feat: support regex with r/.../ syntax (#1745)

* feat: support regex with r/.../ syntax

* cargo fmt

* feat(tests): add some tests for regex matching
This commit is contained in:
依云 2024-03-01 21:21:53 +08:00 committed by GitHub
parent 897af9a326
commit aec5df4123
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 125 additions and 18 deletions

2
Cargo.lock generated
View File

@ -245,7 +245,6 @@ dependencies = [
"indicatif",
"interim",
"itertools",
"lazy_static",
"log",
"memchr",
"minspan",
@ -3510,6 +3509,7 @@ dependencies = [
"libsqlite3-sys",
"log",
"percent-encoding",
"regex",
"serde",
"sqlx-core",
"time",

View File

@ -37,13 +37,12 @@ async-trait = { workspace = true }
itertools = { workspace = true }
rand = { workspace = true }
shellexpand = "3"
sqlx = { workspace = true, features = ["sqlite"] }
sqlx = { workspace = true, features = ["sqlite", "regexp"] }
minspan = "0.1.1"
regex = "1.9.1"
serde_regex = "1.1.0"
fs-err = { workspace = true }
sql-builder = "3"
lazy_static = "1"
memchr = "2.5"
rmp = { version = "0.8.11" }
typed-builder = { workspace = true }

View File

@ -1,4 +1,5 @@
use std::{
borrow::Cow,
env,
path::{Path, PathBuf},
str::FromStr,
@ -9,10 +10,8 @@ use async_trait::async_trait;
use atuin_common::utils;
use fs_err as fs;
use itertools::Itertools;
use lazy_static::lazy_static;
use rand::{distributions::Alphanumeric, Rng};
use regex::Regex;
use sql_builder::{esc, quote, SqlBuilder, SqlName};
use sql_builder::{bind::Bind, esc, quote, SqlBuilder, SqlName};
use sqlx::{
sqlite::{
SqliteConnectOptions, SqliteJournalMode, SqlitePool, SqlitePoolOptions, SqliteRow,
@ -142,6 +141,7 @@ impl Sqlite {
.journal_mode(SqliteJournalMode::Wal)
.optimize_on_close(true, None)
.synchronous(SqliteSynchronous::Normal)
.with_regexp()
.create_if_missing(true);
let pool = SqlitePoolOptions::new()
@ -428,18 +428,42 @@ impl Database for Sqlite {
};
let orig_query = query;
let query = query.replace('*', "%"); // allow wildcard char
let mut regexes = Vec::new();
match search_mode {
SearchMode::Prefix => sql.and_where_like_left("command", query),
SearchMode::Prefix => sql.and_where_like_left("command", query.replace('*', "%")),
_ => {
// don't recompile the regex on successive calls!
lazy_static! {
static ref SPLIT_REGEX: Regex = Regex::new(r" +").unwrap();
}
let mut is_or = false;
for query_part in SPLIT_REGEX.split(query.as_str()) {
let mut regex = None;
for part in query.split_inclusive(' ') {
let query_part: Cow<str> = match (&mut regex, part.starts_with("r/")) {
(None, false) => {
if part.trim_end().is_empty() {
continue;
}
Cow::Owned(part.trim_end().replace('*', "%")) // allow wildcard char
}
(None, true) => {
if part[2..].trim_end().ends_with('/') {
let end_pos = part.trim_end().len() - 1;
regexes.push(String::from(&part[2..end_pos]));
} else {
regex = Some(String::from(&part[2..]));
}
continue;
}
(Some(r), _) => {
if part.trim_end().ends_with('/') {
let end_pos = part.trim_end().len() - 1;
r.push_str(&part.trim_end()[..end_pos]);
regexes.push(regex.take().unwrap());
} else {
r.push_str(part);
}
continue;
}
};
// TODO smart case mode could be made configurable like in fzf
let (is_glob, glob) = if query_part.contains(char::is_uppercase) {
(true, "*")
@ -448,7 +472,7 @@ impl Database for Sqlite {
};
let (is_inverse, query_part) = match query_part.strip_prefix('!') {
Some(stripped) => (true, stripped),
Some(stripped) => (true, Cow::Borrowed(stripped)),
None => (false, query_part),
};
@ -477,10 +501,18 @@ impl Database for Sqlite {
sql.fuzzy_condition("command", param, is_inverse, is_glob, is_or);
is_or = false;
}
if let Some(r) = regex {
regexes.push(r);
}
&mut sql
}
};
for regex in regexes {
sql.and_where("command regexp ?".bind(&regex));
}
filter_options
.exit
.map(|exit| sql.and_where_eq("exit", exit));
@ -825,6 +857,71 @@ mod test {
assert_search_eq(&db, SearchMode::FullText, FilterMode::Global, "hm", 0)
.await
.unwrap();
// regex
assert_search_eq(&db, SearchMode::FullText, FilterMode::Global, "r/^ls ", 1)
.await
.unwrap();
assert_search_eq(
&db,
SearchMode::FullText,
FilterMode::Global,
"r/ls / ie$",
1,
)
.await
.unwrap();
assert_search_eq(
&db,
SearchMode::FullText,
FilterMode::Global,
"r/ls / !ie",
0,
)
.await
.unwrap();
assert_search_eq(
&db,
SearchMode::FullText,
FilterMode::Global,
"meow r/ls/",
0,
)
.await
.unwrap();
assert_search_eq(&db, SearchMode::FullText, FilterMode::Global, "r//hom/", 1)
.await
.unwrap();
assert_search_eq(
&db,
SearchMode::FullText,
FilterMode::Global,
"r//home//",
1,
)
.await
.unwrap();
assert_search_eq(
&db,
SearchMode::FullText,
FilterMode::Global,
"r//home///",
0,
)
.await
.unwrap();
assert_search_eq(&db, SearchMode::FullText, FilterMode::Global, "/home.*e", 0)
.await
.unwrap();
assert_search_eq(
&db,
SearchMode::FullText,
FilterMode::Global,
"r/home.*e",
1,
)
.await
.unwrap();
}
#[tokio::test(flavor = "multi_thread")]
@ -915,6 +1012,17 @@ mod test {
assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "Ellie", 1)
.await
.unwrap();
// regex
assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "r/^ls ", 2)
.await
.unwrap();
assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "r/[Ee]llie", 3)
.await
.unwrap();
assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "/h/e r/^ls ", 1)
.await
.unwrap();
}
#[tokio::test(flavor = "multi_thread")]

View File

@ -26,8 +26,8 @@ impl SearchEngine for Search {
..Default::default()
},
)
.await?
.into_iter()
.collect::<Vec<_>>())
.await
// ignore errors, as they may be caused by an incomplete regex
.map_or(Vec::new(), |r| r.into_iter().collect()))
}
}