Casting operations for Series with differents types (#3702)

* Type in command description

* filter name change

* Clean column name

* Clippy error and updated polars version

* Lint correction in file

* CSV Infer schema optional

* Correct float operations

* changes in series castings to allow other types

* Clippy error correction
This commit is contained in:
Fernando Herrera 2021-06-28 11:17:37 +01:00 committed by GitHub
parent 7cb9fddc11
commit 1d0483c946
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -8,10 +8,9 @@ use nu_protocol::{
use nu_source::Span; use nu_source::Span;
use num_traits::ToPrimitive; use num_traits::ToPrimitive;
use num_bigint::BigInt;
use polars::prelude::{ use polars::prelude::{
BooleanType, ChunkCompare, ChunkedArray, DataType, Float64Type, Int64Type, IntoSeries, BooleanType, ChunkCompare, ChunkedArray, DataType, Float64Type, Int64Type, IntoSeries,
NumOpsDispatchChecked, Series, NumOpsDispatchChecked, PolarsError, Series,
}; };
use std::ops::{Add, BitAnd, BitOr, Div, Mul, Sub}; use std::ops::{Add, BitAnd, BitOr, Div, Mul, Sub};
@ -202,19 +201,20 @@ pub fn compute_series_single_value(
UntaggedValue::Primitive(Primitive::Int(val)) => Ok(compute_series_i64( UntaggedValue::Primitive(Primitive::Int(val)) => Ok(compute_series_i64(
lhs.as_ref(), lhs.as_ref(),
val, val,
<&ChunkedArray<Int64Type>>::add, <ChunkedArray<Int64Type>>::add,
&left.tag.span, &left.tag.span,
)), )),
UntaggedValue::Primitive(Primitive::BigInt(val)) => Ok(compute_series_bigint( UntaggedValue::Primitive(Primitive::BigInt(val)) => Ok(compute_series_i64(
lhs.as_ref(), lhs.as_ref(),
val, &val.to_i64()
<&ChunkedArray<Int64Type>>::add, .expect("Internal error: protocol did not use compatible decimal"),
<ChunkedArray<Int64Type>>::add,
&left.tag.span, &left.tag.span,
)), )),
UntaggedValue::Primitive(Primitive::Decimal(val)) => Ok(compute_series_decimal( UntaggedValue::Primitive(Primitive::Decimal(val)) => Ok(compute_series_decimal(
lhs.as_ref(), lhs.as_ref(),
val, val,
<&ChunkedArray<Float64Type>>::add, <ChunkedArray<Float64Type>>::add,
&left.tag.span, &left.tag.span,
)), )),
_ => Ok(UntaggedValue::Error( _ => Ok(UntaggedValue::Error(
@ -231,19 +231,20 @@ pub fn compute_series_single_value(
UntaggedValue::Primitive(Primitive::Int(val)) => Ok(compute_series_i64( UntaggedValue::Primitive(Primitive::Int(val)) => Ok(compute_series_i64(
lhs.as_ref(), lhs.as_ref(),
val, val,
<&ChunkedArray<Int64Type>>::sub, <ChunkedArray<Int64Type>>::sub,
&left.tag.span, &left.tag.span,
)), )),
UntaggedValue::Primitive(Primitive::BigInt(val)) => Ok(compute_series_bigint( UntaggedValue::Primitive(Primitive::BigInt(val)) => Ok(compute_series_i64(
lhs.as_ref(), lhs.as_ref(),
val, &val.to_i64()
<&ChunkedArray<Int64Type>>::sub, .expect("Internal error: protocol did not use compatible decimal"),
<ChunkedArray<Int64Type>>::sub,
&left.tag.span, &left.tag.span,
)), )),
UntaggedValue::Primitive(Primitive::Decimal(val)) => Ok(compute_series_decimal( UntaggedValue::Primitive(Primitive::Decimal(val)) => Ok(compute_series_decimal(
lhs.as_ref(), lhs.as_ref(),
val, val,
<&ChunkedArray<Float64Type>>::sub, <ChunkedArray<Float64Type>>::sub,
&left.tag.span, &left.tag.span,
)), )),
_ => Ok(UntaggedValue::Error( _ => Ok(UntaggedValue::Error(
@ -260,19 +261,20 @@ pub fn compute_series_single_value(
UntaggedValue::Primitive(Primitive::Int(val)) => Ok(compute_series_i64( UntaggedValue::Primitive(Primitive::Int(val)) => Ok(compute_series_i64(
lhs.as_ref(), lhs.as_ref(),
val, val,
<&ChunkedArray<Int64Type>>::mul, <ChunkedArray<Int64Type>>::mul,
&left.tag.span, &left.tag.span,
)), )),
UntaggedValue::Primitive(Primitive::BigInt(val)) => Ok(compute_series_bigint( UntaggedValue::Primitive(Primitive::BigInt(val)) => Ok(compute_series_i64(
lhs.as_ref(), lhs.as_ref(),
val, &val.to_i64()
<&ChunkedArray<Int64Type>>::mul, .expect("Internal error: protocol did not use compatible decimal"),
<ChunkedArray<Int64Type>>::mul,
&left.tag.span, &left.tag.span,
)), )),
UntaggedValue::Primitive(Primitive::Decimal(val)) => Ok(compute_series_decimal( UntaggedValue::Primitive(Primitive::Decimal(val)) => Ok(compute_series_decimal(
lhs.as_ref(), lhs.as_ref(),
val, val,
<&ChunkedArray<Float64Type>>::mul, <ChunkedArray<Float64Type>>::mul,
&left.tag.span, &left.tag.span,
)), )),
_ => Ok(UntaggedValue::Error( _ => Ok(UntaggedValue::Error(
@ -297,7 +299,7 @@ pub fn compute_series_single_value(
Ok(compute_series_i64( Ok(compute_series_i64(
lhs.as_ref(), lhs.as_ref(),
val, val,
<&ChunkedArray<Int64Type>>::div, <ChunkedArray<Int64Type>>::div,
&left.tag.span, &left.tag.span,
)) ))
} }
@ -310,10 +312,11 @@ pub fn compute_series_single_value(
&right.tag.span, &right.tag.span,
))) )))
} else { } else {
Ok(compute_series_bigint( Ok(compute_series_i64(
lhs.as_ref(), lhs.as_ref(),
val, &val.to_i64()
<&ChunkedArray<Int64Type>>::div, .expect("Internal error: protocol did not use compatible decimal"),
<ChunkedArray<Int64Type>>::div,
&left.tag.span, &left.tag.span,
)) ))
} }
@ -329,7 +332,7 @@ pub fn compute_series_single_value(
Ok(compute_series_decimal( Ok(compute_series_decimal(
lhs.as_ref(), lhs.as_ref(),
val, val,
<&ChunkedArray<Float64Type>>::div, <ChunkedArray<Float64Type>>::div,
&left.tag.span, &left.tag.span,
)) ))
} }
@ -352,9 +355,10 @@ pub fn compute_series_single_value(
ChunkedArray::eq, ChunkedArray::eq,
&left.tag.span, &left.tag.span,
)), )),
UntaggedValue::Primitive(Primitive::BigInt(val)) => Ok(compare_series_bigint( UntaggedValue::Primitive(Primitive::BigInt(val)) => Ok(compare_series_i64(
lhs.as_ref(), lhs.as_ref(),
val, &val.to_i64()
.expect("Internal error: protocol did not use compatible decimal"),
ChunkedArray::eq, ChunkedArray::eq,
&left.tag.span, &left.tag.span,
)), )),
@ -379,9 +383,10 @@ pub fn compute_series_single_value(
ChunkedArray::neq, ChunkedArray::neq,
&left.tag.span, &left.tag.span,
)), )),
UntaggedValue::Primitive(Primitive::BigInt(val)) => Ok(compare_series_bigint( UntaggedValue::Primitive(Primitive::BigInt(val)) => Ok(compare_series_i64(
lhs.as_ref(), lhs.as_ref(),
val, &val.to_i64()
.expect("Internal error: protocol did not use compatible decimal"),
ChunkedArray::neq, ChunkedArray::neq,
&left.tag.span, &left.tag.span,
)), )),
@ -409,9 +414,10 @@ pub fn compute_series_single_value(
ChunkedArray::lt, ChunkedArray::lt,
&left.tag.span, &left.tag.span,
)), )),
UntaggedValue::Primitive(Primitive::BigInt(val)) => Ok(compare_series_bigint( UntaggedValue::Primitive(Primitive::BigInt(val)) => Ok(compare_series_i64(
lhs.as_ref(), lhs.as_ref(),
val, &val.to_i64()
.expect("Internal error: protocol did not use compatible decimal"),
ChunkedArray::lt, ChunkedArray::lt,
&left.tag.span, &left.tag.span,
)), )),
@ -436,9 +442,10 @@ pub fn compute_series_single_value(
ChunkedArray::lt_eq, ChunkedArray::lt_eq,
&left.tag.span, &left.tag.span,
)), )),
UntaggedValue::Primitive(Primitive::BigInt(val)) => Ok(compare_series_bigint( UntaggedValue::Primitive(Primitive::BigInt(val)) => Ok(compare_series_i64(
lhs.as_ref(), lhs.as_ref(),
val, &val.to_i64()
.expect("Internal error: protocol did not use compatible decimal"),
ChunkedArray::lt_eq, ChunkedArray::lt_eq,
&left.tag.span, &left.tag.span,
)), )),
@ -466,9 +473,10 @@ pub fn compute_series_single_value(
ChunkedArray::gt, ChunkedArray::gt,
&left.tag.span, &left.tag.span,
)), )),
UntaggedValue::Primitive(Primitive::BigInt(val)) => Ok(compare_series_bigint( UntaggedValue::Primitive(Primitive::BigInt(val)) => Ok(compare_series_i64(
lhs.as_ref(), lhs.as_ref(),
val, &val.to_i64()
.expect("Internal error: protocol did not use compatible decimal"),
ChunkedArray::gt, ChunkedArray::gt,
&left.tag.span, &left.tag.span,
)), )),
@ -493,9 +501,10 @@ pub fn compute_series_single_value(
ChunkedArray::gt_eq, ChunkedArray::gt_eq,
&left.tag.span, &left.tag.span,
)), )),
UntaggedValue::Primitive(Primitive::BigInt(val)) => Ok(compare_series_bigint( UntaggedValue::Primitive(Primitive::BigInt(val)) => Ok(compare_series_i64(
lhs.as_ref(), lhs.as_ref(),
val, &val.to_i64()
.expect("Internal error: protocol did not use compatible decimal"),
ChunkedArray::gt_eq, ChunkedArray::gt_eq,
&left.tag.span, &left.tag.span,
)), )),
@ -540,14 +549,53 @@ pub fn compute_series_single_value(
} }
} }
fn compute_series_i64<'r, F>(series: &'r Series, val: &i64, f: F, span: &Span) -> UntaggedValue fn compute_series_i64<F>(series: &Series, val: &i64, f: F, span: &Span) -> UntaggedValue
where where
F: Fn(&'r ChunkedArray<Int64Type>, i64) -> ChunkedArray<Int64Type>, F: Fn(ChunkedArray<Int64Type>, i64) -> ChunkedArray<Int64Type>,
{
match series.dtype() {
DataType::UInt32 | DataType::Int32 | DataType::UInt64 => {
let to_i64 = series.cast_with_dtype(&DataType::Int64);
match to_i64 {
Ok(series) => {
let casted = series.i64();
compute_casted_i64(casted, *val, f, span)
}
Err(e) => UntaggedValue::Error(ShellError::labeled_error(
"Casting error",
format!("{}", e),
span,
)),
}
}
DataType::Int64 => {
let casted = series.i64();
compute_casted_i64(casted, *val, f, span)
}
_ => UntaggedValue::Error(ShellError::labeled_error(
"Casting error",
format!(
"Series of type {} can not be used for operations with an i64 value",
series.dtype()
),
span,
)),
}
}
fn compute_casted_i64<F>(
casted: Result<&ChunkedArray<Int64Type>, PolarsError>,
val: i64,
f: F,
span: &Span,
) -> UntaggedValue
where
F: Fn(ChunkedArray<Int64Type>, i64) -> ChunkedArray<Int64Type>,
{ {
let casted = series.i64();
match casted { match casted {
Ok(casted) => { Ok(casted) => {
let res = f(casted, *val); let res = f(casted.clone(), val);
let res = res.into_series(); let res = res.into_series();
NuSeries::series_to_untagged(res) NuSeries::series_to_untagged(res)
} }
@ -559,98 +607,65 @@ where
} }
} }
fn compute_series_bigint<'r, F>( fn compute_series_decimal<F>(series: &Series, val: &BigDecimal, f: F, span: &Span) -> UntaggedValue
series: &'r Series,
val: &BigInt,
f: F,
span: &Span,
) -> UntaggedValue
where where
F: Fn(&'r ChunkedArray<Int64Type>, i64) -> ChunkedArray<Int64Type>, F: Fn(ChunkedArray<Float64Type>, f64) -> ChunkedArray<Float64Type>,
{ {
let casted = series.i64(); match series.dtype() {
match casted { DataType::Float32 => {
Ok(casted) => { let to_f64 = series.cast_with_dtype(&DataType::Float64);
let res = f(
casted,
val.to_i64()
.expect("Internal error: protocol did not use compatible decimal"),
);
let res = res.into_series();
NuSeries::series_to_untagged(res)
}
Err(e) => UntaggedValue::Error(ShellError::labeled_error(
"Casting error",
format!("{}", e),
span,
)),
}
}
fn compute_series_decimal<'r, F>( match to_f64 {
series: &'r Series, Ok(series) => {
val: &BigDecimal, let casted = series.f64();
f: F, compute_casted_f64(
span: &Span, casted,
) -> UntaggedValue val.to_f64()
where .expect("Internal error: protocol did not use compatible decimal"),
F: Fn(&'r ChunkedArray<Float64Type>, f64) -> ChunkedArray<Float64Type>, f,
{ span,
let casted = series.f64(); )
match casted { }
Ok(casted) => { Err(e) => UntaggedValue::Error(ShellError::labeled_error(
let res = f( "Casting error",
format!("{}", e),
span,
)),
}
}
DataType::Float64 => {
let casted = series.f64();
compute_casted_f64(
casted, casted,
val.to_f64() val.to_f64()
.expect("Internal error: protocol did not use compatible decimal"), .expect("Internal error: protocol did not use compatible decimal"),
); f,
let res = res.into_series(); span,
NuSeries::series_to_untagged(res) )
} }
Err(e) => UntaggedValue::Error(ShellError::labeled_error( _ => UntaggedValue::Error(ShellError::labeled_error(
"Casting error", "Casting error",
format!("{}", e), format!(
"Series of type {} can not be used for operations with a decimal value",
series.dtype()
),
span, span,
)), )),
} }
} }
fn compare_series_i64<'r, F>(series: &'r Series, val: &i64, f: F, span: &Span) -> UntaggedValue fn compute_casted_f64<F>(
where casted: Result<&ChunkedArray<Float64Type>, PolarsError>,
F: Fn(&'r ChunkedArray<Int64Type>, i64) -> ChunkedArray<BooleanType>, val: f64,
{
let casted = series.i64();
match casted {
Ok(casted) => {
let res = f(casted, *val);
let res = res.into_series();
NuSeries::series_to_untagged(res)
}
Err(e) => UntaggedValue::Error(ShellError::labeled_error(
"Casting error",
format!("{}", e),
span,
)),
}
}
fn compare_series_bigint<'r, F>(
series: &'r Series,
val: &BigInt,
f: F, f: F,
span: &Span, span: &Span,
) -> UntaggedValue ) -> UntaggedValue
where where
F: Fn(&'r ChunkedArray<Int64Type>, i64) -> ChunkedArray<BooleanType>, F: Fn(ChunkedArray<Float64Type>, f64) -> ChunkedArray<Float64Type>,
{ {
let casted = series.i64();
match casted { match casted {
Ok(casted) => { Ok(casted) => {
let res = f( let res = f(casted.clone(), val);
casted,
val.to_i64()
.expect("Internal error: protocol did not use compatible decimal"),
);
let res = res.into_series(); let res = res.into_series();
NuSeries::series_to_untagged(res) NuSeries::series_to_untagged(res)
} }
@ -662,23 +677,123 @@ where
} }
} }
fn compare_series_decimal<'r, F>( fn compare_series_i64<F>(series: &Series, val: &i64, f: F, span: &Span) -> UntaggedValue
series: &'r Series, where
val: &BigDecimal, F: Fn(&ChunkedArray<Int64Type>, i64) -> ChunkedArray<BooleanType>,
{
match series.dtype() {
DataType::UInt32 | DataType::Int32 | DataType::UInt64 => {
let to_i64 = series.cast_with_dtype(&DataType::Int64);
match to_i64 {
Ok(series) => {
let casted = series.i64();
compare_casted_i64(casted, *val, f, span)
}
Err(e) => UntaggedValue::Error(ShellError::labeled_error(
"Casting error",
format!("{}", e),
span,
)),
}
}
DataType::Int64 => {
let casted = series.i64();
compare_casted_i64(casted, *val, f, span)
}
_ => UntaggedValue::Error(ShellError::labeled_error(
"Casting error",
format!(
"Series of type {} can not be used for operations with an i64 value",
series.dtype()
),
span,
)),
}
}
fn compare_casted_i64<F>(
casted: Result<&ChunkedArray<Int64Type>, PolarsError>,
val: i64,
f: F, f: F,
span: &Span, span: &Span,
) -> UntaggedValue ) -> UntaggedValue
where where
F: Fn(&'r ChunkedArray<Float64Type>, f64) -> ChunkedArray<BooleanType>, F: Fn(&ChunkedArray<Int64Type>, i64) -> ChunkedArray<BooleanType>,
{ {
let casted = series.f64();
match casted { match casted {
Ok(casted) => { Ok(casted) => {
let res = f( let res = f(casted, val);
let res = res.into_series();
NuSeries::series_to_untagged(res)
}
Err(e) => UntaggedValue::Error(ShellError::labeled_error(
"Casting error",
format!("{}", e),
span,
)),
}
}
fn compare_series_decimal<F>(series: &Series, val: &BigDecimal, f: F, span: &Span) -> UntaggedValue
where
F: Fn(&ChunkedArray<Float64Type>, f64) -> ChunkedArray<BooleanType>,
{
match series.dtype() {
DataType::Float32 => {
let to_f64 = series.cast_with_dtype(&DataType::Float64);
match to_f64 {
Ok(series) => {
let casted = series.f64();
compare_casted_f64(
casted,
val.to_f64()
.expect("Internal error: protocol did not use compatible decimal"),
f,
span,
)
}
Err(e) => UntaggedValue::Error(ShellError::labeled_error(
"Casting error",
format!("{}", e),
span,
)),
}
}
DataType::Float64 => {
let casted = series.f64();
compare_casted_f64(
casted, casted,
val.to_f64() val.to_f64()
.expect("Internal error: protocol did not use compatible decimal"), .expect("Internal error: protocol did not use compatible decimal"),
); f,
span,
)
}
_ => UntaggedValue::Error(ShellError::labeled_error(
"Casting error",
format!(
"Series of type {} can not be used for operations with a decimal value",
series.dtype()
),
span,
)),
}
}
fn compare_casted_f64<F>(
casted: Result<&ChunkedArray<Float64Type>, PolarsError>,
val: f64,
f: F,
span: &Span,
) -> UntaggedValue
where
F: Fn(&ChunkedArray<Float64Type>, f64) -> ChunkedArray<BooleanType>,
{
match casted {
Ok(casted) => {
let res = f(casted, val);
let res = res.into_series(); let res = res.into_series();
NuSeries::series_to_untagged(res) NuSeries::series_to_untagged(res)
} }