diff --git a/crates/nu-cli/src/data/value.rs b/crates/nu-cli/src/data/value.rs index d4521bf92..385adf962 100644 --- a/crates/nu-cli/src/data/value.rs +++ b/crates/nu-cli/src/data/value.rs @@ -8,6 +8,7 @@ use nu_protocol::ShellTypeName; use nu_protocol::{Primitive, Type, UntaggedValue}; use nu_source::{DebugDocBuilder, PrettyDebug, Tagged}; use nu_table::TextStyle; +use num_traits::Zero; pub fn date_from_str(s: Tagged<&str>) -> Result { let date = DateTime::parse_from_rfc3339(s.item).map_err(|err| { @@ -35,6 +36,10 @@ pub fn merge_values( } } +fn zero_division_error() -> UntaggedValue { + UntaggedValue::Error(ShellError::untagged_runtime_error("division by zero")) +} + pub fn compute_values( operator: Operator, left: &UntaggedValue, @@ -55,7 +60,9 @@ pub fn compute_values( Operator::Minus => Ok(UntaggedValue::Primitive(Primitive::Int(x - y))), Operator::Multiply => Ok(UntaggedValue::Primitive(Primitive::Int(x * y))), Operator::Divide => { - if x - (y * (x / y)) == num_bigint::BigInt::from(0) { + if y.is_zero() { + Ok(zero_division_error()) + } else if x - (y * (x / y)) == num_bigint::BigInt::from(0) { Ok(UntaggedValue::Primitive(Primitive::Int(x / y))) } else { Ok(UntaggedValue::Primitive(Primitive::Decimal( @@ -71,7 +78,12 @@ pub fn compute_values( Operator::Plus => Ok(x + bigdecimal::BigDecimal::from(y.clone())), Operator::Minus => Ok(x - bigdecimal::BigDecimal::from(y.clone())), Operator::Multiply => Ok(x * bigdecimal::BigDecimal::from(y.clone())), - Operator::Divide => Ok(x / bigdecimal::BigDecimal::from(y.clone())), + Operator::Divide => { + if y.is_zero() { + return Ok(zero_division_error()); + } + Ok(x / bigdecimal::BigDecimal::from(y.clone())) + } _ => Err((left.type_name(), right.type_name())), }?; Ok(UntaggedValue::Primitive(Primitive::Decimal(result))) @@ -81,7 +93,12 @@ pub fn compute_values( Operator::Plus => Ok(bigdecimal::BigDecimal::from(x.clone()) + y), Operator::Minus => Ok(bigdecimal::BigDecimal::from(x.clone()) - y), Operator::Multiply => Ok(bigdecimal::BigDecimal::from(x.clone()) * y), - Operator::Divide => Ok(bigdecimal::BigDecimal::from(x.clone()) / y), + Operator::Divide => { + if y.is_zero() { + return Ok(zero_division_error()); + } + Ok(bigdecimal::BigDecimal::from(x.clone()) / y) + } _ => Err((left.type_name(), right.type_name())), }?; Ok(UntaggedValue::Primitive(Primitive::Decimal(result))) @@ -91,7 +108,12 @@ pub fn compute_values( Operator::Plus => Ok(x + y), Operator::Minus => Ok(x - y), Operator::Multiply => Ok(x * y), - Operator::Divide => Ok(x / y), + Operator::Divide => { + if y.is_zero() { + return Ok(zero_division_error()); + } + Ok(x / y) + } _ => Err((left.type_name(), right.type_name())), }?; Ok(UntaggedValue::Primitive(Primitive::Decimal(result))) diff --git a/crates/nu-cli/src/evaluate/evaluator.rs b/crates/nu-cli/src/evaluate/evaluator.rs index 026e7b466..b8f84d9e0 100644 --- a/crates/nu-cli/src/evaluate/evaluator.rs +++ b/crates/nu-cli/src/evaluate/evaluator.rs @@ -47,7 +47,10 @@ pub(crate) async fn evaluate_baseline_expr( match binary.op.expr { Expression::Literal(hir::Literal::Operator(op)) => { match apply_operator(op, &left, &right) { - Ok(result) => Ok(result.into_value(tag)), + Ok(result) => match result { + UntaggedValue::Error(shell_err) => Err(shell_err), + _ => Ok(result.into_value(tag)), + }, Err((left_type, right_type)) => Err(ShellError::coerce_error( left_type.spanned(binary.left.span), right_type.spanned(binary.right.span), diff --git a/crates/nu-cli/src/evaluate/operator.rs b/crates/nu-cli/src/evaluate/operator.rs index 291c40b4f..9350172b5 100644 --- a/crates/nu-cli/src/evaluate/operator.rs +++ b/crates/nu-cli/src/evaluate/operator.rs @@ -1,4 +1,5 @@ use crate::data::value; +use nu_errors::ShellError; use nu_protocol::hir::Operator; use nu_protocol::{Primitive, ShellTypeName, UntaggedValue, Value}; use std::ops::Not; @@ -24,7 +25,14 @@ pub fn apply_operator( Operator::Plus => value::compute_values(op, left, right), Operator::Minus => value::compute_values(op, left, right), Operator::Multiply => value::compute_values(op, left, right), - Operator::Divide => value::compute_values(op, left, right), + Operator::Divide => value::compute_values(op, left, right).map(|res| match res { + UntaggedValue::Error(_) => UntaggedValue::Error(ShellError::labeled_error( + "Evaluation error", + "division by zero", + &right.tag.span, + )), + _ => res, + }), Operator::In => table_contains(left, right).map(UntaggedValue::boolean), Operator::NotIn => table_contains(left, right).map(|x| UntaggedValue::boolean(!x)), Operator::And => match (left.as_bool(), right.as_bool()) { diff --git a/crates/nu-cli/tests/commands/math/mod.rs b/crates/nu-cli/tests/commands/math/mod.rs index 78c591c90..136d56872 100644 --- a/crates/nu-cli/tests/commands/math/mod.rs +++ b/crates/nu-cli/tests/commands/math/mod.rs @@ -87,6 +87,54 @@ fn division_of_ints2() { assert_eq!(actual.out, "0.25"); } +#[test] +fn error_zero_division_int_int() { + let actual = nu!( + cwd: "tests/fixtures/formats", pipeline( + r#" + = 1 / 0 + "# + )); + + assert!(actual.err.contains("division by zero")); +} + +#[test] +fn error_zero_division_decimal_int() { + let actual = nu!( + cwd: "tests/fixtures/formats", pipeline( + r#" + = 1.0 / 0 + "# + )); + + assert!(actual.err.contains("division by zero")); +} + +#[test] +fn error_zero_division_int_decimal() { + let actual = nu!( + cwd: "tests/fixtures/formats", pipeline( + r#" + = 1 / 0.0 + "# + )); + + assert!(actual.err.contains("division by zero")); +} + +#[test] +fn error_zero_division_decimal_decimal() { + let actual = nu!( + cwd: "tests/fixtures/formats", pipeline( + r#" + = 1.0 / 0.0 + "# + )); + + assert!(actual.err.contains("division by zero")); +} + #[test] fn proper_precedence_history() { let actual = nu!(