Record downloading works

This commit is contained in:
Ellie Huxtable 2023-07-07 09:05:25 +01:00
parent a4066d7255
commit 0c20816662
7 changed files with 136 additions and 10 deletions

View File

@ -110,6 +110,8 @@ async fn sync_upload(
.id .id
}; };
debug!("starting push to remote from: {}", start);
// we have the start point for sync. it is either the head of the store if // we have the start point for sync. it is either the head of the store if
// the remote has no data for it, or the tail that the remote has // the remote has no data for it, or the tail that the remote has
// we need to iterate from the remote tail, and keep going until // we need to iterate from the remote tail, and keep going until
@ -117,11 +119,6 @@ async fn sync_upload(
let mut record = Some(store.get(start).await.unwrap()); let mut record = Some(store.get(start).await.unwrap());
// don't try and upload the head again
if let Some(r) = record {
record = store.next(&r).await?;
}
// We are currently uploading them one at a time. Yes, this sucks. We are // We are currently uploading them one at a time. Yes, this sucks. We are
// also processing all records in serial. That also sucks. // also processing all records in serial. That also sucks.
// Once sync works, we can then make it super fast. // Once sync works, we can then make it super fast.
@ -134,7 +131,15 @@ async fn sync_upload(
Ok(total) Ok(total)
} }
fn sync_download(tail: Uuid, host: Uuid, tag: String) -> Result<i64> {
async fn sync_download(
store: &mut impl Store,
remote_index: &RecordIndex,
client: &Client<'_>,
op: (Uuid, String, Uuid),
) -> Result<i64> {
let mut total = 0;
Ok(0) Ok(0)
} }
@ -157,7 +162,8 @@ pub async fn sync_remote(
sync_upload(local_store, remote_index, &client, (host, tag, tail)).await? sync_upload(local_store, remote_index, &client, (host, tag, tail)).await?
} }
Operation::Download { tail, host, tag } => { Operation::Download { tail, host, tag } => {
downloaded += sync_download(tail, host, tag)? downloaded +=
sync_download(local_store, remote_index, &client, (host, tag, tail)).await?
} }
} }
} }

View File

@ -57,6 +57,14 @@ pub trait Database: Sized + Clone + Send + Sync + 'static {
async fn deleted_history(&self, user: &User) -> DbResult<Vec<String>>; async fn deleted_history(&self, user: &User) -> DbResult<Vec<String>>;
async fn add_records(&self, user: &User, record: &[Record]) -> DbResult<()>; async fn add_records(&self, user: &User, record: &[Record]) -> DbResult<()>;
async fn next_records(
&self,
user: &User,
host: Uuid,
tag: String,
start: Option<Uuid>,
count: u64,
) -> DbResult<Vec<Record>>;
// Return the tail record ID for each store, so (HostID, Tag, TailRecordID) // Return the tail record ID for each store, so (HostID, Tag, TailRecordID)
async fn tail_records(&self, user: &User) -> DbResult<Vec<(Uuid, String, Uuid)>>; async fn tail_records(&self, user: &User) -> DbResult<Vec<(Uuid, String, Uuid)>>;

View File

@ -3,7 +3,7 @@ create table records (
id uuid primary key, -- remember to use uuidv7 for happy indices <3 id uuid primary key, -- remember to use uuidv7 for happy indices <3
client_id uuid not null, -- I am too uncomfortable with the idea of a client-generated primary key client_id uuid not null, -- I am too uncomfortable with the idea of a client-generated primary key
host uuid not null, -- a unique identifier for the host host uuid not null, -- a unique identifier for the host
parent uuid not null, -- the ID of the parent record, bearing in mind this is a linked list parent uuid default null, -- the ID of the parent record, bearing in mind this is a linked list
timestamp bigint not null, -- not a timestamp type, as those do not have nanosecond precision timestamp bigint not null, -- not a timestamp type, as those do not have nanosecond precision
version text not null, version text not null,
tag text not null, -- what is this? history, kv, whatever. Remember clients get a log per tag per host tag text not null, -- what is this? history, kv, whatever. Remember clients get a log per tag per host

View File

@ -12,7 +12,7 @@ use sqlx::Row;
use sqlx::types::Uuid; use sqlx::types::Uuid;
use tracing::instrument; use tracing::instrument;
use wrappers::{DbHistory, DbSession, DbUser}; use wrappers::{DbHistory, DbRecord, DbSession, DbUser};
mod wrappers; mod wrappers;
@ -334,6 +334,7 @@ impl Database for Postgres {
.map(|DbHistory(h)| h) .map(|DbHistory(h)| h)
} }
#[instrument(skip_all)]
async fn add_records(&self, user: &User, records: &[Record]) -> DbResult<()> { async fn add_records(&self, user: &User, records: &[Record]) -> DbResult<()> {
let mut tx = self.pool.begin().await.map_err(fix_error)?; let mut tx = self.pool.begin().await.map_err(fix_error)?;
@ -366,6 +367,57 @@ impl Database for Postgres {
Ok(()) Ok(())
} }
#[instrument(skip_all)]
async fn next_records(
&self,
user: &User,
host: Uuid,
tag: String,
start: Option<Uuid>,
count: u64,
) -> DbResult<Vec<Record>> {
tracing::debug!("{:?} - {:?} - {:?}", host, tag, start);
let mut ret = Vec::with_capacity(count as usize);
let mut parent = start;
// yeah let's do something better
for _ in 0..count {
// a very much not ideal query. but it's simple at least?
// we are basically using postgres as a kv store here, so... maybe consider using an actual
// kv store?
let record: Result<DbRecord, DbError> = sqlx::query_as(
"select client_id, host, parent, timestamp, version, tag, data from records
where user_id = $1
and tag = $2
and host = $3
and parent is not distinct from $4",
)
.bind(user.id)
.bind(tag.clone())
.bind(host)
.bind(parent)
.fetch_one(&self.pool)
.await
.map_err(fix_error);
match record {
Ok(record) => {
let record: Record = record.into();
ret.push(record.clone());
parent = Some(record.id);
}
Err(DbError::NotFound) => {
tracing::debug!("hit tail of store: {:?}/{}", host, tag);
return Ok(ret);
}
Err(e) => return Err(e),
}
}
Ok(ret)
}
async fn tail_records(&self, user: &User) -> DbResult<Vec<(Uuid, String, Uuid)>> { async fn tail_records(&self, user: &User) -> DbResult<Vec<(Uuid, String, Uuid)>> {
const TAIL_RECORDS_SQL: &str = "select host, tag, client_id from records rp where (select count(1) from records where parent=rp.client_id and user_id = $1) = 0;"; const TAIL_RECORDS_SQL: &str = "select host, tag, client_id from records rp where (select count(1) from records where parent=rp.client_id and user_id = $1) = 0;";

View File

@ -1,10 +1,12 @@
use ::sqlx::{FromRow, Result}; use ::sqlx::{FromRow, Result};
use atuin_common::record::Record;
use atuin_server_database::models::{History, Session, User}; use atuin_server_database::models::{History, Session, User};
use sqlx::{postgres::PgRow, Row}; use sqlx::{postgres::PgRow, Row};
pub struct DbUser(pub User); pub struct DbUser(pub User);
pub struct DbSession(pub Session); pub struct DbSession(pub Session);
pub struct DbHistory(pub History); pub struct DbHistory(pub History);
pub struct DbRecord(pub Record);
impl<'a> FromRow<'a, PgRow> for DbUser { impl<'a> FromRow<'a, PgRow> for DbUser {
fn from_row(row: &'a PgRow) -> Result<Self> { fn from_row(row: &'a PgRow) -> Result<Self> {
@ -40,3 +42,25 @@ impl<'a> ::sqlx::FromRow<'a, PgRow> for DbHistory {
})) }))
} }
} }
impl<'a> ::sqlx::FromRow<'a, PgRow> for DbRecord {
fn from_row(row: &'a PgRow) -> ::sqlx::Result<Self> {
let timestamp: i64 = row.try_get("timestamp")?;
Ok(Self(Record {
id: row.try_get("client_id")?,
host: row.try_get("host")?,
parent: row.try_get("parent")?,
timestamp: timestamp as u64,
version: row.try_get("version")?,
tag: row.try_get("tag")?,
data: row.try_get("data")?,
}))
}
}
impl Into<Record> for DbRecord {
fn into(self) -> Record {
Record { ..self.0 }
}
}

View File

@ -1,6 +1,8 @@
use axum::{extract::State, Json}; use axum::{extract::Query, extract::State, Json};
use http::StatusCode; use http::StatusCode;
use serde::Deserialize;
use tracing::{error, instrument}; use tracing::{error, instrument};
use uuid::Uuid;
use super::{ErrorResponse, ErrorResponseStatus, RespExt}; use super::{ErrorResponse, ErrorResponseStatus, RespExt};
use crate::router::{AppState, UserAuth}; use crate::router::{AppState, UserAuth};
@ -68,3 +70,36 @@ pub async fn index<DB: Database>(
Ok(Json(record_index)) Ok(Json(record_index))
} }
#[derive(Deserialize)]
pub struct NextParams {
host: Uuid,
tag: String,
start: Option<Uuid>,
count: u64,
}
#[instrument(skip_all, fields(user.id = user.id))]
pub async fn next<DB: Database>(
params: Query<NextParams>,
UserAuth(user): UserAuth,
state: State<AppState<DB>>,
) -> Result<Json<Vec<Record>>, ErrorResponseStatus<'static>> {
let State(AppState { database, settings }) = state;
let params = params.0;
let records = match database
.next_records(&user, params.host, params.tag, params.start, params.count)
.await
{
Ok(records) => records,
Err(e) => {
error!("failed to get record index: {}", e);
return Err(ErrorResponse::reply("failed to calculate record index")
.with_status(StatusCode::INTERNAL_SERVER_ERROR));
}
};
Ok(Json(records))
}

View File

@ -73,6 +73,7 @@ pub fn router<DB: Database>(database: DB, settings: Settings<DB::Settings>) -> R
.route("/history", delete(handlers::history::delete)) .route("/history", delete(handlers::history::delete))
.route("/record", post(handlers::record::post)) .route("/record", post(handlers::record::post))
.route("/record", get(handlers::record::index)) .route("/record", get(handlers::record::index))
.route("/record/next", get(handlers::record::next))
.route("/user/:username", get(handlers::user::get)) .route("/user/:username", get(handlers::user::get))
.route("/account", delete(handlers::user::delete)) .route("/account", delete(handlers::user::delete))
.route("/register", post(handlers::user::register)) .route("/register", post(handlers::user::register))