From 525acf9d9e3b049280ca531c123ee5610d6b966a Mon Sep 17 00:00:00 2001 From: Jack Wright <56345+ayax79@users.noreply.github.com> Date: Wed, 14 Feb 2024 16:15:00 -0800 Subject: [PATCH] Ability to cast a dataframe's column to a different dtype (#11803) Provides the ability to cast columns in dataframes, lazy dataframes, and expressions. Screenshot 2024-02-14 at 13 53 01 Screenshot 2024-02-14 at 13 52 37 Screenshot 2024-02-14 at 13 54 58 --------- Co-authored-by: Jack Wright --- .../src/dataframe/eager/cast.rs | 207 ++++++++++++++++++ .../src/dataframe/eager/mod.rs | 3 + .../src/dataframe/test_dataframe.rs | 3 +- .../src/dataframe/values/mod.rs | 2 +- .../src/dataframe/values/nu_schema.rs | 64 +++--- 5 files changed, 245 insertions(+), 34 deletions(-) create mode 100644 crates/nu-cmd-dataframe/src/dataframe/eager/cast.rs diff --git a/crates/nu-cmd-dataframe/src/dataframe/eager/cast.rs b/crates/nu-cmd-dataframe/src/dataframe/eager/cast.rs new file mode 100644 index 0000000000..4a8133bf38 --- /dev/null +++ b/crates/nu-cmd-dataframe/src/dataframe/eager/cast.rs @@ -0,0 +1,207 @@ +use crate::dataframe::values::{str_to_dtype, NuExpression, NuLazyFrame}; + +use super::super::values::NuDataFrame; +use nu_engine::CallExt; +use nu_protocol::{ + ast::Call, + engine::{Command, EngineState, Stack}, + Category, Example, PipelineData, Record, ShellError, Signature, Span, SyntaxShape, Type, Value, +}; +use polars::prelude::*; + +#[derive(Clone)] +pub struct CastDF; + +impl Command for CastDF { + fn name(&self) -> &str { + "dfr cast" + } + + fn usage(&self) -> &str { + "Cast a column to a different dtype." + } + + fn signature(&self) -> Signature { + Signature::build(self.name()) + .input_output_types(vec![ + ( + Type::Custom("expression".into()), + Type::Custom("expression".into()), + ), + ( + Type::Custom("dataframe".into()), + Type::Custom("dataframe".into()), + ), + ]) + .required( + "dtype", + SyntaxShape::String, + "The dtype to cast the column to", + ) + .optional( + "column", + SyntaxShape::String, + "The column to cast. Required when used with a dataframe.", + ) + .category(Category::Custom("dataframe".into())) + } + + fn examples(&self) -> Vec { + vec![ + Example { + description: "Cast a column in a dataframe to a different dtype", + example: "[[a b]; [1 2] [3 4]] | dfr into-df | dfr cast u8 a | dfr schema", + result: Some(Value::record( + Record::from_raw_cols_vals_unchecked( + vec!["a".to_string(), "b".to_string()], + vec![ + Value::string("u8", Span::test_data()), + Value::string("i64", Span::test_data()), + ], + ), + Span::test_data(), + )), + }, + Example { + description: "Cast a column in a lazy dataframe to a different dtype", + example: "[[a b]; [1 2] [3 4]] | dfr into-df | dfr into-lazy | dfr cast u8 a | dfr schema", + result: Some(Value::record( + Record::from_raw_cols_vals_unchecked( + vec!["a".to_string(), "b".to_string()], + vec![ + Value::string("u8", Span::test_data()), + Value::string("i64", Span::test_data()), + ], + ), + Span::test_data(), + )), + }, + Example { + description: "Cast a column in a expression to a different dtype", + example: r#"[[a b]; [1 2] [1 4]] | dfr into-df | dfr group-by a | dfr agg [ (dfr col b | dfr cast u8 | dfr min | dfr as "b_min") ] | dfr schema"#, + result: None + } + ] + } + + fn run( + &self, + engine_state: &EngineState, + stack: &mut Stack, + call: &Call, + input: PipelineData, + ) -> Result { + let value = input.into_value(call.head); + if NuLazyFrame::can_downcast(&value) { + let (dtype, column_nm) = df_args(engine_state, stack, call)?; + let df = NuLazyFrame::try_from_value(value)?; + command_lazy(call, column_nm, dtype, df) + } else if NuDataFrame::can_downcast(&value) { + let (dtype, column_nm) = df_args(engine_state, stack, call)?; + let df = NuDataFrame::try_from_value(value)?; + command_eager(call, column_nm, dtype, df) + } else { + let dtype: String = call.req(engine_state, stack, 0)?; + let dtype = str_to_dtype(&dtype, call.head)?; + + let expr = NuExpression::try_from_value(value)?; + let expr: NuExpression = expr.into_polars().cast(dtype).into(); + + Ok(PipelineData::Value( + NuExpression::into_value(expr, call.head), + None, + )) + } + } +} + +fn df_args( + engine_state: &EngineState, + stack: &mut Stack, + call: &Call, +) -> Result<(DataType, String), ShellError> { + let dtype = dtype_arg(engine_state, stack, call)?; + let column_nm: String = + call.opt(engine_state, stack, 1)? + .ok_or(ShellError::MissingParameter { + param_name: "column_name".into(), + span: call.head, + })?; + Ok((dtype, column_nm)) +} + +fn dtype_arg( + engine_state: &EngineState, + stack: &mut Stack, + call: &Call, +) -> Result { + let dtype: String = call.req(engine_state, stack, 0)?; + str_to_dtype(&dtype, call.head) +} + +fn command_lazy( + call: &Call, + column_nm: String, + dtype: DataType, + lazy: NuLazyFrame, +) -> Result { + let column = col(&column_nm).cast(dtype); + let lazy = lazy.into_polars().with_columns(&[column]); + let lazy = NuLazyFrame::new(false, lazy); + + Ok(PipelineData::Value( + NuLazyFrame::into_value(lazy, call.head)?, + None, + )) +} + +fn command_eager( + call: &Call, + column_nm: String, + dtype: DataType, + nu_df: NuDataFrame, +) -> Result { + let mut df = nu_df.df; + let column = df + .column(&column_nm) + .map_err(|e| ShellError::GenericError { + error: format!("{e}"), + msg: "".into(), + span: Some(call.head), + help: None, + inner: vec![], + })?; + + let casted = column.cast(&dtype).map_err(|e| ShellError::GenericError { + error: format!("{e}"), + msg: "".into(), + span: Some(call.head), + help: None, + inner: vec![], + })?; + + let _ = df + .with_column(casted) + .map_err(|e| ShellError::GenericError { + error: format!("{e}"), + msg: "".into(), + span: Some(call.head), + help: None, + inner: vec![], + })?; + + let df = NuDataFrame::new(false, df); + Ok(PipelineData::Value(df.into_value(call.head), None)) +} + +#[cfg(test)] +mod test { + + use super::super::super::test_dataframe::test_dataframe; + use super::*; + + #[test] + fn test_examples() { + test_dataframe(vec![Box::new(CastDF {})]) + } +} diff --git a/crates/nu-cmd-dataframe/src/dataframe/eager/mod.rs b/crates/nu-cmd-dataframe/src/dataframe/eager/mod.rs index 7aedf5ebcb..db7a5c9312 100644 --- a/crates/nu-cmd-dataframe/src/dataframe/eager/mod.rs +++ b/crates/nu-cmd-dataframe/src/dataframe/eager/mod.rs @@ -1,4 +1,5 @@ mod append; +mod cast; mod columns; mod drop; mod drop_duplicates; @@ -35,6 +36,7 @@ use nu_protocol::engine::StateWorkingSet; pub use self::open::OpenDataFrame; pub use append::AppendDF; +pub use cast::CastDF; pub use columns::ColumnsDF; pub use drop::DropDF; pub use drop_duplicates::DropDuplicates; @@ -78,6 +80,7 @@ pub fn add_eager_decls(working_set: &mut StateWorkingSet) { // Dataframe commands bind_command!( AppendDF, + CastDF, ColumnsDF, DataTypes, Summary, diff --git a/crates/nu-cmd-dataframe/src/dataframe/test_dataframe.rs b/crates/nu-cmd-dataframe/src/dataframe/test_dataframe.rs index ff163ae940..904beaf313 100644 --- a/crates/nu-cmd-dataframe/src/dataframe/test_dataframe.rs +++ b/crates/nu-cmd-dataframe/src/dataframe/test_dataframe.rs @@ -5,7 +5,7 @@ use nu_protocol::{ Example, PipelineData, Span, }; -use super::eager::ToDataFrame; +use super::eager::{SchemaDF, ToDataFrame}; use super::expressions::ExprCol; use super::lazy::{LazyCollect, ToLazyFrame}; use nu_cmd_lang::Let; @@ -36,6 +36,7 @@ pub fn build_test_engine_state(cmds: Vec>) -> Box Result, ShellError> Ok(Field::new(col, dtype)) } _ => { - let dtype = dtype_str_to_schema(&val.as_string()?, span)?; + let dtype = str_to_dtype(&val.as_string()?, span)?; Ok(Field::new(col, dtype)) } }) @@ -81,7 +81,7 @@ fn value_to_fields(value: &Value, span: Span) -> Result, ShellError> Ok(fields) } -fn dtype_str_to_schema(dtype: &str, span: Span) -> Result { +pub fn str_to_dtype(dtype: &str, span: Span) -> Result { match dtype { "bool" => Ok(DataType::Boolean), "u8" => Ok(DataType::UInt8), @@ -107,7 +107,7 @@ fn dtype_str_to_schema(dtype: &str, span: Span) -> Result .trim_start_matches('<') .trim_end_matches('>') .trim(); - let dtype = dtype_str_to_schema(dtype, span)?; + let dtype = str_to_dtype(dtype, span)?; Ok(DataType::List(Box::new(dtype))) } _ if dtype.starts_with("datetime") => { @@ -242,82 +242,82 @@ mod test { #[test] fn test_dtype_str_to_schema_simple_types() { let dtype = "bool"; - let schema = dtype_str_to_schema(dtype, Span::unknown()).unwrap(); + let schema = str_to_dtype(dtype, Span::unknown()).unwrap(); let expected = DataType::Boolean; assert_eq!(schema, expected); let dtype = "u8"; - let schema = dtype_str_to_schema(dtype, Span::unknown()).unwrap(); + let schema = str_to_dtype(dtype, Span::unknown()).unwrap(); let expected = DataType::UInt8; assert_eq!(schema, expected); let dtype = "u16"; - let schema = dtype_str_to_schema(dtype, Span::unknown()).unwrap(); + let schema = str_to_dtype(dtype, Span::unknown()).unwrap(); let expected = DataType::UInt16; assert_eq!(schema, expected); let dtype = "u32"; - let schema = dtype_str_to_schema(dtype, Span::unknown()).unwrap(); + let schema = str_to_dtype(dtype, Span::unknown()).unwrap(); let expected = DataType::UInt32; assert_eq!(schema, expected); let dtype = "u64"; - let schema = dtype_str_to_schema(dtype, Span::unknown()).unwrap(); + let schema = str_to_dtype(dtype, Span::unknown()).unwrap(); let expected = DataType::UInt64; assert_eq!(schema, expected); let dtype = "i8"; - let schema = dtype_str_to_schema(dtype, Span::unknown()).unwrap(); + let schema = str_to_dtype(dtype, Span::unknown()).unwrap(); let expected = DataType::Int8; assert_eq!(schema, expected); let dtype = "i16"; - let schema = dtype_str_to_schema(dtype, Span::unknown()).unwrap(); + let schema = str_to_dtype(dtype, Span::unknown()).unwrap(); let expected = DataType::Int16; assert_eq!(schema, expected); let dtype = "i32"; - let schema = dtype_str_to_schema(dtype, Span::unknown()).unwrap(); + let schema = str_to_dtype(dtype, Span::unknown()).unwrap(); let expected = DataType::Int32; assert_eq!(schema, expected); let dtype = "i64"; - let schema = dtype_str_to_schema(dtype, Span::unknown()).unwrap(); + let schema = str_to_dtype(dtype, Span::unknown()).unwrap(); let expected = DataType::Int64; assert_eq!(schema, expected); let dtype = "str"; - let schema = dtype_str_to_schema(dtype, Span::unknown()).unwrap(); + let schema = str_to_dtype(dtype, Span::unknown()).unwrap(); let expected = DataType::String; assert_eq!(schema, expected); let dtype = "binary"; - let schema = dtype_str_to_schema(dtype, Span::unknown()).unwrap(); + let schema = str_to_dtype(dtype, Span::unknown()).unwrap(); let expected = DataType::Binary; assert_eq!(schema, expected); let dtype = "date"; - let schema = dtype_str_to_schema(dtype, Span::unknown()).unwrap(); + let schema = str_to_dtype(dtype, Span::unknown()).unwrap(); let expected = DataType::Date; assert_eq!(schema, expected); let dtype = "time"; - let schema = dtype_str_to_schema(dtype, Span::unknown()).unwrap(); + let schema = str_to_dtype(dtype, Span::unknown()).unwrap(); let expected = DataType::Time; assert_eq!(schema, expected); let dtype = "null"; - let schema = dtype_str_to_schema(dtype, Span::unknown()).unwrap(); + let schema = str_to_dtype(dtype, Span::unknown()).unwrap(); let expected = DataType::Null; assert_eq!(schema, expected); let dtype = "unknown"; - let schema = dtype_str_to_schema(dtype, Span::unknown()).unwrap(); + let schema = str_to_dtype(dtype, Span::unknown()).unwrap(); let expected = DataType::Unknown; assert_eq!(schema, expected); let dtype = "object"; - let schema = dtype_str_to_schema(dtype, Span::unknown()).unwrap(); + let schema = str_to_dtype(dtype, Span::unknown()).unwrap(); let expected = DataType::Object("unknown", None); assert_eq!(schema, expected); } @@ -325,54 +325,54 @@ mod test { #[test] fn test_dtype_str_schema_datetime() { let dtype = "datetime"; - let schema = dtype_str_to_schema(dtype, Span::unknown()).unwrap(); + let schema = str_to_dtype(dtype, Span::unknown()).unwrap(); let expected = DataType::Datetime(TimeUnit::Milliseconds, None); assert_eq!(schema, expected); let dtype = "datetime"; - let schema = dtype_str_to_schema(dtype, Span::unknown()).unwrap(); + let schema = str_to_dtype(dtype, Span::unknown()).unwrap(); let expected = DataType::Datetime(TimeUnit::Microseconds, None); assert_eq!(schema, expected); let dtype = "datetime<μs, *>"; - let schema = dtype_str_to_schema(dtype, Span::unknown()).unwrap(); + let schema = str_to_dtype(dtype, Span::unknown()).unwrap(); let expected = DataType::Datetime(TimeUnit::Microseconds, None); assert_eq!(schema, expected); let dtype = "datetime"; - let schema = dtype_str_to_schema(dtype, Span::unknown()).unwrap(); + let schema = str_to_dtype(dtype, Span::unknown()).unwrap(); let expected = DataType::Datetime(TimeUnit::Nanoseconds, None); assert_eq!(schema, expected); let dtype = "datetime"; - let schema = dtype_str_to_schema(dtype, Span::unknown()).unwrap(); + let schema = str_to_dtype(dtype, Span::unknown()).unwrap(); let expected = DataType::Datetime(TimeUnit::Milliseconds, Some("UTC".into())); assert_eq!(schema, expected); let dtype = "invalid"; - let schema = dtype_str_to_schema(dtype, Span::unknown()); + let schema = str_to_dtype(dtype, Span::unknown()); assert!(schema.is_err()) } #[test] fn test_dtype_str_schema_duration() { let dtype = "duration"; - let schema = dtype_str_to_schema(dtype, Span::unknown()).unwrap(); + let schema = str_to_dtype(dtype, Span::unknown()).unwrap(); let expected = DataType::Duration(TimeUnit::Milliseconds); assert_eq!(schema, expected); let dtype = "duration"; - let schema = dtype_str_to_schema(dtype, Span::unknown()).unwrap(); + let schema = str_to_dtype(dtype, Span::unknown()).unwrap(); let expected = DataType::Duration(TimeUnit::Microseconds); assert_eq!(schema, expected); let dtype = "duration<μs>"; - let schema = dtype_str_to_schema(dtype, Span::unknown()).unwrap(); + let schema = str_to_dtype(dtype, Span::unknown()).unwrap(); let expected = DataType::Duration(TimeUnit::Microseconds); assert_eq!(schema, expected); let dtype = "duration"; - let schema = dtype_str_to_schema(dtype, Span::unknown()).unwrap(); + let schema = str_to_dtype(dtype, Span::unknown()).unwrap(); let expected = DataType::Duration(TimeUnit::Nanoseconds); assert_eq!(schema, expected); } @@ -380,17 +380,17 @@ mod test { #[test] fn test_dtype_str_to_schema_list_types() { let dtype = "list"; - let schema = dtype_str_to_schema(dtype, Span::unknown()).unwrap(); + let schema = str_to_dtype(dtype, Span::unknown()).unwrap(); let expected = DataType::List(Box::new(DataType::Int32)); assert_eq!(schema, expected); let dtype = "list>"; - let schema = dtype_str_to_schema(dtype, Span::unknown()).unwrap(); + let schema = str_to_dtype(dtype, Span::unknown()).unwrap(); let expected = DataType::List(Box::new(DataType::Duration(TimeUnit::Milliseconds))); assert_eq!(schema, expected); let dtype = "list>"; - let schema = dtype_str_to_schema(dtype, Span::unknown()).unwrap(); + let schema = str_to_dtype(dtype, Span::unknown()).unwrap(); let expected = DataType::List(Box::new(DataType::Datetime(TimeUnit::Milliseconds, None))); assert_eq!(schema, expected); }