diff --git a/crates/nu_plugin_polars/src/dataframe/command/aggregation/groupby.rs b/crates/nu_plugin_polars/src/dataframe/command/aggregation/groupby.rs index 67376ca8b1..e0074eb29f 100644 --- a/crates/nu_plugin_polars/src/dataframe/command/aggregation/groupby.rs +++ b/crates/nu_plugin_polars/src/dataframe/command/aggregation/groupby.rs @@ -31,6 +31,10 @@ impl PluginCommand for ToLazyGroupBy { SyntaxShape::Any, "Expression(s) that define the lazy group-by", ) + .switch( + "maintain-order", + "Ensure that the order of the groups is consistent with the input data. This is slower than a default group by and cannot be run on the streaming engine.", + Some('m')) .input_output_type( Type::Custom("dataframe".into()), Type::Custom("dataframe".into()), @@ -104,6 +108,7 @@ impl PluginCommand for ToLazyGroupBy { let vals: Vec = call.rest(0)?; let expr_value = Value::list(vals, call.head); let expressions = NuExpression::extract_exprs(plugin, expr_value)?; + let maintain_order = call.has_flag("maintain-order")?; if expressions .iter() @@ -118,7 +123,7 @@ impl PluginCommand for ToLazyGroupBy { let pipeline_value = input.into_value(call.head)?; let lazy = NuLazyFrame::try_from_value_coerce(plugin, &pipeline_value)?; - command(plugin, engine, call, lazy, expressions) + command(plugin, engine, call, lazy, expressions, maintain_order) .map_err(LabeledError::from) .map(|pd| pd.set_metadata(metadata)) } @@ -130,8 +135,13 @@ fn command( call: &EvaluatedCall, mut lazy: NuLazyFrame, expressions: Vec, + maintain_order: bool, ) -> Result { - let group_by = lazy.to_polars().group_by(expressions); + let group_by = if maintain_order { + lazy.to_polars().group_by_stable(expressions) + } else { + lazy.to_polars().group_by(expressions) + }; let group_by = NuLazyGroupBy::new(group_by, lazy.from_eager, lazy.schema().clone()?); group_by.to_pipeline_data(plugin, engine, call.head) } diff --git a/crates/nu_plugin_polars/src/dataframe/command/data/filter.rs b/crates/nu_plugin_polars/src/dataframe/command/data/filter.rs index 4581f8c84a..10cccf5800 100644 --- a/crates/nu_plugin_polars/src/dataframe/command/data/filter.rs +++ b/crates/nu_plugin_polars/src/dataframe/command/data/filter.rs @@ -1,7 +1,7 @@ use crate::{ PolarsPlugin, dataframe::values::{Column, NuDataFrame, NuExpression, NuLazyFrame}, - values::CustomValueSupport, + values::{CustomValueSupport, PolarsPluginObject, PolarsPluginType, cant_convert_err}, }; use nu_plugin::{EngineInterface, EvaluatedCall, PluginCommand}; @@ -31,10 +31,16 @@ impl PluginCommand for LazyFilter { SyntaxShape::Any, "Expression that define the column selection", ) - .input_output_type( - Type::Custom("dataframe".into()), - Type::Custom("dataframe".into()), - ) + .input_output_types(vec![ + ( + Type::Custom("dataframe".into()), + Type::Custom("dataframe".into()), + ), + ( + Type::Custom("expression".into()), + Type::Custom("expression".into()), + ), + ]) .category(Category::Custom("lazyframe".into())) } @@ -99,6 +105,37 @@ impl PluginCommand for LazyFilter { .into_value(Span::test_data()), ), }, + Example { + description: "Filter a single column in a group-by context", + example: "[[a b]; [foo 1] [foo 2] [foo 3] [bar 2] [bar 3] [bar 4]] | polars into-df + | polars group-by a --maintain-order + | polars agg { + lt: (polars col b | polars filter ((polars col b) < 2) | polars sum) + gte: (polars col b | polars filter ((polars col b) >= 3) | polars sum) + } + | polars collect", + result: Some( + NuDataFrame::try_from_columns( + vec![ + Column::new( + "a".to_string(), + vec![Value::test_string("foo"), Value::test_string("bar")], + ), + Column::new( + "lt".to_string(), + vec![Value::test_int(1), Value::test_int(0)], + ), + Column::new( + "gte".to_string(), + vec![Value::test_int(3), Value::test_int(7)], + ), + ], + None, + ) + .expect("simple df for test should not fail") + .into_value(Span::test_data()), + ), + }, ] } @@ -113,10 +150,31 @@ impl PluginCommand for LazyFilter { let expr_value: Value = call.req(0)?; let filter_expr = NuExpression::try_from_value(plugin, &expr_value)?; let pipeline_value = input.into_value(call.head)?; - let lazy = NuLazyFrame::try_from_value_coerce(plugin, &pipeline_value)?; - command(plugin, engine, call, lazy, filter_expr) - .map_err(LabeledError::from) - .map(|pd| pd.set_metadata(metadata)) + + match PolarsPluginObject::try_from_value(plugin, &pipeline_value)? { + PolarsPluginObject::NuDataFrame(df) => { + command(plugin, engine, call, df.lazy(), filter_expr) + } + PolarsPluginObject::NuLazyFrame(lazy) => { + command(plugin, engine, call, lazy, filter_expr) + } + + PolarsPluginObject::NuExpression(expr) => { + let res: NuExpression = expr.into_polars().filter(filter_expr.into_polars()).into(); + res.to_pipeline_data(plugin, engine, call.head) + } + + _ => Err(cant_convert_err( + &pipeline_value, + &[ + // PolarsPluginType::NuDataFrame, + PolarsPluginType::NuLazyGroupBy, + PolarsPluginType::NuExpression, + ], + )), + } + .map_err(LabeledError::from) + .map(|pd| pd.set_metadata(metadata)) } }