nushell/crates/nu-command/src/commands/dataframe/where_.rs
Fernando Herrera aa1cd7eba6
Series Operation (#3563)
* Sample command

* Join command with checks

* More dataframes commands

* Groupby and aggregate commands

* Missing feature dataframe flag

* Renamed file

* New commands for dataframes

* error parser and df reference

* filter command for dataframes

* removed name from nu_dataframe

* commands to save to parquet and csv

* polars new version

* new dataframe commands

* series type and print

* Series basic arithmetics

* Add new column to dataframe

* Command names changed to nushell standard
2021-06-08 05:27:46 +12:00

196 lines
6.0 KiB
Rust

use crate::prelude::*;
use nu_engine::{evaluate_baseline_expr, EvaluatedCommandArgs, WholeStreamCommand};
use nu_errors::ShellError;
use nu_protocol::{
dataframe::NuDataFrame,
hir::{CapturedBlock, ClassifiedCommand, Expression, Literal, Operator, SpannedExpression},
Primitive, Signature, SyntaxShape, UnspannedPathMember, UntaggedValue,
};
use super::utils::parse_polars_error;
use polars::prelude::{ChunkCompare, Series};
/// Command type that implements `pls where`.
pub struct DataFrame;

impl WholeStreamCommand for DataFrame {
    /// Name under which the command is registered.
    fn name(&self) -> &str {
        "pls where"
    }

    /// One-line description of the command.
    fn usage(&self) -> &str {
        "Filter dataframe to match the condition"
    }

    /// The command takes a single required positional argument: the row
    /// condition used to filter the dataframe.
    fn signature(&self) -> Signature {
        let signature = Signature::build("pls where");

        signature.required(
            "condition",
            SyntaxShape::RowCondition,
            "the condition that must match",
        )
    }

    /// Usage examples; no expected result is attached to the example.
    fn examples(&self) -> Vec<Example> {
        let filter_by_column = Example {
            description: "Filter dataframe based on column a",
            example: "[[a b]; [1 2] [3 4]] | pls to-df | pls where a == 1",
            result: None,
        };

        vec![filter_by_column]
    }

    /// Delegates all the work to the free function `command`.
    fn run(&self, args: CommandArgs) -> Result<OutputStream, ShellError> {
        command(args)
    }
}
/// Entry point of `pls where`.
///
/// Takes the row condition apart into the three pieces needed to filter a
/// dataframe and hands them to `filter_dataframe`:
///
/// * the column name, taken from the left hand side of the comparison,
/// * the comparison operator,
/// * the evaluated right hand side, which must be a primitive value.
///
/// # Errors
///
/// Returns a `ShellError` when the argument is not a binary expression, when
/// the left hand side is not a column path whose first member is a string,
/// when the right hand side fails to evaluate or is not a primitive, or when
/// `filter_dataframe` itself fails.
fn command(args: CommandArgs) -> Result<OutputStream, ShellError> {
    let tag = args.call_info.name_tag.clone();
    let args = args.evaluate_once()?;

    let block: CapturedBlock = args.req(0)?;

    // Only the first command of the first pipeline of the first group is
    // inspected, and it has to be a binary expression (`lhs op rhs`).
    // The error is built lazily so nothing is allocated on the success path.
    let expression = block
        .block
        .block
        .get(0)
        .and_then(|group| group.pipelines.get(0))
        .and_then(|pipeline| pipeline.list.get(0))
        .and_then(|classified| match classified {
            ClassifiedCommand::Expr(expr) => match &expr.as_ref().expr {
                Expression::Binary(binary) => Some(binary),
                _ => None,
            },
            _ => None,
        })
        .ok_or_else(|| {
            ShellError::labeled_error("Expected a condition", "expected a condition", &tag.span)
        })?;

    // The left hand side must be a column path; its first tail member names
    // the column to compare.
    let lhs = match &expression.left.expr {
        Expression::FullColumnPath(p) => p.as_ref().tail.get(0),
        _ => None,
    }
    .ok_or_else(|| {
        ShellError::labeled_error(
            "No column name",
            "Not a column name found in left hand side of comparison",
            &expression.left.span,
        )
    })?;

    let (col_name, col_name_span) = match &lhs.unspanned {
        UnspannedPathMember::String(name) => Ok((name, &lhs.span)),
        _ => Err(ShellError::labeled_error(
            "No column name",
            "Not a string as column name",
            &lhs.span,
        )),
    }?;

    // The right hand side is evaluated eagerly and must be a primitive.
    let rhs = evaluate_baseline_expr(&expression.right, &args.args.context)?;

    let right_condition = match &rhs.value {
        UntaggedValue::Primitive(primitive) => Ok(primitive),
        _ => Err(ShellError::labeled_error(
            "Incorrect argument",
            "Expected primitive values",
            &rhs.tag.span,
        )),
    }?;

    // Everything below is already a reference, so it is passed as is.
    filter_dataframe(
        args,
        col_name,
        col_name_span,
        right_condition,
        &expression.op,
    )
}
// Applies a comparison function to a column and a primitive value.
//
// Arguments:
// * `$comparison`: function called as `$comparison($col, value)`; this file
//   passes the `Series` comparison methods (`Series::eq`, `Series::lt`, ...).
// * `$col`: the column that is compared.
// * `$condition`: reference to the `Primitive` found on the right hand side.
// * `$span`: span attached to every error produced by the macro.
//
// The macro expands to a `Result` holding either the comparison result or a
// `ShellError`. Supported primitives are `Int`, `BigInt`, `Decimal` and
// `String`; any other primitive yields an "Invalid datatype" error.
//
// `BigInt` and `Decimal` values come from user input, so a value that cannot
// be represented as `i64`/`f64` is reported as an error instead of panicking.
macro_rules! comparison_arm {
    ($comparison:expr, $col:expr, $condition:expr, $span:expr) => {
        match $condition {
            Primitive::Int(val) => Ok($comparison($col, *val)),
            Primitive::BigInt(val) => match val.to_i64() {
                Some(number) => Ok($comparison($col, number)),
                None => Err(ShellError::labeled_error(
                    "Value out of range",
                    "the integer on the right hand side does not fit in a 64-bit integer",
                    &$span,
                )),
            },
            Primitive::Decimal(val) => match val.to_f64() {
                Some(number) => Ok($comparison($col, number)),
                None => Err(ShellError::labeled_error(
                    "Value out of range",
                    "the decimal on the right hand side cannot be represented as a 64-bit float",
                    &$span,
                )),
            },
            Primitive::String(val) => {
                let temp: &str = val.as_ref();
                Ok($comparison($col, temp))
            }
            _ => Err(ShellError::labeled_error(
                "Invalid datatype",
                format!(
                    "this operator cannot be used with the selected '{}' datatype",
                    $col.dtype()
                ),
                &$span,
            )),
        }
    };
}
/// Filters the incoming dataframe with the comparison described by the
/// arguments, using polars operations.
///
/// * `args`: evaluated command arguments; the dataframe is read from
///   `args.input` and the command name tag is used for errors and output.
/// * `col_name` / `col_name_span`: column to compare and the span reported
///   when that column cannot be retrieved.
/// * `right_condition`: primitive value the column is compared against.
/// * `operator`: expression expected to hold a literal comparison operator.
///
/// # Errors
///
/// Fails when the input cannot be turned into a dataframe, when the column
/// lookup fails, when `operator` is not one of the six supported comparison
/// operators, when the comparison rejects the value, or when polars fails to
/// apply the filter.
fn filter_dataframe(
    mut args: EvaluatedCommandArgs,
    col_name: &str,
    col_name_span: &Span,
    right_condition: &Primitive,
    operator: &SpannedExpression,
) -> Result<OutputStream, ShellError> {
    let name_tag = args.call_info.name_tag.clone();

    let frame = NuDataFrame::try_from_stream(&mut args.input, &name_tag.span)?;

    let column = frame
        .as_ref()
        .column(col_name)
        .map_err(|e| parse_polars_error::<&str>(&e, col_name_span, None))?;

    // The operator arrives wrapped in an expression; anything other than a
    // literal operator is rejected.
    let comparison = match &operator.expr {
        Expression::Literal(Literal::Operator(found)) => found,
        _ => {
            return Err(ShellError::labeled_error(
                "Incorrect argument",
                "Expected operator",
                &operator.span,
            ))
        }
    };

    // Each supported operator maps onto the matching `Series` comparison.
    let comparison_result = match comparison {
        Operator::Equal => {
            comparison_arm!(Series::eq, column, right_condition, operator.span)
        }
        Operator::NotEqual => {
            comparison_arm!(Series::neq, column, right_condition, operator.span)
        }
        Operator::LessThan => {
            comparison_arm!(Series::lt, column, right_condition, operator.span)
        }
        Operator::LessThanOrEqual => {
            comparison_arm!(Series::lt_eq, column, right_condition, operator.span)
        }
        Operator::GreaterThan => {
            comparison_arm!(Series::gt, column, right_condition, operator.span)
        }
        Operator::GreaterThanOrEqual => {
            comparison_arm!(Series::gt_eq, column, right_condition, operator.span)
        }
        _ => Err(ShellError::labeled_error(
            "Incorrect operator",
            "Not implemented operator for dataframes filter",
            &operator.span,
        )),
    };
    let mask = comparison_result?;

    let filtered = frame
        .as_ref()
        .filter(&mask)
        .map_err(|e| parse_polars_error::<&str>(&e, &name_tag.span, None))?;

    let value = NuDataFrame::dataframe_to_value(filtered, name_tag);
    Ok(OutputStream::one(value))
}