From 525acf9d9e3b049280ca531c123ee5610d6b966a Mon Sep 17 00:00:00 2001
From: Jack Wright <56345+ayax79@users.noreply.github.com>
Date: Wed, 14 Feb 2024 16:15:00 -0800
Subject: [PATCH] Ability to cast a dataframe's column to a different dtype
(#11803)
Provides the ability to cast columns in dataframes, lazy dataframes, and
expressions.
---------
Co-authored-by: Jack Wright
---
.../src/dataframe/eager/cast.rs | 207 ++++++++++++++++++
.../src/dataframe/eager/mod.rs | 3 +
.../src/dataframe/test_dataframe.rs | 3 +-
.../src/dataframe/values/mod.rs | 2 +-
.../src/dataframe/values/nu_schema.rs | 64 +++---
5 files changed, 245 insertions(+), 34 deletions(-)
create mode 100644 crates/nu-cmd-dataframe/src/dataframe/eager/cast.rs
diff --git a/crates/nu-cmd-dataframe/src/dataframe/eager/cast.rs b/crates/nu-cmd-dataframe/src/dataframe/eager/cast.rs
new file mode 100644
index 0000000000..4a8133bf38
--- /dev/null
+++ b/crates/nu-cmd-dataframe/src/dataframe/eager/cast.rs
@@ -0,0 +1,207 @@
+use crate::dataframe::values::{str_to_dtype, NuExpression, NuLazyFrame};
+
+use super::super::values::NuDataFrame;
+use nu_engine::CallExt;
+use nu_protocol::{
+ ast::Call,
+ engine::{Command, EngineState, Stack},
+ Category, Example, PipelineData, Record, ShellError, Signature, Span, SyntaxShape, Type, Value,
+};
+use polars::prelude::*;
+
+/// Command type registered under the name `dfr cast`; all behavior is provided by its
+/// `Command` implementation in this file.
+#[derive(Clone)]
+pub struct CastDF;
+
+impl Command for CastDF {
+    /// Name under which the command is invoked.
+    fn name(&self) -> &str {
+        "dfr cast"
+    }
+
+    /// One-line help text.
+    fn usage(&self) -> &str {
+        "Cast a column to a different dtype."
+    }
+
+    /// Declares the accepted input/output pairs (expression -> expression and
+    /// dataframe -> dataframe), the required `dtype` string and the optional
+    /// `column` string.
+    fn signature(&self) -> Signature {
+        Signature::build(self.name())
+            .input_output_types(vec![
+                (
+                    Type::Custom("expression".into()),
+                    Type::Custom("expression".into()),
+                ),
+                (
+                    Type::Custom("dataframe".into()),
+                    Type::Custom("dataframe".into()),
+                ),
+            ])
+            .required(
+                "dtype",
+                SyntaxShape::String,
+                "The dtype to cast the column to",
+            )
+            .optional(
+                "column",
+                SyntaxShape::String,
+                "The column to cast. Required when used with a dataframe.",
+            )
+            .category(Category::Custom("dataframe".into()))
+    }
+
+    /// Usage examples: casting on an eager dataframe, on a lazy dataframe, and
+    /// inside an expression.
+    fn examples(&self) -> Vec<Example> {
+        // Both dataframe examples cast column `a` to u8 and leave `b` as i64, so
+        // the expected schema record is built once and shared.
+        let schema_after_cast = Value::record(
+            Record::from_raw_cols_vals_unchecked(
+                vec!["a".to_string(), "b".to_string()],
+                vec![
+                    Value::string("u8", Span::test_data()),
+                    Value::string("i64", Span::test_data()),
+                ],
+            ),
+            Span::test_data(),
+        );
+
+        vec![
+            Example {
+                description: "Cast a column in a dataframe to a different dtype",
+                example: "[[a b]; [1 2] [3 4]] | dfr into-df | dfr cast u8 a | dfr schema",
+                result: Some(schema_after_cast.clone()),
+            },
+            Example {
+                description: "Cast a column in a lazy dataframe to a different dtype",
+                example: "[[a b]; [1 2] [3 4]] | dfr into-df | dfr into-lazy | dfr cast u8 a | dfr schema",
+                result: Some(schema_after_cast),
+            },
+            Example {
+                description: "Cast a column in an expression to a different dtype",
+                example: r#"[[a b]; [1 2] [1 4]] | dfr into-df | dfr group-by a | dfr agg [ (dfr col b | dfr cast u8 | dfr min | dfr as "b_min") ] | dfr schema"#,
+                result: None,
+            },
+        ]
+    }
+
+    /// Dispatches on the kind of pipeline input.
+    ///
+    /// Lazy frames and eager dataframes need both `dtype` and `column`; any other
+    /// input is converted to an expression and only `dtype` is read (a `column`
+    /// argument, if supplied, is not consulted on that path).
+    ///
+    /// # Errors
+    /// Propagates argument-evaluation errors, dtype-parsing errors from
+    /// `str_to_dtype`, value-conversion errors, and errors from the cast helpers.
+    fn run(
+        &self,
+        engine_state: &EngineState,
+        stack: &mut Stack,
+        call: &Call,
+        input: PipelineData,
+    ) -> Result<PipelineData, ShellError> {
+        let value = input.into_value(call.head);
+        if NuLazyFrame::can_downcast(&value) {
+            let (dtype, column_nm) = df_args(engine_state, stack, call)?;
+            let df = NuLazyFrame::try_from_value(value)?;
+            command_lazy(call, column_nm, dtype, df)
+        } else if NuDataFrame::can_downcast(&value) {
+            let (dtype, column_nm) = df_args(engine_state, stack, call)?;
+            let df = NuDataFrame::try_from_value(value)?;
+            command_eager(call, column_nm, dtype, df)
+        } else {
+            // Expression path: the dtype is parsed first so an invalid dtype is
+            // reported before the input value is inspected.
+            let dtype = dtype_arg(engine_state, stack, call)?;
+
+            let expr = NuExpression::try_from_value(value)?;
+            let expr: NuExpression = expr.into_polars().cast(dtype).into();
+
+            Ok(PipelineData::Value(
+                NuExpression::into_value(expr, call.head),
+                None,
+            ))
+        }
+    }
+}
+
+/// Reads the arguments required when the input is a dataframe or lazy frame.
+///
+/// Returns the parsed dtype (positional 0) together with the column name
+/// (positional 1).
+///
+/// # Errors
+/// Returns `ShellError::MissingParameter` when the optional `column` positional
+/// was not supplied, and propagates any error from evaluating either argument
+/// or from parsing the dtype.
+fn df_args(
+    engine_state: &EngineState,
+    stack: &mut Stack,
+    call: &Call,
+) -> Result<(DataType, String), ShellError> {
+    let dtype = dtype_arg(engine_state, stack, call)?;
+    let column_nm: String =
+        call.opt(engine_state, stack, 1)?
+            // Built lazily so the error is only allocated when the column is absent.
+            // The name matches the `column` parameter declared in the signature.
+            .ok_or_else(|| ShellError::MissingParameter {
+                param_name: "column".into(),
+                span: call.head,
+            })?;
+    Ok((dtype, column_nm))
+}
+
+/// Reads the required `dtype` positional (index 0) as a string and parses it
+/// with `str_to_dtype`, using the call head as the error span.
+///
+/// # Errors
+/// Propagates errors from evaluating the argument and from `str_to_dtype`.
+fn dtype_arg(
+    engine_state: &EngineState,
+    stack: &mut Stack,
+    call: &Call,
+) -> Result<DataType, ShellError> {
+    let dtype: String = call.req(engine_state, stack, 0)?;
+    str_to_dtype(&dtype, call.head)
+}
+
+/// Adds a cast of `column_nm` to `dtype` to a lazy frame via `with_columns` and
+/// returns the resulting lazy frame as pipeline data.
+///
+/// # Errors
+/// Propagates any error from converting the lazy frame back into a `Value`.
+fn command_lazy(
+    call: &Call,
+    column_nm: String,
+    dtype: DataType,
+    lazy: NuLazyFrame,
+) -> Result<PipelineData, ShellError> {
+    let cast_expr = col(&column_nm).cast(dtype);
+    let plan = lazy.into_polars().with_columns(&[cast_expr]);
+    let lazy = NuLazyFrame::new(false, plan);
+
+    Ok(PipelineData::Value(
+        NuLazyFrame::into_value(lazy, call.head)?,
+        None,
+    ))
+}
+
+/// Wraps a polars failure in a `ShellError::GenericError` whose title is the
+/// error's display text and whose span is `span`.
+fn polars_to_shell_error(e: impl std::fmt::Display, span: Span) -> ShellError {
+    ShellError::GenericError {
+        error: e.to_string(),
+        msg: "".into(),
+        span: Some(span),
+        help: None,
+        inner: vec![],
+    }
+}
+
+/// Casts the column `column_nm` of an eager dataframe to `dtype` and returns the
+/// dataframe, with the casted column put back via `with_column`, as pipeline data.
+///
+/// # Errors
+/// Returns a `ShellError::GenericError` (spanned at the call head) when the
+/// column lookup, the cast, or `with_column` fails.
+fn command_eager(
+    call: &Call,
+    column_nm: String,
+    dtype: DataType,
+    nu_df: NuDataFrame,
+) -> Result<PipelineData, ShellError> {
+    let span = call.head;
+    let mut df = nu_df.df;
+
+    // Look the column up and cast it in one chain; both steps report failures
+    // through the same polars error type. The result is an owned series, so the
+    // shared borrow of `df` ends before the mutable `with_column` call below.
+    let casted = df
+        .column(&column_nm)
+        .and_then(|column| column.cast(&dtype))
+        .map_err(|e| polars_to_shell_error(e, span))?;
+
+    df.with_column(casted)
+        .map_err(|e| polars_to_shell_error(e, span))?;
+
+    let df = NuDataFrame::new(false, df);
+    Ok(PipelineData::Value(df.into_value(span), None))
+}
+
+#[cfg(test)]
+mod test {
+    use super::CastDF;
+    use crate::dataframe::test_dataframe::test_dataframe;
+
+    /// Hands the command to the shared `test_dataframe` helper.
+    #[test]
+    fn test_examples() {
+        // The vector is built inline so each box coerces to the helper's parameter type.
+        test_dataframe(vec![Box::new(CastDF)])
+    }
+}
diff --git a/crates/nu-cmd-dataframe/src/dataframe/eager/mod.rs b/crates/nu-cmd-dataframe/src/dataframe/eager/mod.rs
index 7aedf5ebcb..db7a5c9312 100644
--- a/crates/nu-cmd-dataframe/src/dataframe/eager/mod.rs
+++ b/crates/nu-cmd-dataframe/src/dataframe/eager/mod.rs
@@ -1,4 +1,5 @@
mod append;
+mod cast;
mod columns;
mod drop;
mod drop_duplicates;
@@ -35,6 +36,7 @@ use nu_protocol::engine::StateWorkingSet;
pub use self::open::OpenDataFrame;
pub use append::AppendDF;
+pub use cast::CastDF;
pub use columns::ColumnsDF;
pub use drop::DropDF;
pub use drop_duplicates::DropDuplicates;
@@ -78,6 +80,7 @@ pub fn add_eager_decls(working_set: &mut StateWorkingSet) {
// Dataframe commands
bind_command!(
AppendDF,
+ CastDF,
ColumnsDF,
DataTypes,
Summary,
diff --git a/crates/nu-cmd-dataframe/src/dataframe/test_dataframe.rs b/crates/nu-cmd-dataframe/src/dataframe/test_dataframe.rs
index ff163ae940..904beaf313 100644
--- a/crates/nu-cmd-dataframe/src/dataframe/test_dataframe.rs
+++ b/crates/nu-cmd-dataframe/src/dataframe/test_dataframe.rs
@@ -5,7 +5,7 @@ use nu_protocol::{
Example, PipelineData, Span,
};
-use super::eager::ToDataFrame;
+use super::eager::{SchemaDF, ToDataFrame};
use super::expressions::ExprCol;
use super::lazy::{LazyCollect, ToLazyFrame};
use nu_cmd_lang::Let;
@@ -36,6 +36,7 @@ pub fn build_test_engine_state(cmds: Vec>) -> Box Result, ShellError>
Ok(Field::new(col, dtype))
}
_ => {
- let dtype = dtype_str_to_schema(&val.as_string()?, span)?;
+ let dtype = str_to_dtype(&val.as_string()?, span)?;
Ok(Field::new(col, dtype))
}
})
@@ -81,7 +81,7 @@ fn value_to_fields(value: &Value, span: Span) -> Result, ShellError>
Ok(fields)
}
-fn dtype_str_to_schema(dtype: &str, span: Span) -> Result {
+pub fn str_to_dtype(dtype: &str, span: Span) -> Result {
match dtype {
"bool" => Ok(DataType::Boolean),
"u8" => Ok(DataType::UInt8),
@@ -107,7 +107,7 @@ fn dtype_str_to_schema(dtype: &str, span: Span) -> Result
.trim_start_matches('<')
.trim_end_matches('>')
.trim();
- let dtype = dtype_str_to_schema(dtype, span)?;
+ let dtype = str_to_dtype(dtype, span)?;
Ok(DataType::List(Box::new(dtype)))
}
_ if dtype.starts_with("datetime") => {
@@ -242,82 +242,82 @@ mod test {
#[test]
fn test_dtype_str_to_schema_simple_types() {
let dtype = "bool";
- let schema = dtype_str_to_schema(dtype, Span::unknown()).unwrap();
+ let schema = str_to_dtype(dtype, Span::unknown()).unwrap();
let expected = DataType::Boolean;
assert_eq!(schema, expected);
let dtype = "u8";
- let schema = dtype_str_to_schema(dtype, Span::unknown()).unwrap();
+ let schema = str_to_dtype(dtype, Span::unknown()).unwrap();
let expected = DataType::UInt8;
assert_eq!(schema, expected);
let dtype = "u16";
- let schema = dtype_str_to_schema(dtype, Span::unknown()).unwrap();
+ let schema = str_to_dtype(dtype, Span::unknown()).unwrap();
let expected = DataType::UInt16;
assert_eq!(schema, expected);
let dtype = "u32";
- let schema = dtype_str_to_schema(dtype, Span::unknown()).unwrap();
+ let schema = str_to_dtype(dtype, Span::unknown()).unwrap();
let expected = DataType::UInt32;
assert_eq!(schema, expected);
let dtype = "u64";
- let schema = dtype_str_to_schema(dtype, Span::unknown()).unwrap();
+ let schema = str_to_dtype(dtype, Span::unknown()).unwrap();
let expected = DataType::UInt64;
assert_eq!(schema, expected);
let dtype = "i8";
- let schema = dtype_str_to_schema(dtype, Span::unknown()).unwrap();
+ let schema = str_to_dtype(dtype, Span::unknown()).unwrap();
let expected = DataType::Int8;
assert_eq!(schema, expected);
let dtype = "i16";
- let schema = dtype_str_to_schema(dtype, Span::unknown()).unwrap();
+ let schema = str_to_dtype(dtype, Span::unknown()).unwrap();
let expected = DataType::Int16;
assert_eq!(schema, expected);
let dtype = "i32";
- let schema = dtype_str_to_schema(dtype, Span::unknown()).unwrap();
+ let schema = str_to_dtype(dtype, Span::unknown()).unwrap();
let expected = DataType::Int32;
assert_eq!(schema, expected);
let dtype = "i64";
- let schema = dtype_str_to_schema(dtype, Span::unknown()).unwrap();
+ let schema = str_to_dtype(dtype, Span::unknown()).unwrap();
let expected = DataType::Int64;
assert_eq!(schema, expected);
let dtype = "str";
- let schema = dtype_str_to_schema(dtype, Span::unknown()).unwrap();
+ let schema = str_to_dtype(dtype, Span::unknown()).unwrap();
let expected = DataType::String;
assert_eq!(schema, expected);
let dtype = "binary";
- let schema = dtype_str_to_schema(dtype, Span::unknown()).unwrap();
+ let schema = str_to_dtype(dtype, Span::unknown()).unwrap();
let expected = DataType::Binary;
assert_eq!(schema, expected);
let dtype = "date";
- let schema = dtype_str_to_schema(dtype, Span::unknown()).unwrap();
+ let schema = str_to_dtype(dtype, Span::unknown()).unwrap();
let expected = DataType::Date;
assert_eq!(schema, expected);
let dtype = "time";
- let schema = dtype_str_to_schema(dtype, Span::unknown()).unwrap();
+ let schema = str_to_dtype(dtype, Span::unknown()).unwrap();
let expected = DataType::Time;
assert_eq!(schema, expected);
let dtype = "null";
- let schema = dtype_str_to_schema(dtype, Span::unknown()).unwrap();
+ let schema = str_to_dtype(dtype, Span::unknown()).unwrap();
let expected = DataType::Null;
assert_eq!(schema, expected);
let dtype = "unknown";
- let schema = dtype_str_to_schema(dtype, Span::unknown()).unwrap();
+ let schema = str_to_dtype(dtype, Span::unknown()).unwrap();
let expected = DataType::Unknown;
assert_eq!(schema, expected);
let dtype = "object";
- let schema = dtype_str_to_schema(dtype, Span::unknown()).unwrap();
+ let schema = str_to_dtype(dtype, Span::unknown()).unwrap();
let expected = DataType::Object("unknown", None);
assert_eq!(schema, expected);
}
@@ -325,54 +325,54 @@ mod test {
#[test]
fn test_dtype_str_schema_datetime() {
let dtype = "datetime<ms, *>";
- let schema = dtype_str_to_schema(dtype, Span::unknown()).unwrap();
+ let schema = str_to_dtype(dtype, Span::unknown()).unwrap();
let expected = DataType::Datetime(TimeUnit::Milliseconds, None);
assert_eq!(schema, expected);
let dtype = "datetime<us, *>";
- let schema = dtype_str_to_schema(dtype, Span::unknown()).unwrap();
+ let schema = str_to_dtype(dtype, Span::unknown()).unwrap();
let expected = DataType::Datetime(TimeUnit::Microseconds, None);
assert_eq!(schema, expected);
let dtype = "datetime<μs, *>";
- let schema = dtype_str_to_schema(dtype, Span::unknown()).unwrap();
+ let schema = str_to_dtype(dtype, Span::unknown()).unwrap();
let expected = DataType::Datetime(TimeUnit::Microseconds, None);
assert_eq!(schema, expected);
let dtype = "datetime<ns, *>";
- let schema = dtype_str_to_schema(dtype, Span::unknown()).unwrap();
+ let schema = str_to_dtype(dtype, Span::unknown()).unwrap();
let expected = DataType::Datetime(TimeUnit::Nanoseconds, None);
assert_eq!(schema, expected);
let dtype = "datetime<ms, UTC>";
- let schema = dtype_str_to_schema(dtype, Span::unknown()).unwrap();
+ let schema = str_to_dtype(dtype, Span::unknown()).unwrap();
let expected = DataType::Datetime(TimeUnit::Milliseconds, Some("UTC".into()));
assert_eq!(schema, expected);
let dtype = "invalid";
- let schema = dtype_str_to_schema(dtype, Span::unknown());
+ let schema = str_to_dtype(dtype, Span::unknown());
assert!(schema.is_err())
}
#[test]
fn test_dtype_str_schema_duration() {
let dtype = "duration<ms>";
- let schema = dtype_str_to_schema(dtype, Span::unknown()).unwrap();
+ let schema = str_to_dtype(dtype, Span::unknown()).unwrap();
let expected = DataType::Duration(TimeUnit::Milliseconds);
assert_eq!(schema, expected);
let dtype = "duration<us>";
- let schema = dtype_str_to_schema(dtype, Span::unknown()).unwrap();
+ let schema = str_to_dtype(dtype, Span::unknown()).unwrap();
let expected = DataType::Duration(TimeUnit::Microseconds);
assert_eq!(schema, expected);
let dtype = "duration<μs>";
- let schema = dtype_str_to_schema(dtype, Span::unknown()).unwrap();
+ let schema = str_to_dtype(dtype, Span::unknown()).unwrap();
let expected = DataType::Duration(TimeUnit::Microseconds);
assert_eq!(schema, expected);
let dtype = "duration<ns>";
- let schema = dtype_str_to_schema(dtype, Span::unknown()).unwrap();
+ let schema = str_to_dtype(dtype, Span::unknown()).unwrap();
let expected = DataType::Duration(TimeUnit::Nanoseconds);
assert_eq!(schema, expected);
}
@@ -380,17 +380,17 @@ mod test {
#[test]
fn test_dtype_str_to_schema_list_types() {
let dtype = "list<i32>";
- let schema = dtype_str_to_schema(dtype, Span::unknown()).unwrap();
+ let schema = str_to_dtype(dtype, Span::unknown()).unwrap();
let expected = DataType::List(Box::new(DataType::Int32));
assert_eq!(schema, expected);
let dtype = "list<duration<ms>>";
- let schema = dtype_str_to_schema(dtype, Span::unknown()).unwrap();
+ let schema = str_to_dtype(dtype, Span::unknown()).unwrap();
let expected = DataType::List(Box::new(DataType::Duration(TimeUnit::Milliseconds)));
assert_eq!(schema, expected);
let dtype = "list<datetime<ms, *>>";
- let schema = dtype_str_to_schema(dtype, Span::unknown()).unwrap();
+ let schema = str_to_dtype(dtype, Span::unknown()).unwrap();
let expected = DataType::List(Box::new(DataType::Datetime(TimeUnit::Milliseconds, None)));
assert_eq!(schema, expected);
}