Fix a bunch of tests, remove Option<Uuid>

This commit is contained in:
Ellie Huxtable 2023-06-29 09:25:05 +01:00
parent 1774ec93eb
commit f05fe5bd83
7 changed files with 97 additions and 20 deletions

View File

@ -8,11 +8,14 @@ use reqwest::{
StatusCode, Url, StatusCode, Url,
}; };
use atuin_common::api::{ use atuin_common::record::Record;
use atuin_common::{
api::{
AddHistoryRequest, CountResponse, DeleteHistoryRequest, ErrorResponse, IndexResponse, AddHistoryRequest, CountResponse, DeleteHistoryRequest, ErrorResponse, IndexResponse,
LoginRequest, LoginResponse, RegisterResponse, StatusResponse, SyncHistoryResponse, LoginRequest, LoginResponse, RegisterResponse, StatusResponse, SyncHistoryResponse,
},
record::RecordIndex,
}; };
use atuin_common::record::Record;
use semver::Version; use semver::Version;
use crate::{history::History, sync::hash_str}; use crate::{history::History, sync::hash_str};
@ -205,6 +208,16 @@ impl<'a> Client<'a> {
Ok(()) Ok(())
} }
pub async fn record_index(&self) -> Result<RecordIndex> {
let url = format!("{}/record", self.sync_addr);
let url = Url::parse(url.as_str())?;
let resp = self.client.get(url).send().await?;
let index = resp.json().await?;
Ok(index)
}
pub async fn delete(&self) -> Result<()> { pub async fn delete(&self) -> Result<()> {
let url = format!("{}/account", self.sync_addr); let url = format!("{}/account", self.sync_addr);
let url = Url::parse(url.as_str())?; let url = Url::parse(url.as_str())?;

View File

@ -1,3 +1,4 @@
pub mod encryption; pub mod encryption;
pub mod sqlite_store; pub mod sqlite_store;
pub mod store; pub mod store;
pub mod sync;

View File

@ -83,7 +83,7 @@ impl SqliteStore {
// tbh at this point things are pretty fucked so just panic // tbh at this point things are pretty fucked so just panic
let id = Uuid::from_str(row.get("id")).expect("invalid id UUID format in sqlite DB"); let id = Uuid::from_str(row.get("id")).expect("invalid id UUID format in sqlite DB");
let host = Uuid::from_str(row.get("host")).expect("invalid host UUID format in sqlite DB"); let host = Uuid::from_str(row.get("host")).expect("invalid host UUID format in sqlite DB");
let parent: Option<&str> = row.get("host"); let parent: Option<&str> = row.get("parent");
let parent = if let Some(parent) = parent { let parent = if let Some(parent) = parent {
Some(Uuid::from_str(parent).expect("invalid parent UUID format in sqlite DB")) Some(Uuid::from_str(parent).expect("invalid parent UUID format in sqlite DB"))
@ -125,7 +125,7 @@ impl Store for SqliteStore {
async fn get(&self, id: Uuid) -> Result<Record<EncryptedData>> { async fn get(&self, id: Uuid) -> Result<Record<EncryptedData>> {
let res = sqlx::query("select * from records where id = ?1") let res = sqlx::query("select * from records where id = ?1")
.bind(id) .bind(id.as_simple().to_string())
.map(Self::query_row) .map(Self::query_row)
.fetch_one(&self.pool) .fetch_one(&self.pool)
.await?; .await?;
@ -136,7 +136,7 @@ impl Store for SqliteStore {
async fn len(&self, host: Uuid, tag: &str) -> Result<u64> { async fn len(&self, host: Uuid, tag: &str) -> Result<u64> {
let res: (i64,) = let res: (i64,) =
sqlx::query_as("select count(1) from records where host = ?1 and tag = ?2") sqlx::query_as("select count(1) from records where host = ?1 and tag = ?2")
.bind(host) .bind(host.as_simple().to_string())
.bind(tag) .bind(tag)
.fetch_one(&self.pool) .fetch_one(&self.pool)
.await?; .await?;
@ -146,7 +146,7 @@ impl Store for SqliteStore {
async fn next(&self, record: &Record<EncryptedData>) -> Result<Option<Record<EncryptedData>>> { async fn next(&self, record: &Record<EncryptedData>) -> Result<Option<Record<EncryptedData>>> {
let res = sqlx::query("select * from records where parent = ?1") let res = sqlx::query("select * from records where parent = ?1")
.bind(record.id.clone()) .bind(record.id.as_simple().to_string())
.map(Self::query_row) .map(Self::query_row)
.fetch_one(&self.pool) .fetch_one(&self.pool)
.await; .await;
@ -162,7 +162,7 @@ impl Store for SqliteStore {
let res = sqlx::query( let res = sqlx::query(
"select * from records where host = ?1 and tag = ?2 and parent is null limit 1", "select * from records where host = ?1 and tag = ?2 and parent is null limit 1",
) )
.bind(host) .bind(host.as_simple().to_string())
.bind(tag) .bind(tag)
.map(Self::query_row) .map(Self::query_row)
.fetch_optional(&self.pool) .fetch_optional(&self.pool)
@ -183,6 +183,23 @@ impl Store for SqliteStore {
Ok(res) Ok(res)
} }
async fn tail_records(&self) -> Result<Vec<(Uuid, String, Uuid)>> {
let res = sqlx::query(
"select host, tag, id from records rp where (select count(1) from records where parent=rp.id) = 0;",
)
.map(|row: SqliteRow| {
let host: Uuid= Uuid::from_str(row.get("host")).expect("invalid uuid in db host");
let tag: String= row.get("tag");
let id: Uuid= Uuid::from_str(row.get("id")).expect("invalid uuid in db id");
(host, tag, id)
})
.fetch_all(&self.pool)
.await?;
Ok(res)
}
} }
#[cfg(test)] #[cfg(test)]

View File

@ -31,4 +31,6 @@ pub trait Store {
async fn first(&self, host: Uuid, tag: &str) -> Result<Option<Record<EncryptedData>>>; async fn first(&self, host: Uuid, tag: &str) -> Result<Option<Record<EncryptedData>>>;
/// Get the last record for a given host and tag /// Get the last record for a given host and tag
async fn last(&self, host: Uuid, tag: &str) -> Result<Option<Record<EncryptedData>>>; async fn last(&self, host: Uuid, tag: &str) -> Result<Option<Record<EncryptedData>>>;
async fn tail_records(&self) -> Result<Vec<(Uuid, String, Uuid)>>;
} }

View File

@ -0,0 +1,25 @@
use atuin_common::record::RecordIndex;
// do a sync :O
use eyre::Result;
use uuid::Uuid;
use crate::{api_client::Client, settings::Settings};
use super::store::Store;
pub async fn diff(
settings: &Settings,
store: &mut impl Store,
) -> Result<Vec<(Uuid, String, Uuid)>> {
let client = Client::new(&settings.sync_address, &settings.session_token)?;
// First, build our own index
let local_tail = store.tail_records().await?;
let local_index = RecordIndex::from(local_tail);
let remote_index = client.record_index().await?;
let diff = local_index.diff(&remote_index);
Ok(diff)
}

View File

@ -84,6 +84,17 @@ impl Default for RecordIndex {
} }
} }
impl From<Vec<(Uuid, String, Uuid)>> for RecordIndex {
fn from(f: Vec<(Uuid, String, Uuid)>) -> RecordIndex {
let mut record_index = RecordIndex::new();
for row in f {
record_index.set_raw(row.0, row.1, row.2);
}
record_index
}
}
impl RecordIndex { impl RecordIndex {
pub fn new() -> RecordIndex { pub fn new() -> RecordIndex {
RecordIndex { RecordIndex {
@ -114,7 +125,7 @@ impl RecordIndex {
/// other machine has a different tail, it will be the differing tail. This is useful to /// other machine has a different tail, it will be the differing tail. This is useful to
/// check if the other index is ahead of us, or behind. /// check if the other index is ahead of us, or behind.
/// If the other index does not have the (host, tag) pair, then the other value will be None. /// If the other index does not have the (host, tag) pair, then the other value will be None.
pub fn diff(&self, other: &Self) -> Vec<(Uuid, String, Option<Uuid>)> { pub fn diff(&self, other: &Self) -> Vec<(Uuid, String, Uuid)> {
let mut ret = Vec::new(); let mut ret = Vec::new();
// First, we check if other has everything that self has // First, we check if other has everything that self has
@ -125,10 +136,10 @@ impl RecordIndex {
Some(t) if t.eq(tail) => continue, Some(t) if t.eq(tail) => continue,
// The other store does exist, but it is either ahead or behind us. A diff regardless // The other store does exist, but it is either ahead or behind us. A diff regardless
Some(t) => ret.push((host.clone(), tag.clone(), Some(t))), Some(t) => ret.push((host.clone(), tag.clone(), t)),
// The other store does not exist :O // The other store does not exist :O
None => ret.push((host.clone(), tag.clone(), None)), None => ret.push((host.clone(), tag.clone(), tail.clone())),
}; };
} }
} }
@ -143,7 +154,7 @@ impl RecordIndex {
// If we have this host/tag combo, the comparison and diff will have already happened above // If we have this host/tag combo, the comparison and diff will have already happened above
Some(_) => continue, Some(_) => continue,
None => ret.push((host.clone(), tag.clone(), Some(tail.clone()))), None => ret.push((host.clone(), tag.clone(), tail.clone())),
}; };
} }
} }
@ -311,7 +322,7 @@ mod tests {
let diff = index1.diff(&index2); let diff = index1.diff(&index2);
assert_eq!(1, diff.len(), "expected single diff"); assert_eq!(1, diff.len(), "expected single diff");
assert_eq!(diff[0], (record2.host, record2.tag, Some(record2.id))); assert_eq!(diff[0], (record2.host, record2.tag, record2.id));
} }
#[test] #[test]

View File

@ -1,7 +1,12 @@
use clap::Subcommand; use clap::Subcommand;
use eyre::{Result, WrapErr}; use eyre::{Result, WrapErr};
use atuin_client::{api_client, database::Database, record::store::Store, settings::Settings}; use atuin_client::{
api_client,
database::Database,
record::{store::Store, sync},
settings::Settings,
};
mod status; mod status;
@ -73,11 +78,14 @@ async fn run(
db: &mut impl Database, db: &mut impl Database,
store: &mut impl Store, store: &mut impl Store,
) -> Result<()> { ) -> Result<()> {
let host = Settings::host_id().expect("No host ID found"); let diff = sync::diff(settings, store).await?;
// FOR TESTING ONLY! println!("{:?}", diff);
let kv_tail = store.last(host, "kv").await?.expect("no kv found");
let client = api_client::Client::new(&settings.sync_address, &settings.session_token)?;
client.post_records(&[kv_tail]).await?;
atuin_client::sync::sync(settings, force, db).await?;
println!(
"Sync complete! {} items in database, force: {}",
db.history_count().await?,
force
);
Ok(()) Ok(())
} }