feat(polars): add polars math expression (#15822)

<!--
if this PR closes one or more issues, you can automatically link the PR
with
them by using one of the [*linking
keywords*](https://docs.github.com/en/issues/tracking-your-work-with-issues/linking-a-pull-request-to-an-issue#linking-a-pull-request-to-an-issue-using-a-keyword),
e.g.
- this PR should close #xxxx
- fixes #xxxx

you can also mention related issues, PRs or discussions!
-->

# Description
<!--
Thank you for improving Nushell. Please, check our [contributing
guide](../CONTRIBUTING.md) and talk to the core team before making major
changes.

Description of your pull request goes here. **Provide examples and/or
screenshots** if your changes affect the user experience.
-->
This PR adds a number of math functions under a single `polars math`
command that apply to one or more column expressions.

Note, `polars math` currently resides in the new module
dataframe/command/command/computation/math.rs. I'm open to alternative
organization and naming suggestions.

```nushell
Collection of math functions to be applied on one or more column expressions

This is an incomplete implementation of the available functions listed here: https://docs.pola.rs/api/python/stable/reference/expressions/computation.html.

        The following functions are currently available:
        - abs
        - cos
        - dot <expression>
        - exp
        - log <base; default e>
        - log1p
        - sign
        - sin
        - sqrt


Usage:
  > polars math <type> ...(args)

Flags:
  -h, --help: Display the help message for this command

Parameters:
  type <string>: Function name. See extra description for full list of accepted values
  ...args <any>: Extra arguments required by some functions

Input/output types:
  ╭───┬────────────┬────────────╮
  │ # │   input    │   output   │
  ├───┼────────────┼────────────┤
  │ 0 │ expression │ expression │
  ╰───┴────────────┴────────────╯

Examples:
  Apply function to column expression
  > [[a]; [0] [-1] [2] [-3] [4]]
                    | polars into-df
                    | polars select [
                        (polars col a | polars math abs | polars as a_abs)
                        (polars col a | polars math sign | polars as a_sign)
                        (polars col a | polars math exp | polars as a_exp)]
                    | polars collect
  ╭───┬───────┬────────┬────────╮
  │ # │ a_abs │ a_sign │ a_exp  │
  ├───┼───────┼────────┼────────┤
  │ 0 │     0 │      0 │  1.000 │
  │ 1 │     1 │     -1 │  0.368 │
  │ 2 │     2 │      1 │  7.389 │
  │ 3 │     3 │     -1 │  0.050 │
  │ 4 │     4 │      1 │ 54.598 │
  ╰───┴───────┴────────┴────────╯

  Specify arguments for select functions. See description for more information.
  > [[a]; [0] [1] [2] [4] [8] [16]]
                    | polars into-df
                    | polars select [
                        (polars col a | polars math log 2 | polars as a_base2)]
                    | polars collect
  ╭───┬─────────╮
  │ # │ a_base2 │
  ├───┼─────────┤
  │ 0 │    -inf │
  │ 1 │   0.000 │
  │ 2 │   1.000 │
  │ 3 │   2.000 │
  │ 4 │   3.000 │
  │ 5 │   4.000 │
  ╰───┴─────────╯

  Specify arguments for select functions. See description for more information.
  > [[a b]; [0 0] [1 1] [2 2] [3 3] [4 4] [5 5]]
                    | polars into-df
                    | polars select [
                        (polars col a | polars math dot (polars col b) | polars as ab)]
                    | polars collect
  ╭───┬────────╮
  │ # │   ab   │
  ├───┼────────┤
  │ 0 │ 55.000 │
  ╰───┴────────╯
``` 

# User-Facing Changes
<!-- List of all changes that impact the user experience here. This
helps us keep track of breaking changes. -->
No breaking changes.

# Tests + Formatting
<!--
Don't forget to add tests that cover your changes.

Make sure you've run and fixed any issues with these commands:

- `cargo fmt --all -- --check` to check standard code formatting (`cargo
fmt --all` applies these changes)
- `cargo clippy --workspace -- -D warnings -D clippy::unwrap_used` to
check that you're using the standard code style
- `cargo test --workspace` to check that all tests pass (on Windows make
sure to [enable developer
mode](https://learn.microsoft.com/en-us/windows/apps/get-started/developer-mode-features-and-debugging))
- `cargo run -- -c "use toolkit.nu; toolkit test stdlib"` to run the
tests for the standard library

> **Note**
> from `nushell` you can also use the `toolkit` as follows
> ```bash
> use toolkit.nu # or use an `env_change` hook to activate it
automatically
> toolkit check pr
> ```
-->
Example tests were added to `polars math`.

# After Submitting
<!-- If your PR had any user-facing changes, update [the
documentation](https://github.com/nushell/nushell.github.io) after the
PR is merged, if necessary. This will help us keep the docs up to date.
-->
This commit is contained in:
pyz4 2025-05-27 19:35:48 -04:00 committed by GitHub
parent ae51f6d722
commit 37bc922a67
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 275 additions and 3 deletions

View File

@ -55,6 +55,7 @@ url.workspace = true
[dependencies.polars]
features = [
"abs",
"arg_where",
"bigidx",
"checked_arithmetic",
@ -78,6 +79,7 @@ features = [
"is_in",
"json",
"lazy",
"log",
"object",
"parquet",
"pivot",
@ -87,12 +89,14 @@ features = [
"round_series",
"serde",
"serde-lazy",
"sign",
"strings",
"string_to_integer",
"streaming",
"timezones",
"temporal",
"to_dummies",
"trigonometry",
]
optional = false
version = "0.46"

View File

@ -0,0 +1,255 @@
use crate::{PolarsPlugin, values::CustomValueSupport};
use crate::values::{
NuDataFrame, NuExpression, PolarsPluginObject, PolarsPluginType, cant_convert_err,
};
use nu_plugin::{EngineInterface, EvaluatedCall, PluginCommand};
use nu_protocol::{
Category, Example, LabeledError, PipelineData, ShellError, Signature, Span, Spanned,
SyntaxShape, Type, Value,
};
use num::ToPrimitive;
use polars::prelude::df;
enum FunctionType {
Abs,
Cos,
Dot,
Exp,
Log,
Log1p,
Sign,
Sin,
Sqrt,
}
impl FunctionType {
fn from_str(func_type: &str, span: Span) -> Result<Self, ShellError> {
match func_type {
"abs" => Ok(Self::Abs),
"cos" => Ok(Self::Cos),
"dot" => Ok(Self::Dot),
"exp" => Ok(Self::Exp),
"log" => Ok(Self::Log),
"log1p" => Ok(Self::Log1p),
"sign" => Ok(Self::Sign),
"sin" => Ok(Self::Sin),
"sqrt" => Ok(Self::Sqrt),
_ => Err(ShellError::GenericError {
error: "Invalid function name".into(),
msg: "".into(),
span: Some(span),
help: Some("See description for accepted functions".into()),
inner: vec![],
}),
}
}
#[allow(dead_code)]
fn to_str(&self) -> &'static str {
match self {
FunctionType::Abs => "abs",
FunctionType::Cos => "cos",
FunctionType::Dot => "dot",
FunctionType::Exp => "exp",
FunctionType::Log => "log",
FunctionType::Log1p => "log1p",
FunctionType::Sign => "sign",
FunctionType::Sin => "sin",
FunctionType::Sqrt => "sqrt",
}
}
}
#[derive(Clone)]
pub struct ExprMath;
impl PluginCommand for ExprMath {
type Plugin = PolarsPlugin;
fn name(&self) -> &str {
"polars math"
}
fn description(&self) -> &str {
"Collection of math functions to be applied on one or more column expressions"
}
fn extra_description(&self) -> &str {
r#"This is an incomplete implementation of the available functions listed here: https://docs.pola.rs/api/python/stable/reference/expressions/computation.html.
The following functions are currently available:
- abs
- cos
- dot <expression>
- exp
- log <base; default e>
- log1p
- sign
- sin
- sqrt
"#
}
fn signature(&self) -> Signature {
Signature::build(self.name())
.required(
"type",
SyntaxShape::String,
"Function name. See extra description for full list of accepted values",
)
.rest(
"args",
SyntaxShape::Any,
"Extra arguments required by some functions",
)
.input_output_types(vec![(
Type::Custom("expression".into()),
Type::Custom("expression".into()),
)])
.category(Category::Custom("dataframe".into()))
}
fn examples(&self) -> Vec<Example> {
vec![Example {
description: "Apply function to column expression",
example: "[[a]; [0] [-1] [2] [-3] [4]]
| polars into-df
| polars select [
(polars col a | polars math abs | polars as a_abs)
(polars col a | polars math sign | polars as a_sign)
(polars col a | polars math exp | polars as a_exp)]
| polars collect",
result: Some(
NuDataFrame::from(
df!(
"a_abs" => [0, 1, 2, 3, 4],
"a_sign" => [0, -1, 1, -1, 1],
"a_exp" => [1.000, 0.36787944117144233, 7.38905609893065, 0.049787068367863944, 54.598150033144236],
)
.expect("simple df for test should not fail"),
)
.into_value(Span::test_data()),
),
},
Example {
description: "Specify arguments for select functions. See description for more information.",
example: "[[a]; [0] [1] [2] [4] [8] [16]]
| polars into-df
| polars select [
(polars col a | polars math log 2 | polars as a_base2)]
| polars collect",
result: Some(
NuDataFrame::from(
df!(
"a_base2" => [f64::NEG_INFINITY, 0.0, 1.0, 2.0, 3.0, 4.0],
)
.expect("simple df for test should not fail"),
)
.into_value(Span::test_data()),
),
},
Example {
description: "Specify arguments for select functions. See description for more information.",
example: "[[a b]; [0 0] [1 1] [2 2] [3 3] [4 4] [5 5]]
| polars into-df
| polars select [
(polars col a | polars math dot (polars col b) | polars as ab)]
| polars collect",
result: Some(
NuDataFrame::from(
df!(
"ab" => [55.0],
)
.expect("simple df for test should not fail"),
)
.into_value(Span::test_data()),
),
}
]
}
fn run(
&self,
plugin: &Self::Plugin,
engine: &EngineInterface,
call: &EvaluatedCall,
input: PipelineData,
) -> Result<PipelineData, LabeledError> {
let metadata = input.metadata();
let value = input.into_value(call.head)?;
let func_type: Spanned<String> = call.req(0)?;
let func_type = FunctionType::from_str(&func_type.item, func_type.span)?;
match PolarsPluginObject::try_from_value(plugin, &value)? {
PolarsPluginObject::NuExpression(expr) => {
command_expr(plugin, engine, call, func_type, expr)
}
_ => Err(cant_convert_err(&value, &[PolarsPluginType::NuExpression])),
}
.map_err(LabeledError::from)
.map(|pd| pd.set_metadata(metadata))
}
}
fn command_expr(
plugin: &PolarsPlugin,
engine: &EngineInterface,
call: &EvaluatedCall,
func_type: FunctionType,
expr: NuExpression,
) -> Result<PipelineData, ShellError> {
let res = expr.into_polars();
let res: NuExpression = match func_type {
FunctionType::Abs => res.abs(),
FunctionType::Cos => res.cos(),
FunctionType::Dot => {
let expr = match call.rest::<Value>(1)?.first() {
None => Err(ShellError::GenericError { error: "Second expression to compute dot product with must be provided".into(), msg: "".into(), span: Some(call.head), help: None, inner: vec![] }),
Some(value) => {
match PolarsPluginObject::try_from_value(plugin, value)? {
PolarsPluginObject::NuExpression(expr) => {
Ok(expr.into_polars())
}
_ => Err(cant_convert_err(value, &[PolarsPluginType::NuExpression]))
}
}
}?;
res.dot(expr)
}
FunctionType::Exp => res.exp(),
FunctionType::Log => {
let base = match call.rest::<Value>(1)?.first() {
// default natural log
None => Ok(std::f64::consts::E),
Some(value) => match value {
Value::Float { val, .. } => Ok(*val),
Value::Int { val, .. } => Ok(val.to_f64().expect("i64 to f64 conversion should not panic")),
_ => Err(ShellError::GenericError { error: "log base must be a float or integer. Leave base unspecified for natural log".into(), msg: "".into(), span: Some(value.span()), help: None, inner: vec![] }),
},
}?;
res.log(base)
}
FunctionType::Log1p => res.log1p(),
FunctionType::Sign => res.sign(),
FunctionType::Sin => res.sin(),
FunctionType::Sqrt => res.sqrt(),
}
.into();
res.to_pipeline_data(plugin, engine, call.head)
}
#[cfg(test)]
mod test {
use super::*;
use crate::test::test_polars_plugin_command;
#[test]
fn test_examples() -> Result<(), ShellError> {
test_polars_plugin_command(&ExprMath)
}
}

View File

@ -0,0 +1,10 @@
mod math;
use crate::PolarsPlugin;
use nu_plugin::PluginCommand;
use math::ExprMath;
pub(crate) fn computation_commands() -> Vec<Box<dyn PluginCommand<Plugin = PolarsPlugin>>> {
vec![Box::new(ExprMath)]
}

View File

@ -1,5 +1,6 @@
pub mod aggregation;
pub mod boolean;
pub mod computation;
pub mod core;
pub mod data;
pub mod datetime;

View File

@ -6,9 +6,10 @@ use std::{
use cache::cache_commands;
pub use cache::{Cache, Cacheable};
use command::{
aggregation::aggregation_commands, boolean::boolean_commands, core::core_commands,
data::data_commands, datetime::datetime_commands, index::index_commands,
integer::integer_commands, list::list_commands, string::string_commands, stub::PolarsCmd,
aggregation::aggregation_commands, boolean::boolean_commands,
computation::computation_commands, core::core_commands, data::data_commands,
datetime::datetime_commands, index::index_commands, integer::integer_commands,
list::list_commands, string::string_commands, stub::PolarsCmd,
};
use log::debug;
use nu_plugin::{EngineInterface, Plugin, PluginCommand};
@ -88,6 +89,7 @@ impl Plugin for PolarsPlugin {
commands.append(&mut aggregation_commands());
commands.append(&mut boolean_commands());
commands.append(&mut core_commands());
commands.append(&mut computation_commands());
commands.append(&mut data_commands());
commands.append(&mut datetime_commands());
commands.append(&mut index_commands());