Switch to Warp + SQLx, use async, switch to Rust stable (#36)

* Switch to warp + sql, use async and stable rust

* Update CI to use stable
This commit is contained in:
Ellie Huxtable 2021-04-20 17:07:11 +01:00 committed by GitHub
parent f6de558070
commit 34888827f8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
32 changed files with 1520 additions and 1324 deletions

View File

@ -16,10 +16,10 @@ jobs:
steps: steps:
- uses: actions/checkout@v2 - uses: actions/checkout@v2
- name: Install latest nightly - name: Install rust
uses: actions-rs/toolchain@v1 uses: actions-rs/toolchain@v1
with: with:
toolchain: nightly toolchain: stable
override: true override: true
- name: Run cargo build - name: Run cargo build
@ -31,10 +31,10 @@ jobs:
steps: steps:
- uses: actions/checkout@v2 - uses: actions/checkout@v2
- name: Install latest nightly - name: Install rust
uses: actions-rs/toolchain@v1 uses: actions-rs/toolchain@v1
with: with:
toolchain: nightly toolchain: stable
override: true override: true
- name: Run cargo test - name: Run cargo test
@ -46,10 +46,10 @@ jobs:
steps: steps:
- uses: actions/checkout@v2 - uses: actions/checkout@v2
- name: Install latest nightly - name: Install latest rust
uses: actions-rs/toolchain@v1 uses: actions-rs/toolchain@v1
with: with:
toolchain: nightly toolchain: stable
override: true override: true
components: clippy components: clippy
@ -62,10 +62,10 @@ jobs:
steps: steps:
- uses: actions/checkout@v2 - uses: actions/checkout@v2
- name: Install latest nightly - name: Install latest rust
uses: actions-rs/toolchain@v1 uses: actions-rs/toolchain@v1
with: with:
toolchain: nightly toolchain: stable
override: true override: true
components: rustfmt components: rustfmt

1507
Cargo.lock generated

File diff suppressed because it is too large Load Diff

View File

@ -8,7 +8,7 @@ description = "atuin - magical shell history"
[dependencies] [dependencies]
log = "0.4" log = "0.4"
fern = "0.6.0" fern = {version = "0.6.0", features = ["colored"] }
chrono = { version = "0.4", features = ["serde"] } chrono = { version = "0.4", features = ["serde"] }
eyre = "0.6" eyre = "0.6"
shellexpand = "2" shellexpand = "2"
@ -17,7 +17,6 @@ directories = "3"
uuid = { version = "0.8", features = ["v4"] } uuid = { version = "0.8", features = ["v4"] }
indicatif = "0.15.0" indicatif = "0.15.0"
whoami = "1.1.2" whoami = "1.1.2"
rocket = "0.4.7"
chrono-english = "0.1.4" chrono-english = "0.1.4"
cli-table = "0.4" cli-table = "0.4"
config = "0.11" config = "0.11"
@ -29,8 +28,6 @@ tui = "0.14"
termion = "1.5" termion = "1.5"
unicode-width = "0.1" unicode-width = "0.1"
itertools = "0.10.0" itertools = "0.10.0"
diesel = { version = "1.4.4", features = ["postgres", "chrono"] }
diesel_migrations = "1.4.0"
dotenv = "0.15.0" dotenv = "0.15.0"
sodiumoxide = "0.2.6" sodiumoxide = "0.2.6"
reqwest = { version = "0.11", features = ["blocking", "json"] } reqwest = { version = "0.11", features = ["blocking", "json"] }
@ -40,12 +37,13 @@ parse_duration = "2.1.1"
rand = "0.8.3" rand = "0.8.3"
rust-crypto = "^0.2" rust-crypto = "^0.2"
human-panic = "1.0.3" human-panic = "1.0.3"
tokio = { version = "1", features = ["full"] }
warp = "0.3"
sqlx = { version = "0.5", features = [ "runtime-tokio-native-tls", "uuid", "chrono", "postgres" ] }
async-trait = "0.1.49"
urlencoding = "1.1.1"
humantime = "2.1.0"
[dependencies.rusqlite] [dependencies.rusqlite]
version = "0.25" version = "0.25"
features = ["bundled"] features = ["bundled"]
[dependencies.rocket_contrib]
version = "0.4.7"
default-features = false
features = ["diesel_postgres_pool", "json"]

View File

@ -1,7 +1,4 @@
FROM rust as builder FROM rust:1.51-buster as builder
RUN rustup default nightly
RUN cargo new --bin atuin RUN cargo new --bin atuin
WORKDIR /atuin WORKDIR /atuin

View File

@ -29,7 +29,7 @@
# sync_address = "https://api.atuin.sh" # sync_address = "https://api.atuin.sh"
# This section configures the sync server, if you decide to host your own # This section configures the sync server, if you decide to host your own
[remote] [server]
## host to bind, can also be passed via CLI args ## host to bind, can also be passed via CLI args
# host = "127.0.0.1" # host = "127.0.0.1"

View File

@ -1,8 +1,9 @@
use chrono::Utc; use chrono::Utc;
// This is shared between the client and the server, and has the data structures #[derive(Debug, Serialize, Deserialize)]
// representing the requests/responses for each method. pub struct UserResponse {
// TODO: Properly define responses rather than using json! pub username: String,
}
#[derive(Debug, Serialize, Deserialize)] #[derive(Debug, Serialize, Deserialize)]
pub struct RegisterRequest { pub struct RegisterRequest {
@ -11,12 +12,22 @@ pub struct RegisterRequest {
pub password: String, pub password: String,
} }
#[derive(Debug, Serialize, Deserialize)]
pub struct RegisterResponse {
pub session: String,
}
#[derive(Debug, Serialize, Deserialize)] #[derive(Debug, Serialize, Deserialize)]
pub struct LoginRequest { pub struct LoginRequest {
pub username: String, pub username: String,
pub password: String, pub password: String,
} }
#[derive(Debug, Serialize, Deserialize)]
pub struct LoginResponse {
pub session: String,
}
#[derive(Debug, Serialize, Deserialize)] #[derive(Debug, Serialize, Deserialize)]
pub struct AddHistoryRequest { pub struct AddHistoryRequest {
pub id: String, pub id: String,
@ -31,6 +42,29 @@ pub struct CountResponse {
} }
#[derive(Debug, Serialize, Deserialize)] #[derive(Debug, Serialize, Deserialize)]
pub struct ListHistoryResponse { pub struct SyncHistoryRequest {
pub sync_ts: chrono::DateTime<chrono::FixedOffset>,
pub history_ts: chrono::DateTime<chrono::FixedOffset>,
pub host: String,
}
#[derive(Debug, Serialize, Deserialize)]
pub struct SyncHistoryResponse {
pub history: Vec<String>, pub history: Vec<String>,
} }
#[derive(Debug, Serialize, Deserialize)]
pub struct ErrorResponse {
pub reason: String,
}
impl ErrorResponse {
pub fn reply(reason: &str, status: warp::http::StatusCode) -> impl warp::Reply {
warp::reply::with_status(
warp::reply::json(&ErrorResponse {
reason: String::from(reason),
}),
status,
)
}
}

View File

@ -53,7 +53,7 @@ fn print_list(h: &[History]) {
} }
impl Cmd { impl Cmd {
pub fn run(&self, settings: &Settings, db: &mut impl Database) -> Result<()> { pub async fn run(&self, settings: &Settings, db: &mut (impl Database + Send)) -> Result<()> {
match self { match self {
Self::Start { command: words } => { Self::Start { command: words } => {
let command = words.join(" "); let command = words.join(" ");
@ -69,6 +69,10 @@ impl Cmd {
} }
Self::End { id, exit } => { Self::End { id, exit } => {
if id.trim() == "" {
return Ok(());
}
let mut h = db.load(id)?; let mut h = db.load(id)?;
h.exit = *exit; h.exit = *exit;
h.duration = chrono::Utc::now().timestamp_nanos() - h.timestamp.timestamp_nanos(); h.duration = chrono::Utc::now().timestamp_nanos() - h.timestamp.timestamp_nanos();
@ -82,7 +86,7 @@ impl Cmd {
} }
Ok(Fork::Child) => { Ok(Fork::Child) => {
debug!("running periodic background sync"); debug!("running periodic background sync");
sync::sync(settings, false, db)?; sync::sync(settings, false, db).await?;
} }
Err(_) => println!("Fork failed"), Err(_) => println!("Fork failed"),
} }

View File

@ -2,7 +2,7 @@ use std::collections::HashMap;
use std::fs::File; use std::fs::File;
use std::io::prelude::*; use std::io::prelude::*;
use eyre::Result; use eyre::{eyre, Result};
use structopt::StructOpt; use structopt::StructOpt;
use crate::settings::Settings; use crate::settings::Settings;
@ -28,8 +28,13 @@ impl Cmd {
let url = format!("{}/login", settings.local.sync_address); let url = format!("{}/login", settings.local.sync_address);
let client = reqwest::blocking::Client::new(); let client = reqwest::blocking::Client::new();
let resp = client.post(url).json(&map).send()?; let resp = client.post(url).json(&map).send()?;
if resp.status() != reqwest::StatusCode::OK {
return Err(eyre!("invalid login details"));
}
let session = resp.json::<HashMap<String, String>>()?; let session = resp.json::<HashMap<String, String>>()?;
let session = session["session"].clone(); let session = session["session"].clone();

View File

@ -63,16 +63,16 @@ pub fn uuid_v4() -> String {
} }
impl AtuinCmd { impl AtuinCmd {
pub fn run(self, db: &mut impl Database, settings: &Settings) -> Result<()> { pub async fn run<T: Database + Send>(self, db: &mut T, settings: &Settings) -> Result<()> {
match self { match self {
Self::History(history) => history.run(settings, db), Self::History(history) => history.run(settings, db).await,
Self::Import(import) => import.run(db), Self::Import(import) => import.run(db),
Self::Server(server) => server.run(settings), Self::Server(server) => server.run(settings).await,
Self::Stats(stats) => stats.run(db, settings), Self::Stats(stats) => stats.run(db, settings),
Self::Init => init::init(), Self::Init => init::init(),
Self::Search { query } => search::run(&query, db), Self::Search { query } => search::run(&query, db),
Self::Sync { force } => sync::run(settings, force, db), Self::Sync { force } => sync::run(settings, force, db).await,
Self::Login(l) => l.run(settings), Self::Login(l) => l.run(settings),
Self::Register(r) => register::run( Self::Register(r) => register::run(
settings, settings,

View File

@ -1,6 +1,8 @@
use eyre::Result; use eyre::Result;
use itertools::Itertools; use itertools::Itertools;
use std::io::stdout; use std::io::stdout;
use std::time::Duration;
use termion::{event::Key, input::MouseTerminal, raw::IntoRawMode, screen::AlternateScreen}; use termion::{event::Key, input::MouseTerminal, raw::IntoRawMode, screen::AlternateScreen};
use tui::{ use tui::{
backend::TermionBackend, backend::TermionBackend,
@ -26,6 +28,78 @@ struct State {
results_state: ListState, results_state: ListState,
} }
#[allow(clippy::clippy::cast_sign_loss)]
impl State {
fn durations(&self) -> Vec<String> {
self.results
.iter()
.map(|h| {
let duration =
Duration::from_millis(std::cmp::max(h.duration, 0) as u64 / 1_000_000);
let duration = humantime::format_duration(duration).to_string();
let duration: Vec<&str> = duration.split(' ').collect();
duration[0].to_string()
})
.collect()
}
fn render_results<T: tui::backend::Backend>(
&mut self,
f: &mut tui::Frame<T>,
r: tui::layout::Rect,
) {
let durations = self.durations();
let max_length = durations
.iter()
.fold(0, |largest, i| std::cmp::max(largest, i.len()));
let results: Vec<ListItem> = self
.results
.iter()
.enumerate()
.map(|(i, m)| {
let command = m.command.to_string().replace("\n", " ").replace("\t", " ");
let mut command = Span::raw(command);
let mut duration = durations[i].clone();
while duration.len() < max_length {
duration.push(' ');
}
let duration = Span::styled(
duration,
Style::default().fg(if m.exit == 0 || m.duration == -1 {
Color::Green
} else {
Color::Red
}),
);
if let Some(selected) = self.results_state.selected() {
if selected == i {
command.style =
Style::default().fg(Color::Red).add_modifier(Modifier::BOLD);
}
}
let spans = Spans::from(vec![duration, Span::raw(" "), command]);
ListItem::new(spans)
})
.collect();
let results = List::new(results)
.block(Block::default().borders(Borders::ALL).title("History"))
.start_corner(Corner::BottomLeft)
.highlight_symbol(">> ");
f.render_stateful_widget(results, r, &mut self.results_state);
}
}
fn query_results(app: &mut State, db: &mut impl Database) { fn query_results(app: &mut State, db: &mut impl Database) {
let results = match app.input.as_str() { let results = match app.input.as_str() {
"" => db.list(), "" => db.list(),
@ -48,7 +122,11 @@ fn key_handler(input: Key, db: &mut impl Database, app: &mut State) -> Option<St
Key::Esc | Key::Char('\n') => { Key::Esc | Key::Char('\n') => {
let i = app.results_state.selected().unwrap_or(0); let i = app.results_state.selected().unwrap_or(0);
return Some(app.results.get(i).unwrap().command.clone()); return Some(
app.results
.get(i)
.map_or("".to_string(), |h| h.command.clone()),
);
} }
Key::Char(c) => { Key::Char(c) => {
app.input.push(c); app.input.push(c);
@ -163,32 +241,8 @@ fn select_history(query: &[String], db: &mut impl Database) -> Result<String> {
let help = Text::from(Spans::from(help)); let help = Text::from(Spans::from(help));
let help = Paragraph::new(help); let help = Paragraph::new(help);
let input = Paragraph::new(app.input.as_ref()) let input = Paragraph::new(app.input.clone())
.block(Block::default().borders(Borders::ALL).title("Search")); .block(Block::default().borders(Borders::ALL).title("Query"));
let results: Vec<ListItem> = app
.results
.iter()
.enumerate()
.map(|(i, m)| {
let mut content =
Span::raw(m.command.to_string().replace("\n", " ").replace("\t", " "));
if let Some(selected) = app.results_state.selected() {
if selected == i {
content.style =
Style::default().fg(Color::Red).add_modifier(Modifier::BOLD);
}
}
ListItem::new(content)
})
.collect();
let results = List::new(results)
.block(Block::default().borders(Borders::ALL).title("History"))
.start_corner(Corner::BottomLeft)
.highlight_symbol(">> ");
let stats = Paragraph::new(Text::from(Span::raw(format!( let stats = Paragraph::new(Text::from(Span::raw(format!(
"history count: {}", "history count: {}",
@ -199,8 +253,8 @@ fn select_history(query: &[String], db: &mut impl Database) -> Result<String> {
f.render_widget(title, top_left_chunks[0]); f.render_widget(title, top_left_chunks[0]);
f.render_widget(help, top_left_chunks[1]); f.render_widget(help, top_left_chunks[1]);
app.render_results(f, chunks[1]);
f.render_widget(stats, top_right_chunks[0]); f.render_widget(stats, top_right_chunks[0]);
f.render_stateful_widget(results, chunks[1], &mut app.results_state);
f.render_widget(input, chunks[2]); f.render_widget(input, chunks[2]);
f.set_cursor( f.set_cursor(

View File

@ -1,7 +1,7 @@
use eyre::Result; use eyre::Result;
use structopt::StructOpt; use structopt::StructOpt;
use crate::remote::server; use crate::server;
use crate::settings::Settings; use crate::settings::Settings;
#[derive(StructOpt)] #[derive(StructOpt)]
@ -20,7 +20,7 @@ pub enum Cmd {
} }
impl Cmd { impl Cmd {
pub fn run(&self, settings: &Settings) -> Result<()> { pub async fn run(&self, settings: &Settings) -> Result<()> {
match self { match self {
Self::Start { host, port } => { Self::Start { host, port } => {
let host = host.as_ref().map_or( let host = host.as_ref().map_or(
@ -29,7 +29,7 @@ impl Cmd {
); );
let port = port.map_or(settings.server.port, |p| p); let port = port.map_or(settings.server.port, |p| p);
server::launch(settings, host, port) server::launch(settings, host, port).await
} }
} }
} }

View File

@ -4,8 +4,8 @@ use crate::local::database::Database;
use crate::local::sync; use crate::local::sync;
use crate::settings::Settings; use crate::settings::Settings;
pub fn run(settings: &Settings, force: bool, db: &mut impl Database) -> Result<()> { pub async fn run(settings: &Settings, force: bool, db: &mut (impl Database + Send)) -> Result<()> {
sync::sync(settings, force, db)?; sync::sync(settings, force, db).await?;
println!( println!(
"Sync complete! {} items in database, force: {}", "Sync complete! {} items in database, force: {}",
db.history_count()?, db.history_count()?,

View File

@ -1,93 +1,94 @@
use chrono::Utc; use chrono::Utc;
use eyre::Result; use eyre::Result;
use reqwest::header::AUTHORIZATION; use reqwest::header::{HeaderMap, AUTHORIZATION};
use reqwest::Url;
use sodiumoxide::crypto::secretbox;
use crate::api::{AddHistoryRequest, CountResponse, ListHistoryResponse}; use crate::api::{AddHistoryRequest, CountResponse, SyncHistoryResponse};
use crate::local::encryption::{decrypt, load_key}; use crate::local::encryption::decrypt;
use crate::local::history::History; use crate::local::history::History;
use crate::settings::Settings;
use crate::utils::hash_str; use crate::utils::hash_str;
pub struct Client<'a> { pub struct Client<'a> {
settings: &'a Settings, sync_addr: &'a str,
token: &'a str,
key: secretbox::Key,
client: reqwest::Client,
} }
impl<'a> Client<'a> { impl<'a> Client<'a> {
pub const fn new(settings: &'a Settings) -> Self { pub fn new(sync_addr: &'a str, token: &'a str, key: secretbox::Key) -> Self {
Client { settings } Client {
sync_addr,
token,
key,
client: reqwest::Client::new(),
}
} }
pub fn count(&self) -> Result<i64> { pub async fn count(&self) -> Result<i64> {
let url = format!("{}/sync/count", self.settings.local.sync_address); let url = format!("{}/sync/count", self.sync_addr);
let client = reqwest::blocking::Client::new(); let url = Url::parse(url.as_str())?;
let token = format!("Token {}", self.token);
let token = token.parse()?;
let resp = client let mut headers = HeaderMap::new();
.get(url) headers.insert(AUTHORIZATION, token);
.header(
AUTHORIZATION,
format!("Token {}", self.settings.local.session_token),
)
.send()?;
let count = resp.json::<CountResponse>()?; let resp = self.client.get(url).headers(headers).send().await?;
let count = resp.json::<CountResponse>().await?;
Ok(count.count) Ok(count.count)
} }
pub fn get_history( pub async fn get_history(
&self, &self,
sync_ts: chrono::DateTime<Utc>, sync_ts: chrono::DateTime<Utc>,
history_ts: chrono::DateTime<Utc>, history_ts: chrono::DateTime<Utc>,
host: Option<String>, host: Option<String>,
) -> Result<Vec<History>> { ) -> Result<Vec<History>> {
let key = load_key(self.settings)?;
let host = match host { let host = match host {
None => hash_str(&format!("{}:{}", whoami::hostname(), whoami::username())), None => hash_str(&format!("{}:{}", whoami::hostname(), whoami::username())),
Some(h) => h, Some(h) => h,
}; };
// this allows for syncing between users on the same machine
let url = format!( let url = format!(
"{}/sync/history?sync_ts={}&history_ts={}&host={}", "{}/sync/history?sync_ts={}&history_ts={}&host={}",
self.settings.local.sync_address, self.sync_addr,
sync_ts.to_rfc3339(), urlencoding::encode(sync_ts.to_rfc3339().as_str()),
history_ts.to_rfc3339(), urlencoding::encode(history_ts.to_rfc3339().as_str()),
host, host,
); );
let client = reqwest::blocking::Client::new();
let resp = client let resp = self
.client
.get(url) .get(url)
.header( .header(AUTHORIZATION, format!("Token {}", self.token))
AUTHORIZATION, .send()
format!("Token {}", self.settings.local.session_token), .await?;
)
.send()?;
let history = resp.json::<ListHistoryResponse>()?; let history = resp.json::<SyncHistoryResponse>().await?;
let history = history let history = history
.history .history
.iter() .iter()
.map(|h| serde_json::from_str(h).expect("invalid base64")) .map(|h| serde_json::from_str(h).expect("invalid base64"))
.map(|h| decrypt(&h, &key).expect("failed to decrypt history! check your key")) .map(|h| decrypt(&h, &self.key).expect("failed to decrypt history! check your key"))
.collect(); .collect();
Ok(history) Ok(history)
} }
pub fn post_history(&self, history: &[AddHistoryRequest]) -> Result<()> { pub async fn post_history(&self, history: &[AddHistoryRequest]) -> Result<()> {
let client = reqwest::blocking::Client::new(); let url = format!("{}/history", self.sync_addr);
let url = Url::parse(url.as_str())?;
let url = format!("{}/history", self.settings.local.sync_address); self.client
client
.post(url) .post(url)
.json(history) .json(history)
.header( .header(AUTHORIZATION, format!("Token {}", self.token))
AUTHORIZATION, .send()
format!("Token {}", self.settings.local.session_token), .await?;
)
.send()?;
Ok(()) Ok(())
} }

View File

@ -215,9 +215,9 @@ impl Database for Sqlite {
} }
fn before(&self, timestamp: chrono::DateTime<Utc>, count: i64) -> Result<Vec<History>> { fn before(&self, timestamp: chrono::DateTime<Utc>, count: i64) -> Result<Vec<History>> {
let mut stmt = self.conn.prepare( let mut stmt = self
"SELECT * FROM history where timestamp <= ? order by timestamp desc limit ?", .conn
)?; .prepare("SELECT * FROM history where timestamp < ? order by timestamp desc limit ?")?;
let history_iter = stmt.query_map(params![timestamp.timestamp_nanos(), count], |row| { let history_iter = stmt.query_map(params![timestamp.timestamp_nanos(), count], |row| {
history_from_sqlite_row(None, row) history_from_sqlite_row(None, row)
@ -236,7 +236,7 @@ impl Database for Sqlite {
fn prefix_search(&self, query: &str) -> Result<Vec<History>> { fn prefix_search(&self, query: &str) -> Result<Vec<History>> {
self.query( self.query(
"select * from history where command like ?1 || '%' order by timestamp asc", "select * from history where command like ?1 || '%' order by timestamp asc limit 1000",
&[query], &[query],
) )
} }

View File

@ -7,6 +7,7 @@ use std::{fs::File, path::Path};
use chrono::prelude::*; use chrono::prelude::*;
use chrono::Utc; use chrono::Utc;
use eyre::{eyre, Result}; use eyre::{eyre, Result};
use itertools::Itertools;
use super::history::History; use super::history::History;
@ -42,8 +43,8 @@ impl Zsh {
fn parse_extended(line: &str, counter: i64) -> History { fn parse_extended(line: &str, counter: i64) -> History {
let line = line.replacen(": ", "", 2); let line = line.replacen(": ", "", 2);
let (time, duration) = line.split_once(':').unwrap(); let (time, duration) = line.splitn(2, ':').collect_tuple().unwrap();
let (duration, command) = duration.split_once(';').unwrap(); let (duration, command) = duration.splitn(2, ';').collect_tuple().unwrap();
let time = time let time = time
.parse::<i64>() .parse::<i64>()
@ -60,7 +61,7 @@ fn parse_extended(line: &str, counter: i64) -> History {
time, time,
command.trim_end().to_string(), command.trim_end().to_string(),
String::from("unknown"), String::from("unknown"),
-1, 0, // assume 0, we have no way of knowing :(
duration, duration,
None, None,
None, None,

View File

@ -20,12 +20,12 @@ use crate::{api::AddHistoryRequest, utils::hash_str};
// Check if remote has things we don't, and if so, download them. // Check if remote has things we don't, and if so, download them.
// Returns (num downloaded, total local) // Returns (num downloaded, total local)
fn sync_download( async fn sync_download(
force: bool, force: bool,
client: &api_client::Client, client: &api_client::Client<'_>,
db: &mut impl Database, db: &mut (impl Database + Send),
) -> Result<(i64, i64)> { ) -> Result<(i64, i64)> {
let remote_count = client.count()?; let remote_count = client.count().await?;
let initial_local = db.history_count()?; let initial_local = db.history_count()?;
let mut local_count = initial_local; let mut local_count = initial_local;
@ -41,7 +41,9 @@ fn sync_download(
let host = if force { Some(String::from("")) } else { None }; let host = if force { Some(String::from("")) } else { None };
while remote_count > local_count { while remote_count > local_count {
let page = client.get_history(last_sync, last_timestamp, host.clone())?; let page = client
.get_history(last_sync, last_timestamp, host.clone())
.await?;
if page.len() < HISTORY_PAGE_SIZE.try_into().unwrap() { if page.len() < HISTORY_PAGE_SIZE.try_into().unwrap() {
break; break;
@ -71,13 +73,13 @@ fn sync_download(
} }
// Check if we have things remote doesn't, and if so, upload them // Check if we have things remote doesn't, and if so, upload them
fn sync_upload( async fn sync_upload(
settings: &Settings, settings: &Settings,
_force: bool, _force: bool,
client: &api_client::Client, client: &api_client::Client<'_>,
db: &mut impl Database, db: &mut (impl Database + Send),
) -> Result<()> { ) -> Result<()> {
let initial_remote_count = client.count()?; let initial_remote_count = client.count().await?;
let mut remote_count = initial_remote_count; let mut remote_count = initial_remote_count;
let local_count = db.history_count()?; let local_count = db.history_count()?;
@ -111,21 +113,25 @@ fn sync_upload(
} }
// anything left over outside of the 100 block size // anything left over outside of the 100 block size
client.post_history(&buffer)?; client.post_history(&buffer).await?;
cursor = buffer.last().unwrap().timestamp; cursor = buffer.last().unwrap().timestamp;
remote_count = client.count()?; remote_count = client.count().await?;
} }
Ok(()) Ok(())
} }
pub fn sync(settings: &Settings, force: bool, db: &mut impl Database) -> Result<()> { pub async fn sync(settings: &Settings, force: bool, db: &mut (impl Database + Send)) -> Result<()> {
let client = api_client::Client::new(settings); let client = api_client::Client::new(
settings.local.sync_address.as_str(),
settings.local.session_token.as_str(),
load_key(settings)?,
);
sync_upload(settings, force, &client, db)?; sync_upload(settings, force, &client, db).await?;
let download = sync_download(force, &client, db)?; let download = sync_download(force, &client, db).await?;
debug!("sync downloaded {}", download.0); debug!("sync downloaded {}", download.0);

View File

@ -1,32 +1,19 @@
#![feature(proc_macro_hygiene)]
#![feature(decl_macro)]
#![warn(clippy::pedantic, clippy::nursery)] #![warn(clippy::pedantic, clippy::nursery)]
#![allow(clippy::use_self)] // not 100% reliable #![allow(clippy::use_self)] // not 100% reliable
use std::path::PathBuf; use std::path::PathBuf;
use eyre::{eyre, Result}; use eyre::{eyre, Result};
use fern::colors::{Color, ColoredLevelConfig};
use human_panic::setup_panic; use human_panic::setup_panic;
use structopt::{clap::AppSettings, StructOpt}; use structopt::{clap::AppSettings, StructOpt};
#[macro_use] #[macro_use]
extern crate log; extern crate log;
#[macro_use]
extern crate rocket;
#[macro_use] #[macro_use]
extern crate serde_derive; extern crate serde_derive;
#[macro_use]
extern crate diesel;
#[macro_use]
extern crate diesel_migrations;
#[macro_use]
extern crate rocket_contrib;
use command::AtuinCmd; use command::AtuinCmd;
use local::database::Sqlite; use local::database::Sqlite;
use settings::Settings; use settings::Settings;
@ -34,12 +21,10 @@ use settings::Settings;
mod api; mod api;
mod command; mod command;
mod local; mod local;
mod remote; mod server;
mod settings; mod settings;
mod utils; mod utils;
pub mod schema;
#[derive(StructOpt)] #[derive(StructOpt)]
#[structopt( #[structopt(
author = "Ellie Huxtable <e@elm.sh>", author = "Ellie Huxtable <e@elm.sh>",
@ -56,7 +41,7 @@ struct Atuin {
} }
impl Atuin { impl Atuin {
fn run(self, settings: &Settings) -> Result<()> { async fn run(self, settings: &Settings) -> Result<()> {
let db_path = if let Some(db_path) = self.db { let db_path = if let Some(db_path) = self.db {
let path = db_path let path = db_path
.to_str() .to_str()
@ -69,26 +54,32 @@ impl Atuin {
let mut db = Sqlite::new(db_path)?; let mut db = Sqlite::new(db_path)?;
self.atuin.run(&mut db, settings) self.atuin.run(&mut db, settings).await
} }
} }
fn main() -> Result<()> { #[tokio::main]
setup_panic!(); async fn main() -> Result<()> {
let settings = Settings::new()?; let colors = ColoredLevelConfig::new()
.warn(Color::Yellow)
.error(Color::Red);
fern::Dispatch::new() fern::Dispatch::new()
.format(|out, message, record| { .format(move |out, message, record| {
out.finish(format_args!( out.finish(format_args!(
"{} [{}] {}", "{} [{}] {}",
chrono::Local::now().format("[%Y-%m-%d][%H:%M:%S]"), chrono::Local::now().to_rfc3339(),
record.level(), colors.color(record.level()),
message message
)) ))
}) })
.level(log::LevelFilter::Info) .level(log::LevelFilter::Info)
.level_for("sqlx", log::LevelFilter::Warn)
.chain(std::io::stdout()) .chain(std::io::stdout())
.apply()?; .apply()?;
Atuin::from_args().run(&settings) let settings = Settings::new()?;
setup_panic!();
Atuin::from_args().run(&settings).await
} }

View File

@ -1,22 +0,0 @@
use diesel::pg::PgConnection;
use diesel::prelude::*;
use eyre::{eyre, Result};
use crate::settings::Settings;
#[database("atuin")]
pub struct AtuinDbConn(diesel::PgConnection);
// TODO: connection pooling
pub fn establish_connection(settings: &Settings) -> Result<PgConnection> {
if settings.server.db_uri == "default_uri" {
Err(eyre!(
"Please configure your database! Set db_uri in config.toml"
))
} else {
let database_url = &settings.server.db_uri;
let conn = PgConnection::establish(database_url)?;
Ok(conn)
}
}

View File

@ -1,5 +0,0 @@
pub mod auth;
pub mod database;
pub mod models;
pub mod server;
pub mod views;

View File

@ -1,61 +0,0 @@
use std::collections::HashMap;
use crate::remote::database::establish_connection;
use crate::settings::Settings;
use super::database::AtuinDbConn;
use eyre::Result;
use rocket::config::{Config, Environment, LoggingLevel, Value};
// a bunch of these imports are generated by macros, it's easier to wildcard
#[allow(clippy::clippy::wildcard_imports)]
use super::views::*;
#[allow(clippy::clippy::wildcard_imports)]
use super::auth::*;
embed_migrations!("migrations");
pub fn launch(settings: &Settings, host: String, port: u16) -> Result<()> {
let settings: Settings = settings.clone(); // clone so rocket can manage it
let mut database_config = HashMap::new();
let mut databases = HashMap::new();
database_config.insert("url", Value::from(settings.server.db_uri.clone()));
databases.insert("atuin", Value::from(database_config));
let connection = establish_connection(&settings)?;
embedded_migrations::run(&connection).expect("failed to run migrations");
let config = Config::build(Environment::Production)
.address(host)
.log_level(LoggingLevel::Normal)
.port(port)
.extra("databases", databases)
.finalize()
.unwrap();
let app = rocket::custom(config);
app.mount(
"/",
routes![
index,
register,
add_history,
login,
get_user,
sync_count,
sync_list
],
)
.manage(settings)
.attach(AtuinDbConn::fairing())
.register(catchers![internal_error, bad_request])
.launch();
Ok(())
}

View File

@ -1,185 +0,0 @@
use chrono::Utc;
use rocket::http::uri::Uri;
use rocket::http::RawStr;
use rocket::http::{ContentType, Status};
use rocket::request::FromFormValue;
use rocket::request::Request;
use rocket::response;
use rocket::response::{Responder, Response};
use rocket_contrib::databases::diesel;
use rocket_contrib::json::{Json, JsonValue};
use self::diesel::prelude::*;
use crate::api::AddHistoryRequest;
use crate::schema::history;
use crate::settings::HISTORY_PAGE_SIZE;
use super::database::AtuinDbConn;
use super::models::{History, NewHistory, User};
#[derive(Debug)]
pub struct ApiResponse {
pub json: JsonValue,
pub status: Status,
}
impl<'r> Responder<'r> for ApiResponse {
fn respond_to(self, req: &Request) -> response::Result<'r> {
Response::build_from(self.json.respond_to(req).unwrap())
.status(self.status)
.header(ContentType::JSON)
.ok()
}
}
#[get("/")]
pub const fn index() -> &'static str {
"\"Through the fathomless deeps of space swims the star turtle Great A\u{2019}Tuin, bearing on its back the four giant elephants who carry on their shoulders the mass of the Discworld.\"\n\t-- Sir Terry Pratchett"
}
#[catch(500)]
pub fn internal_error(_req: &Request) -> ApiResponse {
ApiResponse {
status: Status::InternalServerError,
json: json!({"status": "error", "message": "an internal server error has occured"}),
}
}
#[catch(400)]
pub fn bad_request(_req: &Request) -> ApiResponse {
ApiResponse {
status: Status::InternalServerError,
json: json!({"status": "error", "message": "bad request. don't do that."}),
}
}
#[post("/history", data = "<add_history>")]
#[allow(
clippy::clippy::cast_sign_loss,
clippy::cast_possible_truncation,
clippy::clippy::needless_pass_by_value
)]
pub fn add_history(
conn: AtuinDbConn,
user: User,
add_history: Json<Vec<AddHistoryRequest>>,
) -> ApiResponse {
let new_history: Vec<NewHistory> = add_history
.iter()
.map(|h| NewHistory {
client_id: h.id.as_str(),
hostname: h.hostname.to_string(),
user_id: user.id,
timestamp: h.timestamp.naive_utc(),
data: h.data.as_str(),
})
.collect();
match diesel::insert_into(history::table)
.values(&new_history)
.on_conflict_do_nothing()
.execute(&*conn)
{
Ok(_) => ApiResponse {
status: Status::Ok,
json: json!({"status": "ok", "message": "history added"}),
},
Err(_) => ApiResponse {
status: Status::BadRequest,
json: json!({"status": "error", "message": "failed to add history"}),
},
}
}
#[get("/sync/count")]
#[allow(clippy::wildcard_imports, clippy::needless_pass_by_value)]
pub fn sync_count(conn: AtuinDbConn, user: User) -> ApiResponse {
use crate::schema::history::dsl::*;
// we need to return the number of history items we have for this user
// in the future I'd like to use something like a merkel tree to calculate
// which day specifically needs syncing
let count = history
.filter(user_id.eq(user.id))
.count()
.first::<i64>(&*conn);
if count.is_err() {
error!("failed to count: {}", count.err().unwrap());
return ApiResponse {
json: json!({"message": "internal server error"}),
status: Status::InternalServerError,
};
}
ApiResponse {
status: Status::Ok,
json: json!({"count": count.ok()}),
}
}
pub struct UtcDateTime(chrono::DateTime<Utc>);
impl<'v> FromFormValue<'v> for UtcDateTime {
type Error = &'v RawStr;
fn from_form_value(form_value: &'v RawStr) -> Result<UtcDateTime, &'v RawStr> {
let time = Uri::percent_decode(form_value.as_bytes()).map_err(|_| form_value)?;
let time = time.to_string();
match chrono::DateTime::parse_from_rfc3339(time.as_str()) {
Ok(t) => Ok(UtcDateTime(t.with_timezone(&Utc))),
Err(e) => {
error!("failed to parse time {}, got: {}", time, e);
Err(form_value)
}
}
}
}
// Request a list of all history items added to the DB after a given timestamp.
// Provide the current hostname, so that we don't send the client data that
// originated from them
#[get("/sync/history?<sync_ts>&<history_ts>&<host>")]
#[allow(clippy::wildcard_imports, clippy::needless_pass_by_value)]
pub fn sync_list(
conn: AtuinDbConn,
user: User,
sync_ts: UtcDateTime,
history_ts: UtcDateTime,
host: String,
) -> ApiResponse {
use crate::schema::history::dsl::*;
// we need to return the number of history items we have for this user
// in the future I'd like to use something like a merkel tree to calculate
// which day specifically needs syncing
// TODO: Allow for configuring the page size, both from params, and setting
// the max in config. 100 is fine for now.
let h = history
.filter(user_id.eq(user.id))
.filter(hostname.ne(host))
.filter(created_at.ge(sync_ts.0.naive_utc()))
.filter(timestamp.ge(history_ts.0.naive_utc()))
.order(timestamp.asc())
.limit(HISTORY_PAGE_SIZE)
.load::<History>(&*conn);
if let Err(e) = h {
error!("failed to load history: {}", e);
return ApiResponse {
json: json!({"message": "internal server error"}),
status: Status::InternalServerError,
};
}
let user_data: Vec<String> = h.unwrap().iter().map(|i| i.data.to_string()).collect();
ApiResponse {
status: Status::Ok,
json: json!({ "history": user_data }),
}
}

View File

@ -1,30 +0,0 @@
table! {
history (id) {
id -> Int8,
client_id -> Text,
user_id -> Int8,
hostname -> Text,
timestamp -> Timestamp,
data -> Varchar,
created_at -> Timestamp,
}
}
table! {
sessions (id) {
id -> Int8,
user_id -> Int8,
token -> Varchar,
}
}
table! {
users (id) {
id -> Int8,
username -> Varchar,
email -> Varchar,
password -> Varchar,
}
}
allow_tables_to_appear_in_same_query!(history, sessions, users,);

View File

@ -1,3 +1,4 @@
/*
use self::diesel::prelude::*; use self::diesel::prelude::*;
use eyre::Result; use eyre::Result;
use rocket::http::Status; use rocket::http::Status;
@ -218,3 +219,4 @@ pub fn login(conn: AtuinDbConn, login: Json<LoginRequest>) -> ApiResponse {
json: json!({"session": session.token}), json: json!({"session": session.token}),
} }
} }
*/

202
src/server/database.rs Normal file
View File

@ -0,0 +1,202 @@
use async_trait::async_trait;
use eyre::{eyre, Result};
use sqlx::postgres::PgPoolOptions;
use crate::settings::HISTORY_PAGE_SIZE;
use super::models::{History, NewHistory, NewSession, NewUser, Session, User};
#[async_trait]
pub trait Database {
async fn get_session(&self, token: &str) -> Result<Session>;
async fn get_session_user(&self, token: &str) -> Result<User>;
async fn add_session(&self, session: &NewSession) -> Result<()>;
async fn get_user(&self, username: String) -> Result<User>;
async fn get_user_session(&self, u: &User) -> Result<Session>;
async fn add_user(&self, user: NewUser) -> Result<i64>;
async fn count_history(&self, user: &User) -> Result<i64>;
async fn list_history(
&self,
user: &User,
created_since: chrono::NaiveDateTime,
since: chrono::NaiveDateTime,
host: String,
) -> Result<Vec<History>>;
async fn add_history(&self, history: &[NewHistory]) -> Result<()>;
}
#[derive(Clone)]
pub struct Postgres {
pool: sqlx::Pool<sqlx::postgres::Postgres>,
}
impl Postgres {
pub async fn new(uri: &str) -> Result<Self, sqlx::Error> {
let pool = PgPoolOptions::new()
.max_connections(100)
.connect(uri)
.await?;
Ok(Self { pool })
}
}
#[async_trait]
impl Database for Postgres {
async fn get_session(&self, token: &str) -> Result<Session> {
let res: Option<Session> =
sqlx::query_as::<_, Session>("select * from sessions where token = $1")
.bind(token)
.fetch_optional(&self.pool)
.await?;
if let Some(s) = res {
Ok(s)
} else {
Err(eyre!("could not find session"))
}
}
async fn get_user(&self, username: String) -> Result<User> {
let res: Option<User> =
sqlx::query_as::<_, User>("select * from users where username = $1")
.bind(username)
.fetch_optional(&self.pool)
.await?;
if let Some(u) = res {
Ok(u)
} else {
Err(eyre!("could not find user"))
}
}
async fn get_session_user(&self, token: &str) -> Result<User> {
let res: Option<User> = sqlx::query_as::<_, User>(
"select * from users
inner join sessions
on users.id = sessions.user_id
and sessions.token = $1",
)
.bind(token)
.fetch_optional(&self.pool)
.await?;
if let Some(u) = res {
Ok(u)
} else {
Err(eyre!("could not find user"))
}
}
async fn count_history(&self, user: &User) -> Result<i64> {
let res: (i64,) = sqlx::query_as(
"select count(1) from history
where user_id = $1",
)
.bind(user.id)
.fetch_one(&self.pool)
.await?;
Ok(res.0)
}
async fn list_history(
&self,
user: &User,
created_since: chrono::NaiveDateTime,
since: chrono::NaiveDateTime,
host: String,
) -> Result<Vec<History>> {
let res = sqlx::query_as::<_, History>(
"select * from history
where user_id = $1
and hostname != $2
and created_at >= $3
and timestamp >= $4
order by timestamp asc
limit $5",
)
.bind(user.id)
.bind(host)
.bind(created_since)
.bind(since)
.bind(HISTORY_PAGE_SIZE)
.fetch_all(&self.pool)
.await?;
Ok(res)
}
async fn add_history(&self, history: &[NewHistory]) -> Result<()> {
let mut tx = self.pool.begin().await?;
for i in history {
sqlx::query(
"insert into history
(client_id, user_id, hostname, timestamp, data)
values ($1, $2, $3, $4, $5)
on conflict do nothing
",
)
.bind(i.client_id)
.bind(i.user_id)
.bind(i.hostname)
.bind(i.timestamp)
.bind(i.data)
.execute(&mut tx)
.await?;
}
tx.commit().await?;
Ok(())
}
async fn add_user(&self, user: NewUser) -> Result<i64> {
let res: (i64,) = sqlx::query_as(
"insert into users
(username, email, password)
values($1, $2, $3)
returning id",
)
.bind(user.username.as_str())
.bind(user.email.as_str())
.bind(user.password)
.fetch_one(&self.pool)
.await?;
Ok(res.0)
}
async fn add_session(&self, session: &NewSession) -> Result<()> {
sqlx::query(
"insert into sessions
(user_id, token)
values($1, $2)",
)
.bind(session.user_id)
.bind(session.token)
.execute(&self.pool)
.await?;
Ok(())
}
async fn get_user_session(&self, u: &User) -> Result<Session> {
let res: Option<Session> =
sqlx::query_as::<_, Session>("select * from sessions where user_id = $1")
.bind(u.id)
.fetch_optional(&self.pool)
.await?;
if let Some(s) = res {
Ok(s)
} else {
Err(eyre!("could not find session"))
}
}
}

View File

@ -0,0 +1,89 @@
use std::convert::Infallible;
use warp::{http::StatusCode, reply::json};
use crate::api::{
AddHistoryRequest, CountResponse, ErrorResponse, SyncHistoryRequest, SyncHistoryResponse,
};
use crate::server::database::Database;
use crate::server::models::{NewHistory, User};
pub async fn count(
user: User,
db: impl Database + Clone + Send + Sync,
) -> Result<Box<dyn warp::Reply>, Infallible> {
db.count_history(&user).await.map_or(
Ok(Box::new(ErrorResponse::reply(
"failed to query history count",
StatusCode::INTERNAL_SERVER_ERROR,
))),
|count| Ok(Box::new(json(&CountResponse { count }))),
)
}
pub async fn list(
req: SyncHistoryRequest,
user: User,
db: impl Database + Clone + Send + Sync,
) -> Result<Box<dyn warp::Reply>, Infallible> {
let history = db
.list_history(
&user,
req.sync_ts.naive_utc(),
req.history_ts.naive_utc(),
req.host,
)
.await;
if let Err(e) = history {
error!("failed to load history: {}", e);
let resp =
ErrorResponse::reply("failed to load history", StatusCode::INTERNAL_SERVER_ERROR);
let resp = Box::new(resp);
return Ok(resp);
}
let history: Vec<String> = history
.unwrap()
.iter()
.map(|i| i.data.to_string())
.collect();
debug!(
"loaded {} items of history for user {}",
history.len(),
user.id
);
Ok(Box::new(json(&SyncHistoryResponse { history })))
}
pub async fn add(
req: Vec<AddHistoryRequest>,
user: User,
db: impl Database + Clone + Send + Sync,
) -> Result<Box<dyn warp::Reply>, Infallible> {
debug!("request to add {} history items", req.len());
let history: Vec<NewHistory> = req
.iter()
.map(|h| NewHistory {
client_id: h.id.as_str(),
user_id: user.id,
hostname: h.hostname.as_str(),
timestamp: h.timestamp.naive_utc(),
data: h.data.as_str(),
})
.collect();
if let Err(e) = db.add_history(&history).await {
error!("failed to add history: {}", e);
return Ok(Box::new(ErrorResponse::reply(
"failed to add history",
StatusCode::INTERNAL_SERVER_ERROR,
)));
};
Ok(Box::new(warp::reply()))
}

View File

@ -0,0 +1,6 @@
pub mod history;
pub mod user;
pub const fn index() -> &'static str {
"\"Through the fathomless deeps of space swims the star turtle Great A\u{2019}Tuin, bearing on its back the four giant elephants who carry on their shoulders the mass of the Discworld.\"\n\t-- Sir Terry Pratchett"
}

140
src/server/handlers/user.rs Normal file
View File

@ -0,0 +1,140 @@
use std::convert::Infallible;
use sodiumoxide::crypto::pwhash::argon2id13;
use uuid::Uuid;
use warp::http::StatusCode;
use warp::reply::json;
use crate::api::{
ErrorResponse, LoginRequest, LoginResponse, RegisterRequest, RegisterResponse, UserResponse,
};
use crate::server::database::Database;
use crate::server::models::{NewSession, NewUser};
use crate::settings::Settings;
use crate::utils::hash_secret;
pub fn verify_str(secret: &str, verify: &str) -> bool {
sodiumoxide::init().unwrap();
let mut padded = [0_u8; 128];
secret.as_bytes().iter().enumerate().for_each(|(i, val)| {
padded[i] = *val;
});
match argon2id13::HashedPassword::from_slice(&padded) {
Some(hp) => argon2id13::pwhash_verify(&hp, verify.as_bytes()),
None => false,
}
}
pub async fn get(
username: String,
db: impl Database + Clone + Send + Sync,
) -> Result<Box<dyn warp::Reply>, Infallible> {
let user = match db.get_user(username).await {
Ok(user) => user,
Err(e) => {
debug!("user not found: {}", e);
return Ok(Box::new(ErrorResponse::reply(
"user not found",
StatusCode::NOT_FOUND,
)));
}
};
Ok(Box::new(warp::reply::json(&UserResponse {
username: user.username,
})))
}
pub async fn register(
register: RegisterRequest,
settings: Settings,
db: impl Database + Clone + Send + Sync,
) -> Result<Box<dyn warp::Reply>, Infallible> {
if !settings.server.open_registration {
return Ok(Box::new(ErrorResponse::reply(
"this server is not open for registrations",
StatusCode::BAD_REQUEST,
)));
}
let hashed = hash_secret(register.password.as_str());
let new_user = NewUser {
email: register.email,
username: register.username,
password: hashed,
};
let user_id = match db.add_user(new_user).await {
Ok(id) => id,
Err(e) => {
error!("failed to add user: {}", e);
return Ok(Box::new(ErrorResponse::reply(
"failed to add user",
StatusCode::BAD_REQUEST,
)));
}
};
let token = Uuid::new_v4().to_simple().to_string();
let new_session = NewSession {
user_id,
token: token.as_str(),
};
match db.add_session(&new_session).await {
Ok(_) => Ok(Box::new(json(&RegisterResponse { session: token }))),
Err(e) => {
error!("failed to add session: {}", e);
Ok(Box::new(ErrorResponse::reply(
"failed to register user",
StatusCode::BAD_REQUEST,
)))
}
}
}
pub async fn login(
login: LoginRequest,
db: impl Database + Clone + Send + Sync,
) -> Result<Box<dyn warp::Reply>, Infallible> {
let user = match db.get_user(login.username.clone()).await {
Ok(u) => u,
Err(e) => {
error!("failed to get user {}: {}", login.username.clone(), e);
return Ok(Box::new(ErrorResponse::reply(
"user not found",
StatusCode::NOT_FOUND,
)));
}
};
let session = match db.get_user_session(&user).await {
Ok(u) => u,
Err(e) => {
error!("failed to get session for {}: {}", login.username, e);
return Ok(Box::new(ErrorResponse::reply(
"user not found",
StatusCode::NOT_FOUND,
)));
}
};
let verified = verify_str(user.password.as_str(), login.password.as_str());
if !verified {
return Ok(Box::new(ErrorResponse::reply(
"user not found",
StatusCode::NOT_FOUND,
)));
}
Ok(Box::new(warp::reply::json(&LoginResponse {
session: session.token,
})))
}

23
src/server/mod.rs Normal file
View File

@ -0,0 +1,23 @@
use std::net::IpAddr;
use eyre::Result;
use crate::settings::Settings;
pub mod auth;
pub mod database;
pub mod handlers;
pub mod models;
pub mod router;
pub async fn launch(settings: &Settings, host: String, port: u16) -> Result<()> {
// routes to run:
// index, register, add_history, login, get_user, sync_count, sync_list
let host = host.parse::<IpAddr>()?;
let r = router::router(settings).await?;
warp::serve(r).run((host, port)).await;
Ok(())
}

View File

@ -1,10 +1,6 @@
use chrono::prelude::*; use chrono::prelude::*;
use crate::schema::{history, sessions, users}; #[derive(sqlx::FromRow)]
#[derive(Deserialize, Serialize, Identifiable, Queryable, Associations)]
#[table_name = "history"]
#[belongs_to(User)]
pub struct History { pub struct History {
pub id: i64, pub id: i64,
pub client_id: String, // a client generated ID pub client_id: String, // a client generated ID
@ -17,7 +13,16 @@ pub struct History {
pub created_at: NaiveDateTime, pub created_at: NaiveDateTime,
} }
#[derive(Identifiable, Queryable, Associations)] pub struct NewHistory<'a> {
pub client_id: &'a str,
pub user_id: i64,
pub hostname: &'a str,
pub timestamp: chrono::NaiveDateTime,
pub data: &'a str,
}
#[derive(sqlx::FromRow)]
pub struct User { pub struct User {
pub id: i64, pub id: i64,
pub username: String, pub username: String,
@ -25,35 +30,19 @@ pub struct User {
pub password: String, pub password: String,
} }
#[derive(Queryable, Identifiable, Associations)] #[derive(sqlx::FromRow)]
#[belongs_to(User)]
pub struct Session { pub struct Session {
pub id: i64, pub id: i64,
pub user_id: i64, pub user_id: i64,
pub token: String, pub token: String,
} }
#[derive(Insertable)] pub struct NewUser {
#[table_name = "history"] pub username: String,
pub struct NewHistory<'a> { pub email: String,
pub client_id: &'a str, pub password: String,
pub user_id: i64,
pub hostname: String,
pub timestamp: chrono::NaiveDateTime,
pub data: &'a str,
} }
#[derive(Insertable)]
#[table_name = "users"]
pub struct NewUser<'a> {
pub username: &'a str,
pub email: &'a str,
pub password: &'a str,
}
#[derive(Insertable)]
#[table_name = "sessions"]
pub struct NewSession<'a> { pub struct NewSession<'a> {
pub user_id: i64, pub user_id: i64,
pub token: &'a str, pub token: &'a str,

121
src/server/router.rs Normal file
View File

@ -0,0 +1,121 @@
use std::convert::Infallible;
use eyre::Result;
use warp::Filter;
use super::handlers;
use super::{database::Database, database::Postgres};
use crate::server::models::User;
use crate::{api::SyncHistoryRequest, settings::Settings};
fn with_settings(
settings: Settings,
) -> impl Filter<Extract = (Settings,), Error = Infallible> + Clone {
warp::any().map(move || settings.clone())
}
fn with_db(
db: impl Database + Clone + Send + Sync,
) -> impl Filter<Extract = (impl Database + Clone,), Error = Infallible> + Clone {
warp::any().map(move || db.clone())
}
fn with_user(
postgres: Postgres,
) -> impl Filter<Extract = (User,), Error = warp::Rejection> + Clone {
warp::header::<String>("authorization").and_then(move |header: String| {
// async closures are still buggy :(
let postgres = postgres.clone();
async move {
let header: Vec<&str> = header.split(' ').collect();
let token;
if header.len() == 2 {
if header[0] != "Token" {
return Err(warp::reject());
}
token = header[1];
} else {
return Err(warp::reject());
}
let user = postgres
.get_session_user(token)
.await
.map_err(|_| warp::reject())?;
Ok(user)
}
})
}
pub async fn router(
settings: &Settings,
) -> Result<impl Filter<Extract = impl warp::Reply, Error = warp::Rejection> + Clone> {
let postgres = Postgres::new(settings.server.db_uri.as_str()).await?;
let index = warp::get().and(warp::path::end()).map(handlers::index);
let count = warp::get()
.and(warp::path("sync"))
.and(warp::path("count"))
.and(warp::path::end())
.and(with_user(postgres.clone()))
.and(with_db(postgres.clone()))
.and_then(handlers::history::count);
let sync = warp::get()
.and(warp::path("sync"))
.and(warp::path("history"))
.and(warp::query::<SyncHistoryRequest>())
.and(warp::path::end())
.and(with_user(postgres.clone()))
.and(with_db(postgres.clone()))
.and_then(handlers::history::list);
let add_history = warp::post()
.and(warp::path("history"))
.and(warp::path::end())
.and(warp::body::json())
.and(with_user(postgres.clone()))
.and(with_db(postgres.clone()))
.and_then(handlers::history::add);
let user = warp::get()
.and(warp::path("user"))
.and(warp::path::param::<String>())
.and(warp::path::end())
.and(with_db(postgres.clone()))
.and_then(handlers::user::get);
let register = warp::post()
.and(warp::path("register"))
.and(warp::path::end())
.and(warp::body::json())
.and(with_settings(settings.clone()))
.and(with_db(postgres.clone()))
.and_then(handlers::user::register);
let login = warp::post()
.and(warp::path("login"))
.and(warp::path::end())
.and(warp::body::json())
.and(with_db(postgres))
.and_then(handlers::user::login);
let r = warp::any()
.and(
index
.or(count)
.or(sync)
.or(add_history)
.or(user)
.or(register)
.or(login),
)
.with(warp::filters::log::log("atuin::api"));
Ok(r)
}

View File

@ -161,7 +161,7 @@ impl Settings {
// Finally, set the auth token // Finally, set the auth token
if Path::new(session_path.to_string().as_str()).exists() { if Path::new(session_path.to_string().as_str()).exists() {
let token = std::fs::read_to_string(session_path.to_string())?; let token = std::fs::read_to_string(session_path.to_string())?;
s.set("local.session_token", token)?; s.set("local.session_token", token.trim())?;
} else { } else {
s.set("local.session_token", "not logged in")?; s.set("local.session_token", "not logged in")?;
} }

View File

@ -16,6 +16,7 @@ _atuin_precmd(){
[[ -z "${ATUIN_HISTORY_ID}" ]] && return [[ -z "${ATUIN_HISTORY_ID}" ]] && return
atuin history end $ATUIN_HISTORY_ID --exit $EXIT atuin history end $ATUIN_HISTORY_ID --exit $EXIT
export ATUIN_HISTORY_ID=""
} }
_atuin_search(){ _atuin_search(){