diff --git a/Cargo.lock b/Cargo.lock index e333a6631f..fb45198ee1 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2477,6 +2477,7 @@ dependencies = [ "miette 4.2.1", "nu-json", "num-format", + "regex", "serde", "serde_json", "sys-locale", diff --git a/crates/nu-command/src/dataframe/values/nu_dataframe/between_values.rs b/crates/nu-command/src/dataframe/values/nu_dataframe/between_values.rs index da52985edd..96a2b10118 100644 --- a/crates/nu-command/src/dataframe/values/nu_dataframe/between_values.rs +++ b/crates/nu-command/src/dataframe/values/nu_dataframe/between_values.rs @@ -366,7 +366,8 @@ pub(super) fn compute_series_single_value( rhs_span: right.span()?, }), }, - Operator::Contains => match &right { + // TODO: update this to do a regex match instead of a simple contains? + Operator::RegexMatch => match &right { Value::String { val, .. } => contains_series_pat(&lhs, val, lhs_span), _ => Err(ShellError::OperatorMismatch { op_span: operator.span, diff --git a/crates/nu-engine/src/eval.rs b/crates/nu-engine/src/eval.rs index 1a38c553c5..d3d2241e49 100644 --- a/crates/nu-engine/src/eval.rs +++ b/crates/nu-engine/src/eval.rs @@ -405,13 +405,13 @@ pub fn eval_expression( let rhs = eval_expression(engine_state, stack, rhs)?; lhs.not_in(op_span, &rhs) } - Operator::Contains => { + Operator::RegexMatch => { let rhs = eval_expression(engine_state, stack, rhs)?; - lhs.contains(op_span, &rhs) + lhs.regex_match(op_span, &rhs, false) } - Operator::NotContains => { + Operator::NotRegexMatch => { let rhs = eval_expression(engine_state, stack, rhs)?; - lhs.not_contains(op_span, &rhs) + lhs.regex_match(op_span, &rhs, true) } Operator::Modulo => { let rhs = eval_expression(engine_state, stack, rhs)?; diff --git a/crates/nu-parser/src/parser.rs b/crates/nu-parser/src/parser.rs index 42169cd874..35db9ec20f 100644 --- a/crates/nu-parser/src/parser.rs +++ b/crates/nu-parser/src/parser.rs @@ -3967,9 +3967,9 @@ pub fn parse_operator( b"<=" => Operator::LessThanOrEqual, b">" => Operator::GreaterThan, b">=" => Operator::GreaterThanOrEqual, - b"=~" => Operator::Contains, + b"=~" => Operator::RegexMatch, b"=^" => Operator::StartsWith, - b"!~" => Operator::NotContains, + b"!~" => Operator::NotRegexMatch, b"+" => Operator::Plus, b"-" => Operator::Minus, b"*" => Operator::Multiply, diff --git a/crates/nu-parser/src/type_check.rs b/crates/nu-parser/src/type_check.rs index 8d9cf20fa2..11a16d3ea6 100644 --- a/crates/nu-parser/src/type_check.rs +++ b/crates/nu-parser/src/type_check.rs @@ -283,7 +283,7 @@ pub fn math_result_type( }, Operator::Equal => (Type::Bool, None), Operator::NotEqual => (Type::Bool, None), - Operator::Contains => match (&lhs.ty, &rhs.ty) { + Operator::RegexMatch => match (&lhs.ty, &rhs.ty) { (Type::String, Type::String) => (Type::Bool, None), (Type::Any, _) => (Type::Bool, None), (_, Type::Any) => (Type::Bool, None), @@ -301,7 +301,7 @@ pub fn math_result_type( ) } }, - Operator::NotContains => match (&lhs.ty, &rhs.ty) { + Operator::NotRegexMatch => match (&lhs.ty, &rhs.ty) { (Type::String, Type::String) => (Type::Bool, None), (Type::Any, _) => (Type::Bool, None), (_, Type::Any) => (Type::Bool, None), diff --git a/crates/nu-protocol/Cargo.toml b/crates/nu-protocol/Cargo.toml index 8ceffd47db..0f9fb0b546 100644 --- a/crates/nu-protocol/Cargo.toml +++ b/crates/nu-protocol/Cargo.toml @@ -22,6 +22,7 @@ nu-json = { path = "../nu-json", version = "0.60.1" } typetag = "0.1.8" num-format = "0.4.0" sys-locale = "0.2.0" +regex = "1.5.4" [features] plugin = ["serde_json"] diff --git a/crates/nu-protocol/src/ast/expression.rs b/crates/nu-protocol/src/ast/expression.rs index b0e92fec7f..3fe59a1777 100644 --- a/crates/nu-protocol/src/ast/expression.rs +++ b/crates/nu-protocol/src/ast/expression.rs @@ -32,8 +32,8 @@ impl Expression { Operator::Pow => 100, Operator::Multiply | Operator::Divide | Operator::Modulo => 95, Operator::Plus | Operator::Minus => 90, - Operator::NotContains - | Operator::Contains + Operator::NotRegexMatch + | Operator::RegexMatch | Operator::StartsWith | Operator::LessThan | Operator::LessThanOrEqual diff --git a/crates/nu-protocol/src/ast/operator.rs b/crates/nu-protocol/src/ast/operator.rs index c01ccd39eb..9e18683ab3 100644 --- a/crates/nu-protocol/src/ast/operator.rs +++ b/crates/nu-protocol/src/ast/operator.rs @@ -11,8 +11,8 @@ pub enum Operator { GreaterThan, LessThanOrEqual, GreaterThanOrEqual, - Contains, - NotContains, + RegexMatch, + NotRegexMatch, Plus, Minus, Multiply, @@ -33,8 +33,8 @@ impl Display for Operator { Operator::NotEqual => write!(f, "!="), Operator::LessThan => write!(f, "<"), Operator::GreaterThan => write!(f, ">"), - Operator::Contains => write!(f, "=~"), - Operator::NotContains => write!(f, "!~"), + Operator::RegexMatch => write!(f, "=~"), + Operator::NotRegexMatch => write!(f, "!~"), Operator::Plus => write!(f, "+"), Operator::Minus => write!(f, "-"), Operator::Multiply => write!(f, "*"), diff --git a/crates/nu-protocol/src/value/mod.rs b/crates/nu-protocol/src/value/mod.rs index 382f3685ba..bf712c4f7e 100644 --- a/crates/nu-protocol/src/value/mod.rs +++ b/crates/nu-protocol/src/value/mod.rs @@ -12,6 +12,7 @@ pub use from_value::FromValue; use indexmap::map::IndexMap; use num_format::{Locale, ToFormattedString}; pub use range::*; +use regex::Regex; use serde::{Deserialize, Serialize}; pub use stream::*; use sys_locale::get_locale; @@ -2029,17 +2030,38 @@ impl Value { } } - pub fn contains(&self, op: Span, rhs: &Value) -> Result { + pub fn regex_match(&self, op: Span, rhs: &Value, invert: bool) -> Result { let span = span(&[self.span()?, rhs.span()?]); match (self, rhs) { - (Value::String { val: lhs, .. }, Value::String { val: rhs, .. }) => Ok(Value::Bool { - val: lhs.contains(rhs), - span, - }), - (Value::CustomValue { val: lhs, span }, rhs) => { - lhs.operation(*span, Operator::Contains, op, rhs) + ( + Value::String { val: lhs, .. }, + Value::String { + val: rhs, + span: rhs_span, + }, + ) => { + // We are leaving some performance on the table by compiling the regex every time. + // Small regexes compile in microseconds, and the simplicity of this approach currently + // outweighs the performance costs. Revisit this if it ever becomes a bottleneck. + let regex = Regex::new(rhs) + .map_err(|e| ShellError::UnsupportedInput(format!("{e}"), *rhs_span))?; + let is_match = regex.is_match(lhs); + Ok(Value::Bool { + val: if invert { !is_match } else { is_match }, + span, + }) } + (Value::CustomValue { val: lhs, span }, rhs) => lhs.operation( + *span, + if invert { + Operator::NotRegexMatch + } else { + Operator::RegexMatch + }, + op, + rhs, + ), _ => Err(ShellError::OperatorMismatch { op_span: op, lhs_ty: self.get_type(), @@ -2071,27 +2093,6 @@ impl Value { } } - pub fn not_contains(&self, op: Span, rhs: &Value) -> Result { - let span = span(&[self.span()?, rhs.span()?]); - - match (self, rhs) { - (Value::String { val: lhs, .. }, Value::String { val: rhs, .. }) => Ok(Value::Bool { - val: !lhs.contains(rhs), - span, - }), - (Value::CustomValue { val: lhs, span }, rhs) => { - lhs.operation(*span, Operator::NotContains, op, rhs) - } - _ => Err(ShellError::OperatorMismatch { - op_span: op, - lhs_ty: self.get_type(), - lhs_span: self.span()?, - rhs_ty: rhs.get_type(), - rhs_span: rhs.span()?, - }), - } - } - pub fn modulo(&self, op: Span, rhs: &Value) -> Result { let span = span(&[self.span()?, rhs.span()?]); diff --git a/src/tests.rs b/src/tests.rs index a41fed6ea9..eb1575a035 100644 --- a/src/tests.rs +++ b/src/tests.rs @@ -9,6 +9,7 @@ mod test_math; mod test_modules; mod test_parser; mod test_ranges; +mod test_regex; mod test_strings; mod test_table_operations; mod test_type_check; diff --git a/src/tests/test_regex.rs b/src/tests/test_regex.rs new file mode 100644 index 0000000000..0f26f7dfde --- /dev/null +++ b/src/tests/test_regex.rs @@ -0,0 +1,82 @@ +use crate::tests::{fail_test, run_test, TestResult}; + +#[test] +fn contains() -> TestResult { + run_test(r#"'foobarbaz' =~ bar"#, "true") +} + +#[test] +fn contains_case_insensitive() -> TestResult { + run_test(r#"'foobarbaz' =~ '(?i)BaR'"#, "true") +} + +#[test] +fn not_contains() -> TestResult { + run_test(r#"'foobarbaz' !~ asdf"#, "true") +} + +#[test] +fn match_full_line() -> TestResult { + run_test(r#"'foobarbaz' =~ '^foobarbaz$'"#, "true") +} + +#[test] +fn not_match_full_line() -> TestResult { + run_test(r#"'foobarbaz' !~ '^foobarbaz$'"#, "false") +} + +#[test] +fn starts_with() -> TestResult { + run_test(r#"'foobarbaz' =~ '^foo'"#, "true") +} + +#[test] +fn not_starts_with() -> TestResult { + run_test(r#"'foobarbaz' !~ '^foo'"#, "false") +} + +#[test] +fn ends_with() -> TestResult { + run_test(r#"'foobarbaz' =~ 'baz$'"#, "true") +} + +#[test] +fn not_ends_with() -> TestResult { + run_test(r#"'foobarbaz' !~ 'baz$'"#, "false") +} + +#[test] +fn where_works() -> TestResult { + run_test( + r#"[{name: somefile.txt} {name: anotherfile.csv }] | where name =~ ^s | get name.0"#, + "somefile.txt", + ) +} + +#[test] +fn where_not_works() -> TestResult { + run_test( + r#"[{name: somefile.txt} {name: anotherfile.csv }] | where name !~ ^s | get name.0"#, + "anotherfile.csv", + ) +} + +#[test] +fn invalid_regex_fails() -> TestResult { + fail_test(r#"'foo' =~ '['"#, "regex parse error") +} + +#[test] +fn invalid_not_regex_fails() -> TestResult { + fail_test(r#"'foo' !~ '['"#, "regex parse error") +} + +#[test] +fn regex_on_int_fails() -> TestResult { + fail_test(r#"33 =~ foo"#, "Types mismatched") +} + +#[test] +fn not_regex_on_int_fails() -> TestResult { + fail_test(r#"33 !~ foo"#, "Types mismatched") +}