diff --git a/Cargo.lock b/Cargo.lock index 712986b8b6..b82951b2b7 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3442,7 +3442,7 @@ dependencies = [ "polars-plan", "polars-utils", "serde", - "sqlparser 0.47.0", + "sqlparser", "tempfile", "typetag", "uuid", @@ -4052,9 +4052,9 @@ dependencies = [ [[package]] name = "polars" -version = "0.40.0" +version = "0.41.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e148396dca5496566880fa19374f3f789a29db94e3eb458afac1497b4bac5442" +checksum = "ce49e10a756f68eb99c102c6b2a0cbc0c583a0fa7263536ad0913d94be878d2d" dependencies = [ "getrandom", "polars-arrow", @@ -4072,9 +4072,9 @@ dependencies = [ [[package]] name = "polars-arrow" -version = "0.40.0" +version = "0.41.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1cb5e11cd0752ae022fa6ca3afa50a14b0301b7ce53c0135828fbb0f4fa8303e" +checksum = "b436f83f62e864f0d91871e26528f2c5552c7cf07c8d77547f1b8e3fde22bd27" dependencies = [ "ahash 0.8.11", "atoi", @@ -4120,9 +4120,9 @@ dependencies = [ [[package]] name = "polars-compute" -version = "0.40.0" +version = "0.41.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "89fc4578f826234cdecb782952aa9c479dc49373f81694a7b439c70b6f609ba0" +checksum = "f6758f834f07e622a2f859bebb542b2b7f8879b8704dbb2b2bbab460ddcdca4b" dependencies = [ "bytemuck", "either", @@ -4136,9 +4136,9 @@ dependencies = [ [[package]] name = "polars-core" -version = "0.40.0" +version = "0.41.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e490c6bace1366a558feea33d1846f749a8ca90bd72a6748752bc65bb4710b2a" +checksum = "7ed262e9bdda15a12a9bfcfc9200bec5253335633dbd86cf5b94fda0194244b3" dependencies = [ "ahash 0.8.11", "bitflags 2.5.0", @@ -4170,9 +4170,9 @@ dependencies = [ [[package]] name = "polars-error" -version = "0.40.0" +version = "0.41.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "08888f58e61599b00f5ea0c2ccdc796b54b9859559cc0d4582733509451fa01a" +checksum = "53e1707a17475ba5e74c349154b415e3148a1a275e395965427971b5e53ad621" dependencies = [ "avro-schema", "polars-arrow-format", @@ -4183,9 +4183,9 @@ dependencies = [ [[package]] name = "polars-expr" -version = "0.40.0" +version = "0.41.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4173591920fe56ad55af025f92eb0d08421ca85705c326a640c43856094e3484" +checksum = "31a9688d5842e7a7fbad88e67a174778794a91d97d3bba1b3c09dd1656fee3b2" dependencies = [ "ahash 0.8.11", "bitflags 2.5.0", @@ -4203,9 +4203,9 @@ dependencies = [ [[package]] name = "polars-io" -version = "0.40.0" +version = "0.41.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5842896aea46d975b425d63f156f412aed3cfde4c257b64fb1f43ceea288074e" +checksum = "18798dacd94fb9263f65f63f0feab0908675422646d6f7fc37043b85ff6dca35" dependencies = [ "ahash 0.8.11", "async-trait", @@ -4244,9 +4244,9 @@ dependencies = [ [[package]] name = "polars-json" -version = "0.40.0" +version = "0.41.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "160cbad0145b93ac6a88639aadfa6f7d7c769d05a8674f9b7e895b398cae9901" +checksum = "044ea319f667efbf8007c4c38171c2956e0e7f9b078eb66e31e82f80d1e14b51" dependencies = [ "ahash 0.8.11", "chrono", @@ -4265,19 +4265,21 @@ dependencies = [ [[package]] name = "polars-lazy" -version = "0.40.0" +version = "0.41.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e805ea2ebbc6b7749b0afb31b7fc5d32b42b57ba29b984549d43d3a16114c4a5" +checksum = "74a11994c2211f2e99d9ac31776fd7c2c0607d5fe62d5b5db9e396f7d663f3d5" dependencies = [ "ahash 0.8.11", "bitflags 2.5.0", "glob", + "memchr", "once_cell", "polars-arrow", "polars-core", "polars-expr", "polars-io", "polars-json", + "polars-mem-engine", "polars-ops", "polars-pipe", "polars-plan", @@ -4289,10 +4291,29 @@ dependencies = [ ] [[package]] -name = "polars-ops" -version = "0.40.0" +name = "polars-mem-engine" +version = "0.41.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7b0aed7e169c81b98457641cf82b251f52239a668916c2e683abd1f38df00d58" +checksum = "5acd5fde6fadaddfcae3227ec5b64121007928f8e68870c80653438e20c1c587" +dependencies = [ + "polars-arrow", + "polars-core", + "polars-error", + "polars-expr", + "polars-io", + "polars-json", + "polars-ops", + "polars-plan", + "polars-time", + "polars-utils", + "rayon", +] + +[[package]] +name = "polars-ops" +version = "0.41.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f4170c59e974727941edfb722f6d430ed623be9e7f30581ee00832c907f1b9fd" dependencies = [ "ahash 0.8.11", "argminmax", @@ -4326,9 +4347,9 @@ dependencies = [ [[package]] name = "polars-parquet" -version = "0.40.0" +version = "0.41.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c70670a9e51cac66d0e77fd20b5cc957dbcf9f2660d410633862bb72f846d5b8" +checksum = "c684638c36c60c691d707d414249fe8af4a19a35a39d418464b140fe23732e5d" dependencies = [ "ahash 0.8.11", "async-stream", @@ -4341,9 +4362,11 @@ dependencies = [ "num-traits", "parquet-format-safe", "polars-arrow", + "polars-compute", "polars-error", "polars-utils", "seq-macro", + "serde", "simdutf8", "snap", "streaming-decompression", @@ -4352,9 +4375,9 @@ dependencies = [ [[package]] name = "polars-pipe" -version = "0.40.0" +version = "0.41.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0a40ae1b3c74ee07e2d1f7cbf56c5d6e15969e45d9b6f0903bd2acaf783ba436" +checksum = "832af9fbebc4c074d95fb19e1ef9e1bf37c343641238c2476febff296a7028ea" dependencies = [ "crossbeam-channel", "crossbeam-queue", @@ -4378,9 +4401,9 @@ dependencies = [ [[package]] name = "polars-plan" -version = "0.40.0" +version = "0.41.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8daa3541ae7e9af311a4389bc2b21f83349c34c723cc67fa524cdefdaa172d90" +checksum = "801390ea815c05c9cf8337f3148090c9c10c9595a839fa0706b77cc2405b4466" dependencies = [ "ahash 0.8.11", "bytemuck", @@ -4408,9 +4431,9 @@ dependencies = [ [[package]] name = "polars-row" -version = "0.40.0" +version = "0.41.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "deb285f2f3a65b00dd06bef16bb9f712dbb5478f941dab5cf74f9f016d382e40" +checksum = "dee955e91b605fc91db4d0a8ea02609d3a09ff79256d905214a2a6f758cd6f7b" dependencies = [ "bytemuck", "polars-arrow", @@ -4420,9 +4443,9 @@ dependencies = [ [[package]] name = "polars-sql" -version = "0.40.0" +version = "0.41.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a724f699d194cb02c25124d3832f7d4d77f387f1a89ee42f6b9e88ec561d4ad9" +checksum = "d89c00a4b399501d5bd478e8e8022b9391047fe8570324ecba20c4e4833c0e87" dependencies = [ "hex", "once_cell", @@ -4430,18 +4453,20 @@ dependencies = [ "polars-core", "polars-error", "polars-lazy", + "polars-ops", "polars-plan", + "polars-time", "rand", "serde", "serde_json", - "sqlparser 0.39.0", + "sqlparser", ] [[package]] name = "polars-time" -version = "0.40.0" +version = "0.41.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "87ebec238d8b6200d9f0c3ce411c8441e950bd5a7df7806b8172d06c1d5a4b97" +checksum = "9689b3aff99d64befe300495528bdc44c36d2656c3a8b242a790d4f43df027fc" dependencies = [ "atoi", "bytemuck", @@ -4461,9 +4486,9 @@ dependencies = [ [[package]] name = "polars-utils" -version = "0.40.0" +version = "0.41.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "34e1a907c63abf71e5f21467e2e4ff748896c28196746f631c6c25512ec6102c" +checksum = "12081e346983a91e26f395597e1d53dea1b4ecd694653aee1cc402d2fae01f04" dependencies = [ "ahash 0.8.11", "bytemuck", @@ -5614,15 +5639,6 @@ dependencies = [ "windows-sys 0.52.0", ] -[[package]] -name = "sqlparser" -version = "0.39.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "743b4dc2cbde11890ccb254a8fc9d537fa41b36da00de2a1c5e9848c9bc42bd7" -dependencies = [ - "log", -] - [[package]] name = "sqlparser" version = "0.47.0" diff --git a/crates/nu_plugin_polars/Cargo.toml b/crates/nu_plugin_polars/Cargo.toml index 285046381e..9589c89525 100644 --- a/crates/nu_plugin_polars/Cargo.toml +++ b/crates/nu_plugin_polars/Cargo.toml @@ -31,11 +31,11 @@ mimalloc = { version = "0.1.42" } num = {version = "0.4"} serde = { version = "1.0", features = ["derive"] } sqlparser = { version = "0.47"} -polars-io = { version = "0.40", features = ["avro"]} -polars-arrow = { version = "0.40"} -polars-ops = { version = "0.40"} -polars-plan = { version = "0.40", features = ["regex"]} -polars-utils = { version = "0.40"} +polars-io = { version = "0.41", features = ["avro"]} +polars-arrow = { version = "0.41"} +polars-ops = { version = "0.41"} +polars-plan = { version = "0.41", features = ["regex"]} +polars-utils = { version = "0.41"} typetag = "0.2" env_logger = "0.11.3" log.workspace = true @@ -73,7 +73,7 @@ features = [ "to_dummies", ] optional = false -version = "0.40" +version = "0.41" [dev-dependencies] nu-cmd-lang = { path = "../nu-cmd-lang", version = "0.95.1" } diff --git a/crates/nu_plugin_polars/src/dataframe/eager/melt.rs b/crates/nu_plugin_polars/src/dataframe/eager/melt.rs deleted file mode 100644 index b69389ed24..0000000000 --- a/crates/nu_plugin_polars/src/dataframe/eager/melt.rs +++ /dev/null @@ -1,253 +0,0 @@ -use nu_plugin::{EngineInterface, EvaluatedCall, PluginCommand}; -use nu_protocol::{ - Category, Example, LabeledError, PipelineData, ShellError, Signature, Span, Spanned, - SyntaxShape, Type, Value, -}; - -use crate::{ - dataframe::values::utils::convert_columns_string, values::CustomValueSupport, PolarsPlugin, -}; - -use super::super::values::{Column, NuDataFrame}; - -#[derive(Clone)] -pub struct MeltDF; - -impl PluginCommand for MeltDF { - type Plugin = PolarsPlugin; - - fn name(&self) -> &str { - "polars melt" - } - - fn usage(&self) -> &str { - "Unpivot a DataFrame from wide to long format." - } - - fn signature(&self) -> Signature { - Signature::build(self.name()) - .required_named( - "columns", - SyntaxShape::Table(vec![]), - "column names for melting", - Some('c'), - ) - .required_named( - "values", - SyntaxShape::Table(vec![]), - "column names used as value columns", - Some('v'), - ) - .named( - "variable-name", - SyntaxShape::String, - "optional name for variable column", - Some('r'), - ) - .named( - "value-name", - SyntaxShape::String, - "optional name for value column", - Some('l'), - ) - .input_output_type( - Type::Custom("dataframe".into()), - Type::Custom("dataframe".into()), - ) - .category(Category::Custom("dataframe".into())) - } - - fn examples(&self) -> Vec { - vec![Example { - description: "melt dataframe", - example: - "[[a b c d]; [x 1 4 a] [y 2 5 b] [z 3 6 c]] | polars into-df | polars melt -c [b c] -v [a d]", - result: Some( - NuDataFrame::try_from_columns(vec![ - Column::new( - "b".to_string(), - vec![ - Value::test_int(1), - Value::test_int(2), - Value::test_int(3), - Value::test_int(1), - Value::test_int(2), - Value::test_int(3), - ], - ), - Column::new( - "c".to_string(), - vec![ - Value::test_int(4), - Value::test_int(5), - Value::test_int(6), - Value::test_int(4), - Value::test_int(5), - Value::test_int(6), - ], - ), - Column::new( - "variable".to_string(), - vec![ - Value::test_string("a"), - Value::test_string("a"), - Value::test_string("a"), - Value::test_string("d"), - Value::test_string("d"), - Value::test_string("d"), - ], - ), - Column::new( - "value".to_string(), - vec![ - Value::test_string("x"), - Value::test_string("y"), - Value::test_string("z"), - Value::test_string("a"), - Value::test_string("b"), - Value::test_string("c"), - ], - ), - ], None) - .expect("simple df for test should not fail") - .into_value(Span::test_data()), - ), - }] - } - - fn run( - &self, - plugin: &Self::Plugin, - engine: &EngineInterface, - call: &EvaluatedCall, - input: PipelineData, - ) -> Result { - command(plugin, engine, call, input).map_err(LabeledError::from) - } -} - -fn command( - plugin: &PolarsPlugin, - engine: &EngineInterface, - call: &EvaluatedCall, - input: PipelineData, -) -> Result { - let id_col: Vec = call.get_flag("columns")?.expect("required value"); - let val_col: Vec = call.get_flag("values")?.expect("required value"); - - let value_name: Option> = call.get_flag("value-name")?; - let variable_name: Option> = call.get_flag("variable-name")?; - - let (id_col_string, id_col_span) = convert_columns_string(id_col, call.head)?; - let (val_col_string, val_col_span) = convert_columns_string(val_col, call.head)?; - - let df = NuDataFrame::try_from_pipeline_coerce(plugin, input, call.head)?; - - check_column_datatypes(df.as_ref(), &id_col_string, id_col_span)?; - check_column_datatypes(df.as_ref(), &val_col_string, val_col_span)?; - - let mut res = df - .as_ref() - .melt(&id_col_string, &val_col_string) - .map_err(|e| ShellError::GenericError { - error: "Error calculating melt".into(), - msg: e.to_string(), - span: Some(call.head), - help: None, - inner: vec![], - })?; - - if let Some(name) = &variable_name { - res.rename("variable", &name.item) - .map_err(|e| ShellError::GenericError { - error: "Error renaming column".into(), - msg: e.to_string(), - span: Some(name.span), - help: None, - inner: vec![], - })?; - } - - if let Some(name) = &value_name { - res.rename("value", &name.item) - .map_err(|e| ShellError::GenericError { - error: "Error renaming column".into(), - msg: e.to_string(), - span: Some(name.span), - help: None, - inner: vec![], - })?; - } - - let res = NuDataFrame::new(false, res); - res.to_pipeline_data(plugin, engine, call.head) -} - -fn check_column_datatypes>( - df: &polars::prelude::DataFrame, - cols: &[T], - col_span: Span, -) -> Result<(), ShellError> { - if cols.is_empty() { - return Err(ShellError::GenericError { - error: "Merge error".into(), - msg: "empty column list".into(), - span: Some(col_span), - help: None, - inner: vec![], - }); - } - - // Checking if they are same type - if cols.len() > 1 { - for w in cols.windows(2) { - let l_series = df - .column(w[0].as_ref()) - .map_err(|e| ShellError::GenericError { - error: "Error selecting columns".into(), - msg: e.to_string(), - span: Some(col_span), - help: None, - inner: vec![], - })?; - - let r_series = df - .column(w[1].as_ref()) - .map_err(|e| ShellError::GenericError { - error: "Error selecting columns".into(), - msg: e.to_string(), - span: Some(col_span), - help: None, - inner: vec![], - })?; - - if l_series.dtype() != r_series.dtype() { - return Err(ShellError::GenericError { - error: "Merge error".into(), - msg: "found different column types in list".into(), - span: Some(col_span), - help: Some(format!( - "datatypes {} and {} are incompatible", - l_series.dtype(), - r_series.dtype() - )), - inner: vec![], - }); - } - } - } - - Ok(()) -} - -#[cfg(test)] -mod test { - use crate::test::test_polars_plugin_command; - - use super::*; - - #[test] - fn test_examples() -> Result<(), ShellError> { - test_polars_plugin_command(&MeltDF) - } -} diff --git a/crates/nu_plugin_polars/src/dataframe/eager/mod.rs b/crates/nu_plugin_polars/src/dataframe/eager/mod.rs index dc50ba7cd2..509c2b6dc4 100644 --- a/crates/nu_plugin_polars/src/dataframe/eager/mod.rs +++ b/crates/nu_plugin_polars/src/dataframe/eager/mod.rs @@ -9,7 +9,6 @@ mod filter_with; mod first; mod get; mod last; -mod melt; mod open; mod query_df; mod rename; @@ -28,6 +27,7 @@ mod to_df; mod to_json_lines; mod to_nu; mod to_parquet; +mod unpivot; mod with_column; use crate::PolarsPlugin; @@ -44,7 +44,6 @@ pub use filter_with::FilterWith; pub use first::FirstDF; pub use get::GetDF; pub use last::LastDF; -pub use melt::MeltDF; use nu_plugin::PluginCommand; pub use query_df::QueryDf; pub use rename::RenameDF; @@ -62,6 +61,7 @@ pub use to_df::ToDataFrame; pub use to_json_lines::ToJsonLines; pub use to_nu::ToNu; pub use to_parquet::ToParquet; +pub use unpivot::UnpivotDF; pub use with_column::WithColumn; pub(crate) fn eager_commands() -> Vec>> { @@ -76,7 +76,7 @@ pub(crate) fn eager_commands() -> Vec Result { - let infer_schema: usize = call + let infer_schema: NonZeroUsize = call .get_flag("infer-schema")? - .unwrap_or(DEFAULT_INFER_SCHEMA); + .and_then(NonZeroUsize::new) + .unwrap_or( + NonZeroUsize::new(DEFAULT_INFER_SCHEMA) + .expect("The default infer-schema should be non zero"), + ); let maybe_schema = call .get_flag("schema")? .map(|schema| NuSchema::try_from(&schema)) @@ -528,7 +533,7 @@ fn from_csv( .with_infer_schema_length(Some(infer_schema)) .with_skip_rows(skip_rows.unwrap_or_default()) .with_schema(maybe_schema.map(|s| s.into())) - .with_columns(columns.map(Arc::new)) + .with_columns(columns.map(|v| Arc::from(v.into_boxed_slice()))) .map_parse_options(|options| { options .with_separator( diff --git a/crates/nu_plugin_polars/src/dataframe/eager/schema.rs b/crates/nu_plugin_polars/src/dataframe/eager/schema.rs index b55d8ee5e2..e1350289a7 100644 --- a/crates/nu_plugin_polars/src/dataframe/eager/schema.rs +++ b/crates/nu_plugin_polars/src/dataframe/eager/schema.rs @@ -70,7 +70,7 @@ fn command( let value: Value = schema.into(); Ok(PipelineData::Value(value, None)) } - PolarsPluginObject::NuLazyFrame(lazy) => { + PolarsPluginObject::NuLazyFrame(mut lazy) => { let schema = lazy.schema()?; let value: Value = schema.into(); Ok(PipelineData::Value(value, None)) diff --git a/crates/nu_plugin_polars/src/dataframe/eager/unpivot.rs b/crates/nu_plugin_polars/src/dataframe/eager/unpivot.rs new file mode 100644 index 0000000000..c535b54c0e --- /dev/null +++ b/crates/nu_plugin_polars/src/dataframe/eager/unpivot.rs @@ -0,0 +1,358 @@ +use nu_plugin::{EngineInterface, EvaluatedCall, PluginCommand}; +use nu_protocol::{ + Category, Example, LabeledError, PipelineData, ShellError, Signature, Span, Spanned, + SyntaxShape, Type, Value, +}; +use polars::frame::explode::UnpivotArgs; + +use crate::{ + dataframe::values::utils::convert_columns_string, + values::{CustomValueSupport, NuLazyFrame, PolarsPluginObject}, + PolarsPlugin, +}; + +use super::super::values::{Column, NuDataFrame}; + +#[derive(Clone)] +pub struct UnpivotDF; + +impl PluginCommand for UnpivotDF { + type Plugin = PolarsPlugin; + + fn name(&self) -> &str { + "polars unpivot" + } + + fn usage(&self) -> &str { + "Unpivot a DataFrame from wide to long format." + } + + fn signature(&self) -> Signature { + Signature::build(self.name()) + .required_named( + "columns", + SyntaxShape::Table(vec![]), + "column names for unpivoting", + Some('c'), + ) + .required_named( + "values", + SyntaxShape::Table(vec![]), + "column names used as value columns", + Some('v'), + ) + .named( + "variable-name", + SyntaxShape::String, + "optional name for variable column", + Some('r'), + ) + .named( + "value-name", + SyntaxShape::String, + "optional name for value column", + Some('l'), + ) + .input_output_type( + Type::Custom("dataframe".into()), + Type::Custom("dataframe".into()), + ) + .switch( + "streamable", + "Whether or not to use the polars streaming engine. Only valid for lazy dataframes", + Some('s'), + ) + .category(Category::Custom("dataframe".into())) + } + + fn examples(&self) -> Vec { + vec![ + Example { + description: "unpivot on an eager dataframe", + example: + "[[a b c d]; [x 1 4 a] [y 2 5 b] [z 3 6 c]] | polars into-df | polars unpivot -c [b c] -v [a d]", + result: Some( + NuDataFrame::try_from_columns(vec![ + Column::new( + "b".to_string(), + vec![ + Value::test_int(1), + Value::test_int(2), + Value::test_int(3), + Value::test_int(1), + Value::test_int(2), + Value::test_int(3), + ], + ), + Column::new( + "c".to_string(), + vec![ + Value::test_int(4), + Value::test_int(5), + Value::test_int(6), + Value::test_int(4), + Value::test_int(5), + Value::test_int(6), + ], + ), + Column::new( + "variable".to_string(), + vec![ + Value::test_string("a"), + Value::test_string("a"), + Value::test_string("a"), + Value::test_string("d"), + Value::test_string("d"), + Value::test_string("d"), + ], + ), + Column::new( + "value".to_string(), + vec![ + Value::test_string("x"), + Value::test_string("y"), + Value::test_string("z"), + Value::test_string("a"), + Value::test_string("b"), + Value::test_string("c"), + ], + ), + ], None) + .expect("simple df for test should not fail") + .into_value(Span::test_data()), + ), + }, + Example { + description: "unpivot on a lazy dataframe", + example: + "[[a b c d]; [x 1 4 a] [y 2 5 b] [z 3 6 c]] | polars into-lazy | polars unpivot -c [b c] -v [a d] | polars collect", + result: Some( + NuDataFrame::try_from_columns(vec![ + Column::new( + "b".to_string(), + vec![ + Value::test_int(1), + Value::test_int(2), + Value::test_int(3), + Value::test_int(1), + Value::test_int(2), + Value::test_int(3), + ], + ), + Column::new( + "c".to_string(), + vec![ + Value::test_int(4), + Value::test_int(5), + Value::test_int(6), + Value::test_int(4), + Value::test_int(5), + Value::test_int(6), + ], + ), + Column::new( + "variable".to_string(), + vec![ + Value::test_string("a"), + Value::test_string("a"), + Value::test_string("a"), + Value::test_string("d"), + Value::test_string("d"), + Value::test_string("d"), + ], + ), + Column::new( + "value".to_string(), + vec![ + Value::test_string("x"), + Value::test_string("y"), + Value::test_string("z"), + Value::test_string("a"), + Value::test_string("b"), + Value::test_string("c"), + ], + ), + ], None) + .expect("simple df for test should not fail") + .into_value(Span::test_data()), + ), + } + ] + } + + fn run( + &self, + plugin: &Self::Plugin, + engine: &EngineInterface, + call: &EvaluatedCall, + input: PipelineData, + ) -> Result { + match PolarsPluginObject::try_from_pipeline(plugin, input, call.head)? { + PolarsPluginObject::NuDataFrame(df) => command_eager(plugin, engine, call, df), + PolarsPluginObject::NuLazyFrame(lazy) => command_lazy(plugin, engine, call, lazy), + _ => Err(ShellError::GenericError { + error: "Must be a dataframe or lazy dataframe".into(), + msg: "".into(), + span: Some(call.head), + help: None, + inner: vec![], + }), + } + .map_err(LabeledError::from) + } +} + +fn command_eager( + plugin: &PolarsPlugin, + engine: &EngineInterface, + call: &EvaluatedCall, + df: NuDataFrame, +) -> Result { + let id_col: Vec = call.get_flag("columns")?.expect("required value"); + let val_col: Vec = call.get_flag("values")?.expect("required value"); + + let value_name: Option> = call.get_flag("value-name")?; + let variable_name: Option> = call.get_flag("variable-name")?; + + let (id_col_string, id_col_span) = convert_columns_string(id_col, call.head)?; + let (val_col_string, val_col_span) = convert_columns_string(val_col, call.head)?; + + check_column_datatypes(df.as_ref(), &id_col_string, id_col_span)?; + check_column_datatypes(df.as_ref(), &val_col_string, val_col_span)?; + + let mut res = df + .as_ref() + .unpivot(&val_col_string, &id_col_string) + .map_err(|e| ShellError::GenericError { + error: "Error calculating unpivot".into(), + msg: e.to_string(), + span: Some(call.head), + help: None, + inner: vec![], + })?; + + if let Some(name) = &variable_name { + res.rename("variable", &name.item) + .map_err(|e| ShellError::GenericError { + error: "Error renaming column".into(), + msg: e.to_string(), + span: Some(name.span), + help: None, + inner: vec![], + })?; + } + + if let Some(name) = &value_name { + res.rename("value", &name.item) + .map_err(|e| ShellError::GenericError { + error: "Error renaming column".into(), + msg: e.to_string(), + span: Some(name.span), + help: None, + inner: vec![], + })?; + } + + let res = NuDataFrame::new(false, res); + res.to_pipeline_data(plugin, engine, call.head) +} + +fn command_lazy( + plugin: &PolarsPlugin, + engine: &EngineInterface, + call: &EvaluatedCall, + df: NuLazyFrame, +) -> Result { + let id_col: Vec = call.get_flag("columns")?.expect("required value"); + let val_col: Vec = call.get_flag("values")?.expect("required value"); + + let (id_col_string, _id_col_span) = convert_columns_string(id_col, call.head)?; + let (val_col_string, _val_col_span) = convert_columns_string(val_col, call.head)?; + + let value_name: Option = call.get_flag("value-name")?; + let variable_name: Option = call.get_flag("variable-name")?; + + let streamable = call.has_flag("streamable")?; + + let unpivot_args = UnpivotArgs { + on: val_col_string.iter().map(Into::into).collect(), + index: id_col_string.iter().map(Into::into).collect(), + value_name: value_name.map(Into::into), + variable_name: variable_name.map(Into::into), + streamable, + }; + + let polars_df = df.to_polars().unpivot(unpivot_args); + + let res = NuLazyFrame::new(false, polars_df); + res.to_pipeline_data(plugin, engine, call.head) +} + +fn check_column_datatypes>( + df: &polars::prelude::DataFrame, + cols: &[T], + col_span: Span, +) -> Result<(), ShellError> { + if cols.is_empty() { + return Err(ShellError::GenericError { + error: "Merge error".into(), + msg: "empty column list".into(), + span: Some(col_span), + help: None, + inner: vec![], + }); + } + + // Checking if they are same type + if cols.len() > 1 { + for w in cols.windows(2) { + let l_series = df + .column(w[0].as_ref()) + .map_err(|e| ShellError::GenericError { + error: "Error selecting columns".into(), + msg: e.to_string(), + span: Some(col_span), + help: None, + inner: vec![], + })?; + + let r_series = df + .column(w[1].as_ref()) + .map_err(|e| ShellError::GenericError { + error: "Error selecting columns".into(), + msg: e.to_string(), + span: Some(col_span), + help: None, + inner: vec![], + })?; + + if l_series.dtype() != r_series.dtype() { + return Err(ShellError::GenericError { + error: "Merge error".into(), + msg: "found different column types in list".into(), + span: Some(col_span), + help: Some(format!( + "datatypes {} and {} are incompatible", + l_series.dtype(), + r_series.dtype() + )), + inner: vec![], + }); + } + } + } + + Ok(()) +} + +#[cfg(test)] +mod test { + use crate::test::test_polars_plugin_command; + + use super::*; + + #[test] + fn test_examples() -> Result<(), ShellError> { + test_polars_plugin_command(&UnpivotDF) + } +} diff --git a/crates/nu_plugin_polars/src/dataframe/lazy/aggregate.rs b/crates/nu_plugin_polars/src/dataframe/lazy/aggregate.rs index d2fd92ec87..add72faf88 100644 --- a/crates/nu_plugin_polars/src/dataframe/lazy/aggregate.rs +++ b/crates/nu_plugin_polars/src/dataframe/lazy/aggregate.rs @@ -196,7 +196,8 @@ fn get_col_name(expr: &Expr) -> Option { | Expr::Nth(_) | Expr::SubPlan(_, _) | Expr::IndexColumn(_) - | Expr::Selector(_) => None, + | Expr::Selector(_) + | Expr::Field(_) => None, } } diff --git a/crates/nu_plugin_polars/src/dataframe/lazy/groupby.rs b/crates/nu_plugin_polars/src/dataframe/lazy/groupby.rs index 9edf8f5e60..b7bd1017a3 100644 --- a/crates/nu_plugin_polars/src/dataframe/lazy/groupby.rs +++ b/crates/nu_plugin_polars/src/dataframe/lazy/groupby.rs @@ -148,11 +148,11 @@ fn command( plugin: &PolarsPlugin, engine: &EngineInterface, call: &EvaluatedCall, - lazy: NuLazyFrame, + mut lazy: NuLazyFrame, expressions: Vec, ) -> Result { let group_by = lazy.to_polars().group_by(expressions); - let group_by = NuLazyGroupBy::new(group_by, lazy.from_eager, lazy.schema()?); + let group_by = NuLazyGroupBy::new(group_by, lazy.from_eager, lazy.schema().clone()?); group_by.to_pipeline_data(plugin, engine, call.head) } diff --git a/crates/nu_plugin_polars/src/dataframe/lazy/join.rs b/crates/nu_plugin_polars/src/dataframe/lazy/join.rs index 67f5aee9ba..01fb2ac24d 100644 --- a/crates/nu_plugin_polars/src/dataframe/lazy/join.rs +++ b/crates/nu_plugin_polars/src/dataframe/lazy/join.rs @@ -35,7 +35,7 @@ impl PluginCommand for LazyJoin { Some('i'), ) .switch("left", "left join between lazyframes", Some('l')) - .switch("outer", "outer join between lazyframes", Some('o')) + .switch("full", "full join between lazyframes", Some('f')) .switch("cross", "cross join between lazyframes", Some('c')) .named( "suffix", @@ -183,13 +183,13 @@ impl PluginCommand for LazyJoin { input: PipelineData, ) -> Result { let left = call.has_flag("left")?; - let outer = call.has_flag("outer")?; + let full = call.has_flag("full")?; let cross = call.has_flag("cross")?; let how = if left { JoinType::Left - } else if outer { - JoinType::Outer + } else if full { + JoinType::Full } else if cross { JoinType::Cross } else { diff --git a/crates/nu_plugin_polars/src/dataframe/lazy/sort_by_expr.rs b/crates/nu_plugin_polars/src/dataframe/lazy/sort_by_expr.rs index b282bd6d04..20251a0ea0 100644 --- a/crates/nu_plugin_polars/src/dataframe/lazy/sort_by_expr.rs +++ b/crates/nu_plugin_polars/src/dataframe/lazy/sort_by_expr.rs @@ -140,7 +140,7 @@ impl PluginCommand for LazySortBy { let sort_options = SortMultipleOptions { descending: reverse, - nulls_last, + nulls_last: vec![nulls_last], multithreaded: true, maintain_order, }; diff --git a/crates/nu_plugin_polars/src/dataframe/series/indexes/set_with_idx.rs b/crates/nu_plugin_polars/src/dataframe/series/indexes/set_with_idx.rs index 6b1578ad0e..436ba92f96 100644 --- a/crates/nu_plugin_polars/src/dataframe/series/indexes/set_with_idx.rs +++ b/crates/nu_plugin_polars/src/dataframe/series/indexes/set_with_idx.rs @@ -7,7 +7,10 @@ use nu_protocol::{ Category, Example, LabeledError, PipelineData, ShellError, Signature, Span, SyntaxShape, Type, Value, }; -use polars::prelude::{ChunkSet, DataType, IntoSeries}; +use polars::{ + chunked_array::cast::CastOptions, + prelude::{ChunkSet, DataType, IntoSeries}, +}; #[derive(Clone)] pub struct SetWithIndex; @@ -96,7 +99,7 @@ fn command( let casted = match indices.dtype() { DataType::UInt32 | DataType::UInt64 | DataType::Int32 | DataType::Int64 => indices .as_ref() - .cast(&DataType::UInt32) + .cast(&DataType::UInt32, CastOptions::default()) .map_err(|e| ShellError::GenericError { error: "Error casting indices".into(), msg: e.to_string(), diff --git a/crates/nu_plugin_polars/src/dataframe/series/value_counts.rs b/crates/nu_plugin_polars/src/dataframe/series/value_counts.rs index 36d8718ff8..16471e7b9f 100644 --- a/crates/nu_plugin_polars/src/dataframe/series/value_counts.rs +++ b/crates/nu_plugin_polars/src/dataframe/series/value_counts.rs @@ -4,7 +4,8 @@ use super::super::values::{Column, NuDataFrame}; use nu_plugin::{EngineInterface, EvaluatedCall, PluginCommand}; use nu_protocol::{ - Category, Example, LabeledError, PipelineData, ShellError, Signature, Span, Type, Value, + Category, Example, LabeledError, PipelineData, ShellError, Signature, Span, SyntaxShape, Type, + Value, }; use polars::prelude::SeriesMethods; @@ -25,6 +26,24 @@ impl PluginCommand for ValueCount { fn signature(&self) -> Signature { Signature::build(self.name()) + .named( + "column", + SyntaxShape::String, + "Provide a custom name for the coutn column", + Some('c'), + ) + .switch("sort", "Whether or not values should be sorted", Some('s')) + .switch( + "parallel", + "Use multiple threads when processing", + Some('p'), + ) + .named( + "normalize", + SyntaxShape::String, + "Normalize the counts", + Some('n'), + ) .input_output_type( Type::Custom("dataframe".into()), Type::Custom("dataframe".into()), @@ -73,11 +92,15 @@ fn command( call: &EvaluatedCall, input: PipelineData, ) -> Result { + let column = call.get_flag("column")?.unwrap_or("count".to_string()); + let parallel = call.has_flag("parallel")?; + let sort = call.has_flag("sort")?; + let normalize = call.has_flag("normalize")?; let df = NuDataFrame::try_from_pipeline_coerce(plugin, input, call.head)?; let series = df.as_series(call.head)?; let res = series - .value_counts(false, false) + .value_counts(sort, parallel, column, normalize) .map_err(|e| ShellError::GenericError { error: "Error calculating value counts values".into(), msg: e.to_string(), diff --git a/crates/nu_plugin_polars/src/dataframe/values/nu_dataframe/between_values.rs b/crates/nu_plugin_polars/src/dataframe/values/nu_dataframe/between_values.rs index a47197bde8..c92e657fb9 100644 --- a/crates/nu_plugin_polars/src/dataframe/values/nu_dataframe/between_values.rs +++ b/crates/nu_plugin_polars/src/dataframe/values/nu_dataframe/between_values.rs @@ -41,19 +41,37 @@ pub(super) fn compute_between_series( let operation_span = Span::merge(left.span(), right.span()); match operator.item { Operator::Math(Math::Plus) => { - let mut res = lhs + rhs; + let mut res = (lhs + rhs).map_err(|e| ShellError::GenericError { + error: format!("Addition error: {e}"), + msg: "".into(), + span: Some(operation_span), + help: None, + inner: vec![], + })?; let name = format!("sum_{}_{}", lhs.name(), rhs.name()); res.rename(&name); NuDataFrame::try_from_series(res, operation_span) } Operator::Math(Math::Minus) => { - let mut res = lhs - rhs; + let mut res = (lhs - rhs).map_err(|e| ShellError::GenericError { + error: format!("Subtraction error: {e}"), + msg: "".into(), + span: Some(operation_span), + help: None, + inner: vec![], + })?; let name = format!("sub_{}_{}", lhs.name(), rhs.name()); res.rename(&name); NuDataFrame::try_from_series(res, operation_span) } Operator::Math(Math::Multiply) => { - let mut res = lhs * rhs; + let mut res = (lhs * rhs).map_err(|e| ShellError::GenericError { + error: format!("Multiplication error: {e}"), + msg: "".into(), + span: Some(operation_span), + help: None, + inner: vec![], + })?; let name = format!("mul_{}_{}", lhs.name(), rhs.name()); res.rename(&name); NuDataFrame::try_from_series(res, operation_span) diff --git a/crates/nu_plugin_polars/src/dataframe/values/nu_expression/mod.rs b/crates/nu_plugin_polars/src/dataframe/values/nu_expression/mod.rs index cead8c4a11..af00ca94a5 100644 --- a/crates/nu_plugin_polars/src/dataframe/values/nu_expression/mod.rs +++ b/crates/nu_plugin_polars/src/dataframe/values/nu_expression/mod.rs @@ -1,7 +1,10 @@ mod custom_value; use nu_protocol::{record, ShellError, Span, Value}; -use polars::prelude::{col, AggExpr, Expr, Literal}; +use polars::{ + chunked_array::cast::CastOptions, + prelude::{col, AggExpr, Expr, Literal}, +}; use serde::{Deserialize, Deserializer, Serialize, Serializer}; use uuid::Uuid; @@ -269,15 +272,23 @@ pub fn expr_to_value(expr: &Expr, span: Span) -> Result { Expr::Cast { expr, data_type, - strict, - } => Ok(Value::record( - record! { - "expr" => expr_to_value(expr.as_ref(), span)?, - "dtype" => Value::string(format!("{data_type:?}"), span), - "strict" => Value::bool(*strict, span), - }, - span, - )), + options, + } => { + let cast_option_str = match options { + CastOptions::Strict => "STRICT", + CastOptions::NonStrict => "NON_STRICT", + CastOptions::Overflowing => "OVERFLOWING", + }; + + Ok(Value::record( + record! { + "expr" => expr_to_value(expr.as_ref(), span)?, + "dtype" => Value::string(format!("{data_type:?}"), span), + "cast_options" => Value::string(cast_option_str, span) + }, + span, + )) + } Expr::Gather { expr, idx, @@ -388,6 +399,7 @@ pub fn expr_to_value(expr: &Expr, span: Span) -> Result { Expr::Window { function, partition_by, + order_by, options, } => { let partition_by: Result, ShellError> = partition_by @@ -399,6 +411,23 @@ pub fn expr_to_value(expr: &Expr, span: Span) -> Result { record! { "function" => expr_to_value(function, span)?, "partition_by" => Value::list(partition_by?, span), + "order_by" => { + if let Some((order_expr, sort_options)) = order_by { + Value::record(record! { + "expr" => expr_to_value(order_expr.as_ref(), span)?, + "sort_options" => { + Value::record(record!( + "descending" => Value::bool(sort_options.descending, span), + "nulls_last"=> Value::bool(sort_options.nulls_last, span), + "multithreaded"=> Value::bool(sort_options.multithreaded, span), + "maintain_order"=> Value::bool(sort_options.maintain_order, span), + ), span) + } + }, span) + } else { + Value::nothing(span) + } + }, "options" => Value::string(format!("{options:?}"), span), }, span, @@ -424,6 +453,18 @@ pub fn expr_to_value(expr: &Expr, span: Span) -> Result { msg_span: span, input_span: Span::unknown(), }), + Expr::Field(column_name) => { + let fields: Vec = column_name + .iter() + .map(|s| Value::string(s.to_string(), span)) + .collect(); + Ok(Value::record( + record!( + "fields" => Value::list(fields, span) + ), + span, + )) + } } } diff --git a/crates/nu_plugin_polars/src/dataframe/values/nu_lazyframe/mod.rs b/crates/nu_plugin_polars/src/dataframe/values/nu_lazyframe/mod.rs index 4dc231d706..e89f14c316 100644 --- a/crates/nu_plugin_polars/src/dataframe/values/nu_lazyframe/mod.rs +++ b/crates/nu_plugin_polars/src/dataframe/values/nu_lazyframe/mod.rs @@ -77,14 +77,17 @@ impl NuLazyFrame { Self::new(self.from_eager, new_frame) } - pub fn schema(&self) -> Result { - let internal_schema = self.lazy.schema().map_err(|e| ShellError::GenericError { - error: "Error getting schema from lazy frame".into(), - msg: e.to_string(), - span: None, - help: None, - inner: vec![], - })?; + pub fn schema(&mut self) -> Result { + let internal_schema = + Arc::make_mut(&mut self.lazy) + .schema() + .map_err(|e| ShellError::GenericError { + error: "Error getting schema from lazy frame".into(), + msg: e.to_string(), + span: None, + help: None, + inner: vec![], + })?; Ok(internal_schema.into()) }