nushell/crates/nu_plugin_polars/src/lib.rs
pyz4 37bc922a67 feat(polars): add polars math expression (#15822)

# Description
This PR adds a number of math functions under a single `polars math`
command that apply to one or more column expressions.
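Because a single command covers several functions, its first positional argument has to be mapped to one specific operation, and a few functions consume extra arguments: `dot` needs a second operand and `log` takes an optional base. The following is a minimal, hypothetical sketch of that dispatch in plain Rust. It is not the plugin's code. `MathFn` and `parse_math_fn` are invented names, numeric arguments stand in for column expressions, and rejecting surplus arguments is an assumption rather than documented behavior.

```rust
use std::f64::consts::E;

/// One variant per function listed in the `polars math` help text (illustrative only).
#[derive(Debug, Clone, PartialEq)]
enum MathFn {
    Abs,
    Cos,
    Dot(Vec<f64>), // second operand; a column expression in the real command
    Exp,
    Log(f64), // logarithm base, defaulting to e
    Log1p,
    Sign,
    Sin,
    Sqrt,
}

/// Map a function name plus extra arguments to a `MathFn` (hypothetical helper).
fn parse_math_fn(name: &str, args: &[f64]) -> Result<MathFn, String> {
    // Functions without parameters reject surplus arguments (assumed strictness).
    let no_args = |f: MathFn| -> Result<MathFn, String> {
        if args.is_empty() {
            Ok(f)
        } else {
            Err(format!("'{name}' takes no arguments"))
        }
    };

    match name {
        "abs" => no_args(MathFn::Abs),
        "cos" => no_args(MathFn::Cos),
        "dot" if !args.is_empty() => Ok(MathFn::Dot(args.to_vec())),
        "dot" => Err("'dot' requires a second operand".to_string()),
        "exp" => no_args(MathFn::Exp),
        "log" => match args {
            [] => Ok(MathFn::Log(E)),
            [base] => Ok(MathFn::Log(*base)),
            _ => Err("'log' takes at most one argument (the base)".to_string()),
        },
        "log1p" => no_args(MathFn::Log1p),
        "sign" => no_args(MathFn::Sign),
        "sin" => no_args(MathFn::Sin),
        "sqrt" => no_args(MathFn::Sqrt),
        other => Err(format!("unknown math function '{other}'")),
    }
}
```

Keeping the name-to-variant mapping in one `match` means that adding a function from the Polars computation reference touches a single place, and unknown names fail with a clear message.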

Note: `polars math` currently resides in the new module
`dataframe/command/computation/math.rs`. I'm open to alternative
organization and naming suggestions.

```nushell
Collection of math functions to be applied on one or more column expressions

This is an incomplete implementation of the available functions listed here: https://docs.pola.rs/api/python/stable/reference/expressions/computation.html.

        The following functions are currently available:
        - abs
        - cos
        - dot <expression>
        - exp
        - log <base; default e>
        - log1p
        - sign
        - sin
        - sqrt


Usage:
  > polars math <type> ...(args)

Flags:
  -h, --help: Display the help message for this command

Parameters:
  type <string>: Function name. See extra description for full list of accepted values
  ...args <any>: Extra arguments required by some functions

Input/output types:
  ╭───┬────────────┬────────────╮
  │ # │   input    │   output   │
  ├───┼────────────┼────────────┤
  │ 0 │ expression │ expression │
  ╰───┴────────────┴────────────╯

Examples:
  Apply function to column expression
  > [[a]; [0] [-1] [2] [-3] [4]]
                    | polars into-df
                    | polars select [
                        (polars col a | polars math abs | polars as a_abs)
                        (polars col a | polars math sign | polars as a_sign)
                        (polars col a | polars math exp | polars as a_exp)]
                    | polars collect
  ╭───┬───────┬────────┬────────╮
  │ # │ a_abs │ a_sign │ a_exp  │
  ├───┼───────┼────────┼────────┤
  │ 0 │     0 │      0 │  1.000 │
  │ 1 │     1 │     -1 │  0.368 │
  │ 2 │     2 │      1 │  7.389 │
  │ 3 │     3 │     -1 │  0.050 │
  │ 4 │     4 │      1 │ 54.598 │
  ╰───┴───────┴────────┴────────╯

  Specify arguments for select functions. See description for more information.
  > [[a]; [0] [1] [2] [4] [8] [16]]
                    | polars into-df
                    | polars select [
                        (polars col a | polars math log 2 | polars as a_base2)]
                    | polars collect
  ╭───┬─────────╮
  │ # │ a_base2 │
  ├───┼─────────┤
  │ 0 │    -inf │
  │ 1 │   0.000 │
  │ 2 │   1.000 │
  │ 3 │   2.000 │
  │ 4 │   3.000 │
  │ 5 │   4.000 │
  ╰───┴─────────╯

  Specify arguments for select functions. See description for more information.
  > [[a b]; [0 0] [1 1] [2 2] [3 3] [4 4] [5 5]]
                    | polars into-df
                    | polars select [
                        (polars col a | polars math dot (polars col b) | polars as ab)]
                    | polars collect
  ╭───┬────────╮
  │ # │   ab   │
  ├───┼────────┤
  │ 0 │ 55.000 │
  ╰───┴────────╯
``` 
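The values in the example tables above can be checked with a few lines of plain Rust. This sketch uses no Polars API. `sign`, `log_base`, `dot`, and `apply` are illustrative helpers that mirror the semantics the tables display: `sign(0)` is 0, `log` of 0 is `-inf`, and `dot` collapses two columns into one value. Returning `None` on a length mismatch is an assumption; the real command may raise an error instead.

```rust
/// Sign with sign(0) == 0, matching the `a_sign` column.
fn sign(x: f64) -> f64 {
    if x > 0.0 {
        1.0
    } else if x < 0.0 {
        -1.0
    } else {
        x // zero stays zero; NaN propagates
    }
}

/// Logarithm in an arbitrary base via ln(x) / ln(base); ln(0) is -inf.
fn log_base(x: f64, base: f64) -> f64 {
    x.ln() / base.ln()
}

/// Dot product of two columns; `None` when the lengths differ (assumed behavior).
fn dot(a: &[f64], b: &[f64]) -> Option<f64> {
    if a.len() != b.len() {
        return None;
    }
    Some(a.iter().zip(b).map(|(x, y)| x * y).sum())
}

/// Apply a scalar function element-wise, like an expression over one column.
fn apply<F: Fn(f64) -> f64>(col: &[f64], f: F) -> Vec<f64> {
    col.iter().copied().map(f).collect()
}
```

One detail is worth noting. Rust's `f64::signum` returns `1.0` for `+0.0`, so it would not reproduce the first row of `a_sign`. That is why the sketch compares against zero explicitly.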

# User-Facing Changes
No breaking changes.

# Tests + Formatting
Example tests were added to `polars math`.

# After Submitting
2025-05-27 16:35:48 -07:00

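`lib.rs` opens with an `EngineWrapper` trait. It is implemented for `&EngineInterface` in production and for a `TestEngineWrapper` double in tests, so cache code can run without a live Nushell engine. The sketch below is a simplified, hypothetical illustration of that seam in plain Rust. `EnvSource`, `EnvValue`, `FakeEngine`, and `lookup_or` are invented names and are not part of `nu_plugin`. Like the real `get_env_var`, it collapses non-string values to an empty string.

```rust
use std::collections::HashMap;

/// Narrowed, hypothetical version of `EngineWrapper`.
trait EnvSource {
    fn get_env_var(&self, key: &str) -> Option<String>;
    fn use_color(&self) -> bool;
}

/// Values an engine might hold; only strings are meaningful as env vars here.
enum EnvValue {
    Str(String),
    Int(i64),
}

struct FakeEngine {
    vars: HashMap<String, EnvValue>,
}

// Implemented on the reference type, mirroring `impl EngineWrapper for &EngineInterface`,
// so a caller holding `&FakeEngine` can pass it by value to generic code.
impl EnvSource for &FakeEngine {
    fn get_env_var(&self, key: &str) -> Option<String> {
        self.vars.get(key).map(|v| match v {
            EnvValue::Str(s) => s.clone(),
            // Non-string values collapse to "", as lib.rs does.
            EnvValue::Int(_) => String::new(),
        })
    }

    fn use_color(&self) -> bool {
        false
    }
}

/// Generic consumer: depends only on the trait, never on a concrete engine.
fn lookup_or<E: EnvSource>(engine: E, key: &str, default: &str) -> String {
    engine
        .get_env_var(key)
        .unwrap_or_else(|| default.to_string())
}
```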

use std::{
    cmp::Ordering,
    panic::{AssertUnwindSafe, catch_unwind},
};

use cache::cache_commands;
pub use cache::{Cache, Cacheable};
use command::{
    aggregation::aggregation_commands, boolean::boolean_commands,
    computation::computation_commands, core::core_commands, data::data_commands,
    datetime::datetime_commands, index::index_commands, integer::integer_commands,
    list::list_commands, string::string_commands, stub::PolarsCmd,
};
use log::debug;
use nu_plugin::{EngineInterface, Plugin, PluginCommand};

mod cache;
mod cloud;
pub mod dataframe;
pub use dataframe::*;
use nu_protocol::{CustomValue, LabeledError, ShellError, Span, Spanned, Value, ast::Operator};
use tokio::runtime::Runtime;
use values::CustomValueType;

use crate::values::PolarsPluginCustomValue;

pub trait EngineWrapper {
    fn get_env_var(&self, key: &str) -> Option<String>;
    fn use_color(&self) -> bool;
    fn set_gc_disabled(&self, disabled: bool) -> Result<(), ShellError>;
}

impl EngineWrapper for &EngineInterface {
    fn get_env_var(&self, key: &str) -> Option<String> {
        EngineInterface::get_env_var(self, key)
            .ok()
            .flatten()
            .map(|x| match x {
                Value::String { val, .. } => val,
                _ => "".to_string(),
            })
    }

    fn use_color(&self) -> bool {
        self.get_config()
            .ok()
            .and_then(|config| config.color_config.get("use_color").cloned())
            .unwrap_or(Value::bool(false, Span::unknown()))
            .is_true()
    }

    fn set_gc_disabled(&self, disabled: bool) -> Result<(), ShellError> {
        debug!("set_gc_disabled called with {disabled}");
        EngineInterface::set_gc_disabled(self, disabled)
    }
}

pub struct PolarsPlugin {
    pub(crate) cache: Cache,
    /// For testing purposes only
    pub(crate) disable_cache_drop: bool,
    pub(crate) runtime: Runtime,
}

impl PolarsPlugin {
    pub fn new() -> Result<Self, ShellError> {
        Ok(Self {
            cache: Cache::default(),
            disable_cache_drop: false,
            runtime: Runtime::new().map_err(|e| ShellError::GenericError {
                error: format!("Could not instantiate tokio: {e}"),
                msg: "".into(),
                span: None,
                help: None,
                inner: vec![],
            })?,
        })
    }
}

impl Plugin for PolarsPlugin {
    fn version(&self) -> String {
        env!("CARGO_PKG_VERSION").into()
    }

    fn commands(&self) -> Vec<Box<dyn PluginCommand<Plugin = Self>>> {
        let mut commands: Vec<Box<dyn PluginCommand<Plugin = Self>>> = vec![Box::new(PolarsCmd)];
        commands.append(&mut aggregation_commands());
        commands.append(&mut boolean_commands());
        commands.append(&mut core_commands());
        commands.append(&mut computation_commands());
        commands.append(&mut data_commands());
        commands.append(&mut datetime_commands());
        commands.append(&mut index_commands());
        commands.append(&mut integer_commands());
        commands.append(&mut string_commands());
        commands.append(&mut list_commands());
        commands.append(&mut cache_commands());
        commands
    }

    fn custom_value_dropped(
        &self,
        engine: &EngineInterface,
        custom_value: Box<dyn CustomValue>,
    ) -> Result<(), LabeledError> {
        debug!("custom_value_dropped called {:?}", custom_value);
        if !self.disable_cache_drop {
            let id = CustomValueType::try_from_custom_value(custom_value)?.id();
            let _ = self.cache.remove(engine, &id, false);
        }
        Ok(())
    }

    fn custom_value_to_base_value(
        &self,
        engine: &EngineInterface,
        custom_value: Spanned<Box<dyn CustomValue>>,
    ) -> Result<Value, LabeledError> {
        let result = match CustomValueType::try_from_custom_value(custom_value.item)? {
            CustomValueType::NuDataFrame(cv) => cv.custom_value_to_base_value(self, engine),
            CustomValueType::NuLazyFrame(cv) => cv.custom_value_to_base_value(self, engine),
            CustomValueType::NuExpression(cv) => cv.custom_value_to_base_value(self, engine),
            CustomValueType::NuLazyGroupBy(cv) => cv.custom_value_to_base_value(self, engine),
            CustomValueType::NuWhen(cv) => cv.custom_value_to_base_value(self, engine),
            CustomValueType::NuDataType(cv) => cv.custom_value_to_base_value(self, engine),
            CustomValueType::NuSchema(cv) => cv.custom_value_to_base_value(self, engine),
        };
        Ok(result?)
    }

    fn custom_value_operation(
        &self,
        engine: &EngineInterface,
        left: Spanned<Box<dyn CustomValue>>,
        operator: Spanned<Operator>,
        right: Value,
    ) -> Result<Value, LabeledError> {
        let result = match CustomValueType::try_from_custom_value(left.item)? {
            CustomValueType::NuDataFrame(cv) => {
                cv.custom_value_operation(self, engine, left.span, operator, right)
            }
            CustomValueType::NuLazyFrame(cv) => {
                cv.custom_value_operation(self, engine, left.span, operator, right)
            }
            CustomValueType::NuExpression(cv) => {
                cv.custom_value_operation(self, engine, left.span, operator, right)
            }
            CustomValueType::NuLazyGroupBy(cv) => {
                cv.custom_value_operation(self, engine, left.span, operator, right)
            }
            CustomValueType::NuWhen(cv) => {
                cv.custom_value_operation(self, engine, left.span, operator, right)
            }
            CustomValueType::NuDataType(cv) => {
                cv.custom_value_operation(self, engine, left.span, operator, right)
            }
            CustomValueType::NuSchema(cv) => {
                cv.custom_value_operation(self, engine, left.span, operator, right)
            }
        };
        Ok(result?)
    }

    fn custom_value_follow_path_int(
        &self,
        engine: &EngineInterface,
        custom_value: Spanned<Box<dyn CustomValue>>,
        index: Spanned<usize>,
    ) -> Result<Value, LabeledError> {
        let result = match CustomValueType::try_from_custom_value(custom_value.item)? {
            CustomValueType::NuDataFrame(cv) => {
                cv.custom_value_follow_path_int(self, engine, custom_value.span, index)
            }
            CustomValueType::NuLazyFrame(cv) => {
                cv.custom_value_follow_path_int(self, engine, custom_value.span, index)
            }
            CustomValueType::NuExpression(cv) => {
                cv.custom_value_follow_path_int(self, engine, custom_value.span, index)
            }
            CustomValueType::NuLazyGroupBy(cv) => {
                cv.custom_value_follow_path_int(self, engine, custom_value.span, index)
            }
            CustomValueType::NuWhen(cv) => {
                cv.custom_value_follow_path_int(self, engine, custom_value.span, index)
            }
            CustomValueType::NuDataType(cv) => {
                cv.custom_value_follow_path_int(self, engine, custom_value.span, index)
            }
            CustomValueType::NuSchema(cv) => {
                cv.custom_value_follow_path_int(self, engine, custom_value.span, index)
            }
        };
        Ok(result?)
    }

    fn custom_value_follow_path_string(
        &self,
        engine: &EngineInterface,
        custom_value: Spanned<Box<dyn CustomValue>>,
        column_name: Spanned<String>,
    ) -> Result<Value, LabeledError> {
        let result = match CustomValueType::try_from_custom_value(custom_value.item)? {
            CustomValueType::NuDataFrame(cv) => {
                cv.custom_value_follow_path_string(self, engine, custom_value.span, column_name)
            }
            CustomValueType::NuLazyFrame(cv) => {
                cv.custom_value_follow_path_string(self, engine, custom_value.span, column_name)
            }
            CustomValueType::NuExpression(cv) => {
                cv.custom_value_follow_path_string(self, engine, custom_value.span, column_name)
            }
            CustomValueType::NuLazyGroupBy(cv) => {
                cv.custom_value_follow_path_string(self, engine, custom_value.span, column_name)
            }
            CustomValueType::NuWhen(cv) => {
                cv.custom_value_follow_path_string(self, engine, custom_value.span, column_name)
            }
            CustomValueType::NuDataType(cv) => {
                cv.custom_value_follow_path_string(self, engine, custom_value.span, column_name)
            }
            CustomValueType::NuSchema(cv) => {
                cv.custom_value_follow_path_string(self, engine, custom_value.span, column_name)
            }
        };
        Ok(result?)
    }

    fn custom_value_partial_cmp(
        &self,
        engine: &EngineInterface,
        custom_value: Box<dyn CustomValue>,
        other_value: Value,
    ) -> Result<Option<Ordering>, LabeledError> {
        let result = match CustomValueType::try_from_custom_value(custom_value)? {
            CustomValueType::NuDataFrame(cv) => {
                cv.custom_value_partial_cmp(self, engine, other_value)
            }
            CustomValueType::NuLazyFrame(cv) => {
                cv.custom_value_partial_cmp(self, engine, other_value)
            }
            CustomValueType::NuExpression(cv) => {
                cv.custom_value_partial_cmp(self, engine, other_value)
            }
            CustomValueType::NuLazyGroupBy(cv) => {
                cv.custom_value_partial_cmp(self, engine, other_value)
            }
            CustomValueType::NuWhen(cv) => cv.custom_value_partial_cmp(self, engine, other_value),
            CustomValueType::NuDataType(cv) => {
                cv.custom_value_partial_cmp(self, engine, other_value)
            }
            CustomValueType::NuSchema(cv) => cv.custom_value_partial_cmp(self, engine, other_value),
        };
        Ok(result?)
    }
}

pub(crate) fn handle_panic<F, R>(f: F, span: Span) -> Result<R, ShellError>
where
    F: FnOnce() -> Result<R, ShellError>,
{
    match catch_unwind(AssertUnwindSafe(f)) {
        Ok(inner_result) => inner_result,
        Err(_) => Err(ShellError::GenericError {
            error: "Panic occurred".into(),
            msg: "".into(),
            span: Some(span),
            help: None,
            inner: vec![],
        }),
    }
}

#[cfg(test)]
pub mod test {
    use super::*;
    use crate::values::PolarsPluginObject;
    use nu_plugin_test_support::PluginTest;
    use nu_protocol::{ShellError, Span, engine::Command};

    impl PolarsPlugin {
        /// Creates a new polars plugin in test mode
        pub fn new_test_mode() -> Result<Self, ShellError> {
            Ok(PolarsPlugin {
                disable_cache_drop: true,
                ..PolarsPlugin::new()?
            })
        }
    }

    struct TestEngineWrapper;

    impl EngineWrapper for TestEngineWrapper {
        fn get_env_var(&self, key: &str) -> Option<String> {
            std::env::var(key).ok()
        }

        fn use_color(&self) -> bool {
            false
        }

        fn set_gc_disabled(&self, _disabled: bool) -> Result<(), ShellError> {
            Ok(())
        }
    }

    pub fn test_polars_plugin_command(command: &impl PluginCommand) -> Result<(), ShellError> {
        test_polars_plugin_command_with_decls(command, vec![])
    }

    pub fn test_polars_plugin_command_with_decls(
        command: &impl PluginCommand,
        decls: Vec<Box<dyn Command>>,
    ) -> Result<(), ShellError> {
        let plugin = PolarsPlugin::new_test_mode()?;
        let examples = command.examples();

        // we need to cache values in the examples
        for example in &examples {
            if let Some(ref result) = example.result {
                // if it's a polars plugin object, try to cache it
                if let Ok(obj) = PolarsPluginObject::try_from_value(&plugin, result) {
                    let id = obj.id();
                    plugin
                        .cache
                        .insert(TestEngineWrapper {}, id, obj, Span::test_data())
                        .unwrap();
                }
            }
        }

        let mut plugin_test = PluginTest::new(command.name(), plugin.into())?;

        for decl in decls {
            let _ = plugin_test.add_decl(decl)?;
        }

        plugin_test.test_examples(&examples)
    }
}
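`handle_panic` in `lib.rs` turns a panic inside a closure into an ordinary `ShellError`, so a failure deep inside Polars surfaces as an error instead of crashing the plugin. Below is a stand-alone sketch of the same `catch_unwind` pattern. `PluginError` is a hypothetical stand-in for `ShellError::GenericError`, and the real function also attaches a `Span`.

```rust
use std::panic::{catch_unwind, AssertUnwindSafe};

/// Hypothetical error type standing in for `ShellError`.
#[derive(Debug, PartialEq)]
enum PluginError {
    Failed(String),
    Panicked,
}

/// Run `f`, passing its own result through and mapping a panic to `PluginError::Panicked`.
fn handle_panic<F, R>(f: F) -> Result<R, PluginError>
where
    F: FnOnce() -> Result<R, PluginError>,
{
    // AssertUnwindSafe opts the closure into catch_unwind even if it captures
    // references whose state a panic could leave inconsistent.
    match catch_unwind(AssertUnwindSafe(f)) {
        Ok(inner) => inner,                   // closure returned normally (Ok or Err)
        Err(_) => Err(PluginError::Panicked), // closure panicked
    }
}
```

This only works when the crate is built with `panic = "unwind"`, which is the default. Under `panic = "abort"`, the process terminates before `catch_unwind` can intervene. The default panic hook still prints the panic message to stderr even when the panic is caught.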