Adding parse fix for power operator error on negative integers and de… (#3821)

* Adding parse fix for power operator error on negative integers and decimal

* Adding correct formatting

* Changed is negative check to follow conventions

* Adding tests

* Added fix for the negatives and added tests

* Removed comments
This commit is contained in:
Joel Afriyie 2021-07-30 07:08:57 -05:00 committed by GitHub
parent 1e15f26e98
commit 69083bfca0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -1,6 +1,7 @@
use crate::base::coerce_compare; use crate::base::coerce_compare;
use crate::base::shape::{Column, InlineShape}; use crate::base::shape::{Column, InlineShape};
use crate::primitive::style_primitive; use crate::primitive::style_primitive;
use bigdecimal::Signed;
use chrono::{DateTime, NaiveDate, Utc}; use chrono::{DateTime, NaiveDate, Utc};
use nu_errors::ShellError; use nu_errors::ShellError;
use nu_protocol::hir::Operator; use nu_protocol::hir::Operator;
@ -9,6 +10,7 @@ use nu_protocol::{Primitive, Type, UntaggedValue};
use nu_source::{DebugDocBuilder, PrettyDebug, Span, Tagged}; use nu_source::{DebugDocBuilder, PrettyDebug, Span, Tagged};
use nu_table::TextStyle; use nu_table::TextStyle;
use num_bigint::BigInt; use num_bigint::BigInt;
use num_bigint::ToBigInt;
use num_traits::{ToPrimitive, Zero}; use num_traits::{ToPrimitive, Zero};
use std::collections::HashMap; use std::collections::HashMap;
@ -171,10 +173,31 @@ pub fn compute_values(
} }
Operator::Pow => { Operator::Pow => {
let prim_u32 = ToPrimitive::to_u32(y); let prim_u32 = ToPrimitive::to_u32(y);
let sign = match x.is_negative() {
true => -1,
false => 1,
};
if !y.is_negative() {
match prim_u32 { match prim_u32 {
Some(num) => Ok(UntaggedValue::Primitive(Primitive::Int(x.pow(num)))), Some(num) => Ok(UntaggedValue::Primitive(Primitive::Int(
sign * (x.pow(num)),
))),
_ => Err((left.type_name(), right.type_name())), _ => Err((left.type_name(), right.type_name())),
} }
} else {
let yp = bigdecimal::ToPrimitive::to_f64(y).unwrap_or(0.0);
let xp = bigdecimal::ToPrimitive::to_f64(x).unwrap_or(0.0);
let pow =
bigdecimal::FromPrimitive::from_f64((sign as f64) * (xp.powf(yp)));
match pow {
Some(p) => Ok(UntaggedValue::Primitive(Primitive::Decimal(p))),
_ => Err((left.type_name(), right.type_name())),
}
}
} }
_ => Err((left.type_name(), right.type_name())), _ => Err((left.type_name(), right.type_name())),
}, },
@ -211,10 +234,29 @@ pub fn compute_values(
} }
Operator::Pow => { Operator::Pow => {
let prim_u32 = ToPrimitive::to_u32(y); let prim_u32 = ToPrimitive::to_u32(y);
let sign = match x.is_negative() {
true => -1,
false => 1,
};
if !y.is_negative() {
match prim_u32 { match prim_u32 {
Some(num) => Ok(UntaggedValue::Primitive(Primitive::Int(x.pow(num)))), Some(num) => Ok(UntaggedValue::Primitive(Primitive::Int(
sign * (x.pow(num)),
))),
_ => Err((left.type_name(), right.type_name())), _ => Err((left.type_name(), right.type_name())),
} }
} else {
let yp = bigdecimal::ToPrimitive::to_f64(y).unwrap_or(0.0);
let xp = bigdecimal::ToPrimitive::to_f64(x).unwrap_or(0.0);
let pow =
bigdecimal::FromPrimitive::from_f64((sign as f64) * (xp.powf(yp)));
match pow {
Some(p) => Ok(UntaggedValue::Primitive(Primitive::Decimal(p))),
_ => Err((left.type_name(), right.type_name())),
}
}
} }
_ => Err((left.type_name(), right.type_name())), _ => Err((left.type_name(), right.type_name())),
}, },
@ -251,10 +293,28 @@ pub fn compute_values(
} }
Operator::Pow => { Operator::Pow => {
let prim_u32 = ToPrimitive::to_u32(y); let prim_u32 = ToPrimitive::to_u32(y);
let sign = match x.is_negative() {
true => -1,
false => 1,
};
if !y.is_negative() {
match prim_u32 { match prim_u32 {
Some(num) => Ok(UntaggedValue::Primitive(Primitive::BigInt(x.pow(num)))), Some(num) => Ok(UntaggedValue::Primitive(Primitive::BigInt(
(sign.to_bigint().unwrap_or_default()) * x.pow(num),
))),
_ => Err((left.type_name(), right.type_name())), _ => Err((left.type_name(), right.type_name())),
} }
} else {
let yp = bigdecimal::ToPrimitive::to_f64(y).unwrap_or(0.0);
let xp = bigdecimal::ToPrimitive::to_f64(x).unwrap_or(0.0);
let pow = bigdecimal::FromPrimitive::from_f64((sign as f64) * xp.powf(yp));
match pow {
Some(p) => Ok(UntaggedValue::Primitive(Primitive::Decimal(p))),
_ => Err((left.type_name(), right.type_name())),
}
}
} }
_ => Err((left.type_name(), right.type_name())), _ => Err((left.type_name(), right.type_name())),
}, },
@ -283,10 +343,31 @@ pub fn compute_values(
} }
Operator::Pow => { Operator::Pow => {
let prim_u32 = ToPrimitive::to_u32(y); let prim_u32 = ToPrimitive::to_u32(y);
let sign = match x.is_negative() {
true => -1,
false => 1,
};
if !y.is_negative() {
match prim_u32 { match prim_u32 {
Some(num) => Ok(UntaggedValue::Primitive(Primitive::BigInt(x.pow(num)))), Some(num) => Ok(UntaggedValue::Primitive(Primitive::BigInt(
(sign.to_bigint().unwrap_or_default()).pow(num),
))),
_ => Err((left.type_name(), right.type_name())), _ => Err((left.type_name(), right.type_name())),
} }
} else {
let yp = bigdecimal::ToPrimitive::to_f64(y).unwrap_or(0.0);
let xp = bigdecimal::ToPrimitive::to_f64(x).unwrap_or(0.0);
let pow =
bigdecimal::FromPrimitive::from_f64((sign as f64) * (xp.powf(yp)));
match pow {
Some(p) => Ok(UntaggedValue::Primitive(Primitive::Decimal(p))),
_ => Err((left.type_name(), right.type_name())),
}
}
} }
_ => Err((left.type_name(), right.type_name())), _ => Err((left.type_name(), right.type_name())),
}, },
@ -307,16 +388,22 @@ pub fn compute_values(
} }
Ok(x % bigdecimal::BigDecimal::from(*y)) Ok(x % bigdecimal::BigDecimal::from(*y))
} }
// leaving this here for the hope that bigdecimal will one day support pow/powf/fpow
// Operator::Pow => { Operator::Pow => {
// let xp = bigdecimal::ToPrimitive::to_f64(x).unwrap_or(0.0); let sign = match x.is_negative() {
// let yp = bigdecimal::ToPrimitive::to_f64(y).unwrap_or(0.0); true => -1,
// let pow = bigdecimal::FromPrimitive::from_f64(xp.powf(yp)); false => 1,
// match pow { };
// Some(p) => Ok(p),
// None => Err((left.type_name(), right.type_name())), let yp = bigdecimal::ToPrimitive::to_f64(y).unwrap_or(0.0);
// } let xp = bigdecimal::ToPrimitive::to_f64(x).unwrap_or(0.0);
// } let pow =
bigdecimal::FromPrimitive::from_f64((sign as f64) * (xp.powf(yp)));
match pow {
Some(p) => Ok(p),
None => Err((left.type_name(), right.type_name())),
}
}
_ => Err((left.type_name(), right.type_name())), _ => Err((left.type_name(), right.type_name())),
}?; }?;
Ok(UntaggedValue::Primitive(Primitive::Decimal(result))) Ok(UntaggedValue::Primitive(Primitive::Decimal(result)))
@ -338,11 +425,22 @@ pub fn compute_values(
} }
Ok(bigdecimal::BigDecimal::from(*x) % y) Ok(bigdecimal::BigDecimal::from(*x) % y)
} }
// big decimal doesn't support pow yet
// Operator::Pow => { Operator::Pow => {
// let yp = bigdecimal::ToPrimitive::to_u32(y).unwrap_or(0); let sign = match x.is_negative() {
// Ok(bigdecimal::BigDecimal::from(x.pow(yp))) true => -1,
// } false => 1,
};
let yp = bigdecimal::ToPrimitive::to_f64(y).unwrap_or(0.0);
let xp = bigdecimal::ToPrimitive::to_f64(x).unwrap_or(0.0);
let pow =
bigdecimal::FromPrimitive::from_f64((sign as f64) * (xp.powf(yp)));
match pow {
Some(p) => Ok(p),
None => Err((left.type_name(), right.type_name())),
}
}
_ => Err((left.type_name(), right.type_name())), _ => Err((left.type_name(), right.type_name())),
}?; }?;
Ok(UntaggedValue::Primitive(Primitive::Decimal(result))) Ok(UntaggedValue::Primitive(Primitive::Decimal(result)))
@ -364,16 +462,22 @@ pub fn compute_values(
} }
Ok(x % bigdecimal::BigDecimal::from(y.clone())) Ok(x % bigdecimal::BigDecimal::from(y.clone()))
} }
// leaving this here for the hope that bigdecimal will one day support pow/powf/fpow
// Operator::Pow => { Operator::Pow => {
// let xp = bigdecimal::ToPrimitive::to_f64(x).unwrap_or(0.0); let sign = match x.is_negative() {
// let yp = bigdecimal::ToPrimitive::to_f64(y).unwrap_or(0.0); true => -1,
// let pow = bigdecimal::FromPrimitive::from_f64(xp.powf(yp)); false => 1,
// match pow { };
// Some(p) => Ok(p),
// None => Err((left.type_name(), right.type_name())), let yp = bigdecimal::ToPrimitive::to_f64(y).unwrap_or(0.0);
// } let xp = bigdecimal::ToPrimitive::to_f64(x).unwrap_or(0.0);
// } let pow =
bigdecimal::FromPrimitive::from_f64((sign as f64) * (xp.powf(yp)));
match pow {
Some(p) => Ok(p),
None => Err((left.type_name(), right.type_name())),
}
}
_ => Err((left.type_name(), right.type_name())), _ => Err((left.type_name(), right.type_name())),
}?; }?;
Ok(UntaggedValue::Primitive(Primitive::Decimal(result))) Ok(UntaggedValue::Primitive(Primitive::Decimal(result)))
@ -395,11 +499,22 @@ pub fn compute_values(
} }
Ok(bigdecimal::BigDecimal::from(x.clone()) % y) Ok(bigdecimal::BigDecimal::from(x.clone()) % y)
} }
// big decimal doesn't support pow yet
// Operator::Pow => { Operator::Pow => {
// let yp = bigdecimal::ToPrimitive::to_u32(y).unwrap_or(0); let sign = match x.is_negative() {
// Ok(bigdecimal::BigDecimal::from(x.pow(yp))) true => -1,
// } false => 1,
};
let yp = bigdecimal::ToPrimitive::to_f64(y).unwrap_or(0.0);
let xp = bigdecimal::ToPrimitive::to_f64(x).unwrap_or(0.0);
let pow =
bigdecimal::FromPrimitive::from_f64((sign as f64) * (xp.powf(yp)));
match pow {
Some(p) => Ok(p),
None => Err((left.type_name(), right.type_name())),
}
}
_ => Err((left.type_name(), right.type_name())), _ => Err((left.type_name(), right.type_name())),
}?; }?;
Ok(UntaggedValue::Primitive(Primitive::Decimal(result))) Ok(UntaggedValue::Primitive(Primitive::Decimal(result)))
@ -421,16 +536,22 @@ pub fn compute_values(
} }
Ok(x % y) Ok(x % y)
} }
// big decimal doesn't support pow yet
// Operator::Pow => { Operator::Pow => {
// let xp = bigdecimal::ToPrimitive::to_f64(x).unwrap_or(0.0); let sign = match x.is_negative() {
// let yp = bigdecimal::ToPrimitive::to_f64(y).unwrap_or(0.0); true => -1,
// let pow = bigdecimal::FromPrimitive::from_f64(xp.powf(yp)); false => 1,
// match pow { };
// Some(p) => Ok(p),
// None => Err((left.type_name(), right.type_name())), let yp = bigdecimal::ToPrimitive::to_f64(y).unwrap_or(0.0);
// } let xp = bigdecimal::ToPrimitive::to_f64(x).unwrap_or(0.0);
// } let pow =
bigdecimal::FromPrimitive::from_f64((sign as f64) * (xp.powf(yp)));
match pow {
Some(p) => Ok(p),
None => Err((left.type_name(), right.type_name())),
}
}
_ => Err((left.type_name(), right.type_name())), _ => Err((left.type_name(), right.type_name())),
}?; }?;
Ok(UntaggedValue::Primitive(Primitive::Decimal(result))) Ok(UntaggedValue::Primitive(Primitive::Decimal(result)))
@ -578,9 +699,11 @@ pub fn format_for_column<'a>(
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::merge_values;
use super::Date as d; use super::Date as d;
use super::UntaggedValue as v; use super::UntaggedValue as v;
use super::{compute_values, merge_values};
use nu_protocol::hir::Operator;
use nu_protocol::{Primitive, UntaggedValue};
use nu_source::TaggedItem; use nu_source::TaggedItem;
use indexmap::indexmap; use indexmap::indexmap;
@ -609,4 +732,107 @@ mod tests {
merge_values(&table_author_row, &other_table_author_row).unwrap() merge_values(&table_author_row, &other_table_author_row).unwrap()
); );
} }
#[test]
fn pow_operator_negatives_and_decimals() {
// test 2 ** 2
let result_one = compute_values(
Operator::Pow,
&UntaggedValue::Primitive(Primitive::Int(2)),
&UntaggedValue::Primitive(Primitive::Int(2)),
);
assert_eq!(
result_one.unwrap(),
UntaggedValue::Primitive(Primitive::Int(4))
);
// test 2 ** 2.0
let rhs_decimal = bigdecimal::FromPrimitive::from_f64(2.0).unwrap();
let result_two = compute_values(
Operator::Pow,
&UntaggedValue::Primitive(Primitive::Int(2)),
&UntaggedValue::Primitive(Primitive::Decimal(rhs_decimal)),
);
let should_equal_four_decimal = bigdecimal::FromPrimitive::from_f64(4.0).unwrap();
assert_eq!(
result_two.unwrap(),
UntaggedValue::Primitive(Primitive::Decimal(should_equal_four_decimal))
);
// test 2.0 ** 2.0
let rhs_decimal = bigdecimal::FromPrimitive::from_f64(2.0).unwrap();
let lhs_decimal = bigdecimal::FromPrimitive::from_f64(2.0).unwrap();
let should_equal_four_decimal = bigdecimal::FromPrimitive::from_f64(4.0).unwrap();
let result_three = compute_values(
Operator::Pow,
&UntaggedValue::Primitive(Primitive::Decimal(lhs_decimal)),
&UntaggedValue::Primitive(Primitive::Decimal(rhs_decimal)),
);
assert_eq!(
result_three.unwrap(),
UntaggedValue::Primitive(Primitive::Decimal(should_equal_four_decimal))
);
// test 2 ** -2
let result_four = compute_values(
Operator::Pow,
&UntaggedValue::Primitive(Primitive::Int(2)),
&UntaggedValue::Primitive(Primitive::Int(-2)),
);
let should_equal_zero_decimal = bigdecimal::FromPrimitive::from_f64(0.25).unwrap();
assert_eq!(
result_four.unwrap(),
UntaggedValue::Primitive(Primitive::Decimal(should_equal_zero_decimal))
);
// test -2 ** -2
let result_five = compute_values(
Operator::Pow,
&UntaggedValue::Primitive(Primitive::Int(-2)),
&UntaggedValue::Primitive(Primitive::Int(-2)),
);
let should_equal_neg_zero_decimal = bigdecimal::FromPrimitive::from_f64(-0.25).unwrap();
assert_eq!(
result_five.unwrap(),
UntaggedValue::Primitive(Primitive::Decimal(should_equal_neg_zero_decimal))
);
// test -2 ** 2
let result_six = compute_values(
Operator::Pow,
&UntaggedValue::Primitive(Primitive::Int(-2)),
&UntaggedValue::Primitive(Primitive::Int(2)),
);
assert_eq!(
result_six.unwrap(),
UntaggedValue::Primitive(Primitive::Int(-4))
);
// test -2.0 ** 2
let lhs_decimal = bigdecimal::FromPrimitive::from_f64(-2.0).unwrap();
let should_equal_neg_four_decimal = bigdecimal::FromPrimitive::from_f64(-4.0).unwrap();
let result_seven = compute_values(
Operator::Pow,
&UntaggedValue::Primitive(Primitive::Decimal(lhs_decimal)),
&UntaggedValue::Primitive(Primitive::Int(2)),
);
// Need to validate
assert_eq!(
result_seven.unwrap(),
UntaggedValue::Primitive(Primitive::Decimal(should_equal_neg_four_decimal))
);
}
} }