diff --git a/crates/nu_plugin_polars/src/dataframe/command/data/col.rs b/crates/nu_plugin_polars/src/dataframe/command/data/col.rs index e39952378d..e02d0f8b2f 100644 --- a/crates/nu_plugin_polars/src/dataframe/command/data/col.rs +++ b/crates/nu_plugin_polars/src/dataframe/command/data/col.rs @@ -1,14 +1,14 @@ use crate::{ dataframe::values::NuExpression, - values::{Column, CustomValueSupport, NuDataFrame}, + values::{str_to_dtype, Column, CustomValueSupport, NuDataFrame}, PolarsPlugin, }; use nu_plugin::{EngineInterface, EvaluatedCall, PluginCommand}; use nu_protocol::{ - record, Category, Example, LabeledError, PipelineData, Signature, Span, SyntaxShape, Type, - Value, + record, Category, Example, LabeledError, PipelineData, ShellError, Signature, Span, + SyntaxShape, Type, Value, }; -use polars::prelude::col; +use polars::prelude::DataType; #[derive(Clone)] pub struct ExprCol; @@ -31,6 +31,12 @@ impl PluginCommand for ExprCol { SyntaxShape::String, "Name of column to be used. '*' can be used for all columns.", ) + .rest( + "more columns", + SyntaxShape::String, + "Additional columns to be used. Cannot be '*'", + ) + .switch("type", "Treat column names as type names", Some('t')) .input_output_type(Type::Any, Type::Custom("expression".into())) .category(Category::Custom("expression".into())) } @@ -57,6 +63,31 @@ impl PluginCommand for ExprCol { .into_value(Span::test_data()), ), }, + Example { + description: "Select multiple columns (cannot be used with asterisk wildcard)", + example: "[[a b c]; [x 1 1.1] [y 2 2.2] [z 3 3.3]] | polars into-df | polars select (polars col b c | polars sum) | polars collect", + result: Some( + NuDataFrame::try_from_columns(vec![ + Column::new("b".to_string(), vec![Value::test_int(6)]), + Column::new("c".to_string(), vec![Value::test_float(6.6)]), + ],None) + .expect("should not fail") + .into_value(Span::test_data()), + ), + }, + Example { + description: "Select multiple columns by types (cannot be used with asterisk wildcard)", + example: "[[a b c]; [x o 1.1] [y p 2.2] [z q 3.3]] | polars into-df | polars select (polars col str f64 --type | polars max) | polars collect", + result: Some( + NuDataFrame::try_from_columns(vec![ + Column::new("a".to_string(), vec![Value::test_string("z")]), + Column::new("b".to_string(), vec![Value::test_string("q")]), + Column::new("c".to_string(), vec![Value::test_float(3.3)]), + ],None) + .expect("should not fail") + .into_value(Span::test_data()), + ), + }, ] } @@ -71,8 +102,27 @@ impl PluginCommand for ExprCol { call: &EvaluatedCall, _input: PipelineData, ) -> Result { - let name: String = call.req(0)?; - let expr: NuExpression = col(name.as_str()).into(); + let mut names: Vec = vec![call.req(0)?]; + names.extend(call.rest(1)?); + + let as_type = call.has_flag("type")?; + + let expr: NuExpression = match as_type { + false => match names.as_slice() { + [single] => polars::prelude::col(single).into(), + _ => polars::prelude::cols(&names).into(), + }, + true => { + let dtypes = names + .iter() + .map(|n| str_to_dtype(n, call.head)) + .collect::, ShellError>>() + .map_err(LabeledError::from)?; + + polars::prelude::dtype_cols(dtypes).into() + } + }; + expr.to_pipeline_data(plugin, engine, call.head) .map_err(LabeledError::from) }