mirror of
https://github.com/atuinsh/atuin.git
synced 2025-06-20 18:07:57 +02:00
feat: automatically init history store when record sync is enabled (#1634)
* add support for getting the total length of a store * tidy up sync * auto call init if history is ahead * fix import order, key regen * fix import order, key regen * do not delete key when user deletes account * message output * remote init store command; this is now automatic * should probs make that function return u64 at some point
This commit is contained in:
parent
15bad15f48
commit
366b8ea97b
@ -4,7 +4,7 @@ use eyre::{bail, eyre, Result};
|
|||||||
use rmp::decode::Bytes;
|
use rmp::decode::Bytes;
|
||||||
|
|
||||||
use crate::{
|
use crate::{
|
||||||
database::Database,
|
database::{self, Database},
|
||||||
record::{encryption::PASETO_V4, sqlite_store::SqliteStore, store::Store},
|
record::{encryption::PASETO_V4, sqlite_store::SqliteStore, store::Store},
|
||||||
};
|
};
|
||||||
use atuin_common::record::{DecryptedData, Host, HostId, Record, RecordId, RecordIdx};
|
use atuin_common::record::{DecryptedData, Host, HostId, Record, RecordId, RecordIdx};
|
||||||
@ -255,6 +255,34 @@ impl HistoryStore {
|
|||||||
|
|
||||||
Ok(ret)
|
Ok(ret)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub async fn init_store(&self, context: database::Context, db: &impl Database) -> Result<()> {
|
||||||
|
println!("Importing all history.db data into records.db");
|
||||||
|
|
||||||
|
println!("Fetching history from old database");
|
||||||
|
let history = db.list(&[], &context, None, false, true).await?;
|
||||||
|
|
||||||
|
println!("Fetching history already in store");
|
||||||
|
let store_ids = self.history_ids().await?;
|
||||||
|
|
||||||
|
for i in history {
|
||||||
|
println!("loaded {}", i.id);
|
||||||
|
|
||||||
|
if store_ids.contains(&i.id) {
|
||||||
|
println!("skipping {} - already exists", i.id);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
if i.deleted_at.is_some() {
|
||||||
|
self.push(i.clone()).await?;
|
||||||
|
self.delete(i.id).await?;
|
||||||
|
} else {
|
||||||
|
self.push(i).await?;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
|
@ -155,6 +155,18 @@ impl Store for SqliteStore {
|
|||||||
self.idx(host, tag, 0).await
|
self.idx(host, tag, 0).await
|
||||||
}
|
}
|
||||||
|
|
||||||
|
async fn len_tag(&self, tag: &str) -> Result<u64> {
|
||||||
|
let res: Result<(i64,), sqlx::Error> =
|
||||||
|
sqlx::query_as("select count(*) from store where tag=?1")
|
||||||
|
.bind(tag)
|
||||||
|
.fetch_one(&self.pool)
|
||||||
|
.await;
|
||||||
|
match res {
|
||||||
|
Err(e) => Err(eyre!("failed to fetch local store len: {}", e)),
|
||||||
|
Ok(v) => Ok(v.0 as u64),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
async fn len(&self, host: HostId, tag: &str) -> Result<u64> {
|
async fn len(&self, host: HostId, tag: &str) -> Result<u64> {
|
||||||
let last = self.last(host, tag).await?;
|
let last = self.last(host, tag).await?;
|
||||||
|
|
||||||
@ -342,6 +354,20 @@ mod tests {
|
|||||||
assert_eq!(len, 1, "expected length of 1 after insert");
|
assert_eq!(len, 1, "expected length of 1 after insert");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn len_tag() {
|
||||||
|
let db = SqliteStore::new(":memory:", 0.1).await.unwrap();
|
||||||
|
let record = test_record();
|
||||||
|
db.push(&record).await.unwrap();
|
||||||
|
|
||||||
|
let len = db
|
||||||
|
.len_tag(record.tag.as_str())
|
||||||
|
.await
|
||||||
|
.expect("failed to get store len");
|
||||||
|
|
||||||
|
assert_eq!(len, 1, "expected length of 1 after insert");
|
||||||
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn len_different_tags() {
|
async fn len_different_tags() {
|
||||||
let db = SqliteStore::new(":memory:", 0.1).await.unwrap();
|
let db = SqliteStore::new(":memory:", 0.1).await.unwrap();
|
||||||
@ -379,6 +405,12 @@ mod tests {
|
|||||||
100,
|
100,
|
||||||
"failed to insert 100 records"
|
"failed to insert 100 records"
|
||||||
);
|
);
|
||||||
|
|
||||||
|
assert_eq!(
|
||||||
|
db.len_tag(tail.tag.as_str()).await.unwrap(),
|
||||||
|
100,
|
||||||
|
"failed to insert 100 records"
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
|
@ -21,7 +21,9 @@ pub trait Store {
|
|||||||
) -> Result<()>;
|
) -> Result<()>;
|
||||||
|
|
||||||
async fn get(&self, id: RecordId) -> Result<Record<EncryptedData>>;
|
async fn get(&self, id: RecordId) -> Result<Record<EncryptedData>>;
|
||||||
|
|
||||||
async fn len(&self, host: HostId, tag: &str) -> Result<u64>;
|
async fn len(&self, host: HostId, tag: &str) -> Result<u64>;
|
||||||
|
async fn len_tag(&self, tag: &str) -> Result<u64>;
|
||||||
|
|
||||||
async fn last(&self, host: HostId, tag: &str) -> Result<Option<Record<EncryptedData>>>;
|
async fn last(&self, host: HostId, tag: &str) -> Result<Option<Record<EncryptedData>>>;
|
||||||
async fn first(&self, host: HostId, tag: &str) -> Result<Option<Record<EncryptedData>>>;
|
async fn first(&self, host: HostId, tag: &str) -> Result<Option<Record<EncryptedData>>>;
|
||||||
|
@ -14,14 +14,17 @@ pub enum SyncError {
|
|||||||
#[error("the local store is ahead of the remote, but for another host. has remote lost data?")]
|
#[error("the local store is ahead of the remote, but for another host. has remote lost data?")]
|
||||||
LocalAheadOtherHost,
|
LocalAheadOtherHost,
|
||||||
|
|
||||||
#[error("an issue with the local database occured")]
|
#[error("an issue with the local database occured: {msg:?}")]
|
||||||
LocalStoreError,
|
LocalStoreError { msg: String },
|
||||||
|
|
||||||
#[error("something has gone wrong with the sync logic: {msg:?}")]
|
#[error("something has gone wrong with the sync logic: {msg:?}")]
|
||||||
SyncLogicError { msg: String },
|
SyncLogicError { msg: String },
|
||||||
|
|
||||||
#[error("a request to the sync server failed")]
|
#[error("operational error: {msg:?}")]
|
||||||
RemoteRequestError,
|
OperationalError { msg: String },
|
||||||
|
|
||||||
|
#[error("a request to the sync server failed: {msg:?}")]
|
||||||
|
RemoteRequestError { msg: String },
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Eq, PartialEq)]
|
#[derive(Debug, Eq, PartialEq)]
|
||||||
@ -45,16 +48,27 @@ pub enum Operation {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn diff(settings: &Settings, store: &impl Store) -> Result<(Vec<Diff>, RecordStatus)> {
|
pub async fn diff(
|
||||||
|
settings: &Settings,
|
||||||
|
store: &impl Store,
|
||||||
|
) -> Result<(Vec<Diff>, RecordStatus), SyncError> {
|
||||||
let client = Client::new(
|
let client = Client::new(
|
||||||
&settings.sync_address,
|
&settings.sync_address,
|
||||||
&settings.session_token,
|
&settings.session_token,
|
||||||
settings.network_connect_timeout,
|
settings.network_connect_timeout,
|
||||||
settings.network_timeout,
|
settings.network_timeout,
|
||||||
)?;
|
)
|
||||||
|
.map_err(|e| SyncError::OperationalError { msg: e.to_string() })?;
|
||||||
|
|
||||||
let local_index = store.status().await?;
|
let local_index = store
|
||||||
let remote_index = client.record_status().await?;
|
.status()
|
||||||
|
.await
|
||||||
|
.map_err(|e| SyncError::LocalStoreError { msg: e.to_string() })?;
|
||||||
|
|
||||||
|
let remote_index = client
|
||||||
|
.record_status()
|
||||||
|
.await
|
||||||
|
.map_err(|e| SyncError::RemoteRequestError { msg: e.to_string() })?;
|
||||||
|
|
||||||
let diff = local_index.diff(&remote_index);
|
let diff = local_index.diff(&remote_index);
|
||||||
|
|
||||||
@ -166,13 +180,13 @@ async fn sync_upload(
|
|||||||
.map_err(|e| {
|
.map_err(|e| {
|
||||||
error!("failed to read upload page: {e:?}");
|
error!("failed to read upload page: {e:?}");
|
||||||
|
|
||||||
SyncError::LocalStoreError
|
SyncError::LocalStoreError { msg: e.to_string() }
|
||||||
})?;
|
})?;
|
||||||
|
|
||||||
client.post_records(&page).await.map_err(|e| {
|
client.post_records(&page).await.map_err(|e| {
|
||||||
error!("failed to post records: {e:?}");
|
error!("failed to post records: {e:?}");
|
||||||
|
|
||||||
SyncError::RemoteRequestError
|
SyncError::RemoteRequestError { msg: e.to_string() }
|
||||||
})?;
|
})?;
|
||||||
|
|
||||||
println!(
|
println!(
|
||||||
@ -217,12 +231,12 @@ async fn sync_download(
|
|||||||
let page = client
|
let page = client
|
||||||
.next_records(host, tag.clone(), local + progress, download_page_size)
|
.next_records(host, tag.clone(), local + progress, download_page_size)
|
||||||
.await
|
.await
|
||||||
.map_err(|_| SyncError::RemoteRequestError)?;
|
.map_err(|e| SyncError::RemoteRequestError { msg: e.to_string() })?;
|
||||||
|
|
||||||
store
|
store
|
||||||
.push_batch(page.iter())
|
.push_batch(page.iter())
|
||||||
.await
|
.await
|
||||||
.map_err(|_| SyncError::LocalStoreError)?;
|
.map_err(|e| SyncError::LocalStoreError { msg: e.to_string() })?;
|
||||||
|
|
||||||
println!(
|
println!(
|
||||||
"downloaded {} records from remote, progress {}/{}",
|
"downloaded {} records from remote, progress {}/{}",
|
||||||
@ -283,6 +297,17 @@ pub async fn sync_remote(
|
|||||||
Ok((uploaded, downloaded))
|
Ok((uploaded, downloaded))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub async fn sync(
|
||||||
|
settings: &Settings,
|
||||||
|
store: &impl Store,
|
||||||
|
) -> Result<(i64, Vec<RecordId>), SyncError> {
|
||||||
|
let (diff, _) = diff(settings, store).await?;
|
||||||
|
let operations = operations(diff, store).await?;
|
||||||
|
let (uploaded, downloaded) = sync_remote(operations, store, settings).await?;
|
||||||
|
|
||||||
|
Ok((uploaded, downloaded))
|
||||||
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use atuin_common::record::{Diff, EncryptedData, HostId, Record};
|
use atuin_common::record::{Diff, EncryptedData, HostId, Record};
|
||||||
|
@ -5,7 +5,6 @@ use std::path::PathBuf;
|
|||||||
|
|
||||||
pub async fn run(settings: &Settings) -> Result<()> {
|
pub async fn run(settings: &Settings) -> Result<()> {
|
||||||
let session_path = settings.session_path.as_str();
|
let session_path = settings.session_path.as_str();
|
||||||
let key_path = settings.key_path.as_str();
|
|
||||||
|
|
||||||
if !PathBuf::from(session_path).exists() {
|
if !PathBuf::from(session_path).exists() {
|
||||||
bail!("You are not logged in");
|
bail!("You are not logged in");
|
||||||
@ -25,10 +24,6 @@ pub async fn run(settings: &Settings) -> Result<()> {
|
|||||||
remove_file(PathBuf::from(session_path))?;
|
remove_file(PathBuf::from(session_path))?;
|
||||||
}
|
}
|
||||||
|
|
||||||
if PathBuf::from(key_path).exists() {
|
|
||||||
remove_file(PathBuf::from(key_path))?;
|
|
||||||
}
|
|
||||||
|
|
||||||
println!("Your account is deleted");
|
println!("Your account is deleted");
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
|
@ -49,8 +49,7 @@ pub async fn run(
|
|||||||
let mut file = File::create(path).await?;
|
let mut file = File::create(path).await?;
|
||||||
file.write_all(session.session.as_bytes()).await?;
|
file.write_all(session.session.as_bytes()).await?;
|
||||||
|
|
||||||
// Create a new key, and save it to disk
|
let _key = atuin_client::encryption::load_key(settings)?;
|
||||||
let _key = atuin_client::encryption::new_key(settings)?;
|
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
@ -88,10 +88,6 @@ pub enum Cmd {
|
|||||||
#[arg(long, short)]
|
#[arg(long, short)]
|
||||||
format: Option<String>,
|
format: Option<String>,
|
||||||
},
|
},
|
||||||
|
|
||||||
/// Import all old history.db data into the record store. Do not run more than once, and do not
|
|
||||||
/// run unless you know what you're doing (or the docs ask you to)
|
|
||||||
InitStore,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Clone, Copy, Debug)]
|
#[derive(Clone, Copy, Debug)]
|
||||||
@ -321,10 +317,7 @@ impl Cmd {
|
|||||||
#[cfg(feature = "sync")]
|
#[cfg(feature = "sync")]
|
||||||
{
|
{
|
||||||
if settings.sync.records {
|
if settings.sync.records {
|
||||||
let (diff, _) = record::sync::diff(settings, &store).await?;
|
let (_, downloaded) = record::sync::sync(settings, &store).await?;
|
||||||
let operations = record::sync::operations(diff, &store).await?;
|
|
||||||
let (_, downloaded) =
|
|
||||||
record::sync::sync_remote(operations, &store, settings).await?;
|
|
||||||
|
|
||||||
history_store.incremental_build(db, &downloaded).await?;
|
history_store.incremental_build(db, &downloaded).await?;
|
||||||
} else {
|
} else {
|
||||||
@ -380,38 +373,6 @@ impl Cmd {
|
|||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn init_store(
|
|
||||||
context: atuin_client::database::Context,
|
|
||||||
db: &impl Database,
|
|
||||||
store: HistoryStore,
|
|
||||||
) -> Result<()> {
|
|
||||||
println!("Importing all history.db data into records.db");
|
|
||||||
|
|
||||||
println!("Fetching history from old database");
|
|
||||||
let history = db.list(&[], &context, None, false, true).await?;
|
|
||||||
|
|
||||||
println!("Fetching history already in store");
|
|
||||||
let store_ids = store.history_ids().await?;
|
|
||||||
|
|
||||||
for i in history {
|
|
||||||
println!("loaded {}", i.id);
|
|
||||||
|
|
||||||
if store_ids.contains(&i.id) {
|
|
||||||
println!("skipping {} - already exists", i.id);
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
if i.deleted_at.is_some() {
|
|
||||||
store.push(i.clone()).await?;
|
|
||||||
store.delete(i.id).await?;
|
|
||||||
} else {
|
|
||||||
store.push(i).await?;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
|
|
||||||
pub async fn run(
|
pub async fn run(
|
||||||
self,
|
self,
|
||||||
settings: &Settings,
|
settings: &Settings,
|
||||||
@ -468,8 +429,6 @@ impl Cmd {
|
|||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
Self::InitStore => Self::init_store(context, db, history_store).await,
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -2,10 +2,10 @@ use clap::Subcommand;
|
|||||||
use eyre::{Result, WrapErr};
|
use eyre::{Result, WrapErr};
|
||||||
|
|
||||||
use atuin_client::{
|
use atuin_client::{
|
||||||
database::Database,
|
database::{current_context, Database},
|
||||||
encryption,
|
encryption,
|
||||||
history::store::HistoryStore,
|
history::store::HistoryStore,
|
||||||
record::{sqlite_store::SqliteStore, sync},
|
record::{sqlite_store::SqliteStore, store::Store, sync},
|
||||||
settings::Settings,
|
settings::Settings,
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -80,10 +80,6 @@ async fn run(
|
|||||||
store: SqliteStore,
|
store: SqliteStore,
|
||||||
) -> Result<()> {
|
) -> Result<()> {
|
||||||
if settings.sync.records {
|
if settings.sync.records {
|
||||||
let (diff, _) = sync::diff(settings, &store).await?;
|
|
||||||
let operations = sync::operations(diff, &store).await?;
|
|
||||||
let (uploaded, downloaded) = sync::sync_remote(operations, &store, settings).await?;
|
|
||||||
|
|
||||||
let encryption_key: [u8; 32] = encryption::load_key(settings)
|
let encryption_key: [u8; 32] = encryption::load_key(settings)
|
||||||
.context("could not load encryption key")?
|
.context("could not load encryption key")?
|
||||||
.into();
|
.into();
|
||||||
@ -91,6 +87,22 @@ async fn run(
|
|||||||
let host_id = Settings::host_id().expect("failed to get host_id");
|
let host_id = Settings::host_id().expect("failed to get host_id");
|
||||||
let history_store = HistoryStore::new(store.clone(), host_id, encryption_key);
|
let history_store = HistoryStore::new(store.clone(), host_id, encryption_key);
|
||||||
|
|
||||||
|
let history_length = db.history_count(true).await?;
|
||||||
|
let store_history_length = store.len_tag("history").await?;
|
||||||
|
|
||||||
|
#[allow(clippy::cast_sign_loss)]
|
||||||
|
if history_length as u64 > store_history_length {
|
||||||
|
println!("History DB is longer than history record store");
|
||||||
|
println!("This happens when you used Atuin pre-record-store");
|
||||||
|
|
||||||
|
let context = current_context();
|
||||||
|
history_store.init_store(context, db).await?;
|
||||||
|
|
||||||
|
println!("\n");
|
||||||
|
}
|
||||||
|
|
||||||
|
let (uploaded, downloaded) = sync::sync(settings, &store).await?;
|
||||||
|
|
||||||
history_store.incremental_build(db, &downloaded).await?;
|
history_store.incremental_build(db, &downloaded).await?;
|
||||||
|
|
||||||
println!("{uploaded}/{} up/down to record store", downloaded.len());
|
println!("{uploaded}/{} up/down to record store", downloaded.len());
|
||||||
|
Loading…
x
Reference in New Issue
Block a user