This commit is contained in:
Ellie Huxtable
2023-12-02 22:20:00 +00:00
parent 3193ede1fe
commit eff7de4720
9 changed files with 147 additions and 84 deletions

View File

@ -8,7 +8,7 @@ use reqwest::{
StatusCode, Url,
};
use atuin_common::record::{EncryptedData, HostId, Record, RecordId};
use atuin_common::record::{EncryptedData, HostId, Record, RecordId, RecordIdx};
use atuin_common::{
api::{
AddHistoryRequest, CountResponse, DeleteHistoryRequest, ErrorResponse, IndexResponse,
@ -231,14 +231,14 @@ impl<'a> Client<'a> {
&self,
host: HostId,
tag: String,
start: Option<RecordId>,
start: RecordIdx,
count: u64,
) -> Result<Vec<Record<EncryptedData>>> {
debug!(
"fetching record/s from host {}/{}/{}",
host.0.to_string(),
tag,
start.map_or(String::from("empty"), |f| f.0.to_string())
start
);
let url = format!(

View File

@ -170,9 +170,9 @@ impl Store for SqliteStore {
limit: u64,
) -> Result<Vec<Record<EncryptedData>>> {
let res =
sqlx::query("select * from store where idx > ?1 and host = ?2 and tag = ?3 limit ?4")
sqlx::query("select * from store where idx >= ?1 and host = ?2 and tag = ?3 limit ?4")
.bind(idx as i64)
.bind(host)
.bind(host.0.as_hyphenated().to_string())
.bind(tag)
.bind(limit as i64)
.map(Self::query_row)

View File

@ -12,11 +12,14 @@ pub enum SyncError {
#[error("the local store is ahead of the remote, but for another host. has remote lost data?")]
LocalAheadOtherHost,
#[error("some issue with the local database occured")]
#[error("an issue with the local database occured")]
LocalStoreError,
#[error("something has gone wrong with the sync logic: {msg:?}")]
SyncLogicError { msg: String },
#[error("a request to the sync server failed")]
RemoteRequestError,
}
#[derive(Debug, Eq, PartialEq)]
@ -63,7 +66,10 @@ pub async fn diff(
// With the store as context, we can determine whether a tail exists locally, and therefore whether it needs uploading or downloading.
// In theory this could be done as part of the diffing stage, but it's easier to reason
// about and test this way
pub async fn operations(diffs: Vec<Diff>, _store: &impl Store) -> Result<Vec<Operation>, SyncError> {
pub async fn operations(
diffs: Vec<Diff>,
_store: &impl Store,
) -> Result<Vec<Operation>, SyncError> {
let mut operations = Vec::with_capacity(diffs.len());
let _host = Settings::host_id().expect("got to record sync without a host id; abort");
@ -141,22 +147,16 @@ pub async fn operations(diffs: Vec<Diff>, _store: &impl Store) -> Result<Vec<Ope
}
async fn sync_upload(
_store: &mut impl Store,
_client: &Client<'_>,
store: &mut impl Store,
client: &Client<'_>,
host: HostId,
tag: String,
local: RecordIdx,
remote: Option<RecordIdx>,
) -> Result<i64, SyncError> {
let expected = local - remote.unwrap_or(0);
let _upload_page_size = 100;
let _total = 0;
if expected < 0 {
return Err(SyncError::SyncLogicError {
msg: String::from("ran upload, but remote ahead of local"),
});
}
let upload_page_size = 100;
let mut progress = 0;
println!(
"Uploading {} records to {}/{}",
@ -165,28 +165,47 @@ async fn sync_upload(
tag
);
// TODO: actually upload lmfao
// preload with the first entry if remote does not know of this store
while progress < expected {
let page = store
.next(
host,
tag.as_str(),
remote.unwrap_or(0) + progress,
upload_page_size,
)
.await
.map_err(|_| SyncError::LocalStoreError)?;
Ok(0)
let _ = client
.post_records(&page)
.await
.map_err(|_| SyncError::RemoteRequestError)?;
println!(
"uploaded {} to remote, progress {}/{}",
page.len(),
progress,
expected
);
progress += page.len() as u64;
}
Ok(progress as i64)
}
async fn sync_download(
_store: &mut impl Store,
_client: &Client<'_>,
store: &mut impl Store,
client: &Client<'_>,
host: HostId,
tag: String,
local: Option<RecordIdx>,
remote: RecordIdx,
) -> Result<i64, SyncError> {
let local = local.unwrap_or(0);
let expected = remote - local.unwrap_or(0);
let _download_page_size = 100;
let _total = 0;
if expected < 0 {
return Err(SyncError::SyncLogicError {
msg: String::from("ran download, but local ahead of remote"),
});
}
let download_page_size = 100;
let mut progress = 0;
println!(
"Downloading {} records from {}/{}",
@ -195,7 +214,25 @@ async fn sync_download(
tag
);
// TODO: actually upload lmfao
// preload with the first entry if remote does not know of this store
while progress < expected {
let page = client.next_records(host, tag, Some(local + progress), download_page_size);
let _ = client
.post_records(&page)
.await
.map_err(|_| SyncError::RemoteRequestError)?;
println!(
"uploaded {} to remote, progress {}/{}",
page.len(),
progress,
expected
);
progress += page.len() as u64;
}
Ok(progress as i64)
Ok(0)
}
@ -204,13 +241,14 @@ pub async fn sync_remote(
operations: Vec<Operation>,
local_store: &mut impl Store,
settings: &Settings,
) -> Result<(i64, i64)> {
) -> Result<(i64, i64), SyncError> {
let client = Client::new(
&settings.sync_address,
&settings.session_token,
settings.network_connect_timeout,
settings.network_timeout,
)?;
)
.expect("failed to create client");
let mut uploaded = 0;
let mut downloaded = 0;

View File

@ -14,7 +14,7 @@ use self::{
models::{History, NewHistory, NewSession, NewUser, Session, User},
};
use async_trait::async_trait;
use atuin_common::record::{EncryptedData, HostId, Record, RecordId, RecordStatus};
use atuin_common::record::{EncryptedData, HostId, Record, RecordId, RecordIdx, RecordStatus};
use serde::{de::DeserializeOwned, Serialize};
use time::{Date, Duration, Month, OffsetDateTime, Time, UtcOffset};
use tracing::instrument;
@ -68,12 +68,12 @@ pub trait Database: Sized + Clone + Send + Sync + 'static {
user: &User,
host: HostId,
tag: String,
start: Option<RecordId>,
start: Option<RecordIdx>,
count: u64,
) -> DbResult<Vec<Record<EncryptedData>>>;
// Return the tail record ID for each store, so (HostID, Tag, TailRecordID)
async fn tail_records(&self, user: &User) -> DbResult<RecordStatus>;
async fn status(&self, user: &User) -> DbResult<RecordStatus>;
async fn count_history_range(&self, user: &User, range: Range<OffsetDateTime>)
-> DbResult<i64>;

View File

@ -0,0 +1,15 @@
-- Server-side record store: one row per encrypted record. Each (user_id, host, tag)
-- triple forms its own log, and idx is a record's position within that log.
create table store (
    id uuid primary key, -- generated by the server on insert; remember to use uuidv7 for happy indices <3
    client_id uuid not null, -- the id the client assigned. I am too uncomfortable with the idea of a client-generated primary key, even though it's fine mathematically
    host uuid not null, -- a unique identifier for the host
    idx bigint not null, -- the index of the record in this store, identified by (host, tag)
    timestamp bigint not null, -- not a timestamp type, as those do not have nanosecond precision
    version text not null, -- NOTE(review): presumably the record format version; confirm against the client
    tag text not null, -- what is this? history, kv, whatever. Remember clients get a log per tag per host
    data text not null, -- store the actual history data, encrypted. I don't wanna know!
    cek text not null, -- NOTE(review): presumably the wrapped content-encryption key for data; confirm
    user_id bigint not null, -- allow multiple users
    created_at timestamp not null default current_timestamp, -- defaults to the time the row was inserted

    -- A position in a log can hold only one record. The primary key is freshly
    -- generated for every insert, so without this constraint an
    -- `insert ... on conflict do nothing` can never conflict, and a record that is
    -- uploaded twice would be stored twice.
    unique (user_id, host, tag, idx)
);

View File

@ -1,7 +1,7 @@
use std::ops::Range;
use async_trait::async_trait;
use atuin_common::record::{EncryptedData, HostId, Record, RecordId, RecordIndex};
use atuin_common::record::{EncryptedData, HostId, Record, RecordId, RecordIdx, RecordStatus};
use atuin_server_database::models::{History, NewHistory, NewSession, NewUser, Session, User};
use atuin_server_database::{Database, DbError, DbResult};
use futures_util::TryStreamExt;
@ -11,6 +11,7 @@ use sqlx::Row;
use time::{OffsetDateTime, PrimitiveDateTime, UtcOffset};
use tracing::instrument;
use uuid::Uuid;
use wrappers::{DbHistory, DbRecord, DbSession, DbUser};
mod wrappers;
@ -361,16 +362,16 @@ impl Database for Postgres {
let id = atuin_common::utils::uuid_v7();
sqlx::query(
"insert into records
(id, client_id, host, parent, timestamp, version, tag, data, cek, user_id)
"insert into store
(id, client_id, host, idx, timestamp, version, tag, data, cek, user_id)
values ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10)
on conflict do nothing
",
)
.bind(id)
.bind(i.id)
.bind(i.host)
.bind(i.parent)
.bind(i.host.id)
.bind(i.idx as i64)
.bind(i.timestamp as i64) // throwing away some data, but i64 is still big in terms of time
.bind(&i.version)
.bind(&i.tag)
@ -393,62 +394,71 @@ impl Database for Postgres {
user: &User,
host: HostId,
tag: String,
start: Option<RecordId>,
start: Option<RecordIdx>,
count: u64,
) -> DbResult<Vec<Record<EncryptedData>>> {
tracing::debug!("{:?} - {:?} - {:?}", host, tag, start);
let mut ret = Vec::with_capacity(count as usize);
let mut parent = start;
let start = start.unwrap_or(0);
// yeah let's do something better
for _ in 0..count {
// a very much not ideal query. but it's simple at least?
// we are basically using postgres as a kv store here, so... maybe consider using an actual
// kv store?
let record: Result<DbRecord, DbError> = sqlx::query_as(
"select client_id, host, parent, timestamp, version, tag, data, cek from records
let records: Result<Vec<DbRecord>, DbError> = sqlx::query_as(
"select client_id, host, idx, timestamp, version, tag, data, cek from store
where user_id = $1
and tag = $2
and host = $3
and parent is not distinct from $4",
)
.bind(user.id)
.bind(tag.clone())
.bind(host)
.bind(parent)
.fetch_one(&self.pool)
.await
.map_err(fix_error);
and idx >= $4
order by idx asc
limit $5",
)
.bind(user.id)
.bind(tag.clone())
.bind(host)
.bind(start as i64)
.bind(count as i64)
.fetch_all(&self.pool)
.await
.map_err(fix_error);
match record {
Ok(record) => {
let record: Record<EncryptedData> = record.into();
ret.push(record.clone());
let ret = match records {
Ok(records) => {
let records: Vec<Record<EncryptedData>> = records
.into_iter()
.map(|f| {
let record: Record<EncryptedData> = f.into();
record
})
.collect();
parent = Some(record.id);
}
Err(DbError::NotFound) => {
tracing::debug!("hit tail of store: {:?}/{}", host, tag);
return Ok(ret);
}
Err(e) => return Err(e),
records
}
}
Err(DbError::NotFound) => {
tracing::debug!("hit end of store: {:?}/{}", host, tag);
return Ok(ret);
}
Err(e) => return Err(e),
};
Ok(ret)
}
async fn tail_records(&self, user: &User) -> DbResult<RecordIndex> {
const TAIL_RECORDS_SQL: &str = "select host, tag, client_id from records rp where (select count(1) from records where parent=rp.client_id and user_id = $1) = 0 and user_id = $1;";
async fn status(&self, user: &User) -> DbResult<RecordStatus> {
const STATUS_SQL: &str =
"select host, tag, max(idx) from store where user_id = $1 group by host, tag";
let res = sqlx::query_as(TAIL_RECORDS_SQL)
let res: Vec<(Uuid, String, i64)> = sqlx::query_as(STATUS_SQL)
.bind(user.id)
.fetch(&self.pool)
.try_collect()
.fetch_all(&self.pool)
.await
.map_err(fix_error)?;
Ok(res)
let mut status = RecordStatus::new();
for i in res {
status.set_raw(HostId(i.0), i.1, i.2 as u64);
}
Ok(status)
}
}

View File

@ -1,5 +1,5 @@
use ::sqlx::{FromRow, Result};
use atuin_common::record::{EncryptedData, Record};
use atuin_common::record::{EncryptedData, Host, Record};
use atuin_server_database::models::{History, Session, User};
use sqlx::{postgres::PgRow, Row};
use time::PrimitiveDateTime;
@ -51,6 +51,7 @@ impl<'a> ::sqlx::FromRow<'a, PgRow> for DbHistory {
impl<'a> ::sqlx::FromRow<'a, PgRow> for DbRecord {
fn from_row(row: &'a PgRow) -> ::sqlx::Result<Self> {
let timestamp: i64 = row.try_get("timestamp")?;
let idx: i64 = row.try_get("idx")?;
let data = EncryptedData {
data: row.try_get("data")?,
@ -59,8 +60,8 @@ impl<'a> ::sqlx::FromRow<'a, PgRow> for DbRecord {
Ok(Self(Record {
id: row.try_get("client_id")?,
host: row.try_get("host")?,
parent: row.try_get("parent")?,
host: Host::new(row.try_get("host")?),
idx: idx as u64,
timestamp: timestamp as u64,
version: row.try_get("version")?,
tag: row.try_get("tag")?,

View File

@ -8,7 +8,7 @@ use super::{ErrorResponse, ErrorResponseStatus, RespExt};
use crate::router::{AppState, UserAuth};
use atuin_server_database::Database;
use atuin_common::record::{EncryptedData, HostId, Record, RecordId, RecordIndex};
use atuin_common::record::{EncryptedData, HostId, Record, RecordId, RecordIdx, RecordStatus};
#[instrument(skip_all, fields(user.id = user.id))]
pub async fn post<DB: Database>(
@ -53,13 +53,13 @@ pub async fn post<DB: Database>(
pub async fn index<DB: Database>(
UserAuth(user): UserAuth,
state: State<AppState<DB>>,
) -> Result<Json<RecordIndex>, ErrorResponseStatus<'static>> {
) -> Result<Json<RecordStatus>, ErrorResponseStatus<'static>> {
let State(AppState {
database,
settings: _,
}) = state;
let record_index = match database.tail_records(&user).await {
let record_index = match database.status(&user).await {
Ok(index) => index,
Err(e) => {
error!("failed to get record index: {}", e);
@ -76,7 +76,7 @@ pub async fn index<DB: Database>(
pub struct NextParams {
host: HostId,
tag: String,
start: Option<RecordId>,
start: Option<RecordIdx>,
count: u64,
}

View File

@ -79,8 +79,7 @@ async fn run(
) -> Result<()> {
let (diff, remote_index) = sync::diff(settings, store).await?;
let operations = sync::operations(diff, store).await?;
let (uploaded, downloaded) =
sync::sync_remote(operations, &remote_index, store, settings).await?;
let (uploaded, downloaded) = sync::sync_remote(operations, store, settings).await?;
println!("{uploaded}/{downloaded} up/down to record store");