diff --git a/crates/nu_plugin_polars/Cargo.toml b/crates/nu_plugin_polars/Cargo.toml index 033eb533ee..29f39009d1 100644 --- a/crates/nu_plugin_polars/Cargo.toml +++ b/crates/nu_plugin_polars/Cargo.toml @@ -32,7 +32,7 @@ serde = { version = "1.0", features = ["derive"] } sqlparser = { version = "0.53"} polars-io = { version = "0.46", features = ["avro", "cloud", "aws"]} polars-arrow = { version = "0.46"} -polars-ops = { version = "0.46", features = ["pivot"]} +polars-ops = { version = "0.46", features = ["pivot", "cutqcut"]} polars-plan = { version = "0.46", features = ["regex"]} polars-utils = { version = "0.46"} typetag = "0.2" diff --git a/crates/nu_plugin_polars/src/dataframe/command/data/cut.rs b/crates/nu_plugin_polars/src/dataframe/command/data/cut.rs new file mode 100644 index 0000000000..5e402a96fa --- /dev/null +++ b/crates/nu_plugin_polars/src/dataframe/command/data/cut.rs @@ -0,0 +1,89 @@ +use nu_plugin::{EngineInterface, EvaluatedCall, PluginCommand}; +use nu_protocol::{Category, Example, PipelineData, ShellError, Signature, SyntaxShape, Type}; +use polars::prelude::PlSmallStr; + +use crate::{ + values::{CustomValueSupport, NuDataFrame}, + PolarsPlugin, +}; + +pub struct CutSeries; + +impl PluginCommand for CutSeries { + type Plugin = PolarsPlugin; + fn name(&self) -> &str { + "polars cut" + } + + fn description(&self) -> &str { + "Bin continuous values into discrete categories for a series." + } + + fn signature(&self) -> nu_protocol::Signature { + Signature::build(self.name()) + .required("breaks", SyntaxShape::Any, "Dataframe that contains a series of unique cut points.") + .named( + "labels", + SyntaxShape::List(Box::new(SyntaxShape::String)), + "Names of the categories. The number of labels must be equal to the number of cut points plus one.", + Some('l'), + ) + .switch("left_closed", "Set the intervals to be left-closed instead of right-closed.", Some('c')) + .switch("include_breaks", "Include a column with the right endpoint of the bin each observation falls in. This will change the data type of the output from a Categorical to a Struct.", Some('b')) + .input_output_type( + Type::Custom("dataframe".into()), + Type::Custom("dataframe".into()), + ) + .category(Category::Custom("dataframe".into())) + } + + fn examples(&self) -> Vec { + vec![Example { + description: "Divide the column into three categories.", + example: r#"[-2, -1, 0, 1, 2] | polars into-df | polars cut [-1, 1] --labels ["a", "b", "c"]"#, + result: None, + }] + } + + fn run( + &self, + plugin: &Self::Plugin, + engine: &EngineInterface, + call: &EvaluatedCall, + input: PipelineData, + ) -> Result { + command(plugin, engine, call, input).map_err(|e| e.into()) + } +} + +fn command( + plugin: &PolarsPlugin, + engine: &EngineInterface, + call: &EvaluatedCall, + input: PipelineData, +) -> Result { + let df = NuDataFrame::try_from_pipeline_coerce(plugin, input, call.head)?; + let series = df.as_series(call.head)?; + + let breaks = call.req::>(0)?; + + let labels: Option> = call.get_flag::>("labels")?.map(|l| { + l.into_iter() + .map(PlSmallStr::from) + .collect::>() + }); + + let left_closed = call.has_flag("left_closed")?; + let include_breaks = call.has_flag("include_breaks")?; + + let new_series = polars_ops::series::cut(&series, breaks, labels, left_closed, include_breaks) + .map_err(|e| ShellError::GenericError { + error: "Error cutting series".into(), + msg: e.to_string(), + span: Some(call.head), + help: None, + inner: vec![], + })?; + + NuDataFrame::try_from_series(new_series, call.head)?.to_pipeline_data(plugin, engine, call.head) +} diff --git a/crates/nu_plugin_polars/src/dataframe/command/data/mod.rs b/crates/nu_plugin_polars/src/dataframe/command/data/mod.rs index 6089e3c858..dd6cdc4eff 100644 --- a/crates/nu_plugin_polars/src/dataframe/command/data/mod.rs +++ b/crates/nu_plugin_polars/src/dataframe/command/data/mod.rs @@ -5,6 +5,7 @@ mod cast; mod col; mod collect; mod concat; +mod cut; mod drop; mod drop_duplicates; mod drop_nulls; @@ -22,6 +23,7 @@ mod last; mod len; mod lit; mod pivot; +mod qcut; mod query_df; mod rename; mod reverse; @@ -75,6 +77,7 @@ pub(crate) fn data_commands() -> Vec Vec &str { + "polars qcut" + } + + fn description(&self) -> &str { + "Bin continuous values into discrete categories based on their quantiles for a series." + } + + fn signature(&self) -> nu_protocol::Signature { + Signature::build(self.name()) + .required("quantiles", SyntaxShape::Any, "Either a list of quantile probabilities between 0 and 1 or a positive integer determining the number of bins with uniform probability.") + .named( + "labels", + SyntaxShape::List(Box::new(SyntaxShape::String)), + "Names of the categories. The number of labels must be equal to the number of cut points plus one.", + Some('l'), + ) + .switch("left_closed", "Set the intervals to be left-closed instead of right-closed.", Some('c')) + .switch("include_breaks", "Include a column with the right endpoint of the bin each observation falls in. This will change the data type of the output from a Categorical to a Struct.", Some('b')) + .switch("allow_duplicates", "If set, duplicates in the resulting quantiles are dropped, rather than raising an error. This can happen even with unique probabilities, depending on the data.", Some('d')) + .input_output_type( + Type::Custom("dataframe".into()), + Type::Custom("dataframe".into()), + ) + .category(Category::Custom("dataframe".into())) + } + + fn examples(&self) -> Vec { + vec![Example { + description: "Divide a column into three categories according to pre-defined quantile probabilities.", + example: r#"[-2, -1, 0, 1, 2] | polars into-df | polars qcut [0.25, 0.75] --labels ["a", "b", "c"]"#, + result: None, + }] + } + + fn run( + &self, + plugin: &Self::Plugin, + engine: &EngineInterface, + call: &EvaluatedCall, + input: PipelineData, + ) -> Result { + command(plugin, engine, call, input).map_err(|e| e.into()) + } +} + +fn command( + plugin: &PolarsPlugin, + engine: &EngineInterface, + call: &EvaluatedCall, + input: PipelineData, +) -> Result { + let df = NuDataFrame::try_from_pipeline_coerce(plugin, input, call.head)?; + let series = df.as_series(call.head)?; + + let quantiles = call.req::>(0)?; + + let labels: Option> = call.get_flag::>("labels")?.map(|l| { + l.into_iter() + .map(PlSmallStr::from) + .collect::>() + }); + + let left_closed = call.has_flag("left_closed")?; + let include_breaks = call.has_flag("include_breaks")?; + let allow_duplicates = call.has_flag("allow_duplicates")?; + + let new_series = polars_ops::series::qcut( + &series, + quantiles, + labels, + left_closed, + allow_duplicates, + include_breaks, + ) + .map_err(|e| ShellError::GenericError { + error: "Error cutting series".into(), + msg: e.to_string(), + span: Some(call.head), + help: None, + inner: vec![], + })?; + + NuDataFrame::try_from_series(new_series, call.head)?.to_pipeline_data(plugin, engine, call.head) +}