diff --git a/atuin-client/src/api_client.rs b/atuin-client/src/api_client.rs index 5cab36fd..b4066197 100644 --- a/atuin-client/src/api_client.rs +++ b/atuin-client/src/api_client.rs @@ -8,7 +8,7 @@ use reqwest::{ StatusCode, Url, }; -use atuin_common::record::{EncryptedData, HostId, Record, RecordId}; +use atuin_common::record::{EncryptedData, HostId, Record, RecordId, RecordIdx}; use atuin_common::{ api::{ AddHistoryRequest, CountResponse, DeleteHistoryRequest, ErrorResponse, IndexResponse, @@ -231,14 +231,14 @@ impl<'a> Client<'a> { &self, host: HostId, tag: String, - start: Option, + start: RecordIdx, count: u64, ) -> Result>> { debug!( "fetching record/s from host {}/{}/{}", host.0.to_string(), tag, - start.map_or(String::from("empty"), |f| f.0.to_string()) + start ); let url = format!( diff --git a/atuin-client/src/record/sqlite_store.rs b/atuin-client/src/record/sqlite_store.rs index 50ed4fe0..245baefd 100644 --- a/atuin-client/src/record/sqlite_store.rs +++ b/atuin-client/src/record/sqlite_store.rs @@ -170,9 +170,9 @@ impl Store for SqliteStore { limit: u64, ) -> Result>> { let res = - sqlx::query("select * from store where idx > ?1 and host = ?2 and tag = ?3 limit ?4") + sqlx::query("select * from store where idx >= ?1 and host = ?2 and tag = ?3 limit ?4") .bind(idx as i64) - .bind(host) + .bind(host.0.as_hyphenated().to_string()) .bind(tag) .bind(limit as i64) .map(Self::query_row) diff --git a/atuin-client/src/record/sync.rs b/atuin-client/src/record/sync.rs index b43199b7..11b4c268 100644 --- a/atuin-client/src/record/sync.rs +++ b/atuin-client/src/record/sync.rs @@ -12,11 +12,14 @@ pub enum SyncError { #[error("the local store is ahead of the remote, but for another host. has remote lost data?")] LocalAheadOtherHost, - #[error("some issue with the local database occured")] + #[error("an issue with the local database occured")] LocalStoreError, #[error("something has gone wrong with the sync logic: {msg:?}")] SyncLogicError { msg: String }, + + #[error("a request to the sync server failed")] + RemoteRequestError, } #[derive(Debug, Eq, PartialEq)] @@ -63,7 +66,10 @@ pub async fn diff( // With the store as context, we can determine if a tail exists locally or not and therefore if it needs uploading or download. // In theory this could be done as a part of the diffing stage, but it's easier to reason // about and test this way -pub async fn operations(diffs: Vec, _store: &impl Store) -> Result, SyncError> { +pub async fn operations( + diffs: Vec, + _store: &impl Store, +) -> Result, SyncError> { let mut operations = Vec::with_capacity(diffs.len()); let _host = Settings::host_id().expect("got to record sync without a host id; abort"); @@ -141,22 +147,16 @@ pub async fn operations(diffs: Vec, _store: &impl Store) -> Result, + store: &mut impl Store, + client: &Client<'_>, host: HostId, tag: String, local: RecordIdx, remote: Option, ) -> Result { let expected = local - remote.unwrap_or(0); - let _upload_page_size = 100; - let _total = 0; - - if expected < 0 { - return Err(SyncError::SyncLogicError { - msg: String::from("ran upload, but remote ahead of local"), - }); - } + let upload_page_size = 100; + let mut progress = 0; println!( "Uploading {} records to {}/{}", @@ -165,28 +165,47 @@ async fn sync_upload( tag ); - // TODO: actually upload lmfao + // preload with the first entry if remote does not know of this store + while progress < expected { + let page = store + .next( + host, + tag.as_str(), + remote.unwrap_or(0) + progress, + upload_page_size, + ) + .await + .map_err(|_| SyncError::LocalStoreError)?; - Ok(0) + let _ = client + .post_records(&page) + .await + .map_err(|_| SyncError::RemoteRequestError)?; + + println!( + "uploaded {} to remote, progress {}/{}", + page.len(), + progress, + expected + ); + progress += page.len() as u64; + } + + Ok(progress as i64) } async fn sync_download( - _store: &mut impl Store, - _client: &Client<'_>, + store: &mut impl Store, + client: &Client<'_>, host: HostId, tag: String, local: Option, remote: RecordIdx, ) -> Result { + let local = local.unwrap_or(0); let expected = remote - local.unwrap_or(0); - let _download_page_size = 100; - let _total = 0; - - if expected < 0 { - return Err(SyncError::SyncLogicError { - msg: String::from("ran download, but local ahead of remote"), - }); - } + let download_page_size = 100; + let mut progress = 0; println!( "Downloading {} records from {}/{}", @@ -195,7 +214,25 @@ async fn sync_download( tag ); - // TODO: actually upload lmfao + // preload with the first entry if remote does not know of this store + while progress < expected { + let page = client.next_records(host, tag, Some(local + progress), download_page_size); + + let _ = client + .post_records(&page) + .await + .map_err(|_| SyncError::RemoteRequestError)?; + + println!( + "uploaded {} to remote, progress {}/{}", + page.len(), + progress, + expected + ); + progress += page.len() as u64; + } + + Ok(progress as i64) Ok(0) } @@ -204,13 +241,14 @@ pub async fn sync_remote( operations: Vec, local_store: &mut impl Store, settings: &Settings, -) -> Result<(i64, i64)> { +) -> Result<(i64, i64), SyncError> { let client = Client::new( &settings.sync_address, &settings.session_token, settings.network_connect_timeout, settings.network_timeout, - )?; + ) + .expect("failed to create client"); let mut uploaded = 0; let mut downloaded = 0; diff --git a/atuin-server-database/src/lib.rs b/atuin-server-database/src/lib.rs index c60fdcf9..23ad540d 100644 --- a/atuin-server-database/src/lib.rs +++ b/atuin-server-database/src/lib.rs @@ -14,7 +14,7 @@ use self::{ models::{History, NewHistory, NewSession, NewUser, Session, User}, }; use async_trait::async_trait; -use atuin_common::record::{EncryptedData, HostId, Record, RecordId, RecordStatus}; +use atuin_common::record::{EncryptedData, HostId, Record, RecordId, RecordIdx, RecordStatus}; use serde::{de::DeserializeOwned, Serialize}; use time::{Date, Duration, Month, OffsetDateTime, Time, UtcOffset}; use tracing::instrument; @@ -68,12 +68,12 @@ pub trait Database: Sized + Clone + Send + Sync + 'static { user: &User, host: HostId, tag: String, - start: Option, + start: Option, count: u64, ) -> DbResult>>; // Return the tail record ID for each store, so (HostID, Tag, TailRecordID) - async fn tail_records(&self, user: &User) -> DbResult; + async fn status(&self, user: &User) -> DbResult; async fn count_history_range(&self, user: &User, range: Range) -> DbResult; diff --git a/atuin-server-postgres/migrations/20231202170508_create-store.sql b/atuin-server-postgres/migrations/20231202170508_create-store.sql new file mode 100644 index 00000000..ffb57966 --- /dev/null +++ b/atuin-server-postgres/migrations/20231202170508_create-store.sql @@ -0,0 +1,15 @@ +-- Add migration script here +create table store ( + id uuid primary key, -- remember to use uuidv7 for happy indices <3 + client_id uuid not null, -- I am too uncomfortable with the idea of a client-generated primary key, even though it's fine mathematically + host uuid not null, -- a unique identifier for the host + idx bigint not null, -- the index of the record in this store, identified by (host, tag) + timestamp bigint not null, -- not a timestamp type, as those do not have nanosecond precision + version text not null, + tag text not null, -- what is this? history, kv, whatever. Remember clients get a log per tag per host + data text not null, -- store the actual history data, encrypted. I don't wanna know! + cek text not null, + + user_id bigint not null, -- allow multiple users + created_at timestamp not null default current_timestamp +); diff --git a/atuin-server-postgres/src/lib.rs b/atuin-server-postgres/src/lib.rs index f22e6bee..b838534a 100644 --- a/atuin-server-postgres/src/lib.rs +++ b/atuin-server-postgres/src/lib.rs @@ -1,7 +1,7 @@ use std::ops::Range; use async_trait::async_trait; -use atuin_common::record::{EncryptedData, HostId, Record, RecordId, RecordIndex}; +use atuin_common::record::{EncryptedData, HostId, Record, RecordId, RecordIdx, RecordStatus}; use atuin_server_database::models::{History, NewHistory, NewSession, NewUser, Session, User}; use atuin_server_database::{Database, DbError, DbResult}; use futures_util::TryStreamExt; @@ -11,6 +11,7 @@ use sqlx::Row; use time::{OffsetDateTime, PrimitiveDateTime, UtcOffset}; use tracing::instrument; +use uuid::Uuid; use wrappers::{DbHistory, DbRecord, DbSession, DbUser}; mod wrappers; @@ -361,16 +362,16 @@ impl Database for Postgres { let id = atuin_common::utils::uuid_v7(); sqlx::query( - "insert into records - (id, client_id, host, parent, timestamp, version, tag, data, cek, user_id) + "insert into store + (id, client_id, host, idx, timestamp, version, tag, data, cek, user_id) values ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10) on conflict do nothing ", ) .bind(id) .bind(i.id) - .bind(i.host) - .bind(i.parent) + .bind(i.host.id) + .bind(i.idx as i64) .bind(i.timestamp as i64) // throwing away some data, but i64 is still big in terms of time .bind(&i.version) .bind(&i.tag) @@ -393,62 +394,71 @@ impl Database for Postgres { user: &User, host: HostId, tag: String, - start: Option, + start: Option, count: u64, ) -> DbResult>> { tracing::debug!("{:?} - {:?} - {:?}", host, tag, start); let mut ret = Vec::with_capacity(count as usize); let mut parent = start; + let start = start.unwrap_or(0); - // yeah let's do something better - for _ in 0..count { - // a very much not ideal query. but it's simple at least? - // we are basically using postgres as a kv store here, so... maybe consider using an actual - // kv store? - let record: Result = sqlx::query_as( - "select client_id, host, parent, timestamp, version, tag, data, cek from records + let records: Result, DbError> = sqlx::query_as( + "select client_id, host, idx, timestamp, version, tag, data, cek from store where user_id = $1 and tag = $2 and host = $3 - and parent is not distinct from $4", - ) - .bind(user.id) - .bind(tag.clone()) - .bind(host) - .bind(parent) - .fetch_one(&self.pool) - .await - .map_err(fix_error); + and idx >= $4 + order by idx asc + limit $5", + ) + .bind(user.id) + .bind(tag.clone()) + .bind(host) + .bind(start as i64) + .bind(count as i64) + .fetch_all(&self.pool) + .await + .map_err(fix_error); - match record { - Ok(record) => { - let record: Record = record.into(); - ret.push(record.clone()); + let ret = match records { + Ok(records) => { + let records: Vec> = records + .into_iter() + .map(|f| { + let record: Record = f.into(); + record + }) + .collect(); - parent = Some(record.id); - } - Err(DbError::NotFound) => { - tracing::debug!("hit tail of store: {:?}/{}", host, tag); - return Ok(ret); - } - Err(e) => return Err(e), + records } - } + Err(DbError::NotFound) => { + tracing::debug!("hit end of store: {:?}/{}", host, tag); + return Ok(ret); + } + Err(e) => return Err(e), + }; Ok(ret) } - async fn tail_records(&self, user: &User) -> DbResult { - const TAIL_RECORDS_SQL: &str = "select host, tag, client_id from records rp where (select count(1) from records where parent=rp.client_id and user_id = $1) = 0 and user_id = $1;"; + async fn status(&self, user: &User) -> DbResult { + const STATUS_SQL: &str = + "select host, tag, max(idx) from store where user_id = $1 group by host, tag"; - let res = sqlx::query_as(TAIL_RECORDS_SQL) + let res: Vec<(Uuid, String, i64)> = sqlx::query_as(STATUS_SQL) .bind(user.id) - .fetch(&self.pool) - .try_collect() + .fetch_all(&self.pool) .await .map_err(fix_error)?; - Ok(res) + let mut status = RecordStatus::new(); + + for i in res { + status.set_raw(HostId(i.0), i.1, i.2 as u64); + } + + Ok(status) } } diff --git a/atuin-server-postgres/src/wrappers.rs b/atuin-server-postgres/src/wrappers.rs index b4ae48ae..3ccf9c19 100644 --- a/atuin-server-postgres/src/wrappers.rs +++ b/atuin-server-postgres/src/wrappers.rs @@ -1,5 +1,5 @@ use ::sqlx::{FromRow, Result}; -use atuin_common::record::{EncryptedData, Record}; +use atuin_common::record::{EncryptedData, Host, Record}; use atuin_server_database::models::{History, Session, User}; use sqlx::{postgres::PgRow, Row}; use time::PrimitiveDateTime; @@ -51,6 +51,7 @@ impl<'a> ::sqlx::FromRow<'a, PgRow> for DbHistory { impl<'a> ::sqlx::FromRow<'a, PgRow> for DbRecord { fn from_row(row: &'a PgRow) -> ::sqlx::Result { let timestamp: i64 = row.try_get("timestamp")?; + let idx: i64 = row.try_get("idx")?; let data = EncryptedData { data: row.try_get("data")?, @@ -59,8 +60,8 @@ impl<'a> ::sqlx::FromRow<'a, PgRow> for DbRecord { Ok(Self(Record { id: row.try_get("client_id")?, - host: row.try_get("host")?, - parent: row.try_get("parent")?, + host: Host::new(row.try_get("host")?), + idx: idx as u64, timestamp: timestamp as u64, version: row.try_get("version")?, tag: row.try_get("tag")?, diff --git a/atuin-server/src/handlers/record.rs b/atuin-server/src/handlers/record.rs index 91b937b3..473e3206 100644 --- a/atuin-server/src/handlers/record.rs +++ b/atuin-server/src/handlers/record.rs @@ -8,7 +8,7 @@ use super::{ErrorResponse, ErrorResponseStatus, RespExt}; use crate::router::{AppState, UserAuth}; use atuin_server_database::Database; -use atuin_common::record::{EncryptedData, HostId, Record, RecordId, RecordIndex}; +use atuin_common::record::{EncryptedData, HostId, Record, RecordId, RecordIdx, RecordStatus}; #[instrument(skip_all, fields(user.id = user.id))] pub async fn post( @@ -53,13 +53,13 @@ pub async fn post( pub async fn index( UserAuth(user): UserAuth, state: State>, -) -> Result, ErrorResponseStatus<'static>> { +) -> Result, ErrorResponseStatus<'static>> { let State(AppState { database, settings: _, }) = state; - let record_index = match database.tail_records(&user).await { + let record_index = match database.status(&user).await { Ok(index) => index, Err(e) => { error!("failed to get record index: {}", e); @@ -76,7 +76,7 @@ pub async fn index( pub struct NextParams { host: HostId, tag: String, - start: Option, + start: Option, count: u64, } diff --git a/atuin/src/command/client/sync.rs b/atuin/src/command/client/sync.rs index 0b796804..cdb0f214 100644 --- a/atuin/src/command/client/sync.rs +++ b/atuin/src/command/client/sync.rs @@ -79,8 +79,7 @@ async fn run( ) -> Result<()> { let (diff, remote_index) = sync::diff(settings, store).await?; let operations = sync::operations(diff, store).await?; - let (uploaded, downloaded) = - sync::sync_remote(operations, &remote_index, store, settings).await?; + let (uploaded, downloaded) = sync::sync_remote(operations, store, settings).await?; println!("{uploaded}/{downloaded} up/down to record store");