From 05c36d1bc76ab187559ed0e0e327bbd58fe38735 Mon Sep 17 00:00:00 2001 From: Matthias Meschede Date: Thu, 24 Apr 2025 23:44:29 +0200 Subject: [PATCH] add polars join_where command (#15635) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit # Description This adds `polars join_where` which allows joining two dataframes based on a conditions. The command can be used as: ``` ➜ let df_a = [[name cash];[Alice 5] [Bob 10]] | polars into-lazy ➜ let df_b = [[item price];[A 3] [B 7] [C 12]] | polars into-lazy ➜ $df_a | polars join_where $df_b ((polars col cash) > (polars col price)) | polars collect ╭───┬───────┬──────┬──────┬───────╮ │ # │ name │ cash │ item │ price │ ├───┼───────┼──────┼──────┼───────┤ │ 0 │ Bob │ 10 │ B │ 7 │ │ 1 │ Bob │ 10 │ A │ 3 │ │ 2 │ Alice │ 5 │ A │ 3 │ ╰───┴───────┴──────┴──────┴───────╯ ``` # User-Facing Changes - new command `polars join_where` --- crates/nu_plugin_polars/Cargo.toml | 1 + .../src/dataframe/command/data/join_where.rs | 119 ++++++++++++++++++ .../src/dataframe/command/data/mod.rs | 3 + 3 files changed, 123 insertions(+) create mode 100644 crates/nu_plugin_polars/src/dataframe/command/data/join_where.rs diff --git a/crates/nu_plugin_polars/Cargo.toml b/crates/nu_plugin_polars/Cargo.toml index 5bdff452aa..479613a728 100644 --- a/crates/nu_plugin_polars/Cargo.toml +++ b/crates/nu_plugin_polars/Cargo.toml @@ -61,6 +61,7 @@ features = [ "cloud", "concat_str", "cross_join", + "iejoin", "csv", "cum_agg", "default", diff --git a/crates/nu_plugin_polars/src/dataframe/command/data/join_where.rs b/crates/nu_plugin_polars/src/dataframe/command/data/join_where.rs new file mode 100644 index 0000000000..970a268174 --- /dev/null +++ b/crates/nu_plugin_polars/src/dataframe/command/data/join_where.rs @@ -0,0 +1,119 @@ +use crate::{ + dataframe::values::{Column, NuDataFrame, NuExpression, NuLazyFrame}, + values::CustomValueSupport, + PolarsPlugin, +}; +use nu_plugin::{EngineInterface, EvaluatedCall, PluginCommand}; +use nu_protocol::{ + Category, Example, LabeledError, PipelineData, Signature, Span, SyntaxShape, Type, Value, +}; + +#[derive(Clone)] +pub struct LazyJoinWhere; + +impl PluginCommand for LazyJoinWhere { + type Plugin = PolarsPlugin; + + fn name(&self) -> &str { + "polars join_where" + } + + fn description(&self) -> &str { + "Joins a lazy frame with other lazy frame based on conditions." + } + + fn signature(&self) -> Signature { + Signature::build(self.name()) + .required("other", SyntaxShape::Any, "LazyFrame to join with") + .required("condition", SyntaxShape::Any, "Condition") + .input_output_type( + Type::Custom("dataframe".into()), + Type::Custom("dataframe".into()), + ) + .category(Category::Custom("lazyframe".into())) + } + + fn examples(&self) -> Vec { + vec![Example { + description: "Join two lazy dataframes with a condition", + example: r#"let df_a = ([[name cash];[Alice 5] [Bob 10]] | polars into-lazy) + let df_b = ([[item price];[A 3] [B 7] [C 12]] | polars into-lazy) + $df_a | polars join_where $df_b ((polars col cash) > (polars col price)) | polars collect"#, + result: Some( + NuDataFrame::try_from_columns( + vec![ + Column::new( + "name".to_string(), + vec![ + Value::test_string("Bob"), + Value::test_string("Bob"), + Value::test_string("Alice"), + ], + ), + Column::new( + "cash".to_string(), + vec![Value::test_int(10), Value::test_int(10), Value::test_int(5)], + ), + Column::new( + "item".to_string(), + vec![ + Value::test_string("B"), + Value::test_string("A"), + Value::test_string("A"), + ], + ), + Column::new( + "price".to_string(), + vec![Value::test_int(7), Value::test_int(3), Value::test_int(3)], + ), + ], + None, + ) + .expect("simple df for test should not fail") + .into_value(Span::test_data()), + ), + }] + } + + fn run( + &self, + plugin: &Self::Plugin, + engine: &EngineInterface, + call: &EvaluatedCall, + input: PipelineData, + ) -> Result { + let other: Value = call.req(0)?; + let other = NuLazyFrame::try_from_value_coerce(plugin, &other)?; + let other = other.to_polars(); + + let condition: Value = call.req(1)?; + let condition = NuExpression::try_from_value(plugin, &condition)?; + let condition = condition.into_polars(); + + let pipeline_value = input.into_value(call.head)?; + let lazy = NuLazyFrame::try_from_value_coerce(plugin, &pipeline_value)?; + let from_eager = lazy.from_eager; + let lazy = lazy.to_polars(); + + let lazy = lazy + .join_builder() + .with(other) + .force_parallel(true) + .join_where(vec![condition]); + + let lazy = NuLazyFrame::new(from_eager, lazy); + lazy.to_pipeline_data(plugin, engine, call.head) + .map_err(LabeledError::from) + } +} + +#[cfg(test)] +mod test { + use super::*; + use crate::test::test_polars_plugin_command; + + #[test] + fn test_examples() -> Result<(), nu_protocol::ShellError> { + test_polars_plugin_command(&LazyJoinWhere) + } +} diff --git a/crates/nu_plugin_polars/src/dataframe/command/data/mod.rs b/crates/nu_plugin_polars/src/dataframe/command/data/mod.rs index dd6cdc4eff..f4089a03c7 100644 --- a/crates/nu_plugin_polars/src/dataframe/command/data/mod.rs +++ b/crates/nu_plugin_polars/src/dataframe/command/data/mod.rs @@ -19,6 +19,7 @@ mod first; mod flatten; mod get; mod join; +mod join_where; mod last; mod len; mod lit; @@ -61,6 +62,7 @@ pub use first::FirstDF; use flatten::LazyFlatten; pub use get::GetDF; use join::LazyJoin; +use join_where::LazyJoinWhere; pub use last::LastDF; pub use lit::ExprLit; use query_df::QueryDf; @@ -106,6 +108,7 @@ pub(crate) fn data_commands() -> Vec