diff --git a/Cargo.lock b/Cargo.lock index ab0eaa31..50ffc518 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -245,7 +245,6 @@ dependencies = [ "indicatif", "interim", "itertools", - "lazy_static", "log", "memchr", "minspan", @@ -3510,6 +3509,7 @@ dependencies = [ "libsqlite3-sys", "log", "percent-encoding", + "regex", "serde", "sqlx-core", "time", diff --git a/atuin-client/Cargo.toml b/atuin-client/Cargo.toml index 51227044..e2353daf 100644 --- a/atuin-client/Cargo.toml +++ b/atuin-client/Cargo.toml @@ -37,13 +37,12 @@ async-trait = { workspace = true } itertools = { workspace = true } rand = { workspace = true } shellexpand = "3" -sqlx = { workspace = true, features = ["sqlite"] } +sqlx = { workspace = true, features = ["sqlite", "regexp"] } minspan = "0.1.1" regex = "1.9.1" serde_regex = "1.1.0" fs-err = { workspace = true } sql-builder = "3" -lazy_static = "1" memchr = "2.5" rmp = { version = "0.8.11" } typed-builder = { workspace = true } diff --git a/atuin-client/src/database.rs b/atuin-client/src/database.rs index 572955f8..2be27ac8 100644 --- a/atuin-client/src/database.rs +++ b/atuin-client/src/database.rs @@ -1,4 +1,5 @@ use std::{ + borrow::Cow, env, path::{Path, PathBuf}, str::FromStr, @@ -9,10 +10,8 @@ use async_trait::async_trait; use atuin_common::utils; use fs_err as fs; use itertools::Itertools; -use lazy_static::lazy_static; use rand::{distributions::Alphanumeric, Rng}; -use regex::Regex; -use sql_builder::{esc, quote, SqlBuilder, SqlName}; +use sql_builder::{bind::Bind, esc, quote, SqlBuilder, SqlName}; use sqlx::{ sqlite::{ SqliteConnectOptions, SqliteJournalMode, SqlitePool, SqlitePoolOptions, SqliteRow, @@ -142,6 +141,7 @@ impl Sqlite { .journal_mode(SqliteJournalMode::Wal) .optimize_on_close(true, None) .synchronous(SqliteSynchronous::Normal) + .with_regexp() .create_if_missing(true); let pool = SqlitePoolOptions::new() @@ -428,18 +428,42 @@ impl Database for Sqlite { }; let orig_query = query; - let query = query.replace('*', "%"); // allow wildcard char + let mut regexes = Vec::new(); match search_mode { - SearchMode::Prefix => sql.and_where_like_left("command", query), + SearchMode::Prefix => sql.and_where_like_left("command", query.replace('*', "%")), _ => { - // don't recompile the regex on successive calls! - lazy_static! { - static ref SPLIT_REGEX: Regex = Regex::new(r" +").unwrap(); - } - let mut is_or = false; - for query_part in SPLIT_REGEX.split(query.as_str()) { + let mut regex = None; + for part in query.split_inclusive(' ') { + let query_part: Cow = match (&mut regex, part.starts_with("r/")) { + (None, false) => { + if part.trim_end().is_empty() { + continue; + } + Cow::Owned(part.trim_end().replace('*', "%")) // allow wildcard char + } + (None, true) => { + if part[2..].trim_end().ends_with('/') { + let end_pos = part.trim_end().len() - 1; + regexes.push(String::from(&part[2..end_pos])); + } else { + regex = Some(String::from(&part[2..])); + } + continue; + } + (Some(r), _) => { + if part.trim_end().ends_with('/') { + let end_pos = part.trim_end().len() - 1; + r.push_str(&part.trim_end()[..end_pos]); + regexes.push(regex.take().unwrap()); + } else { + r.push_str(part); + } + continue; + } + }; + // TODO smart case mode could be made configurable like in fzf let (is_glob, glob) = if query_part.contains(char::is_uppercase) { (true, "*") @@ -448,7 +472,7 @@ impl Database for Sqlite { }; let (is_inverse, query_part) = match query_part.strip_prefix('!') { - Some(stripped) => (true, stripped), + Some(stripped) => (true, Cow::Borrowed(stripped)), None => (false, query_part), }; @@ -477,10 +501,18 @@ impl Database for Sqlite { sql.fuzzy_condition("command", param, is_inverse, is_glob, is_or); is_or = false; } + if let Some(r) = regex { + regexes.push(r); + } + &mut sql } }; + for regex in regexes { + sql.and_where("command regexp ?".bind(®ex)); + } + filter_options .exit .map(|exit| sql.and_where_eq("exit", exit)); @@ -825,6 +857,71 @@ mod test { assert_search_eq(&db, SearchMode::FullText, FilterMode::Global, "hm", 0) .await .unwrap(); + + // regex + assert_search_eq(&db, SearchMode::FullText, FilterMode::Global, "r/^ls ", 1) + .await + .unwrap(); + assert_search_eq( + &db, + SearchMode::FullText, + FilterMode::Global, + "r/ls / ie$", + 1, + ) + .await + .unwrap(); + assert_search_eq( + &db, + SearchMode::FullText, + FilterMode::Global, + "r/ls / !ie", + 0, + ) + .await + .unwrap(); + assert_search_eq( + &db, + SearchMode::FullText, + FilterMode::Global, + "meow r/ls/", + 0, + ) + .await + .unwrap(); + assert_search_eq(&db, SearchMode::FullText, FilterMode::Global, "r//hom/", 1) + .await + .unwrap(); + assert_search_eq( + &db, + SearchMode::FullText, + FilterMode::Global, + "r//home//", + 1, + ) + .await + .unwrap(); + assert_search_eq( + &db, + SearchMode::FullText, + FilterMode::Global, + "r//home///", + 0, + ) + .await + .unwrap(); + assert_search_eq(&db, SearchMode::FullText, FilterMode::Global, "/home.*e", 0) + .await + .unwrap(); + assert_search_eq( + &db, + SearchMode::FullText, + FilterMode::Global, + "r/home.*e", + 1, + ) + .await + .unwrap(); } #[tokio::test(flavor = "multi_thread")] @@ -915,6 +1012,17 @@ mod test { assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "Ellie", 1) .await .unwrap(); + + // regex + assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "r/^ls ", 2) + .await + .unwrap(); + assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "r/[Ee]llie", 3) + .await + .unwrap(); + assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "/h/e r/^ls ", 1) + .await + .unwrap(); } #[tokio::test(flavor = "multi_thread")] diff --git a/atuin/src/command/client/search/engines/db.rs b/atuin/src/command/client/search/engines/db.rs index b4f24561..e638f9d9 100644 --- a/atuin/src/command/client/search/engines/db.rs +++ b/atuin/src/command/client/search/engines/db.rs @@ -26,8 +26,8 @@ impl SearchEngine for Search { ..Default::default() }, ) - .await? - .into_iter() - .collect::>()) + .await + // ignore errors as it may be caused by incomplete regex + .map_or(Vec::new(), |r| r.into_iter().collect())) } }