Ability to cast a dataframe's column to a different dtype (#11803)

Provides the ability to cast columns in dataframes, lazy dataframes, and
expressions.

<img width="587" alt="Screenshot 2024-02-14 at 13 53 01"
src="https://github.com/nushell/nushell/assets/56345/b894f746-0e37-472e-9fb0-eb6f71f2bf27">

<img width="616" alt="Screenshot 2024-02-14 at 13 52 37"
src="https://github.com/nushell/nushell/assets/56345/cf10efa7-d89c-4189-ab71-d368b2354d19">

<img width="626" alt="Screenshot 2024-02-14 at 13 54 58"
src="https://github.com/nushell/nushell/assets/56345/cd57cdf0-5096-41dd-8ab5-46e3d1e061b8">

---------

Co-authored-by: Jack Wright <jack.wright@disqo.com>
This commit is contained in:
Jack Wright 2024-02-14 16:15:00 -08:00 committed by GitHub
parent cb67de675e
commit 525acf9d9e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 245 additions and 34 deletions

View File

@ -0,0 +1,207 @@
use crate::dataframe::values::{str_to_dtype, NuExpression, NuLazyFrame};
use super::super::values::NuDataFrame;
use nu_engine::CallExt;
use nu_protocol::{
ast::Call,
engine::{Command, EngineState, Stack},
Category, Example, PipelineData, Record, ShellError, Signature, Span, SyntaxShape, Type, Value,
};
use polars::prelude::*;
#[derive(Clone)]
pub struct CastDF;
impl Command for CastDF {
fn name(&self) -> &str {
"dfr cast"
}
fn usage(&self) -> &str {
"Cast a column to a different dtype."
}
fn signature(&self) -> Signature {
Signature::build(self.name())
.input_output_types(vec![
(
Type::Custom("expression".into()),
Type::Custom("expression".into()),
),
(
Type::Custom("dataframe".into()),
Type::Custom("dataframe".into()),
),
])
.required(
"dtype",
SyntaxShape::String,
"The dtype to cast the column to",
)
.optional(
"column",
SyntaxShape::String,
"The column to cast. Required when used with a dataframe.",
)
.category(Category::Custom("dataframe".into()))
}
fn examples(&self) -> Vec<Example> {
vec![
Example {
description: "Cast a column in a dataframe to a different dtype",
example: "[[a b]; [1 2] [3 4]] | dfr into-df | dfr cast u8 a | dfr schema",
result: Some(Value::record(
Record::from_raw_cols_vals_unchecked(
vec!["a".to_string(), "b".to_string()],
vec![
Value::string("u8", Span::test_data()),
Value::string("i64", Span::test_data()),
],
),
Span::test_data(),
)),
},
Example {
description: "Cast a column in a lazy dataframe to a different dtype",
example: "[[a b]; [1 2] [3 4]] | dfr into-df | dfr into-lazy | dfr cast u8 a | dfr schema",
result: Some(Value::record(
Record::from_raw_cols_vals_unchecked(
vec!["a".to_string(), "b".to_string()],
vec![
Value::string("u8", Span::test_data()),
Value::string("i64", Span::test_data()),
],
),
Span::test_data(),
)),
},
Example {
description: "Cast a column in a expression to a different dtype",
example: r#"[[a b]; [1 2] [1 4]] | dfr into-df | dfr group-by a | dfr agg [ (dfr col b | dfr cast u8 | dfr min | dfr as "b_min") ] | dfr schema"#,
result: None
}
]
}
fn run(
&self,
engine_state: &EngineState,
stack: &mut Stack,
call: &Call,
input: PipelineData,
) -> Result<PipelineData, ShellError> {
let value = input.into_value(call.head);
if NuLazyFrame::can_downcast(&value) {
let (dtype, column_nm) = df_args(engine_state, stack, call)?;
let df = NuLazyFrame::try_from_value(value)?;
command_lazy(call, column_nm, dtype, df)
} else if NuDataFrame::can_downcast(&value) {
let (dtype, column_nm) = df_args(engine_state, stack, call)?;
let df = NuDataFrame::try_from_value(value)?;
command_eager(call, column_nm, dtype, df)
} else {
let dtype: String = call.req(engine_state, stack, 0)?;
let dtype = str_to_dtype(&dtype, call.head)?;
let expr = NuExpression::try_from_value(value)?;
let expr: NuExpression = expr.into_polars().cast(dtype).into();
Ok(PipelineData::Value(
NuExpression::into_value(expr, call.head),
None,
))
}
}
}
fn df_args(
engine_state: &EngineState,
stack: &mut Stack,
call: &Call,
) -> Result<(DataType, String), ShellError> {
let dtype = dtype_arg(engine_state, stack, call)?;
let column_nm: String =
call.opt(engine_state, stack, 1)?
.ok_or(ShellError::MissingParameter {
param_name: "column_name".into(),
span: call.head,
})?;
Ok((dtype, column_nm))
}
fn dtype_arg(
engine_state: &EngineState,
stack: &mut Stack,
call: &Call,
) -> Result<DataType, ShellError> {
let dtype: String = call.req(engine_state, stack, 0)?;
str_to_dtype(&dtype, call.head)
}
fn command_lazy(
call: &Call,
column_nm: String,
dtype: DataType,
lazy: NuLazyFrame,
) -> Result<PipelineData, ShellError> {
let column = col(&column_nm).cast(dtype);
let lazy = lazy.into_polars().with_columns(&[column]);
let lazy = NuLazyFrame::new(false, lazy);
Ok(PipelineData::Value(
NuLazyFrame::into_value(lazy, call.head)?,
None,
))
}
fn command_eager(
call: &Call,
column_nm: String,
dtype: DataType,
nu_df: NuDataFrame,
) -> Result<PipelineData, ShellError> {
let mut df = nu_df.df;
let column = df
.column(&column_nm)
.map_err(|e| ShellError::GenericError {
error: format!("{e}"),
msg: "".into(),
span: Some(call.head),
help: None,
inner: vec![],
})?;
let casted = column.cast(&dtype).map_err(|e| ShellError::GenericError {
error: format!("{e}"),
msg: "".into(),
span: Some(call.head),
help: None,
inner: vec![],
})?;
let _ = df
.with_column(casted)
.map_err(|e| ShellError::GenericError {
error: format!("{e}"),
msg: "".into(),
span: Some(call.head),
help: None,
inner: vec![],
})?;
let df = NuDataFrame::new(false, df);
Ok(PipelineData::Value(df.into_value(call.head), None))
}
#[cfg(test)]
mod test {
use super::super::super::test_dataframe::test_dataframe;
use super::*;
#[test]
fn test_examples() {
test_dataframe(vec![Box::new(CastDF {})])
}
}

View File

@ -1,4 +1,5 @@
mod append;
mod cast;
mod columns;
mod drop;
mod drop_duplicates;
@ -35,6 +36,7 @@ use nu_protocol::engine::StateWorkingSet;
pub use self::open::OpenDataFrame;
pub use append::AppendDF;
pub use cast::CastDF;
pub use columns::ColumnsDF;
pub use drop::DropDF;
pub use drop_duplicates::DropDuplicates;
@ -78,6 +80,7 @@ pub fn add_eager_decls(working_set: &mut StateWorkingSet) {
// Dataframe commands
bind_command!(
AppendDF,
CastDF,
ColumnsDF,
DataTypes,
Summary,

View File

@ -5,7 +5,7 @@ use nu_protocol::{
Example, PipelineData, Span,
};
use super::eager::ToDataFrame;
use super::eager::{SchemaDF, ToDataFrame};
use super::expressions::ExprCol;
use super::lazy::{LazyCollect, ToLazyFrame};
use nu_cmd_lang::Let;
@ -36,6 +36,7 @@ pub fn build_test_engine_state(cmds: Vec<Box<dyn Command + 'static>>) -> Box<Eng
working_set.add_decl(Box::new(ToLazyFrame));
working_set.add_decl(Box::new(LazyCollect));
working_set.add_decl(Box::new(ExprCol));
working_set.add_decl(Box::new(SchemaDF));
// Adding the command that is being tested to the working set
for cmd in cmds.clone() {

View File

@ -10,5 +10,5 @@ pub use nu_dataframe::{Axis, Column, NuDataFrame};
pub use nu_expression::NuExpression;
pub use nu_lazyframe::NuLazyFrame;
pub use nu_lazygroupby::NuLazyGroupBy;
pub use nu_schema::NuSchema;
pub use nu_schema::{str_to_dtype, NuSchema};
pub use nu_when::NuWhen;

View File

@ -73,7 +73,7 @@ fn value_to_fields(value: &Value, span: Span) -> Result<Vec<Field>, ShellError>
Ok(Field::new(col, dtype))
}
_ => {
let dtype = dtype_str_to_schema(&val.as_string()?, span)?;
let dtype = str_to_dtype(&val.as_string()?, span)?;
Ok(Field::new(col, dtype))
}
})
@ -81,7 +81,7 @@ fn value_to_fields(value: &Value, span: Span) -> Result<Vec<Field>, ShellError>
Ok(fields)
}
fn dtype_str_to_schema(dtype: &str, span: Span) -> Result<DataType, ShellError> {
pub fn str_to_dtype(dtype: &str, span: Span) -> Result<DataType, ShellError> {
match dtype {
"bool" => Ok(DataType::Boolean),
"u8" => Ok(DataType::UInt8),
@ -107,7 +107,7 @@ fn dtype_str_to_schema(dtype: &str, span: Span) -> Result<DataType, ShellError>
.trim_start_matches('<')
.trim_end_matches('>')
.trim();
let dtype = dtype_str_to_schema(dtype, span)?;
let dtype = str_to_dtype(dtype, span)?;
Ok(DataType::List(Box::new(dtype)))
}
_ if dtype.starts_with("datetime") => {
@ -242,82 +242,82 @@ mod test {
#[test]
fn test_dtype_str_to_schema_simple_types() {
let dtype = "bool";
let schema = dtype_str_to_schema(dtype, Span::unknown()).unwrap();
let schema = str_to_dtype(dtype, Span::unknown()).unwrap();
let expected = DataType::Boolean;
assert_eq!(schema, expected);
let dtype = "u8";
let schema = dtype_str_to_schema(dtype, Span::unknown()).unwrap();
let schema = str_to_dtype(dtype, Span::unknown()).unwrap();
let expected = DataType::UInt8;
assert_eq!(schema, expected);
let dtype = "u16";
let schema = dtype_str_to_schema(dtype, Span::unknown()).unwrap();
let schema = str_to_dtype(dtype, Span::unknown()).unwrap();
let expected = DataType::UInt16;
assert_eq!(schema, expected);
let dtype = "u32";
let schema = dtype_str_to_schema(dtype, Span::unknown()).unwrap();
let schema = str_to_dtype(dtype, Span::unknown()).unwrap();
let expected = DataType::UInt32;
assert_eq!(schema, expected);
let dtype = "u64";
let schema = dtype_str_to_schema(dtype, Span::unknown()).unwrap();
let schema = str_to_dtype(dtype, Span::unknown()).unwrap();
let expected = DataType::UInt64;
assert_eq!(schema, expected);
let dtype = "i8";
let schema = dtype_str_to_schema(dtype, Span::unknown()).unwrap();
let schema = str_to_dtype(dtype, Span::unknown()).unwrap();
let expected = DataType::Int8;
assert_eq!(schema, expected);
let dtype = "i16";
let schema = dtype_str_to_schema(dtype, Span::unknown()).unwrap();
let schema = str_to_dtype(dtype, Span::unknown()).unwrap();
let expected = DataType::Int16;
assert_eq!(schema, expected);
let dtype = "i32";
let schema = dtype_str_to_schema(dtype, Span::unknown()).unwrap();
let schema = str_to_dtype(dtype, Span::unknown()).unwrap();
let expected = DataType::Int32;
assert_eq!(schema, expected);
let dtype = "i64";
let schema = dtype_str_to_schema(dtype, Span::unknown()).unwrap();
let schema = str_to_dtype(dtype, Span::unknown()).unwrap();
let expected = DataType::Int64;
assert_eq!(schema, expected);
let dtype = "str";
let schema = dtype_str_to_schema(dtype, Span::unknown()).unwrap();
let schema = str_to_dtype(dtype, Span::unknown()).unwrap();
let expected = DataType::String;
assert_eq!(schema, expected);
let dtype = "binary";
let schema = dtype_str_to_schema(dtype, Span::unknown()).unwrap();
let schema = str_to_dtype(dtype, Span::unknown()).unwrap();
let expected = DataType::Binary;
assert_eq!(schema, expected);
let dtype = "date";
let schema = dtype_str_to_schema(dtype, Span::unknown()).unwrap();
let schema = str_to_dtype(dtype, Span::unknown()).unwrap();
let expected = DataType::Date;
assert_eq!(schema, expected);
let dtype = "time";
let schema = dtype_str_to_schema(dtype, Span::unknown()).unwrap();
let schema = str_to_dtype(dtype, Span::unknown()).unwrap();
let expected = DataType::Time;
assert_eq!(schema, expected);
let dtype = "null";
let schema = dtype_str_to_schema(dtype, Span::unknown()).unwrap();
let schema = str_to_dtype(dtype, Span::unknown()).unwrap();
let expected = DataType::Null;
assert_eq!(schema, expected);
let dtype = "unknown";
let schema = dtype_str_to_schema(dtype, Span::unknown()).unwrap();
let schema = str_to_dtype(dtype, Span::unknown()).unwrap();
let expected = DataType::Unknown;
assert_eq!(schema, expected);
let dtype = "object";
let schema = dtype_str_to_schema(dtype, Span::unknown()).unwrap();
let schema = str_to_dtype(dtype, Span::unknown()).unwrap();
let expected = DataType::Object("unknown", None);
assert_eq!(schema, expected);
}
@ -325,54 +325,54 @@ mod test {
#[test]
fn test_dtype_str_schema_datetime() {
let dtype = "datetime<ms, *>";
let schema = dtype_str_to_schema(dtype, Span::unknown()).unwrap();
let schema = str_to_dtype(dtype, Span::unknown()).unwrap();
let expected = DataType::Datetime(TimeUnit::Milliseconds, None);
assert_eq!(schema, expected);
let dtype = "datetime<us, *>";
let schema = dtype_str_to_schema(dtype, Span::unknown()).unwrap();
let schema = str_to_dtype(dtype, Span::unknown()).unwrap();
let expected = DataType::Datetime(TimeUnit::Microseconds, None);
assert_eq!(schema, expected);
let dtype = "datetime<μs, *>";
let schema = dtype_str_to_schema(dtype, Span::unknown()).unwrap();
let schema = str_to_dtype(dtype, Span::unknown()).unwrap();
let expected = DataType::Datetime(TimeUnit::Microseconds, None);
assert_eq!(schema, expected);
let dtype = "datetime<ns, *>";
let schema = dtype_str_to_schema(dtype, Span::unknown()).unwrap();
let schema = str_to_dtype(dtype, Span::unknown()).unwrap();
let expected = DataType::Datetime(TimeUnit::Nanoseconds, None);
assert_eq!(schema, expected);
let dtype = "datetime<ms, UTC>";
let schema = dtype_str_to_schema(dtype, Span::unknown()).unwrap();
let schema = str_to_dtype(dtype, Span::unknown()).unwrap();
let expected = DataType::Datetime(TimeUnit::Milliseconds, Some("UTC".into()));
assert_eq!(schema, expected);
let dtype = "invalid";
let schema = dtype_str_to_schema(dtype, Span::unknown());
let schema = str_to_dtype(dtype, Span::unknown());
assert!(schema.is_err())
}
#[test]
fn test_dtype_str_schema_duration() {
let dtype = "duration<ms>";
let schema = dtype_str_to_schema(dtype, Span::unknown()).unwrap();
let schema = str_to_dtype(dtype, Span::unknown()).unwrap();
let expected = DataType::Duration(TimeUnit::Milliseconds);
assert_eq!(schema, expected);
let dtype = "duration<us>";
let schema = dtype_str_to_schema(dtype, Span::unknown()).unwrap();
let schema = str_to_dtype(dtype, Span::unknown()).unwrap();
let expected = DataType::Duration(TimeUnit::Microseconds);
assert_eq!(schema, expected);
let dtype = "duration<μs>";
let schema = dtype_str_to_schema(dtype, Span::unknown()).unwrap();
let schema = str_to_dtype(dtype, Span::unknown()).unwrap();
let expected = DataType::Duration(TimeUnit::Microseconds);
assert_eq!(schema, expected);
let dtype = "duration<ns>";
let schema = dtype_str_to_schema(dtype, Span::unknown()).unwrap();
let schema = str_to_dtype(dtype, Span::unknown()).unwrap();
let expected = DataType::Duration(TimeUnit::Nanoseconds);
assert_eq!(schema, expected);
}
@ -380,17 +380,17 @@ mod test {
#[test]
fn test_dtype_str_to_schema_list_types() {
let dtype = "list<i32>";
let schema = dtype_str_to_schema(dtype, Span::unknown()).unwrap();
let schema = str_to_dtype(dtype, Span::unknown()).unwrap();
let expected = DataType::List(Box::new(DataType::Int32));
assert_eq!(schema, expected);
let dtype = "list<duration<ms>>";
let schema = dtype_str_to_schema(dtype, Span::unknown()).unwrap();
let schema = str_to_dtype(dtype, Span::unknown()).unwrap();
let expected = DataType::List(Box::new(DataType::Duration(TimeUnit::Milliseconds)));
assert_eq!(schema, expected);
let dtype = "list<datetime<ms, *>>";
let schema = dtype_str_to_schema(dtype, Span::unknown()).unwrap();
let schema = str_to_dtype(dtype, Span::unknown()).unwrap();
let expected = DataType::List(Box::new(DataType::Datetime(TimeUnit::Milliseconds, None)));
assert_eq!(schema, expected);
}