diff --git a/src/evaluate/evaluator.rs b/src/evaluate/evaluator.rs index 1930b64c82..8c5f23f1a5 100644 --- a/src/evaluate/evaluator.rs +++ b/src/evaluate/evaluator.rs @@ -1,5 +1,6 @@ use crate::data::base::Block; use crate::errors::ArgumentError; +use crate::evaluate::operator::apply_operator; use crate::parser::hir::path::{ColumnPath, RawPathMember}; use crate::parser::{ hir::{self, Expression, RawExpression}, @@ -79,8 +80,8 @@ pub(crate) fn evaluate_baseline_expr( trace!("left={:?} right={:?}", left.item, right.item); - match left.compare(binary.op(), &*right) { - Ok(result) => Ok(Value::boolean(result).tagged(tag)), + match apply_operator(binary.op(), &*left, &*right) { + Ok(result) => Ok(result.tagged(tag)), Err((left_type, right_type)) => Err(ShellError::coerce_error( left_type.spanned(binary.left().span), right_type.spanned(binary.right().span), diff --git a/src/evaluate/mod.rs b/src/evaluate/mod.rs index 21a3b369d8..f8133808e0 100644 --- a/src/evaluate/mod.rs +++ b/src/evaluate/mod.rs @@ -1,3 +1,4 @@ pub(crate) mod evaluator; +pub(crate) mod operator; pub(crate) use evaluator::{evaluate_baseline_expr, Scope}; diff --git a/src/evaluate/operator.rs b/src/evaluate/operator.rs new file mode 100644 index 0000000000..d73e122bc9 --- /dev/null +++ b/src/evaluate/operator.rs @@ -0,0 +1,33 @@ +use crate::data::Primitive; +use crate::data::Value; +use crate::parser::Operator; +use crate::traits::ShellTypeName; +use std::ops::Not; + +pub fn apply_operator( + op: &Operator, + left: &Value, + right: &Value, +) -> Result { + match *op { + Operator::Equal + | Operator::NotEqual + | Operator::LessThan + | Operator::GreaterThan + | Operator::LessThanOrEqual + | Operator::GreaterThanOrEqual => left.compare(op, right).map(Value::boolean), + Operator::Dot => Ok(Value::boolean(false)), + Operator::Contains => contains(left, right).map(Value::boolean), + Operator::NotContains => contains(left, right).map(Not::not).map(Value::boolean), + } +} + +fn contains(left: &Value, right: &Value) -> Result { + if let (Value::Primitive(Primitive::String(l)), Value::Primitive(Primitive::String(r))) = + (left, right) + { + Ok(l.contains(r)) + } else { + Err((left.type_name(), right.type_name())) + } +} diff --git a/src/parser/parse/operator.rs b/src/parser/parse/operator.rs index 47c63075af..0a596e5897 100644 --- a/src/parser/parse/operator.rs +++ b/src/parser/parse/operator.rs @@ -12,6 +12,8 @@ pub enum Operator { LessThanOrEqual, GreaterThanOrEqual, Dot, + Contains, + NotContains, } impl FormatDebug for Operator { @@ -34,6 +36,8 @@ impl Operator { Operator::LessThanOrEqual => "<=", Operator::GreaterThanOrEqual => ">=", Operator::Dot => ".", + Operator::Contains => "=~", + Operator::NotContains => "!~", } } } @@ -55,6 +59,8 @@ impl FromStr for Operator { "<=" => Ok(Operator::LessThanOrEqual), ">=" => Ok(Operator::GreaterThanOrEqual), "." => Ok(Operator::Dot), + "=~" => Ok(Operator::Contains), + "!~" => Ok(Operator::NotContains), _ => Err(()), } } diff --git a/src/parser/parse/parser.rs b/src/parser/parse/parser.rs index e5fe60559a..c352cc4cf9 100644 --- a/src/parser/parse/parser.rs +++ b/src/parser/parse/parser.rs @@ -59,7 +59,7 @@ macro_rules! operator { #[tracable_parser] pub fn $name(input: NomSpan) -> IResult { let start = input.offset; - let (input, tag) = tag(stringify!($token))(input)?; + let (input, tag) = tag($token)(input)?; let end = input.offset; Ok(( @@ -70,13 +70,15 @@ macro_rules! operator { }; } -operator! { gt: > } -operator! { lt: < } -operator! { gte: >= } -operator! { lte: <= } -operator! { eq: == } -operator! { neq: != } -operator! { dot: . } +operator! { gt: ">" } +operator! { lt: "<" } +operator! { gte: ">=" } +operator! { lte: "<=" } +operator! { eq: "==" } +operator! { neq: "!=" } +operator! { dot: "." } +operator! { cont: "=~" } +operator! { ncont: "!~" } #[derive(Debug, Clone, Eq, PartialEq, Hash, Ord, PartialOrd, Serialize, Deserialize)] pub enum Number { @@ -228,7 +230,7 @@ pub fn raw_number(input: NomSpan) -> IResult> { #[tracable_parser] pub fn operator(input: NomSpan) -> IResult { - let (input, operator) = alt((gte, lte, neq, gt, lt, eq))(input)?; + let (input, operator) = alt((gte, lte, neq, gt, lt, eq, cont, ncont))(input)?; Ok((input, operator)) } @@ -830,6 +832,16 @@ mod tests { "!=" -> b::token_list(vec![b::op("!=")]) } + + equal_tokens! { + + "=~" -> b::token_list(vec![b::op("=~")]) + } + + equal_tokens! { + + "!~" -> b::token_list(vec![b::op("!~")]) + } } #[test] diff --git a/tests/filter_where_tests.rs b/tests/filter_where_tests.rs new file mode 100644 index 0000000000..e802607d85 --- /dev/null +++ b/tests/filter_where_tests.rs @@ -0,0 +1,112 @@ +mod helpers; + +use helpers as h; + +#[test] +fn test_compare() { + let actual = nu!( + cwd: "tests/fixtures/formats", h::pipeline( + r#" + open sample.db + | where table_name == ints + | get table_values + | first 4 + | where z > 4200 + | get z + | echo $it + "# + )); + + assert_eq!(actual, "4253"); + + let actual = nu!( + cwd: "tests/fixtures/formats", h::pipeline( + r#" + open sample.db + | where table_name == ints + | get table_values + | first 4 + | where z >= 4253 + | get z + | echo $it + "# + )); + + assert_eq!(actual, "4253"); + + let actual = nu!( + cwd: "tests/fixtures/formats", h::pipeline( + r#" + open sample.db + | where table_name == ints + | get table_values + | first 4 + | where z < 10 + | get z + | echo $it + "# + )); + + assert_eq!(actual, "1"); + + let actual = nu!( + cwd: "tests/fixtures/formats", h::pipeline( + r#" + open sample.db + | where table_name == ints + | get table_values + | first 4 + | where z <= 1 + | get z + | echo $it + "# + )); + + assert_eq!(actual, "1"); + + let actual = nu!( + cwd: "tests/fixtures/formats", h::pipeline( + r#" + open sample.db + | where table_name == ints + | get table_values + | where z != 1 + | first 1 + | get z + | echo $it + "# + )); + + assert_eq!(actual, "42"); +} + +#[test] +fn test_contains() { + let actual = nu!( + cwd: "tests/fixtures/formats", h::pipeline( + r#" + open sample.db + | where table_name == strings + | get table_values + | where x =~ ell + | count + | echo $it + "# + )); + + assert_eq!(actual, "4"); + + let actual = nu!( + cwd: "tests/fixtures/formats", h::pipeline( + r#" + open sample.db + | where table_name == strings + | get table_values + | where x !~ ell + | count + | echo $it + "# + )); + + assert_eq!(actual, "2"); +}