Add next() and iterate, test the pointer chain

This commit is contained in:
Ellie Huxtable 2023-06-08 08:49:24 +01:00
parent 485ad84b74
commit 963c356e6b
3 changed files with 125 additions and 2 deletions

View File

@ -90,7 +90,6 @@ impl SqliteStore {
#[async_trait]
impl Store for SqliteStore {
async fn push(&self, record: Record) -> Result<Record> {
// TODO: batch inserts
let mut tx = self.pool.begin().await?;
Self::save_raw(&mut tx, &record).await?;
tx.commit().await?;
@ -98,6 +97,25 @@ impl Store for SqliteStore {
Ok(record)
}
async fn push_batch(
&self,
records: impl Iterator<Item = &Record> + Send + Sync,
) -> Result<Option<Record>> {
let mut tx = self.pool.begin().await?;
// If you push in a batch of nothing it does... nothing.
let mut last: Option<Record> = None;
for record in records {
Self::save_raw(&mut tx, &record).await?;
last = Some(record.clone());
}
tx.commit().await?;
Ok(last)
}
async fn get(&self, id: &str) -> Result<Record> {
let res = sqlx::query("select * from records where id = ?1")
.bind(id)
@ -119,9 +137,20 @@ impl Store for SqliteStore {
Ok(res.0 as u64)
}
async fn next(&self, record: &Record) -> Option<Record> {
let res = sqlx::query("select * from records where parent = ?1")
.bind(record.id.clone())
.map(Self::query_row)
.fetch_one(&self.pool)
.await
.ok();
res
}
async fn first(&self, host: &str, tag: &str) -> Result<Record> {
let res = sqlx::query(
"select * from records where tag = ?1 and host = ?2 and parent is null limit 1",
"select * from records where host = ?1 and tag = ?2 and parent is null limit 1",
)
.bind(host)
.bind(tag)
@ -245,4 +274,77 @@ mod tests {
assert_eq!(first_len, 1, "expected length of 1 after insert");
assert_eq!(second_len, 1, "expected length of 1 after insert");
}
#[tokio::test]
async fn append_a_bunch() {
let db = SqliteStore::new(":memory:").await.unwrap();
let mut tail = db.push(test_record()).await.expect("failed to push record");
for _ in 1..100 {
tail = db.push(tail.new_child(vec![1, 2, 3, 4])).await.unwrap();
}
assert_eq!(
db.len(tail.host.as_str(), tail.tag.as_str()).await.unwrap(),
100,
"failed to insert 100 records"
);
}
#[tokio::test]
async fn append_a_big_bunch() {
let db = SqliteStore::new(":memory:").await.unwrap();
let mut records: Vec<Record> = Vec::with_capacity(10000);
let mut tail = test_record();
records.push(tail.clone());
for _ in 1..10000 {
tail = tail.new_child(vec![1, 2, 3]);
records.push(tail.clone());
}
db.push_batch(records.iter()).await.unwrap();
assert_eq!(
db.len(tail.host.as_str(), tail.tag.as_str()).await.unwrap(),
10000,
"failed to insert 10k records"
);
}
#[tokio::test]
async fn test_chain() {
let db = SqliteStore::new(":memory:").await.unwrap();
let mut records: Vec<Record> = Vec::with_capacity(1000);
let mut tail = test_record();
records.push(tail.clone());
for _ in 1..1000 {
tail = tail.new_child(vec![1, 2, 3]);
records.push(tail.clone());
}
db.push_batch(records.iter()).await.unwrap();
let mut record = db
.first(tail.host.as_str(), tail.tag.as_str())
.await
.unwrap();
let mut count = 1;
while let Some(next) = db.next(&record).await {
assert_eq!(record.id, next.clone().parent.unwrap());
record = next;
count += 1;
}
assert_eq!(count, 1000);
}
}

View File

@ -9,10 +9,21 @@ use atuin_common::record::Record;
/// be shell history, kvs, etc.
#[async_trait]
pub trait Store {
// Push a record and return it
async fn push(&self, record: Record) -> Result<Record>;
// Push a batch of records, all in one transaction
// Returns a record if you push at least one. If the iterator is empty, then
// there is no return record.
async fn push_batch(
&self,
records: impl Iterator<Item = &Record> + Send + Sync,
) -> Result<Option<Record>>;
async fn get(&self, id: &str) -> Result<Record>;
async fn len(&self, host: &str, tag: &str) -> Result<u64>;
async fn next(&self, record: &Record) -> Option<Record>;
// Get the first record for a given host and tag
async fn first(&self, host: &str, tag: &str) -> Result<Record>;
async fn last(&self, host: &str, tag: &str) -> Result<Record>;

View File

@ -46,4 +46,14 @@ impl Record {
data,
}
}
pub fn new_child(&self, data: Vec<u8>) -> Record {
Self::new(
self.host.clone(),
self.version.clone(),
self.tag.clone(),
Some(self.id.clone()),
data,
)
}
}