feat: automatically init history store when record sync is enabled (#1634)

* add support for getting the total length of a store

* tidy up sync

* auto call init if history is ahead

* fix import order, key regen

* fix import order, key regen

* do not delete key when user deletes account

* message output

* remove init store command; this is now automatic

* should probs make that function return u64 at some point
Ellie Huxtable, 2024-01-29 16:38:24 +00:00 (committed by GitHub)
parent 15bad15f48
commit 366b8ea97b
8 changed files with 120 additions and 68 deletions
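
The whole change, condensed into one place. This is a sketch assembled from the hunks below, not a function that exists in the repository; the imports and parameter types are the ones the sync command already has in scope.

    use atuin_client::{
        database::{current_context, Database},
        history::store::HistoryStore,
        record::{sqlite_store::SqliteStore, store::Store, sync},
        settings::Settings,
    };
    use eyre::Result;

    async fn auto_init_then_sync(
        settings: &Settings,
        db: &impl Database,
        store: &SqliteStore,
        history_store: &HistoryStore,
    ) -> Result<()> {
        // Compare the old history.db row count with the number of records
        // tagged "history" in the record store.
        let history_length = db.history_count(true).await?;
        let store_history_length = store.len_tag("history").await?;

        // If the old database is ahead, import it first; the user no longer
        // has to run the (now removed) init-store subcommand by hand.
        if history_length as u64 > store_history_length {
            history_store.init_store(current_context(), db).await?;
        }

        // Then sync as usual and rebuild local history from downloaded records.
        let (uploaded, downloaded) = sync::sync(settings, store).await?;
        history_store.incremental_build(db, &downloaded).await?;

        println!("{uploaded}/{} up/down to record store", downloaded.len());
        Ok(())
    }
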


@@ -4,7 +4,7 @@ use eyre::{bail, eyre, Result};
 use rmp::decode::Bytes;

 use crate::{
-    database::Database,
+    database::{self, Database},
     record::{encryption::PASETO_V4, sqlite_store::SqliteStore, store::Store},
 };
 use atuin_common::record::{DecryptedData, Host, HostId, Record, RecordId, RecordIdx};
@@ -255,6 +255,34 @@ impl HistoryStore {
         Ok(ret)
     }

+    pub async fn init_store(&self, context: database::Context, db: &impl Database) -> Result<()> {
+        println!("Importing all history.db data into records.db");
+
+        println!("Fetching history from old database");
+        let history = db.list(&[], &context, None, false, true).await?;
+
+        println!("Fetching history already in store");
+        let store_ids = self.history_ids().await?;
+
+        for i in history {
+            println!("loaded {}", i.id);
+
+            if store_ids.contains(&i.id) {
+                println!("skipping {} - already exists", i.id);
+                continue;
+            }
+
+            if i.deleted_at.is_some() {
+                self.push(i.clone()).await?;
+                self.delete(i.id).await?;
+            } else {
+                self.push(i).await?;
+            }
+        }
+
+        Ok(())
+    }
 }

 #[cfg(test)]
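
The subcommand this method replaces (removed further down in this diff) warned against running it more than once, but init_store checks history_ids() before pushing, so repeated runs are harmless; that is what makes calling it automatically on every sync safe. A hypothetical illustration, not part of the commit:

    use atuin_client::{
        database::{current_context, Database},
        history::store::HistoryStore,
    };
    use eyre::Result;

    async fn import_twice(history_store: &HistoryStore, db: &impl Database) -> Result<()> {
        // First call copies everything from history.db that the store has not seen yet.
        history_store.init_store(current_context(), db).await?;
        // Second call is effectively a no-op: every id now comes back from
        // history_ids() and gets skipped.
        history_store.init_store(current_context(), db).await?;
        Ok(())
    }
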


@@ -155,6 +155,18 @@ impl Store for SqliteStore {
         self.idx(host, tag, 0).await
     }

+    async fn len_tag(&self, tag: &str) -> Result<u64> {
+        let res: Result<(i64,), sqlx::Error> =
+            sqlx::query_as("select count(*) from store where tag=?1")
+                .bind(tag)
+                .fetch_one(&self.pool)
+                .await;
+
+        match res {
+            Err(e) => Err(eyre!("failed to fetch local store len: {}", e)),
+            Ok(v) => Ok(v.0 as u64),
+        }
+    }
+
     async fn len(&self, host: HostId, tag: &str) -> Result<u64> {
         let last = self.last(host, tag).await?;
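
The (i64,) row type and the `as u64` cast are there because count(*) comes back from SQLite, via sqlx, as a signed 64-bit integer. The same query written standalone, as a sketch against a plain sqlx::SqlitePool that already has the store table (not code from the commit):

    use eyre::Result;
    use sqlx::SqlitePool;

    async fn count_tag(pool: &SqlitePool, tag: &str) -> Result<u64> {
        // count(*) arrives as i64; cast on the way out.
        let (count,): (i64,) = sqlx::query_as("select count(*) from store where tag = ?1")
            .bind(tag)
            .fetch_one(pool)
            .await?;

        Ok(count as u64)
    }
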
@@ -342,6 +354,20 @@ mod tests {
         assert_eq!(len, 1, "expected length of 1 after insert");
     }

+    #[tokio::test]
+    async fn len_tag() {
+        let db = SqliteStore::new(":memory:", 0.1).await.unwrap();
+        let record = test_record();
+        db.push(&record).await.unwrap();
+
+        let len = db
+            .len_tag(record.tag.as_str())
+            .await
+            .expect("failed to get store len");
+
+        assert_eq!(len, 1, "expected length of 1 after insert");
+    }
+
     #[tokio::test]
     async fn len_different_tags() {
         let db = SqliteStore::new(":memory:", 0.1).await.unwrap();
@@ -379,6 +405,12 @@ mod tests {
             100,
             "failed to insert 100 records"
         );
+
+        assert_eq!(
+            db.len_tag(tail.tag.as_str()).await.unwrap(),
+            100,
+            "failed to insert 100 records"
+        );
     }

     #[tokio::test]


@@ -21,7 +21,9 @@ pub trait Store {
     ) -> Result<()>;

     async fn get(&self, id: RecordId) -> Result<Record<EncryptedData>>;
     async fn len(&self, host: HostId, tag: &str) -> Result<u64>;
+    async fn len_tag(&self, tag: &str) -> Result<u64>;
+
     async fn last(&self, host: HostId, tag: &str) -> Result<Option<Record<EncryptedData>>>;
     async fn first(&self, host: HostId, tag: &str) -> Result<Option<Record<EncryptedData>>>;
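
For any other Store implementation the contract is: len_tag counts every record carrying the tag, across all hosts, while len(host, tag) stays scoped to a single host's chain. A reference definition of that semantics over a plain slice, purely illustrative:

    use atuin_common::record::{EncryptedData, Record};

    fn len_tag_reference(records: &[Record<EncryptedData>], tag: &str) -> u64 {
        // Every host's records count, as long as the tag matches.
        records.iter().filter(|r| r.tag == tag).count() as u64
    }
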


@@ -14,14 +14,17 @@ pub enum SyncError {
     #[error("the local store is ahead of the remote, but for another host. has remote lost data?")]
     LocalAheadOtherHost,

-    #[error("an issue with the local database occured")]
-    LocalStoreError,
+    #[error("an issue with the local database occured: {msg:?}")]
+    LocalStoreError { msg: String },

     #[error("something has gone wrong with the sync logic: {msg:?}")]
     SyncLogicError { msg: String },

-    #[error("a request to the sync server failed")]
-    RemoteRequestError,
+    #[error("operational error: {msg:?}")]
+    OperationalError { msg: String },
+
+    #[error("a request to the sync server failed: {msg:?}")]
+    RemoteRequestError { msg: String },
 }

 #[derive(Debug, Eq, PartialEq)]
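
A small point about the new variants: the format strings interpolate the field as {msg:?}, so the underlying cause is rendered with Debug formatting (quoted and escaped). A self-contained illustration of the pattern, assuming these derives use thiserror as the #[error] attributes suggest:

    use thiserror::Error;

    #[derive(Debug, Error)]
    enum Demo {
        #[error("a request to the sync server failed: {msg:?}")]
        RemoteRequestError { msg: String },
    }

    fn main() {
        let e = Demo::RemoteRequestError {
            msg: "connection refused".into(),
        };
        // Prints: a request to the sync server failed: "connection refused"
        println!("{e}");
    }
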
@@ -45,16 +48,27 @@ pub enum Operation {
     },
 }

-pub async fn diff(settings: &Settings, store: &impl Store) -> Result<(Vec<Diff>, RecordStatus)> {
+pub async fn diff(
+    settings: &Settings,
+    store: &impl Store,
+) -> Result<(Vec<Diff>, RecordStatus), SyncError> {
     let client = Client::new(
         &settings.sync_address,
         &settings.session_token,
         settings.network_connect_timeout,
         settings.network_timeout,
-    )?;
+    )
+    .map_err(|e| SyncError::OperationalError { msg: e.to_string() })?;

-    let local_index = store.status().await?;
-    let remote_index = client.record_status().await?;
+    let local_index = store
+        .status()
+        .await
+        .map_err(|e| SyncError::LocalStoreError { msg: e.to_string() })?;
+
+    let remote_index = client
+        .record_status()
+        .await
+        .map_err(|e| SyncError::RemoteRequestError { msg: e.to_string() })?;

     let diff = local_index.diff(&remote_index);
@@ -166,13 +180,13 @@ async fn sync_upload(
             .map_err(|e| {
                 error!("failed to read upload page: {e:?}");
-                SyncError::LocalStoreError
+                SyncError::LocalStoreError { msg: e.to_string() }
             })?;

         client.post_records(&page).await.map_err(|e| {
             error!("failed to post records: {e:?}");
-            SyncError::RemoteRequestError
+            SyncError::RemoteRequestError { msg: e.to_string() }
         })?;

         println!(
@@ -217,12 +231,12 @@ async fn sync_download(
         let page = client
             .next_records(host, tag.clone(), local + progress, download_page_size)
             .await
-            .map_err(|_| SyncError::RemoteRequestError)?;
+            .map_err(|e| SyncError::RemoteRequestError { msg: e.to_string() })?;

         store
             .push_batch(page.iter())
             .await
-            .map_err(|_| SyncError::LocalStoreError)?;
+            .map_err(|e| SyncError::LocalStoreError { msg: e.to_string() })?;

         println!(
             "downloaded {} records from remote, progress {}/{}",
@@ -283,6 +297,17 @@ pub async fn sync_remote(
     Ok((uploaded, downloaded))
 }

+pub async fn sync(
+    settings: &Settings,
+    store: &impl Store,
+) -> Result<(i64, Vec<RecordId>), SyncError> {
+    let (diff, _) = diff(settings, store).await?;
+    let operations = operations(diff, store).await?;
+    let (uploaded, downloaded) = sync_remote(operations, store, settings).await?;
+
+    Ok((uploaded, downloaded))
+}
+
 #[cfg(test)]
 mod tests {
     use atuin_common::record::{Diff, EncryptedData, HostId, Record};
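
From the caller's side, the wrapper plus the msg-carrying error variants mean a single call and enough detail to tell a network failure from a local database problem. A hypothetical caller, not code from the commit:

    use atuin_client::{
        record::{
            store::Store,
            sync::{self, SyncError},
        },
        settings::Settings,
    };

    async fn sync_and_report(settings: &Settings, store: &impl Store) {
        match sync::sync(settings, store).await {
            Ok((uploaded, downloaded)) => {
                println!("{uploaded}/{} up/down to record store", downloaded.len());
            }
            Err(SyncError::RemoteRequestError { msg }) => {
                eprintln!("sync server request failed: {msg}");
            }
            Err(SyncError::LocalStoreError { msg }) => {
                eprintln!("local record store error: {msg}");
            }
            Err(e) => eprintln!("sync failed: {e}"),
        }
    }
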


@@ -5,7 +5,6 @@ use std::path::PathBuf;

 pub async fn run(settings: &Settings) -> Result<()> {
     let session_path = settings.session_path.as_str();
-    let key_path = settings.key_path.as_str();

     if !PathBuf::from(session_path).exists() {
         bail!("You are not logged in");
@@ -25,10 +24,6 @@ pub async fn run(settings: &Settings) -> Result<()> {
         remove_file(PathBuf::from(session_path))?;
     }

-    if PathBuf::from(key_path).exists() {
-        remove_file(PathBuf::from(key_path))?;
-    }
-
     println!("Your account is deleted");

     Ok(())


@@ -49,8 +49,7 @@ pub async fn run(
     let mut file = File::create(path).await?;
     file.write_all(session.session.as_bytes()).await?;

-    // Create a new key, and save it to disk
-    let _key = atuin_client::encryption::new_key(settings)?;
+    let _key = atuin_client::encryption::load_key(settings)?;

     Ok(())
 }
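
Reusing an existing key here matters because the machine may already have one on disk, especially now that deleting an account (previous hunk) leaves the key in place. load_key's body is not part of this diff, so the fallback behaviour sketched below is an assumption based on how it is used elsewhere in the commit:

    use std::path::PathBuf;

    use atuin_client::{encryption, settings::Settings};
    use eyre::Result;

    fn ensure_key(settings: &Settings) -> Result<()> {
        if PathBuf::from(settings.key_path.as_str()).exists() {
            // An existing key is reused rather than regenerated, so previously
            // encrypted history stays decryptable after logging back in.
            println!("reusing encryption key at {}", settings.key_path);
        }
        let _key = encryption::load_key(settings)?;
        Ok(())
    }
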


@@ -88,10 +88,6 @@ pub enum Cmd {
         #[arg(long, short)]
         format: Option<String>,
     },
-
-    /// Import all old history.db data into the record store. Do not run more than once, and do not
-    /// run unless you know what you're doing (or the docs ask you to)
-    InitStore,
 }

 #[derive(Clone, Copy, Debug)]
@@ -321,10 +317,7 @@ impl Cmd {
         #[cfg(feature = "sync")]
         {
             if settings.sync.records {
-                let (diff, _) = record::sync::diff(settings, &store).await?;
-                let operations = record::sync::operations(diff, &store).await?;
-                let (_, downloaded) =
-                    record::sync::sync_remote(operations, &store, settings).await?;
+                let (_, downloaded) = record::sync::sync(settings, &store).await?;

                 history_store.incremental_build(db, &downloaded).await?;
             } else {
@@ -380,38 +373,6 @@ impl Cmd {
         Ok(())
     }

-    async fn init_store(
-        context: atuin_client::database::Context,
-        db: &impl Database,
-        store: HistoryStore,
-    ) -> Result<()> {
-        println!("Importing all history.db data into records.db");
-
-        println!("Fetching history from old database");
-        let history = db.list(&[], &context, None, false, true).await?;
-
-        println!("Fetching history already in store");
-        let store_ids = store.history_ids().await?;
-
-        for i in history {
-            println!("loaded {}", i.id);
-
-            if store_ids.contains(&i.id) {
-                println!("skipping {} - already exists", i.id);
-                continue;
-            }
-
-            if i.deleted_at.is_some() {
-                store.push(i.clone()).await?;
-                store.delete(i.id).await?;
-            } else {
-                store.push(i).await?;
-            }
-        }
-
-        Ok(())
-    }
-
     pub async fn run(
         self,
         settings: &Settings,
@@ -468,8 +429,6 @@ impl Cmd {
                 Ok(())
             }
-
-            Self::InitStore => Self::init_store(context, db, history_store).await,
         }
     }
 }


@@ -2,10 +2,10 @@ use clap::Subcommand;
 use eyre::{Result, WrapErr};

 use atuin_client::{
-    database::Database,
+    database::{current_context, Database},
     encryption,
     history::store::HistoryStore,
-    record::{sqlite_store::SqliteStore, sync},
+    record::{sqlite_store::SqliteStore, store::Store, sync},
     settings::Settings,
 };
@@ -80,10 +80,6 @@ async fn run(
     store: SqliteStore,
 ) -> Result<()> {
     if settings.sync.records {
-        let (diff, _) = sync::diff(settings, &store).await?;
-        let operations = sync::operations(diff, &store).await?;
-        let (uploaded, downloaded) = sync::sync_remote(operations, &store, settings).await?;
-
         let encryption_key: [u8; 32] = encryption::load_key(settings)
             .context("could not load encryption key")?
             .into();
@@ -91,6 +87,22 @@ async fn run(
         let host_id = Settings::host_id().expect("failed to get host_id");
         let history_store = HistoryStore::new(store.clone(), host_id, encryption_key);

+        let history_length = db.history_count(true).await?;
+        let store_history_length = store.len_tag("history").await?;
+
+        #[allow(clippy::cast_sign_loss)]
+        if history_length as u64 > store_history_length {
+            println!("History DB is longer than history record store");
+            println!("This happens when you used Atuin pre-record-store");
+
+            let context = current_context();
+            history_store.init_store(context, db).await?;
+
+            println!("\n");
+        }
+
+        let (uploaded, downloaded) = sync::sync(settings, &store).await?;
+
         history_store.incremental_build(db, &downloaded).await?;

         println!("{uploaded}/{} up/down to record store", downloaded.len());