From 6c1c7f95090649b3d73845f9cc81f35a3161161b Mon Sep 17 00:00:00 2001 From: Jack Wright <56345+ayax79@users.noreply.github.com> Date: Fri, 6 Sep 2024 20:03:51 -0700 Subject: [PATCH] Added expression support for `polars cumulative` (#13799) # Description Provides the ability to use `polars cumulative` as an expression: Screenshot 2024-09-06 at 17 47 15 # User-Facing Changes - `polars cumulative` can now be used as an expression. --- .../command/aggregation/cumulative.rs | 96 ++++++++++++++++--- 1 file changed, 83 insertions(+), 13 deletions(-) diff --git a/crates/nu_plugin_polars/src/dataframe/command/aggregation/cumulative.rs b/crates/nu_plugin_polars/src/dataframe/command/aggregation/cumulative.rs index e745bf5432..db3affbd25 100644 --- a/crates/nu_plugin_polars/src/dataframe/command/aggregation/cumulative.rs +++ b/crates/nu_plugin_polars/src/dataframe/command/aggregation/cumulative.rs @@ -1,6 +1,8 @@ use crate::{values::CustomValueSupport, PolarsPlugin}; -use crate::values::{Column, NuDataFrame}; +use crate::values::{ + cant_convert_err, Column, NuDataFrame, NuExpression, PolarsPluginObject, PolarsPluginType, +}; use nu_plugin::{EngineInterface, EvaluatedCall, PluginCommand}; use nu_protocol::{ @@ -52,22 +54,53 @@ impl PluginCommand for Cumulative { } fn description(&self) -> &str { - "Cumulative calculation for a series." + "Cumulative calculation for a column or series." } fn signature(&self) -> Signature { Signature::build(self.name()) - .required("type", SyntaxShape::String, "rolling operation") - .switch("reverse", "Reverse cumulative calculation", Some('r')) - .input_output_type( - Type::Custom("dataframe".into()), - Type::Custom("dataframe".into()), + .required( + "type", + SyntaxShape::String, + "rolling operation. Values of min, max, and sum are accepted.", ) + .switch("reverse", "Reverse cumulative calculation", Some('r')) + .input_output_types(vec![ + ( + Type::Custom("dataframe".into()), + Type::Custom("dataframe".into()), + ), + ( + Type::Custom("expression".into()), + Type::Custom("expression".into()), + ), + ]) .category(Category::Custom("dataframe".into())) } fn examples(&self) -> Vec { vec![ + Example { + description: "Cumulative sum for a column", + example: "[[a]; [1] [2] [3] [4] [5]] | polars into-df | polars select (polars col a | polars cumulative sum | polars as cum_a) | polars collect", + result: Some( + NuDataFrame::try_from_columns( + vec![Column::new( + "cum_a".to_string(), + vec![ + Value::test_int(1), + Value::test_int(3), + Value::test_int(6), + Value::test_int(10), + Value::test_int(15), + ], + )], + None, + ) + .expect("simple df for test should not fail") + .into_value(Span::test_data()), + ), + }, Example { description: "Cumulative sum for a series", example: "[1 2 3 4 5] | polars into-df | polars cumulative sum", @@ -120,20 +153,58 @@ impl PluginCommand for Cumulative { call: &EvaluatedCall, input: PipelineData, ) -> Result { - command(plugin, engine, call, input).map_err(LabeledError::from) + let value = input.into_value(call.head)?; + let cum_type: Spanned = call.req(0)?; + let cum_type = CumulativeType::from_str(&cum_type.item, cum_type.span)?; + match PolarsPluginObject::try_from_value(plugin, &value)? { + PolarsPluginObject::NuDataFrame(df) => command_df(plugin, engine, call, cum_type, df), + PolarsPluginObject::NuLazyFrame(lazy) => { + command_df(plugin, engine, call, cum_type, lazy.collect(call.head)?) + } + PolarsPluginObject::NuExpression(expr) => { + command_expr(plugin, engine, call, cum_type, expr) + } + _ => Err(cant_convert_err( + &value, + &[ + PolarsPluginType::NuDataFrame, + PolarsPluginType::NuLazyFrame, + PolarsPluginType::NuExpression, + ], + )), + } + .map_err(LabeledError::from) } } -fn command( +fn command_expr( plugin: &PolarsPlugin, engine: &EngineInterface, call: &EvaluatedCall, - input: PipelineData, + cum_type: CumulativeType, + expr: NuExpression, ) -> Result { - let cum_type: Spanned = call.req(0)?; let reverse = call.has_flag("reverse")?; + let polars = expr.into_polars(); - let df = NuDataFrame::try_from_pipeline_coerce(plugin, input, call.head)?; + let res: NuExpression = match cum_type { + CumulativeType::Max => polars.cum_max(reverse), + CumulativeType::Min => polars.cum_min(reverse), + CumulativeType::Sum => polars.cum_sum(reverse), + } + .into(); + + res.to_pipeline_data(plugin, engine, call.head) +} + +fn command_df( + plugin: &PolarsPlugin, + engine: &EngineInterface, + call: &EvaluatedCall, + cum_type: CumulativeType, + df: NuDataFrame, +) -> Result { + let reverse = call.has_flag("reverse")?; let series = df.as_series(call.head)?; if let DataType::Object(..) = series.dtype() { @@ -146,7 +217,6 @@ fn command( }); } - let cum_type = CumulativeType::from_str(&cum_type.item, cum_type.span)?; let mut res = match cum_type { CumulativeType::Max => cum_max(&series, reverse), CumulativeType::Min => cum_min(&series, reverse),