get almost all client tests working

This commit is contained in:
Ellie Huxtable 2023-12-02 20:24:49 +00:00
parent 26abf41be4
commit 767aadeb63
9 changed files with 352 additions and 362 deletions

View File

@ -1,30 +1,15 @@
-- Add migration script here
create table if not exists host(
id integer primary key, -- just the rowid for normalization
host text unique not null, -- the globally unique host id (uuid)
name text unique -- an optional user-friendly alias
);
-- this will become more useful when we allow for multiple recipients of
-- some data (same cek, multiple long term keys)
-- This could be a key per host rather than one global key, or separate users.
create table if not exists cek (
id integer primary key, -- normalization rowid
cek text unique not null,
);
create table if not exists store (
id text primary key, -- globally unique ID
idx integer, -- incrementing integer ID unique per (host, tag)
host integer not null, -- references the host row
cek integer not null, -- references the cek row
host text not null, -- references the host row
tag text not null,
timestamp integer not null,
tag text not null,
version text not null,
data blob not null,
foreign key(host) references host(id),
foreign key(cek) references cek(id)
cek blob not null
);
create unique index record_uniq ON store(host, tag, idx);

View File

@ -14,7 +14,7 @@ use atuin_common::{
AddHistoryRequest, CountResponse, DeleteHistoryRequest, ErrorResponse, IndexResponse,
LoginRequest, LoginResponse, RegisterResponse, StatusResponse, SyncHistoryResponse,
},
record::RecordIndex,
record::RecordStatus,
};
use semver::Version;
use time::format_description::well_known::Rfc3339;
@ -264,7 +264,7 @@ impl<'a> Client<'a> {
Ok(records)
}
pub async fn record_index(&self) -> Result<RecordIndex> {
pub async fn record_status(&self) -> Result<RecordStatus> {
let url = format!("{}/record", self.sync_addr);
let url = Url::parse(url.as_str())?;

View File

@ -1,12 +1,12 @@
use rmp::decode::bytes::BytesReadError;
use rmp::decode::{read_u64, ValueReadError};
use rmp::decode::{ValueReadError};
use rmp::{decode::Bytes, Marker};
use std::env;
use atuin_common::record::{DecryptedData, HostId};
use atuin_common::record::{DecryptedData};
use atuin_common::utils::uuid_v7;
use eyre::{bail, ensure, eyre, Result};
use eyre::{bail, eyre, Result};
use regex::RegexSet;
use crate::{secrets::SECRET_PATTERNS, settings::Settings};
@ -313,6 +313,7 @@ impl History {
#[cfg(test)]
mod tests {
use atuin_common::record::DecryptedData;
use regex::RegexSet;
use time::{macros::datetime, OffsetDateTime};
@ -415,9 +416,9 @@ mod tests {
};
let serialized = history.serialize().expect("failed to serialize history");
assert_eq!(serialized, bytes);
assert_eq!(serialized.0, bytes);
let deserialized = History::deserialize(&serialized, HISTORY_VERSION)
let deserialized = History::deserialize(&serialized.0, HISTORY_VERSION)
.expect("failed to deserialize history");
assert_eq!(history, deserialized);
@ -443,7 +444,7 @@ mod tests {
let serialized = history.serialize().expect("failed to serialize history");
let deserialized = History::deserialize(&serialized, HISTORY_VERSION)
let deserialized = History::deserialize(&serialized.0, HISTORY_VERSION)
.expect("failed to deserialize history");
assert_eq!(history, deserialized);

View File

@ -1,10 +1,10 @@
use std::sync::Arc;
use eyre::Result;
use serde::{Deserialize, Serialize};
use serde::{Serialize};
use crate::record::{self, encryption::PASETO_V4, sqlite_store::SqliteStore, store::Store};
use atuin_common::record::{HostId, Record};
use crate::record::{encryption::PASETO_V4, sqlite_store::SqliteStore, store::Store};
use atuin_common::record::{Host, HostId, Record};
use super::{History, HISTORY_TAG, HISTORY_VERSION};
@ -26,17 +26,17 @@ impl HistoryStore {
pub async fn push(&self, history: &History) -> Result<()> {
let bytes = history.serialize()?;
let parent = self
let id = self
.store
.tail(self.host_id, HISTORY_TAG)
.last(self.host_id, HISTORY_TAG)
.await?
.map(|p| p.id);
.map_or(0, |p| p.idx + 1);
let record = Record::builder()
.host(self.host_id)
.host(Host::new(self.host_id))
.version(HISTORY_VERSION.to_string())
.tag(HISTORY_TAG.to_string())
.parent(parent)
.idx(id)
.data(bytes)
.build();

View File

@ -1,6 +1,6 @@
use std::collections::BTreeMap;
use atuin_common::record::{DecryptedData, HostId};
use atuin_common::record::{DecryptedData, Host, HostId};
use eyre::{bail, ensure, eyre, Result};
use serde::Deserialize;
@ -111,13 +111,16 @@ impl KvStore {
let bytes = record.serialize()?;
let parent = store.tail(host_id, KV_TAG).await?.map(|entry| entry.id);
let idx = store
.last(host_id, KV_TAG)
.await?
.map_or(0, |entry| entry.idx + 1);
let record = atuin_common::record::Record::builder()
.host(host_id)
.host(Host::new(host_id))
.version(KV_VERSION.to_string())
.tag(KV_TAG.to_string())
.parent(parent)
.idx(idx)
.data(bytes)
.build();
@ -132,47 +135,13 @@ impl KvStore {
// well.
pub async fn get(
&self,
store: &impl Store,
encryption_key: &[u8; 32],
namespace: &str,
key: &str,
_store: &impl Store,
_encryption_key: &[u8; 32],
_namespace: &str,
_key: &str,
) -> Result<Option<KvRecord>> {
// Currently, this is O(n). When we have an actual KV store, it can be better
// Just a poc for now!
// TODO: implement
// iterate records to find the value we want
// start at the end, so we get the most recent version
let tails = store.tag_tails(KV_TAG).await?;
if tails.is_empty() {
return Ok(None);
}
// first, decide on a record.
// try getting the newest first
// we always need a way of deciding the "winner" of a write
// TODO(ellie): something better than last-write-wins, what if two write at the same time?
let mut record = tails.iter().max_by_key(|r| r.timestamp).unwrap().clone();
loop {
let decrypted = match record.version.as_str() {
KV_VERSION => record.decrypt::<PASETO_V4>(encryption_key)?,
version => bail!("unknown version {version:?}"),
};
let kv = KvRecord::deserialize(&decrypted.data, &decrypted.version)?;
if kv.key == key && kv.namespace == namespace {
return Ok(Some(kv));
}
if let Some(parent) = decrypted.parent {
record = store.get(parent).await?;
} else {
break;
}
}
// if we get here, then... we didn't find the record with that key :(
Ok(None)
}
@ -182,35 +151,11 @@ impl KvStore {
// use as a write-through cache to avoid constant rebuilds.
pub async fn build_kv(
&self,
store: &impl Store,
encryption_key: &[u8; 32],
_store: &impl Store,
_encryption_key: &[u8; 32],
) -> Result<BTreeMap<String, BTreeMap<String, String>>> {
let mut map = BTreeMap::new();
let tails = store.tag_tails(KV_TAG).await?;
if tails.is_empty() {
return Ok(map);
}
let mut record = tails.iter().max_by_key(|r| r.timestamp).unwrap().clone();
loop {
let decrypted = match record.version.as_str() {
KV_VERSION => record.decrypt::<PASETO_V4>(encryption_key)?,
version => bail!("unknown version {version:?}"),
};
let kv = KvRecord::deserialize(&decrypted.data, &decrypted.version)?;
let ns = map.entry(kv.namespace).or_insert_with(BTreeMap::new);
ns.entry(kv.key).or_insert_with(|| kv.value);
if let Some(parent) = decrypted.parent {
record = store.get(parent).await?;
} else {
break;
}
}
let map = BTreeMap::new();
// TODO: implement
Ok(map)
}

View File

@ -1,5 +1,5 @@
use atuin_common::record::{
AdditionalData, DecryptedData, EncryptedData, Encryption, HostId, RecordId,
AdditionalData, DecryptedData, EncryptedData, Encryption, HostId, RecordId, RecordIdx,
};
use base64::{engine::general_purpose, Engine};
use eyre::{ensure, Context, Result};
@ -170,10 +170,10 @@ struct AtuinFooter {
#[derive(Debug, Copy, Clone, Serialize)]
struct Assertions<'a> {
id: &'a RecordId,
idx: &'a RecordIdx,
version: &'a str,
tag: &'a str,
host: &'a HostId,
parent: Option<&'a RecordId>,
}
impl<'a> From<AdditionalData<'a>> for Assertions<'a> {
@ -183,7 +183,7 @@ impl<'a> From<AdditionalData<'a>> for Assertions<'a> {
version: ad.version,
tag: ad.tag,
host: ad.host,
parent: ad.parent,
idx: ad.idx,
}
}
}
@ -196,7 +196,10 @@ impl Assertions<'_> {
#[cfg(test)]
mod tests {
use atuin_common::{record::Record, utils::uuid_v7};
use atuin_common::{
record::{Host, Record},
utils::uuid_v7,
};
use super::*;
@ -209,7 +212,7 @@ mod tests {
version: "v0",
tag: "kv",
host: &HostId(uuid_v7()),
parent: None,
idx: &0,
};
let data = DecryptedData(vec![1, 2, 3, 4]);
@ -228,7 +231,7 @@ mod tests {
version: "v0",
tag: "kv",
host: &HostId(uuid_v7()),
parent: None,
idx: &0,
};
let data = DecryptedData(vec![1, 2, 3, 4]);
@ -252,7 +255,7 @@ mod tests {
version: "v0",
tag: "kv",
host: &HostId(uuid_v7()),
parent: None,
idx: &0,
};
let data = DecryptedData(vec![1, 2, 3, 4]);
@ -270,7 +273,7 @@ mod tests {
version: "v0",
tag: "kv",
host: &HostId(uuid_v7()),
parent: None,
idx: &0,
};
let data = DecryptedData(vec![1, 2, 3, 4]);
@ -294,7 +297,7 @@ mod tests {
version: "v0",
tag: "kv",
host: &HostId(uuid_v7()),
parent: None,
idx: &0,
};
let data = DecryptedData(vec![1, 2, 3, 4]);
@ -323,9 +326,10 @@ mod tests {
.id(RecordId(uuid_v7()))
.version("v0".to_owned())
.tag("kv".to_owned())
.host(HostId(uuid_v7()))
.host(Host::new(HostId(uuid_v7())))
.timestamp(1687244806000000)
.data(DecryptedData(vec![1, 2, 3, 4]))
.idx(0)
.build();
let encrypted = record.encrypt::<PASETO_V4>(&key);
@ -345,15 +349,16 @@ mod tests {
.id(RecordId(uuid_v7()))
.version("v0".to_owned())
.tag("kv".to_owned())
.host(HostId(uuid_v7()))
.host(Host::new(HostId(uuid_v7())))
.timestamp(1687244806000000)
.data(DecryptedData(vec![1, 2, 3, 4]))
.idx(0)
.build();
let encrypted = record.encrypt::<PASETO_V4>(&key);
let mut enc1 = encrypted.clone();
enc1.host = HostId(uuid_v7());
enc1.host = Host::new(HostId(uuid_v7()));
let _ = enc1
.decrypt::<PASETO_V4>(&key)
.expect_err("tampering with the host should result in auth failure");

View File

@ -8,13 +8,15 @@ use std::str::FromStr;
use async_trait::async_trait;
use eyre::{eyre, Result};
use fs_err as fs;
use futures::TryStreamExt;
use sqlx::{
sqlite::{SqliteConnectOptions, SqliteJournalMode, SqlitePool, SqlitePoolOptions, SqliteRow},
Row,
};
use atuin_common::record::{EncryptedData, Host, HostId, Record, RecordId, RecordIndex};
use atuin_common::record::{
EncryptedData, Host, HostId, Record, RecordId, RecordIdx, RecordStatus,
};
use uuid::Uuid;
use super::store::Store;
@ -49,49 +51,6 @@ impl SqliteStore {
Ok(Self { pool })
}
async fn host(tx: &mut sqlx::Transaction<'_, sqlx::Sqlite>, host: &Host) -> Result<u64> {
// try selecting the id from the host. return if exists, or insert new and return id
let res: Result<(i64,), sqlx::Error> =
sqlx::query_as("select id from host where host = ?1")
.bind(host.id.0.as_hyphenated().to_string())
.fetch_one(&mut **tx)
.await;
if let Ok(res) = res {
return Ok(res.0 as u64);
}
let res: (i64,) =
sqlx::query_as("insert into host(host, name) values (?1, ?2) returning id")
.bind(host.id.0.as_hyphenated().to_string())
.bind(host.name.as_str())
.fetch_one(&mut **tx)
.await?;
Ok(res.0 as u64)
}
async fn cek(tx: &mut sqlx::Transaction<'_, sqlx::Sqlite>, cek: &str) -> Result<u64> {
// try selecting the id from the host. return if exists, or insert new and return id
let res: Result<(i64,), sqlx::Error> = sqlx::query_as("select id from cek where cek = ?1")
.bind(cek)
.fetch_one(&mut **tx)
.await;
if let Ok(res) = res {
return Ok(res.0 as u64);
}
let res: (i64,) = sqlx::query_as("insert into cek(cek) values (?1) returning id")
.bind(cek)
.fetch_one(&mut **tx)
.await?;
Ok(res.0 as u64)
}
async fn setup_db(pool: &SqlitePool) -> Result<()> {
debug!("running sqlite database setup");
@ -104,22 +63,19 @@ impl SqliteStore {
tx: &mut sqlx::Transaction<'_, sqlx::Sqlite>,
r: &Record<EncryptedData>,
) -> Result<()> {
let host = Self::host(tx, &r.host).await?;
let cek = Self::cek(tx, r.data.content_encryption_key.as_str()).await?;
// In sqlite, we are "limited" to i64. But that is still fine, until 2262.
sqlx::query(
"insert or ignore into store(id, idx, host, cek, timestamp, tag, version, data)
"insert or ignore into store(id, idx, host, tag, timestamp, version, data, cek)
values(?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8)",
)
.bind(r.id.0.as_simple().to_string())
.bind(r.id.0.as_hyphenated().to_string())
.bind(r.idx as i64)
.bind(host as i64)
.bind(cek as i64)
.bind(r.timestamp as i64)
.bind(r.host.id.0.as_hyphenated().to_string())
.bind(r.tag.as_str())
.bind(r.timestamp as i64)
.bind(r.version.as_str())
.bind(r.data.data.as_str())
.bind(r.data.content_encryption_key.as_str())
.execute(&mut **tx)
.await?;
@ -132,8 +88,7 @@ impl SqliteStore {
// tbh at this point things are pretty fucked so just panic
let id = Uuid::from_str(row.get("id")).expect("invalid id UUID format in sqlite DB");
let host =
Uuid::from_str(row.get("host.host")).expect("invalid host UUID format in sqlite DB");
let host = Uuid::from_str(row.get("host")).expect("invalid host UUID format in sqlite DB");
Record {
id: RecordId(id),
@ -144,7 +99,7 @@ impl SqliteStore {
version: row.get("version"),
data: EncryptedData {
data: row.get("data"),
content_encryption_key: row.get("cek.cek"),
content_encryption_key: row.get("cek"),
},
}
}
@ -168,8 +123,8 @@ impl Store for SqliteStore {
}
async fn get(&self, id: RecordId) -> Result<Record<EncryptedData>> {
let res = sqlx::query("select * from store inner join host on store.host=host.id inner join cek on store.cek=cek.id where store.id = ?1")
.bind(id.0.as_simple().to_string())
let res = sqlx::query("select * from store where store.id = ?1")
.bind(id.0.as_hyphenated().to_string())
.map(Self::query_row)
.fetch_one(&self.pool)
.await?;
@ -177,20 +132,66 @@ impl Store for SqliteStore {
Ok(res)
}
async fn last(&self, host: HostId, tag: &str) -> Result<u64> {
let res: (i64,) =
sqlx::query_as("select max(idx) from records where host = ?1 and tag = ?2")
.bind(host.0.as_simple().to_string())
async fn last(&self, host: HostId, tag: &str) -> Result<Option<Record<EncryptedData>>> {
let res =
sqlx::query("select * from store where host=?1 and tag=?2 order by idx desc limit 1")
.bind(host.0.as_hyphenated().to_string())
.bind(tag)
.map(Self::query_row)
.fetch_one(&self.pool)
.await?;
.await;
Ok(res.0 as u64)
match res {
Err(sqlx::Error::RowNotFound) => Ok(None),
Err(e) => Err(eyre!("an error occured: {}", e)),
Ok(record) => Ok(Some(record)),
}
}
async fn next(&self, record: &Record<EncryptedData>) -> Result<Option<Record<EncryptedData>>> {
let res = sqlx::query("select * from records where parent = ?1")
.bind(record.id.0.as_simple().to_string())
async fn first(&self, host: HostId, tag: &str) -> Result<Option<Record<EncryptedData>>> {
self.idx(host, tag, 0).await
}
async fn len(&self, host: HostId, tag: &str) -> Result<Option<u64>> {
let last = self.last(host, tag).await?;
if let Some(last) = last {
return Ok(Some(last.idx + 1));
}
return Ok(None);
}
async fn next(
&self,
host: HostId,
tag: &str,
idx: RecordIdx,
limit: u64,
) -> Result<Vec<Record<EncryptedData>>> {
let res =
sqlx::query("select * from store where idx > ?1 and host = ?2 and tag = ?3 limit ?4")
.bind(idx as i64)
.bind(host)
.bind(tag)
.bind(limit as i64)
.map(Self::query_row)
.fetch_all(&self.pool)
.await?;
Ok(res)
}
async fn idx(
&self,
host: HostId,
tag: &str,
idx: RecordIdx,
) -> Result<Option<Record<EncryptedData>>> {
let res = sqlx::query("select * from store where idx = ?1 and host = ?2 and tag = ?3")
.bind(idx as i64)
.bind(host)
.bind(tag)
.map(Self::query_row)
.fetch_one(&self.pool)
.await;
@ -202,58 +203,36 @@ impl Store for SqliteStore {
}
}
async fn head(&self, host: HostId, tag: &str) -> Result<Option<Record<EncryptedData>>> {
let res = sqlx::query(
"select * from records where host = ?1 and tag = ?2 and parent is null limit 1",
)
.bind(host.0.as_simple().to_string())
.bind(tag)
.map(Self::query_row)
.fetch_optional(&self.pool)
.await?;
async fn status(&self) -> Result<RecordStatus> {
let mut status = RecordStatus::new();
Ok(res)
let res: Result<Vec<(String, String, i64)>, sqlx::Error> =
sqlx::query_as("select host, tag, max(idx) from store group by host, tag")
.fetch_all(&self.pool)
.await;
let res = match res {
Err(e) => return Err(eyre!("failed to fetch local store status: {}", e)),
Ok(v) => v,
};
for i in res {
let host = HostId(
Uuid::from_str(i.0.as_str()).expect("failed to parse uuid for local store status"),
);
status.set_raw(host, i.1, i.2 as u64);
}
Ok(status)
}
async fn tail(&self, host: HostId, tag: &str) -> Result<Option<Record<EncryptedData>>> {
let res = sqlx::query(
"select * from records rp where tag=?1 and host=?2 and (select count(1) from records where parent=rp.id) = 0;",
)
.bind(tag)
.bind(host.0.as_simple().to_string())
.map(Self::query_row)
.fetch_optional(&self.pool)
.await?;
Ok(res)
}
async fn tag_tails(&self, tag: &str) -> Result<Vec<Record<EncryptedData>>> {
let res = sqlx::query(
"select * from records rp where tag=?1 and (select count(1) from records where parent=rp.id) = 0;",
)
.bind(tag)
.map(Self::query_row)
.fetch_all(&self.pool)
.await?;
Ok(res)
}
async fn tail_records(&self) -> Result<RecordIndex> {
let res = sqlx::query(
"select host, tag, id from records rp where (select count(1) from records where parent=rp.id) = 0;",
)
.map(|row: SqliteRow| {
let host: Uuid= Uuid::from_str(row.get("host")).expect("invalid uuid in db host");
let tag: String= row.get("tag");
let id: Uuid= Uuid::from_str(row.get("id")).expect("invalid uuid in db id");
(HostId(host), tag, RecordId(id))
})
.fetch(&self.pool)
.try_collect()
.await?;
async fn all_tagged(&self, tag: &str) -> Result<Vec<Record<EncryptedData>>> {
let res = sqlx::query("select * from store where idx = 0 and tag = ?1")
.bind(tag)
.map(Self::query_row)
.fetch_all(&self.pool)
.await?;
Ok(res)
}
@ -261,7 +240,7 @@ impl Store for SqliteStore {
#[cfg(test)]
mod tests {
use atuin_common::record::{EncryptedData, HostId, Record};
use atuin_common::record::{EncryptedData, Host, HostId, Record};
use crate::record::{encryption::PASETO_V4, store::Store};
@ -269,13 +248,14 @@ mod tests {
fn test_record() -> Record<EncryptedData> {
Record::builder()
.host(HostId(atuin_common::utils::uuid_v7()))
.host(Host::new(HostId(atuin_common::utils::uuid_v7())))
.version("v1".into())
.tag(atuin_common::utils::uuid_v7().simple().to_string())
.data(EncryptedData {
data: "1234".into(),
content_encryption_key: "1234".into(),
})
.idx(0)
.build()
}
@ -309,6 +289,24 @@ mod tests {
assert_eq!(record, new_record, "records are not equal");
}
#[tokio::test]
async fn last() {
let db = SqliteStore::new(":memory:").await.unwrap();
let record = test_record();
db.push(&record).await.unwrap();
let last = db
.last(record.host.id, record.tag.as_str())
.await
.expect("failed to get store len");
assert_eq!(
last.unwrap().id,
record.id,
"expected to get back the same record that was inserted"
);
}
#[tokio::test]
async fn len() {
let db = SqliteStore::new(":memory:").await.unwrap();
@ -316,11 +314,11 @@ mod tests {
db.push(&record).await.unwrap();
let len = db
.len(record.host, record.tag.as_str())
.len(record.host.id, record.tag.as_str())
.await
.expect("failed to get store len");
assert_eq!(len, 1, "expected length of 1 after insert");
assert_eq!(len, Some(1), "expected length of 1 after insert");
}
#[tokio::test]
@ -336,11 +334,11 @@ mod tests {
db.push(&first).await.unwrap();
db.push(&second).await.unwrap();
let first_len = db.len(first.host, first.tag.as_str()).await.unwrap();
let second_len = db.len(second.host, second.tag.as_str()).await.unwrap();
let first_len = db.len(first.host.id, first.tag.as_str()).await.unwrap();
let second_len = db.len(second.host.id, second.tag.as_str()).await.unwrap();
assert_eq!(first_len, 1, "expected length of 1 after insert");
assert_eq!(second_len, 1, "expected length of 1 after insert");
assert_eq!(first_len, Some(1), "expected length of 1 after insert");
assert_eq!(second_len, Some(1), "expected length of 1 after insert");
}
#[tokio::test]
@ -351,15 +349,13 @@ mod tests {
db.push(&tail).await.expect("failed to push record");
for _ in 1..100 {
tail = tail
.new_child(vec![1, 2, 3, 4])
.encrypt::<PASETO_V4>(&[0; 32]);
tail = tail.append(vec![1, 2, 3, 4]).encrypt::<PASETO_V4>(&[0; 32]);
db.push(&tail).await.unwrap();
}
assert_eq!(
db.len(tail.host, tail.tag.as_str()).await.unwrap(),
100,
db.len(tail.host.id, tail.tag.as_str()).await.unwrap(),
Some(100),
"failed to insert 100 records"
);
}
@ -374,50 +370,16 @@ mod tests {
records.push(tail.clone());
for _ in 1..10000 {
tail = tail.new_child(vec![1, 2, 3]).encrypt::<PASETO_V4>(&[0; 32]);
tail = tail.append(vec![1, 2, 3]).encrypt::<PASETO_V4>(&[0; 32]);
records.push(tail.clone());
}
db.push_batch(records.iter()).await.unwrap();
assert_eq!(
db.len(tail.host, tail.tag.as_str()).await.unwrap(),
10000,
db.len(tail.host.id, tail.tag.as_str()).await.unwrap(),
Some(10000),
"failed to insert 10k records"
);
}
#[tokio::test]
async fn test_chain() {
let db = SqliteStore::new(":memory:").await.unwrap();
let mut records: Vec<Record<EncryptedData>> = Vec::with_capacity(1000);
let mut tail = test_record();
records.push(tail.clone());
for _ in 1..1000 {
tail = tail.new_child(vec![1, 2, 3]).encrypt::<PASETO_V4>(&[0; 32]);
records.push(tail.clone());
}
db.push_batch(records.iter()).await.unwrap();
let mut record = db
.head(tail.host, tail.tag.as_str())
.await
.expect("in memory sqlite should not fail")
.expect("entry exists");
let mut count = 1;
while let Some(next) = db.next(&record).await.unwrap() {
assert_eq!(record.id, next.clone().parent.unwrap());
record = next;
count += 1;
}
assert_eq!(count, 1000);
}
}

View File

@ -1,7 +1,7 @@
use async_trait::async_trait;
use eyre::Result;
use atuin_common::record::{EncryptedData, HostId, Record, RecordId, RecordIndex};
use atuin_common::record::{EncryptedData, HostId, Record, RecordId, RecordIdx, RecordStatus};
/// A record store stores records
/// In more detail - we tend to need to process this into _another_ format to actually query it.
@ -21,21 +21,32 @@ pub trait Store {
) -> Result<()>;
async fn get(&self, id: RecordId) -> Result<Record<EncryptedData>>;
async fn last(&self, host: HostId, tag: &str) -> Result<Option<u64>>;
async fn len(&self, host: HostId, tag: &str) -> Result<Option<u64>>;
async fn last(&self, host: HostId, tag: &str) -> Result<Option<Record<EncryptedData>>>;
async fn first(&self, host: HostId, tag: &str) -> Result<Option<Record<EncryptedData>>>;
/// Get the record that follows this record
async fn next(&self, record: &Record<EncryptedData>) -> Result<Option<Record<EncryptedData>>>;
async fn next(
&self,
host: HostId,
tag: &str,
idx: RecordIdx,
limit: u64,
) -> Result<Vec<Record<EncryptedData>>>;
/// Get the first record for a given host and tag
async fn head(&self, host: HostId, tag: &str) -> Result<Option<Record<EncryptedData>>>;
async fn idx(
&self,
host: HostId,
tag: &str,
idx: RecordIdx,
) -> Result<Option<Record<EncryptedData>>>;
/// Get the last record for a given host and tag
async fn tail(&self, host: HostId, tag: &str) -> Result<Option<Record<EncryptedData>>>;
async fn status(&self) -> Result<RecordStatus>;
// Get the last record for all hosts for a given tag, useful for the read path of apps.
async fn tag_tails(&self, tag: &str) -> Result<Vec<Record<EncryptedData>>>;
// Get the latest host/tag/record tuple for every set in the store. useful for building an
// index
async fn tail_records(&self) -> Result<RecordIndex>;
/// Get every start record for a given tag, regardless of host.
/// Useful when actually operating on synchronized data, and will often have conflict
/// resolution applied.
async fn all_tagged(&self, tag: &str) -> Result<Vec<Record<EncryptedData>>>;
}

View File

@ -5,7 +5,7 @@ use thiserror::Error;
use super::store::Store;
use crate::{api_client::Client, settings::Settings};
use atuin_common::record::{Diff, HostId, RecordId, RecordIdx, RecordStatus};
use atuin_common::record::{Diff, HostId, RecordIdx, RecordStatus};
#[derive(Error, Debug)]
pub enum SyncError {
@ -51,8 +51,8 @@ pub async fn diff(
settings.network_timeout,
)?;
let local_index = store.tail_records().await?;
let remote_index = client.record_index().await?;
let local_index = store.status().await?;
let remote_index = client.record_status().await?;
let diff = local_index.diff(&remote_index);
@ -63,41 +63,31 @@ pub async fn diff(
// With the store as context, we can determine if a tail exists locally or not and therefore if it needs uploading or download.
// In theory this could be done as a part of the diffing stage, but it's easier to reason
// about and test this way
pub async fn operations(diffs: Vec<Diff>, store: &impl Store) -> Result<Vec<Operation>, SyncError> {
pub async fn operations(diffs: Vec<Diff>, _store: &impl Store) -> Result<Vec<Operation>, SyncError> {
let mut operations = Vec::with_capacity(diffs.len());
let host = Settings::host_id().expect("got to record sync without a host id; abort");
let _host = Settings::host_id().expect("got to record sync without a host id; abort");
for diff in diffs {
// First, try to fetch the tail
// If it exists locally, then that means we need to update the remote
// host until it has the same tail. Ie, upload.
// If it does not exist locally, that means remote is ahead of us.
// Therefore, we need to download until our local tail matches
let last = store
.last(diff.host, diff.tag.as_str())
.await
.map_err(|_| SyncError::LocalStoreError)?;
let op = match (last, diff.remote) {
let op = match (diff.local, diff.remote) {
// We both have it! Could be either. Compare.
(Some(last), Some(remote)) => {
if last == remote {
(Some(local), Some(remote)) => {
if local == remote {
// between the diff and now, a sync has somehow occured.
// regardless, no work is needed!
Operation::Noop {
host: diff.host,
tag: diff.tag,
}
} else if last > remote {
} else if local > remote {
Operation::Upload {
local: last,
local,
remote: Some(remote),
host: diff.host,
tag: diff.tag,
}
} else {
Operation::Download {
local: Some(last),
local: Some(local),
remote,
host: diff.host,
tag: diff.tag,
@ -114,8 +104,8 @@ pub async fn operations(diffs: Vec<Diff>, store: &impl Store) -> Result<Vec<Oper
},
// We have it, remote doesn't. Gotta be upload.
(Some(last), None) => Operation::Upload {
local: last,
(Some(local), None) => Operation::Upload {
local,
remote: None,
host: diff.host,
tag: diff.tag,
@ -151,16 +141,16 @@ pub async fn operations(diffs: Vec<Diff>, store: &impl Store) -> Result<Vec<Oper
}
async fn sync_upload(
store: &mut impl Store,
client: &Client<'_>,
_store: &mut impl Store,
_client: &Client<'_>,
host: HostId,
tag: String,
local: RecordIdx,
remote: Option<RecordIdx>,
) -> Result<i64, SyncError> {
let expected = local - remote.unwrap_or(0);
let upload_page_size = 100;
let mut total = 0;
let _upload_page_size = 100;
let _total = 0;
if expected < 0 {
return Err(SyncError::SyncLogicError {
@ -181,16 +171,16 @@ async fn sync_upload(
}
async fn sync_download(
store: &mut impl Store,
client: &Client<'_>,
_store: &mut impl Store,
_client: &Client<'_>,
host: HostId,
tag: String,
local: Option<RecordIdx>,
remote: RecordIdx,
) -> Result<i64, SyncError> {
let expected = remote - local.unwrap_or(0);
let download_page_size = 100;
let mut total = 0;
let _download_page_size = 100;
let _total = 0;
if expected < 0 {
return Err(SyncError::SyncLogicError {
@ -300,8 +290,8 @@ mod tests {
remote_store.push(&i).await.unwrap();
}
let local_index = local_store.tail_records().await.unwrap();
let remote_index = remote_store.tail_records().await.unwrap();
let local_index = local_store.status().await.unwrap();
let remote_index = remote_store.status().await.unwrap();
let diff = local_index.diff(&remote_index);
@ -338,12 +328,14 @@ mod tests {
// another. One upload, one download
let shared_record = test_record();
let remote_ahead = test_record();
let local_ahead = shared_record
.append(vec![1, 2, 3])
.encrypt::<PASETO_V4>(&[0; 32]);
assert_eq!(local_ahead.idx, 1);
let local = vec![shared_record.clone(), local_ahead.clone()]; // local knows about the already synced, and something newer in the same store
let remote = vec![shared_record.clone(), remote_ahead.clone()]; // remote knows about the already-synced, and one new record in a new store
@ -355,18 +347,20 @@ mod tests {
assert_eq!(
operations,
vec![
// Or in otherwords, local is ahead by one
Operation::Upload {
host: local_ahead.host.id,
tag: local_ahead.tag,
local: 1,
remote: Some(0),
},
// Or in other words, remote knows of a record in an entirely new store (tag)
Operation::Download {
host: remote_ahead.host.id,
tag: remote_ahead.tag,
local: None,
remote: 0,
},
Operation::Upload {
host: local_ahead.host.id,
tag: local_ahead.tag,
local: 0,
remote: None,
},
]
);
}
@ -378,62 +372,149 @@ mod tests {
// One known only by remote
let shared_record = test_record();
let local_only = test_record();
let remote_known = test_record();
let local_known = test_record();
let local_only_20 = test_record();
let local_only_21 = local_only_20
.append(vec![1, 2, 3])
.encrypt::<PASETO_V4>(&[0; 32]);
let local_only_22 = local_only_21
.append(vec![1, 2, 3])
.encrypt::<PASETO_V4>(&[0; 32]);
let local_only_23 = local_only_22
.append(vec![1, 2, 3])
.encrypt::<PASETO_V4>(&[0; 32]);
let remote_only = test_record();
let remote_only_20 = test_record();
let remote_only_21 = remote_only_20
.append(vec![2, 3, 2])
.encrypt::<PASETO_V4>(&[0; 32]);
let remote_only_22 = remote_only_21
.append(vec![2, 3, 2])
.encrypt::<PASETO_V4>(&[0; 32]);
let remote_only_23 = remote_only_22
.append(vec![2, 3, 2])
.encrypt::<PASETO_V4>(&[0; 32]);
let remote_only_24 = remote_only_23
.append(vec![2, 3, 2])
.encrypt::<PASETO_V4>(&[0; 32]);
let second_shared = test_record();
let second_shared_remote_ahead = second_shared
.append(vec![1, 2, 3])
.encrypt::<PASETO_V4>(&[0; 32]);
let second_shared_remote_ahead2 = second_shared_remote_ahead
.append(vec![1, 2, 3])
.encrypt::<PASETO_V4>(&[0; 32]);
let local_ahead = shared_record
let third_shared = test_record();
let third_shared_local_ahead = third_shared
.append(vec![1, 2, 3])
.encrypt::<PASETO_V4>(&[0; 32]);
let third_shared_local_ahead2 = third_shared_local_ahead
.append(vec![1, 2, 3])
.encrypt::<PASETO_V4>(&[0; 32]);
let fourth_shared = test_record();
let fourth_shared_remote_ahead = fourth_shared
.append(vec![1, 2, 3])
.encrypt::<PASETO_V4>(&[0; 32]);
let fourth_shared_remote_ahead2 = fourth_shared_remote_ahead
.append(vec![1, 2, 3])
.encrypt::<PASETO_V4>(&[0; 32]);
let local = vec![
shared_record.clone(),
second_shared.clone(),
local_known.clone(),
local_ahead.clone(),
third_shared.clone(),
fourth_shared.clone(),
fourth_shared_remote_ahead.clone(),
// single store, only local has it
local_only.clone(),
// bigger store, also only known by local
local_only_20.clone(),
local_only_21.clone(),
local_only_22.clone(),
local_only_23.clone(),
// another shared store, but local is ahead on this one
third_shared_local_ahead.clone(),
third_shared_local_ahead2.clone(),
];
let remote = vec![
remote_only.clone(),
remote_only_20.clone(),
remote_only_21.clone(),
remote_only_22.clone(),
remote_only_23.clone(),
remote_only_24.clone(),
shared_record.clone(),
second_shared.clone(),
third_shared.clone(),
second_shared_remote_ahead.clone(),
remote_known.clone(),
second_shared_remote_ahead2.clone(),
fourth_shared.clone(),
fourth_shared_remote_ahead.clone(),
fourth_shared_remote_ahead2.clone(),
]; // remote knows about the already-synced, and one new record in a new store
let (store, diff) = build_test_diff(local, remote).await;
let operations = sync::operations(diff, &store).await.unwrap();
assert_eq!(operations.len(), 4);
assert_eq!(operations.len(), 7);
let mut result_ops = vec![
// We started with a shared record, but the remote knows of two newer records in the
// same store
Operation::Download {
host: remote_known.host.id,
tag: remote_known.tag,
local: Some(second_shared.idx),
remote: second_shared_remote_ahead.idx,
local: Some(0),
remote: 2,
host: second_shared_remote_ahead.host.id,
tag: second_shared_remote_ahead.tag,
},
// We have a shared record, local knows of the first two but not the last
Operation::Download {
local: Some(1),
remote: 2,
host: fourth_shared_remote_ahead2.host.id,
tag: fourth_shared_remote_ahead2.tag,
},
// Remote knows of a store with a single record that local does not have
Operation::Download {
host: second_shared.host.id,
tag: second_shared.tag,
local: None,
remote: remote_known.idx,
remote: 0,
host: remote_only.host.id,
tag: remote_only.tag,
},
Operation::Upload {
host: local_ahead.host.id,
tag: local_ahead.tag,
local: local_ahead.idx,
remote: Some(shared_record.idx),
// Remote knows of a store with a bunch of records that local does not have
Operation::Download {
local: None,
remote: 4,
host: remote_only_20.host.id,
tag: remote_only_20.tag,
},
// Local knows of a record in a store that remote does not have
Operation::Upload {
host: local_known.host.id,
tag: local_known.tag,
local: local_known.idx,
local: 0,
remote: None,
host: local_only.host.id,
tag: local_only.tag,
},
// Local knows of 4 records in a store that remote does not have
Operation::Upload {
local: 3,
remote: None,
host: local_only_20.host.id,
tag: local_only_20.tag,
},
// Local knows of 2 more records in a shared store that remote only has one of
Operation::Upload {
local: 2,
remote: Some(0),
host: third_shared.host.id,
tag: third_shared.tag,
},
];
@ -445,6 +526,6 @@ mod tests {
Operation::Download { host, tag, .. } => (2, *host, tag.clone()),
});
assert_eq!(operations, result_ops);
assert_eq!(result_ops, operations);
}
}