mirror of
https://github.com/atuinsh/atuin.git
synced 2025-01-12 17:28:16 +01:00
feat: automatically init history store when record sync is enabled (#1634)
* add support for getting the total length of a store * tidy up sync * auto call init if history is ahead * fix import order, key regen * fix import order, key regen * do not delete key when user deletes account * message output * remote init store command; this is now automatic * should probs make that function return u64 at some point
This commit is contained in:
parent
15bad15f48
commit
366b8ea97b
@ -4,7 +4,7 @@ use eyre::{bail, eyre, Result};
|
||||
use rmp::decode::Bytes;
|
||||
|
||||
use crate::{
|
||||
database::Database,
|
||||
database::{self, Database},
|
||||
record::{encryption::PASETO_V4, sqlite_store::SqliteStore, store::Store},
|
||||
};
|
||||
use atuin_common::record::{DecryptedData, Host, HostId, Record, RecordId, RecordIdx};
|
||||
@ -255,6 +255,34 @@ impl HistoryStore {
|
||||
|
||||
Ok(ret)
|
||||
}
|
||||
|
||||
pub async fn init_store(&self, context: database::Context, db: &impl Database) -> Result<()> {
|
||||
println!("Importing all history.db data into records.db");
|
||||
|
||||
println!("Fetching history from old database");
|
||||
let history = db.list(&[], &context, None, false, true).await?;
|
||||
|
||||
println!("Fetching history already in store");
|
||||
let store_ids = self.history_ids().await?;
|
||||
|
||||
for i in history {
|
||||
println!("loaded {}", i.id);
|
||||
|
||||
if store_ids.contains(&i.id) {
|
||||
println!("skipping {} - already exists", i.id);
|
||||
continue;
|
||||
}
|
||||
|
||||
if i.deleted_at.is_some() {
|
||||
self.push(i.clone()).await?;
|
||||
self.delete(i.id).await?;
|
||||
} else {
|
||||
self.push(i).await?;
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
|
@ -155,6 +155,18 @@ impl Store for SqliteStore {
|
||||
self.idx(host, tag, 0).await
|
||||
}
|
||||
|
||||
async fn len_tag(&self, tag: &str) -> Result<u64> {
|
||||
let res: Result<(i64,), sqlx::Error> =
|
||||
sqlx::query_as("select count(*) from store where tag=?1")
|
||||
.bind(tag)
|
||||
.fetch_one(&self.pool)
|
||||
.await;
|
||||
match res {
|
||||
Err(e) => Err(eyre!("failed to fetch local store len: {}", e)),
|
||||
Ok(v) => Ok(v.0 as u64),
|
||||
}
|
||||
}
|
||||
|
||||
async fn len(&self, host: HostId, tag: &str) -> Result<u64> {
|
||||
let last = self.last(host, tag).await?;
|
||||
|
||||
@ -342,6 +354,20 @@ mod tests {
|
||||
assert_eq!(len, 1, "expected length of 1 after insert");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn len_tag() {
|
||||
let db = SqliteStore::new(":memory:", 0.1).await.unwrap();
|
||||
let record = test_record();
|
||||
db.push(&record).await.unwrap();
|
||||
|
||||
let len = db
|
||||
.len_tag(record.tag.as_str())
|
||||
.await
|
||||
.expect("failed to get store len");
|
||||
|
||||
assert_eq!(len, 1, "expected length of 1 after insert");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn len_different_tags() {
|
||||
let db = SqliteStore::new(":memory:", 0.1).await.unwrap();
|
||||
@ -379,6 +405,12 @@ mod tests {
|
||||
100,
|
||||
"failed to insert 100 records"
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
db.len_tag(tail.tag.as_str()).await.unwrap(),
|
||||
100,
|
||||
"failed to insert 100 records"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
|
@ -21,7 +21,9 @@ pub trait Store {
|
||||
) -> Result<()>;
|
||||
|
||||
async fn get(&self, id: RecordId) -> Result<Record<EncryptedData>>;
|
||||
|
||||
async fn len(&self, host: HostId, tag: &str) -> Result<u64>;
|
||||
async fn len_tag(&self, tag: &str) -> Result<u64>;
|
||||
|
||||
async fn last(&self, host: HostId, tag: &str) -> Result<Option<Record<EncryptedData>>>;
|
||||
async fn first(&self, host: HostId, tag: &str) -> Result<Option<Record<EncryptedData>>>;
|
||||
|
@ -14,14 +14,17 @@ pub enum SyncError {
|
||||
#[error("the local store is ahead of the remote, but for another host. has remote lost data?")]
|
||||
LocalAheadOtherHost,
|
||||
|
||||
#[error("an issue with the local database occured")]
|
||||
LocalStoreError,
|
||||
#[error("an issue with the local database occured: {msg:?}")]
|
||||
LocalStoreError { msg: String },
|
||||
|
||||
#[error("something has gone wrong with the sync logic: {msg:?}")]
|
||||
SyncLogicError { msg: String },
|
||||
|
||||
#[error("a request to the sync server failed")]
|
||||
RemoteRequestError,
|
||||
#[error("operational error: {msg:?}")]
|
||||
OperationalError { msg: String },
|
||||
|
||||
#[error("a request to the sync server failed: {msg:?}")]
|
||||
RemoteRequestError { msg: String },
|
||||
}
|
||||
|
||||
#[derive(Debug, Eq, PartialEq)]
|
||||
@ -45,16 +48,27 @@ pub enum Operation {
|
||||
},
|
||||
}
|
||||
|
||||
pub async fn diff(settings: &Settings, store: &impl Store) -> Result<(Vec<Diff>, RecordStatus)> {
|
||||
pub async fn diff(
|
||||
settings: &Settings,
|
||||
store: &impl Store,
|
||||
) -> Result<(Vec<Diff>, RecordStatus), SyncError> {
|
||||
let client = Client::new(
|
||||
&settings.sync_address,
|
||||
&settings.session_token,
|
||||
settings.network_connect_timeout,
|
||||
settings.network_timeout,
|
||||
)?;
|
||||
)
|
||||
.map_err(|e| SyncError::OperationalError { msg: e.to_string() })?;
|
||||
|
||||
let local_index = store.status().await?;
|
||||
let remote_index = client.record_status().await?;
|
||||
let local_index = store
|
||||
.status()
|
||||
.await
|
||||
.map_err(|e| SyncError::LocalStoreError { msg: e.to_string() })?;
|
||||
|
||||
let remote_index = client
|
||||
.record_status()
|
||||
.await
|
||||
.map_err(|e| SyncError::RemoteRequestError { msg: e.to_string() })?;
|
||||
|
||||
let diff = local_index.diff(&remote_index);
|
||||
|
||||
@ -166,13 +180,13 @@ async fn sync_upload(
|
||||
.map_err(|e| {
|
||||
error!("failed to read upload page: {e:?}");
|
||||
|
||||
SyncError::LocalStoreError
|
||||
SyncError::LocalStoreError { msg: e.to_string() }
|
||||
})?;
|
||||
|
||||
client.post_records(&page).await.map_err(|e| {
|
||||
error!("failed to post records: {e:?}");
|
||||
|
||||
SyncError::RemoteRequestError
|
||||
SyncError::RemoteRequestError { msg: e.to_string() }
|
||||
})?;
|
||||
|
||||
println!(
|
||||
@ -217,12 +231,12 @@ async fn sync_download(
|
||||
let page = client
|
||||
.next_records(host, tag.clone(), local + progress, download_page_size)
|
||||
.await
|
||||
.map_err(|_| SyncError::RemoteRequestError)?;
|
||||
.map_err(|e| SyncError::RemoteRequestError { msg: e.to_string() })?;
|
||||
|
||||
store
|
||||
.push_batch(page.iter())
|
||||
.await
|
||||
.map_err(|_| SyncError::LocalStoreError)?;
|
||||
.map_err(|e| SyncError::LocalStoreError { msg: e.to_string() })?;
|
||||
|
||||
println!(
|
||||
"downloaded {} records from remote, progress {}/{}",
|
||||
@ -283,6 +297,17 @@ pub async fn sync_remote(
|
||||
Ok((uploaded, downloaded))
|
||||
}
|
||||
|
||||
pub async fn sync(
|
||||
settings: &Settings,
|
||||
store: &impl Store,
|
||||
) -> Result<(i64, Vec<RecordId>), SyncError> {
|
||||
let (diff, _) = diff(settings, store).await?;
|
||||
let operations = operations(diff, store).await?;
|
||||
let (uploaded, downloaded) = sync_remote(operations, store, settings).await?;
|
||||
|
||||
Ok((uploaded, downloaded))
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use atuin_common::record::{Diff, EncryptedData, HostId, Record};
|
||||
|
@ -5,7 +5,6 @@ use std::path::PathBuf;
|
||||
|
||||
pub async fn run(settings: &Settings) -> Result<()> {
|
||||
let session_path = settings.session_path.as_str();
|
||||
let key_path = settings.key_path.as_str();
|
||||
|
||||
if !PathBuf::from(session_path).exists() {
|
||||
bail!("You are not logged in");
|
||||
@ -25,10 +24,6 @@ pub async fn run(settings: &Settings) -> Result<()> {
|
||||
remove_file(PathBuf::from(session_path))?;
|
||||
}
|
||||
|
||||
if PathBuf::from(key_path).exists() {
|
||||
remove_file(PathBuf::from(key_path))?;
|
||||
}
|
||||
|
||||
println!("Your account is deleted");
|
||||
|
||||
Ok(())
|
||||
|
@ -49,8 +49,7 @@ pub async fn run(
|
||||
let mut file = File::create(path).await?;
|
||||
file.write_all(session.session.as_bytes()).await?;
|
||||
|
||||
// Create a new key, and save it to disk
|
||||
let _key = atuin_client::encryption::new_key(settings)?;
|
||||
let _key = atuin_client::encryption::load_key(settings)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
@ -88,10 +88,6 @@ pub enum Cmd {
|
||||
#[arg(long, short)]
|
||||
format: Option<String>,
|
||||
},
|
||||
|
||||
/// Import all old history.db data into the record store. Do not run more than once, and do not
|
||||
/// run unless you know what you're doing (or the docs ask you to)
|
||||
InitStore,
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy, Debug)]
|
||||
@ -321,10 +317,7 @@ impl Cmd {
|
||||
#[cfg(feature = "sync")]
|
||||
{
|
||||
if settings.sync.records {
|
||||
let (diff, _) = record::sync::diff(settings, &store).await?;
|
||||
let operations = record::sync::operations(diff, &store).await?;
|
||||
let (_, downloaded) =
|
||||
record::sync::sync_remote(operations, &store, settings).await?;
|
||||
let (_, downloaded) = record::sync::sync(settings, &store).await?;
|
||||
|
||||
history_store.incremental_build(db, &downloaded).await?;
|
||||
} else {
|
||||
@ -380,38 +373,6 @@ impl Cmd {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn init_store(
|
||||
context: atuin_client::database::Context,
|
||||
db: &impl Database,
|
||||
store: HistoryStore,
|
||||
) -> Result<()> {
|
||||
println!("Importing all history.db data into records.db");
|
||||
|
||||
println!("Fetching history from old database");
|
||||
let history = db.list(&[], &context, None, false, true).await?;
|
||||
|
||||
println!("Fetching history already in store");
|
||||
let store_ids = store.history_ids().await?;
|
||||
|
||||
for i in history {
|
||||
println!("loaded {}", i.id);
|
||||
|
||||
if store_ids.contains(&i.id) {
|
||||
println!("skipping {} - already exists", i.id);
|
||||
continue;
|
||||
}
|
||||
|
||||
if i.deleted_at.is_some() {
|
||||
store.push(i.clone()).await?;
|
||||
store.delete(i.id).await?;
|
||||
} else {
|
||||
store.push(i).await?;
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn run(
|
||||
self,
|
||||
settings: &Settings,
|
||||
@ -468,8 +429,6 @@ impl Cmd {
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
Self::InitStore => Self::init_store(context, db, history_store).await,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -2,10 +2,10 @@ use clap::Subcommand;
|
||||
use eyre::{Result, WrapErr};
|
||||
|
||||
use atuin_client::{
|
||||
database::Database,
|
||||
database::{current_context, Database},
|
||||
encryption,
|
||||
history::store::HistoryStore,
|
||||
record::{sqlite_store::SqliteStore, sync},
|
||||
record::{sqlite_store::SqliteStore, store::Store, sync},
|
||||
settings::Settings,
|
||||
};
|
||||
|
||||
@ -80,10 +80,6 @@ async fn run(
|
||||
store: SqliteStore,
|
||||
) -> Result<()> {
|
||||
if settings.sync.records {
|
||||
let (diff, _) = sync::diff(settings, &store).await?;
|
||||
let operations = sync::operations(diff, &store).await?;
|
||||
let (uploaded, downloaded) = sync::sync_remote(operations, &store, settings).await?;
|
||||
|
||||
let encryption_key: [u8; 32] = encryption::load_key(settings)
|
||||
.context("could not load encryption key")?
|
||||
.into();
|
||||
@ -91,6 +87,22 @@ async fn run(
|
||||
let host_id = Settings::host_id().expect("failed to get host_id");
|
||||
let history_store = HistoryStore::new(store.clone(), host_id, encryption_key);
|
||||
|
||||
let history_length = db.history_count(true).await?;
|
||||
let store_history_length = store.len_tag("history").await?;
|
||||
|
||||
#[allow(clippy::cast_sign_loss)]
|
||||
if history_length as u64 > store_history_length {
|
||||
println!("History DB is longer than history record store");
|
||||
println!("This happens when you used Atuin pre-record-store");
|
||||
|
||||
let context = current_context();
|
||||
history_store.init_store(context, db).await?;
|
||||
|
||||
println!("\n");
|
||||
}
|
||||
|
||||
let (uploaded, downloaded) = sync::sync(settings, &store).await?;
|
||||
|
||||
history_store.incremental_build(db, &downloaded).await?;
|
||||
|
||||
println!("{uploaded}/{} up/down to record store", downloaded.len());
|
||||
|
Loading…
Reference in New Issue
Block a user