From 37bc922a67add1c0b8450e743efcc629b03076fe Mon Sep 17 00:00:00 2001 From: pyz4 <42039243+pyz4@users.noreply.github.com> Date: Tue, 27 May 2025 19:35:48 -0400 Subject: [PATCH] feat(polars): add `polars math` expression (#15822) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit # Description This PR adds a number of math functions under a single `polars math` command that apply to one or more column expressions. Note, `polars math` currently resides in the new module dataframe/command/command/computation/math.rs. I'm open to alternative organization and naming suggestions. ```nushell Collection of math functions to be applied on one or more column expressions This is an incomplete implementation of the available functions listed here: https://docs.pola.rs/api/python/stable/reference/expressions/computation.html. The following functions are currently available: - abs - cos - dot - exp - log - log1p - sign - sin - sqrt Usage: > polars math ...(args) Flags: -h, --help: Display the help message for this command Parameters: type : Function name. See extra description for full list of accepted values ...args : Extra arguments required by some functions Input/output types: ╭───┬────────────┬────────────╮ │ # │ input │ output │ ├───┼────────────┼────────────┤ │ 0 │ expression │ expression │ ╰───┴────────────┴────────────╯ Examples: Apply function to column expression > [[a]; [0] [-1] [2] [-3] [4]] | polars into-df | polars select [ (polars col a | polars math abs | polars as a_abs) (polars col a | polars math sign | polars as a_sign) (polars col a | polars math exp | polars as a_exp)] | polars collect ╭───┬───────┬────────┬────────╮ │ # │ a_abs │ a_sign │ a_exp │ ├───┼───────┼────────┼────────┤ │ 0 │ 0 │ 0 │ 1.000 │ │ 1 │ 1 │ -1 │ 0.368 │ │ 2 │ 2 │ 1 │ 7.389 │ │ 3 │ 3 │ -1 │ 0.050 │ │ 4 │ 4 │ 1 │ 54.598 │ ╰───┴───────┴────────┴────────╯ Specify arguments for select functions. See description for more information. > [[a]; [0] [1] [2] [4] [8] [16]] | polars into-df | polars select [ (polars col a | polars math log 2 | polars as a_base2)] | polars collect ╭───┬─────────╮ │ # │ a_base2 │ ├───┼─────────┤ │ 0 │ -inf │ │ 1 │ 0.000 │ │ 2 │ 1.000 │ │ 3 │ 2.000 │ │ 4 │ 3.000 │ │ 5 │ 4.000 │ ╰───┴─────────╯ Specify arguments for select functions. See description for more information. > [[a b]; [0 0] [1 1] [2 2] [3 3] [4 4] [5 5]] | polars into-df | polars select [ (polars col a | polars math dot (polars col b) | polars as ab)] | polars collect ╭───┬────────╮ │ # │ ab │ ├───┼────────┤ │ 0 │ 55.000 │ ╰───┴────────╯ ``` # User-Facing Changes No breaking changes. # Tests + Formatting Example tests were added to `polars math`. # After Submitting --- crates/nu_plugin_polars/Cargo.toml | 4 + .../src/dataframe/command/computation/math.rs | 255 ++++++++++++++++++ .../src/dataframe/command/computation/mod.rs | 10 + .../src/dataframe/command/mod.rs | 1 + crates/nu_plugin_polars/src/lib.rs | 8 +- 5 files changed, 275 insertions(+), 3 deletions(-) create mode 100644 crates/nu_plugin_polars/src/dataframe/command/computation/math.rs create mode 100644 crates/nu_plugin_polars/src/dataframe/command/computation/mod.rs diff --git a/crates/nu_plugin_polars/Cargo.toml b/crates/nu_plugin_polars/Cargo.toml index a4f0fbd965..240b49414a 100644 --- a/crates/nu_plugin_polars/Cargo.toml +++ b/crates/nu_plugin_polars/Cargo.toml @@ -55,6 +55,7 @@ url.workspace = true [dependencies.polars] features = [ + "abs", "arg_where", "bigidx", "checked_arithmetic", @@ -78,6 +79,7 @@ features = [ "is_in", "json", "lazy", + "log", "object", "parquet", "pivot", @@ -87,12 +89,14 @@ features = [ "round_series", "serde", "serde-lazy", + "sign", "strings", "string_to_integer", "streaming", "timezones", "temporal", "to_dummies", + "trigonometry", ] optional = false version = "0.46" diff --git a/crates/nu_plugin_polars/src/dataframe/command/computation/math.rs b/crates/nu_plugin_polars/src/dataframe/command/computation/math.rs new file mode 100644 index 0000000000..be140a5d2d --- /dev/null +++ b/crates/nu_plugin_polars/src/dataframe/command/computation/math.rs @@ -0,0 +1,255 @@ +use crate::{PolarsPlugin, values::CustomValueSupport}; + +use crate::values::{ + NuDataFrame, NuExpression, PolarsPluginObject, PolarsPluginType, cant_convert_err, +}; + +use nu_plugin::{EngineInterface, EvaluatedCall, PluginCommand}; +use nu_protocol::{ + Category, Example, LabeledError, PipelineData, ShellError, Signature, Span, Spanned, + SyntaxShape, Type, Value, +}; +use num::ToPrimitive; +use polars::prelude::df; + +enum FunctionType { + Abs, + Cos, + Dot, + Exp, + Log, + Log1p, + Sign, + Sin, + Sqrt, +} + +impl FunctionType { + fn from_str(func_type: &str, span: Span) -> Result { + match func_type { + "abs" => Ok(Self::Abs), + "cos" => Ok(Self::Cos), + "dot" => Ok(Self::Dot), + "exp" => Ok(Self::Exp), + "log" => Ok(Self::Log), + "log1p" => Ok(Self::Log1p), + "sign" => Ok(Self::Sign), + "sin" => Ok(Self::Sin), + "sqrt" => Ok(Self::Sqrt), + _ => Err(ShellError::GenericError { + error: "Invalid function name".into(), + msg: "".into(), + span: Some(span), + help: Some("See description for accepted functions".into()), + inner: vec![], + }), + } + } + + #[allow(dead_code)] + fn to_str(&self) -> &'static str { + match self { + FunctionType::Abs => "abs", + FunctionType::Cos => "cos", + FunctionType::Dot => "dot", + FunctionType::Exp => "exp", + FunctionType::Log => "log", + FunctionType::Log1p => "log1p", + FunctionType::Sign => "sign", + FunctionType::Sin => "sin", + FunctionType::Sqrt => "sqrt", + } + } +} + +#[derive(Clone)] +pub struct ExprMath; + +impl PluginCommand for ExprMath { + type Plugin = PolarsPlugin; + + fn name(&self) -> &str { + "polars math" + } + + fn description(&self) -> &str { + "Collection of math functions to be applied on one or more column expressions" + } + + fn extra_description(&self) -> &str { + r#"This is an incomplete implementation of the available functions listed here: https://docs.pola.rs/api/python/stable/reference/expressions/computation.html. + + The following functions are currently available: + - abs + - cos + - dot + - exp + - log + - log1p + - sign + - sin + - sqrt + "# + } + + fn signature(&self) -> Signature { + Signature::build(self.name()) + .required( + "type", + SyntaxShape::String, + "Function name. See extra description for full list of accepted values", + ) + .rest( + "args", + SyntaxShape::Any, + "Extra arguments required by some functions", + ) + .input_output_types(vec![( + Type::Custom("expression".into()), + Type::Custom("expression".into()), + )]) + .category(Category::Custom("dataframe".into())) + } + + fn examples(&self) -> Vec { + vec![Example { + description: "Apply function to column expression", + example: "[[a]; [0] [-1] [2] [-3] [4]] + | polars into-df + | polars select [ + (polars col a | polars math abs | polars as a_abs) + (polars col a | polars math sign | polars as a_sign) + (polars col a | polars math exp | polars as a_exp)] + | polars collect", + result: Some( + NuDataFrame::from( + df!( + "a_abs" => [0, 1, 2, 3, 4], + "a_sign" => [0, -1, 1, -1, 1], + "a_exp" => [1.000, 0.36787944117144233, 7.38905609893065, 0.049787068367863944, 54.598150033144236], + ) + .expect("simple df for test should not fail"), + ) + .into_value(Span::test_data()), + ), + }, + Example { + description: "Specify arguments for select functions. See description for more information.", + example: "[[a]; [0] [1] [2] [4] [8] [16]] + | polars into-df + | polars select [ + (polars col a | polars math log 2 | polars as a_base2)] + | polars collect", + result: Some( + NuDataFrame::from( + df!( + "a_base2" => [f64::NEG_INFINITY, 0.0, 1.0, 2.0, 3.0, 4.0], + ) + .expect("simple df for test should not fail"), + ) + .into_value(Span::test_data()), + ), + }, + Example { + description: "Specify arguments for select functions. See description for more information.", + example: "[[a b]; [0 0] [1 1] [2 2] [3 3] [4 4] [5 5]] + | polars into-df + | polars select [ + (polars col a | polars math dot (polars col b) | polars as ab)] + | polars collect", + result: Some( + NuDataFrame::from( + df!( + "ab" => [55.0], + ) + .expect("simple df for test should not fail"), + ) + .into_value(Span::test_data()), + ), + } + ] + } + + fn run( + &self, + plugin: &Self::Plugin, + engine: &EngineInterface, + call: &EvaluatedCall, + input: PipelineData, + ) -> Result { + let metadata = input.metadata(); + let value = input.into_value(call.head)?; + let func_type: Spanned = call.req(0)?; + let func_type = FunctionType::from_str(&func_type.item, func_type.span)?; + + match PolarsPluginObject::try_from_value(plugin, &value)? { + PolarsPluginObject::NuExpression(expr) => { + command_expr(plugin, engine, call, func_type, expr) + } + _ => Err(cant_convert_err(&value, &[PolarsPluginType::NuExpression])), + } + .map_err(LabeledError::from) + .map(|pd| pd.set_metadata(metadata)) + } +} + +fn command_expr( + plugin: &PolarsPlugin, + engine: &EngineInterface, + call: &EvaluatedCall, + func_type: FunctionType, + expr: NuExpression, +) -> Result { + let res = expr.into_polars(); + + let res: NuExpression = match func_type { + FunctionType::Abs => res.abs(), + FunctionType::Cos => res.cos(), + FunctionType::Dot => { + let expr = match call.rest::(1)?.first() { + None => Err(ShellError::GenericError { error: "Second expression to compute dot product with must be provided".into(), msg: "".into(), span: Some(call.head), help: None, inner: vec![] }), + Some(value) => { + match PolarsPluginObject::try_from_value(plugin, value)? { + PolarsPluginObject::NuExpression(expr) => { + Ok(expr.into_polars()) + } + _ => Err(cant_convert_err(value, &[PolarsPluginType::NuExpression])) + } + } + }?; + res.dot(expr) + } + FunctionType::Exp => res.exp(), + FunctionType::Log => { + let base = match call.rest::(1)?.first() { + // default natural log + None => Ok(std::f64::consts::E), + Some(value) => match value { + Value::Float { val, .. } => Ok(*val), + Value::Int { val, .. } => Ok(val.to_f64().expect("i64 to f64 conversion should not panic")), + _ => Err(ShellError::GenericError { error: "log base must be a float or integer. Leave base unspecified for natural log".into(), msg: "".into(), span: Some(value.span()), help: None, inner: vec![] }), + }, + }?; + + res.log(base) + } + FunctionType::Log1p => res.log1p(), + FunctionType::Sign => res.sign(), + FunctionType::Sin => res.sin(), + FunctionType::Sqrt => res.sqrt(), + } + .into(); + + res.to_pipeline_data(plugin, engine, call.head) +} + +#[cfg(test)] +mod test { + use super::*; + use crate::test::test_polars_plugin_command; + + #[test] + fn test_examples() -> Result<(), ShellError> { + test_polars_plugin_command(&ExprMath) + } +} diff --git a/crates/nu_plugin_polars/src/dataframe/command/computation/mod.rs b/crates/nu_plugin_polars/src/dataframe/command/computation/mod.rs new file mode 100644 index 0000000000..d4ed540b67 --- /dev/null +++ b/crates/nu_plugin_polars/src/dataframe/command/computation/mod.rs @@ -0,0 +1,10 @@ +mod math; + +use crate::PolarsPlugin; +use nu_plugin::PluginCommand; + +use math::ExprMath; + +pub(crate) fn computation_commands() -> Vec>> { + vec![Box::new(ExprMath)] +} diff --git a/crates/nu_plugin_polars/src/dataframe/command/mod.rs b/crates/nu_plugin_polars/src/dataframe/command/mod.rs index c439411b75..e629e8d690 100644 --- a/crates/nu_plugin_polars/src/dataframe/command/mod.rs +++ b/crates/nu_plugin_polars/src/dataframe/command/mod.rs @@ -1,5 +1,6 @@ pub mod aggregation; pub mod boolean; +pub mod computation; pub mod core; pub mod data; pub mod datetime; diff --git a/crates/nu_plugin_polars/src/lib.rs b/crates/nu_plugin_polars/src/lib.rs index 8a75509f47..9612ddabbd 100644 --- a/crates/nu_plugin_polars/src/lib.rs +++ b/crates/nu_plugin_polars/src/lib.rs @@ -6,9 +6,10 @@ use std::{ use cache::cache_commands; pub use cache::{Cache, Cacheable}; use command::{ - aggregation::aggregation_commands, boolean::boolean_commands, core::core_commands, - data::data_commands, datetime::datetime_commands, index::index_commands, - integer::integer_commands, list::list_commands, string::string_commands, stub::PolarsCmd, + aggregation::aggregation_commands, boolean::boolean_commands, + computation::computation_commands, core::core_commands, data::data_commands, + datetime::datetime_commands, index::index_commands, integer::integer_commands, + list::list_commands, string::string_commands, stub::PolarsCmd, }; use log::debug; use nu_plugin::{EngineInterface, Plugin, PluginCommand}; @@ -88,6 +89,7 @@ impl Plugin for PolarsPlugin { commands.append(&mut aggregation_commands()); commands.append(&mut boolean_commands()); commands.append(&mut core_commands()); + commands.append(&mut computation_commands()); commands.append(&mut data_commands()); commands.append(&mut datetime_commands()); commands.append(&mut index_commands());