This commit is contained in:
Ellie Huxtable
2023-12-02 22:20:00 +00:00
parent 3193ede1fe
commit eff7de4720
9 changed files with 147 additions and 84 deletions

View File

@ -8,7 +8,7 @@ use reqwest::{
StatusCode, Url, StatusCode, Url,
}; };
use atuin_common::record::{EncryptedData, HostId, Record, RecordId}; use atuin_common::record::{EncryptedData, HostId, Record, RecordId, RecordIdx};
use atuin_common::{ use atuin_common::{
api::{ api::{
AddHistoryRequest, CountResponse, DeleteHistoryRequest, ErrorResponse, IndexResponse, AddHistoryRequest, CountResponse, DeleteHistoryRequest, ErrorResponse, IndexResponse,
@ -231,14 +231,14 @@ impl<'a> Client<'a> {
&self, &self,
host: HostId, host: HostId,
tag: String, tag: String,
start: Option<RecordId>, start: RecordIdx,
count: u64, count: u64,
) -> Result<Vec<Record<EncryptedData>>> { ) -> Result<Vec<Record<EncryptedData>>> {
debug!( debug!(
"fetching record/s from host {}/{}/{}", "fetching record/s from host {}/{}/{}",
host.0.to_string(), host.0.to_string(),
tag, tag,
start.map_or(String::from("empty"), |f| f.0.to_string()) start
); );
let url = format!( let url = format!(

View File

@ -170,9 +170,9 @@ impl Store for SqliteStore {
limit: u64, limit: u64,
) -> Result<Vec<Record<EncryptedData>>> { ) -> Result<Vec<Record<EncryptedData>>> {
let res = let res =
sqlx::query("select * from store where idx > ?1 and host = ?2 and tag = ?3 limit ?4") sqlx::query("select * from store where idx >= ?1 and host = ?2 and tag = ?3 limit ?4")
.bind(idx as i64) .bind(idx as i64)
.bind(host) .bind(host.0.as_hyphenated().to_string())
.bind(tag) .bind(tag)
.bind(limit as i64) .bind(limit as i64)
.map(Self::query_row) .map(Self::query_row)

View File

@ -12,11 +12,14 @@ pub enum SyncError {
#[error("the local store is ahead of the remote, but for another host. has remote lost data?")] #[error("the local store is ahead of the remote, but for another host. has remote lost data?")]
LocalAheadOtherHost, LocalAheadOtherHost,
#[error("some issue with the local database occured")] #[error("an issue with the local database occured")]
LocalStoreError, LocalStoreError,
#[error("something has gone wrong with the sync logic: {msg:?}")] #[error("something has gone wrong with the sync logic: {msg:?}")]
SyncLogicError { msg: String }, SyncLogicError { msg: String },
#[error("a request to the sync server failed")]
RemoteRequestError,
} }
#[derive(Debug, Eq, PartialEq)] #[derive(Debug, Eq, PartialEq)]
@ -63,7 +66,10 @@ pub async fn diff(
// With the store as context, we can determine if a tail exists locally or not and therefore if it needs uploading or download. // With the store as context, we can determine if a tail exists locally or not and therefore if it needs uploading or download.
// In theory this could be done as a part of the diffing stage, but it's easier to reason // In theory this could be done as a part of the diffing stage, but it's easier to reason
// about and test this way // about and test this way
pub async fn operations(diffs: Vec<Diff>, _store: &impl Store) -> Result<Vec<Operation>, SyncError> { pub async fn operations(
diffs: Vec<Diff>,
_store: &impl Store,
) -> Result<Vec<Operation>, SyncError> {
let mut operations = Vec::with_capacity(diffs.len()); let mut operations = Vec::with_capacity(diffs.len());
let _host = Settings::host_id().expect("got to record sync without a host id; abort"); let _host = Settings::host_id().expect("got to record sync without a host id; abort");
@ -141,22 +147,16 @@ pub async fn operations(diffs: Vec<Diff>, _store: &impl Store) -> Result<Vec<Ope
} }
async fn sync_upload( async fn sync_upload(
_store: &mut impl Store, store: &mut impl Store,
_client: &Client<'_>, client: &Client<'_>,
host: HostId, host: HostId,
tag: String, tag: String,
local: RecordIdx, local: RecordIdx,
remote: Option<RecordIdx>, remote: Option<RecordIdx>,
) -> Result<i64, SyncError> { ) -> Result<i64, SyncError> {
let expected = local - remote.unwrap_or(0); let expected = local - remote.unwrap_or(0);
let _upload_page_size = 100; let upload_page_size = 100;
let _total = 0; let mut progress = 0;
if expected < 0 {
return Err(SyncError::SyncLogicError {
msg: String::from("ran upload, but remote ahead of local"),
});
}
println!( println!(
"Uploading {} records to {}/{}", "Uploading {} records to {}/{}",
@ -165,28 +165,47 @@ async fn sync_upload(
tag tag
); );
// TODO: actually upload lmfao // preload with the first entry if remote does not know of this store
while progress < expected {
let page = store
.next(
host,
tag.as_str(),
remote.unwrap_or(0) + progress,
upload_page_size,
)
.await
.map_err(|_| SyncError::LocalStoreError)?;
Ok(0) let _ = client
.post_records(&page)
.await
.map_err(|_| SyncError::RemoteRequestError)?;
println!(
"uploaded {} to remote, progress {}/{}",
page.len(),
progress,
expected
);
progress += page.len() as u64;
}
Ok(progress as i64)
} }
async fn sync_download( async fn sync_download(
_store: &mut impl Store, store: &mut impl Store,
_client: &Client<'_>, client: &Client<'_>,
host: HostId, host: HostId,
tag: String, tag: String,
local: Option<RecordIdx>, local: Option<RecordIdx>,
remote: RecordIdx, remote: RecordIdx,
) -> Result<i64, SyncError> { ) -> Result<i64, SyncError> {
let local = local.unwrap_or(0);
let expected = remote - local.unwrap_or(0); let expected = remote - local.unwrap_or(0);
let _download_page_size = 100; let download_page_size = 100;
let _total = 0; let mut progress = 0;
if expected < 0 {
return Err(SyncError::SyncLogicError {
msg: String::from("ran download, but local ahead of remote"),
});
}
println!( println!(
"Downloading {} records from {}/{}", "Downloading {} records from {}/{}",
@ -195,7 +214,25 @@ async fn sync_download(
tag tag
); );
// TODO: actually upload lmfao // preload with the first entry if remote does not know of this store
while progress < expected {
let page = client.next_records(host, tag, Some(local + progress), download_page_size);
let _ = client
.post_records(&page)
.await
.map_err(|_| SyncError::RemoteRequestError)?;
println!(
"uploaded {} to remote, progress {}/{}",
page.len(),
progress,
expected
);
progress += page.len() as u64;
}
Ok(progress as i64)
Ok(0) Ok(0)
} }
@ -204,13 +241,14 @@ pub async fn sync_remote(
operations: Vec<Operation>, operations: Vec<Operation>,
local_store: &mut impl Store, local_store: &mut impl Store,
settings: &Settings, settings: &Settings,
) -> Result<(i64, i64)> { ) -> Result<(i64, i64), SyncError> {
let client = Client::new( let client = Client::new(
&settings.sync_address, &settings.sync_address,
&settings.session_token, &settings.session_token,
settings.network_connect_timeout, settings.network_connect_timeout,
settings.network_timeout, settings.network_timeout,
)?; )
.expect("failed to create client");
let mut uploaded = 0; let mut uploaded = 0;
let mut downloaded = 0; let mut downloaded = 0;

View File

@ -14,7 +14,7 @@ use self::{
models::{History, NewHistory, NewSession, NewUser, Session, User}, models::{History, NewHistory, NewSession, NewUser, Session, User},
}; };
use async_trait::async_trait; use async_trait::async_trait;
use atuin_common::record::{EncryptedData, HostId, Record, RecordId, RecordStatus}; use atuin_common::record::{EncryptedData, HostId, Record, RecordId, RecordIdx, RecordStatus};
use serde::{de::DeserializeOwned, Serialize}; use serde::{de::DeserializeOwned, Serialize};
use time::{Date, Duration, Month, OffsetDateTime, Time, UtcOffset}; use time::{Date, Duration, Month, OffsetDateTime, Time, UtcOffset};
use tracing::instrument; use tracing::instrument;
@ -68,12 +68,12 @@ pub trait Database: Sized + Clone + Send + Sync + 'static {
user: &User, user: &User,
host: HostId, host: HostId,
tag: String, tag: String,
start: Option<RecordId>, start: Option<RecordIdx>,
count: u64, count: u64,
) -> DbResult<Vec<Record<EncryptedData>>>; ) -> DbResult<Vec<Record<EncryptedData>>>;
// Return the tail record ID for each store, so (HostID, Tag, TailRecordID) // Return the tail record ID for each store, so (HostID, Tag, TailRecordID)
async fn tail_records(&self, user: &User) -> DbResult<RecordStatus>; async fn status(&self, user: &User) -> DbResult<RecordStatus>;
async fn count_history_range(&self, user: &User, range: Range<OffsetDateTime>) async fn count_history_range(&self, user: &User, range: Range<OffsetDateTime>)
-> DbResult<i64>; -> DbResult<i64>;

View File

@ -0,0 +1,15 @@
-- Add migration script here
create table store (
id uuid primary key, -- remember to use uuidv7 for happy indices <3
client_id uuid not null, -- I am too uncomfortable with the idea of a client-generated primary key, even though it's fine mathematically
host uuid not null, -- a unique identifier for the host
idx bigint not null, -- the index of the record in this store, identified by (host, tag)
timestamp bigint not null, -- not a timestamp type, as those do not have nanosecond precision
version text not null,
tag text not null, -- what is this? history, kv, whatever. Remember clients get a log per tag per host
data text not null, -- store the actual history data, encrypted. I don't wanna know!
cek text not null,
user_id bigint not null, -- allow multiple users
created_at timestamp not null default current_timestamp
);

View File

@ -1,7 +1,7 @@
use std::ops::Range; use std::ops::Range;
use async_trait::async_trait; use async_trait::async_trait;
use atuin_common::record::{EncryptedData, HostId, Record, RecordId, RecordIndex}; use atuin_common::record::{EncryptedData, HostId, Record, RecordId, RecordIdx, RecordStatus};
use atuin_server_database::models::{History, NewHistory, NewSession, NewUser, Session, User}; use atuin_server_database::models::{History, NewHistory, NewSession, NewUser, Session, User};
use atuin_server_database::{Database, DbError, DbResult}; use atuin_server_database::{Database, DbError, DbResult};
use futures_util::TryStreamExt; use futures_util::TryStreamExt;
@ -11,6 +11,7 @@ use sqlx::Row;
use time::{OffsetDateTime, PrimitiveDateTime, UtcOffset}; use time::{OffsetDateTime, PrimitiveDateTime, UtcOffset};
use tracing::instrument; use tracing::instrument;
use uuid::Uuid;
use wrappers::{DbHistory, DbRecord, DbSession, DbUser}; use wrappers::{DbHistory, DbRecord, DbSession, DbUser};
mod wrappers; mod wrappers;
@ -361,16 +362,16 @@ impl Database for Postgres {
let id = atuin_common::utils::uuid_v7(); let id = atuin_common::utils::uuid_v7();
sqlx::query( sqlx::query(
"insert into records "insert into store
(id, client_id, host, parent, timestamp, version, tag, data, cek, user_id) (id, client_id, host, idx, timestamp, version, tag, data, cek, user_id)
values ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10) values ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10)
on conflict do nothing on conflict do nothing
", ",
) )
.bind(id) .bind(id)
.bind(i.id) .bind(i.id)
.bind(i.host) .bind(i.host.id)
.bind(i.parent) .bind(i.idx as i64)
.bind(i.timestamp as i64) // throwing away some data, but i64 is still big in terms of time .bind(i.timestamp as i64) // throwing away some data, but i64 is still big in terms of time
.bind(&i.version) .bind(&i.version)
.bind(&i.tag) .bind(&i.tag)
@ -393,62 +394,71 @@ impl Database for Postgres {
user: &User, user: &User,
host: HostId, host: HostId,
tag: String, tag: String,
start: Option<RecordId>, start: Option<RecordIdx>,
count: u64, count: u64,
) -> DbResult<Vec<Record<EncryptedData>>> { ) -> DbResult<Vec<Record<EncryptedData>>> {
tracing::debug!("{:?} - {:?} - {:?}", host, tag, start); tracing::debug!("{:?} - {:?} - {:?}", host, tag, start);
let mut ret = Vec::with_capacity(count as usize); let mut ret = Vec::with_capacity(count as usize);
let mut parent = start; let mut parent = start;
let start = start.unwrap_or(0);
// yeah let's do something better let records: Result<Vec<DbRecord>, DbError> = sqlx::query_as(
for _ in 0..count { "select client_id, host, idx, timestamp, version, tag, data, cek from store
// a very much not ideal query. but it's simple at least?
// we are basically using postgres as a kv store here, so... maybe consider using an actual
// kv store?
let record: Result<DbRecord, DbError> = sqlx::query_as(
"select client_id, host, parent, timestamp, version, tag, data, cek from records
where user_id = $1 where user_id = $1
and tag = $2 and tag = $2
and host = $3 and host = $3
and parent is not distinct from $4", and idx >= $4
) order by idx asc
.bind(user.id) limit $5",
.bind(tag.clone()) )
.bind(host) .bind(user.id)
.bind(parent) .bind(tag.clone())
.fetch_one(&self.pool) .bind(host)
.await .bind(start as i64)
.map_err(fix_error); .bind(count as i64)
.fetch_all(&self.pool)
.await
.map_err(fix_error);
match record { let ret = match records {
Ok(record) => { Ok(records) => {
let record: Record<EncryptedData> = record.into(); let records: Vec<Record<EncryptedData>> = records
ret.push(record.clone()); .into_iter()
.map(|f| {
let record: Record<EncryptedData> = f.into();
record
})
.collect();
parent = Some(record.id); records
}
Err(DbError::NotFound) => {
tracing::debug!("hit tail of store: {:?}/{}", host, tag);
return Ok(ret);
}
Err(e) => return Err(e),
} }
} Err(DbError::NotFound) => {
tracing::debug!("hit end of store: {:?}/{}", host, tag);
return Ok(ret);
}
Err(e) => return Err(e),
};
Ok(ret) Ok(ret)
} }
async fn tail_records(&self, user: &User) -> DbResult<RecordIndex> { async fn status(&self, user: &User) -> DbResult<RecordStatus> {
const TAIL_RECORDS_SQL: &str = "select host, tag, client_id from records rp where (select count(1) from records where parent=rp.client_id and user_id = $1) = 0 and user_id = $1;"; const STATUS_SQL: &str =
"select host, tag, max(idx) from store where user_id = $1 group by host, tag";
let res = sqlx::query_as(TAIL_RECORDS_SQL) let res: Vec<(Uuid, String, i64)> = sqlx::query_as(STATUS_SQL)
.bind(user.id) .bind(user.id)
.fetch(&self.pool) .fetch_all(&self.pool)
.try_collect()
.await .await
.map_err(fix_error)?; .map_err(fix_error)?;
Ok(res) let mut status = RecordStatus::new();
for i in res {
status.set_raw(HostId(i.0), i.1, i.2 as u64);
}
Ok(status)
} }
} }

View File

@ -1,5 +1,5 @@
use ::sqlx::{FromRow, Result}; use ::sqlx::{FromRow, Result};
use atuin_common::record::{EncryptedData, Record}; use atuin_common::record::{EncryptedData, Host, Record};
use atuin_server_database::models::{History, Session, User}; use atuin_server_database::models::{History, Session, User};
use sqlx::{postgres::PgRow, Row}; use sqlx::{postgres::PgRow, Row};
use time::PrimitiveDateTime; use time::PrimitiveDateTime;
@ -51,6 +51,7 @@ impl<'a> ::sqlx::FromRow<'a, PgRow> for DbHistory {
impl<'a> ::sqlx::FromRow<'a, PgRow> for DbRecord { impl<'a> ::sqlx::FromRow<'a, PgRow> for DbRecord {
fn from_row(row: &'a PgRow) -> ::sqlx::Result<Self> { fn from_row(row: &'a PgRow) -> ::sqlx::Result<Self> {
let timestamp: i64 = row.try_get("timestamp")?; let timestamp: i64 = row.try_get("timestamp")?;
let idx: i64 = row.try_get("idx")?;
let data = EncryptedData { let data = EncryptedData {
data: row.try_get("data")?, data: row.try_get("data")?,
@ -59,8 +60,8 @@ impl<'a> ::sqlx::FromRow<'a, PgRow> for DbRecord {
Ok(Self(Record { Ok(Self(Record {
id: row.try_get("client_id")?, id: row.try_get("client_id")?,
host: row.try_get("host")?, host: Host::new(row.try_get("host")?),
parent: row.try_get("parent")?, idx: idx as u64,
timestamp: timestamp as u64, timestamp: timestamp as u64,
version: row.try_get("version")?, version: row.try_get("version")?,
tag: row.try_get("tag")?, tag: row.try_get("tag")?,

View File

@ -8,7 +8,7 @@ use super::{ErrorResponse, ErrorResponseStatus, RespExt};
use crate::router::{AppState, UserAuth}; use crate::router::{AppState, UserAuth};
use atuin_server_database::Database; use atuin_server_database::Database;
use atuin_common::record::{EncryptedData, HostId, Record, RecordId, RecordIndex}; use atuin_common::record::{EncryptedData, HostId, Record, RecordId, RecordIdx, RecordStatus};
#[instrument(skip_all, fields(user.id = user.id))] #[instrument(skip_all, fields(user.id = user.id))]
pub async fn post<DB: Database>( pub async fn post<DB: Database>(
@ -53,13 +53,13 @@ pub async fn post<DB: Database>(
pub async fn index<DB: Database>( pub async fn index<DB: Database>(
UserAuth(user): UserAuth, UserAuth(user): UserAuth,
state: State<AppState<DB>>, state: State<AppState<DB>>,
) -> Result<Json<RecordIndex>, ErrorResponseStatus<'static>> { ) -> Result<Json<RecordStatus>, ErrorResponseStatus<'static>> {
let State(AppState { let State(AppState {
database, database,
settings: _, settings: _,
}) = state; }) = state;
let record_index = match database.tail_records(&user).await { let record_index = match database.status(&user).await {
Ok(index) => index, Ok(index) => index,
Err(e) => { Err(e) => {
error!("failed to get record index: {}", e); error!("failed to get record index: {}", e);
@ -76,7 +76,7 @@ pub async fn index<DB: Database>(
pub struct NextParams { pub struct NextParams {
host: HostId, host: HostId,
tag: String, tag: String,
start: Option<RecordId>, start: Option<RecordIdx>,
count: u64, count: u64,
} }

View File

@ -79,8 +79,7 @@ async fn run(
) -> Result<()> { ) -> Result<()> {
let (diff, remote_index) = sync::diff(settings, store).await?; let (diff, remote_index) = sync::diff(settings, store).await?;
let operations = sync::operations(diff, store).await?; let operations = sync::operations(diff, store).await?;
let (uploaded, downloaded) = let (uploaded, downloaded) = sync::sync_remote(operations, store, settings).await?;
sync::sync_remote(operations, &remote_index, store, settings).await?;
println!("{uploaded}/{downloaded} up/down to record store"); println!("{uploaded}/{downloaded} up/down to record store");