mirror of
https://github.com/nushell/nushell.git
synced 2024-11-07 17:14:23 +01:00
Ability to cast a dataframe's column to a different dtype (#11803)
Provides the ability to cast columns in dataframes, lazy dataframes, and expressions. <img width="587" alt="Screenshot 2024-02-14 at 13 53 01" src="https://github.com/nushell/nushell/assets/56345/b894f746-0e37-472e-9fb0-eb6f71f2bf27"> <img width="616" alt="Screenshot 2024-02-14 at 13 52 37" src="https://github.com/nushell/nushell/assets/56345/cf10efa7-d89c-4189-ab71-d368b2354d19"> <img width="626" alt="Screenshot 2024-02-14 at 13 54 58" src="https://github.com/nushell/nushell/assets/56345/cd57cdf0-5096-41dd-8ab5-46e3d1e061b8"> --------- Co-authored-by: Jack Wright <jack.wright@disqo.com>
This commit is contained in:
parent
cb67de675e
commit
525acf9d9e
207
crates/nu-cmd-dataframe/src/dataframe/eager/cast.rs
Normal file
207
crates/nu-cmd-dataframe/src/dataframe/eager/cast.rs
Normal file
@ -0,0 +1,207 @@
|
||||
use crate::dataframe::values::{str_to_dtype, NuExpression, NuLazyFrame};
|
||||
|
||||
use super::super::values::NuDataFrame;
|
||||
use nu_engine::CallExt;
|
||||
use nu_protocol::{
|
||||
ast::Call,
|
||||
engine::{Command, EngineState, Stack},
|
||||
Category, Example, PipelineData, Record, ShellError, Signature, Span, SyntaxShape, Type, Value,
|
||||
};
|
||||
use polars::prelude::*;
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct CastDF;
|
||||
|
||||
impl Command for CastDF {
|
||||
fn name(&self) -> &str {
|
||||
"dfr cast"
|
||||
}
|
||||
|
||||
fn usage(&self) -> &str {
|
||||
"Cast a column to a different dtype."
|
||||
}
|
||||
|
||||
fn signature(&self) -> Signature {
|
||||
Signature::build(self.name())
|
||||
.input_output_types(vec![
|
||||
(
|
||||
Type::Custom("expression".into()),
|
||||
Type::Custom("expression".into()),
|
||||
),
|
||||
(
|
||||
Type::Custom("dataframe".into()),
|
||||
Type::Custom("dataframe".into()),
|
||||
),
|
||||
])
|
||||
.required(
|
||||
"dtype",
|
||||
SyntaxShape::String,
|
||||
"The dtype to cast the column to",
|
||||
)
|
||||
.optional(
|
||||
"column",
|
||||
SyntaxShape::String,
|
||||
"The column to cast. Required when used with a dataframe.",
|
||||
)
|
||||
.category(Category::Custom("dataframe".into()))
|
||||
}
|
||||
|
||||
fn examples(&self) -> Vec<Example> {
|
||||
vec![
|
||||
Example {
|
||||
description: "Cast a column in a dataframe to a different dtype",
|
||||
example: "[[a b]; [1 2] [3 4]] | dfr into-df | dfr cast u8 a | dfr schema",
|
||||
result: Some(Value::record(
|
||||
Record::from_raw_cols_vals_unchecked(
|
||||
vec!["a".to_string(), "b".to_string()],
|
||||
vec![
|
||||
Value::string("u8", Span::test_data()),
|
||||
Value::string("i64", Span::test_data()),
|
||||
],
|
||||
),
|
||||
Span::test_data(),
|
||||
)),
|
||||
},
|
||||
Example {
|
||||
description: "Cast a column in a lazy dataframe to a different dtype",
|
||||
example: "[[a b]; [1 2] [3 4]] | dfr into-df | dfr into-lazy | dfr cast u8 a | dfr schema",
|
||||
result: Some(Value::record(
|
||||
Record::from_raw_cols_vals_unchecked(
|
||||
vec!["a".to_string(), "b".to_string()],
|
||||
vec![
|
||||
Value::string("u8", Span::test_data()),
|
||||
Value::string("i64", Span::test_data()),
|
||||
],
|
||||
),
|
||||
Span::test_data(),
|
||||
)),
|
||||
},
|
||||
Example {
|
||||
description: "Cast a column in a expression to a different dtype",
|
||||
example: r#"[[a b]; [1 2] [1 4]] | dfr into-df | dfr group-by a | dfr agg [ (dfr col b | dfr cast u8 | dfr min | dfr as "b_min") ] | dfr schema"#,
|
||||
result: None
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
fn run(
|
||||
&self,
|
||||
engine_state: &EngineState,
|
||||
stack: &mut Stack,
|
||||
call: &Call,
|
||||
input: PipelineData,
|
||||
) -> Result<PipelineData, ShellError> {
|
||||
let value = input.into_value(call.head);
|
||||
if NuLazyFrame::can_downcast(&value) {
|
||||
let (dtype, column_nm) = df_args(engine_state, stack, call)?;
|
||||
let df = NuLazyFrame::try_from_value(value)?;
|
||||
command_lazy(call, column_nm, dtype, df)
|
||||
} else if NuDataFrame::can_downcast(&value) {
|
||||
let (dtype, column_nm) = df_args(engine_state, stack, call)?;
|
||||
let df = NuDataFrame::try_from_value(value)?;
|
||||
command_eager(call, column_nm, dtype, df)
|
||||
} else {
|
||||
let dtype: String = call.req(engine_state, stack, 0)?;
|
||||
let dtype = str_to_dtype(&dtype, call.head)?;
|
||||
|
||||
let expr = NuExpression::try_from_value(value)?;
|
||||
let expr: NuExpression = expr.into_polars().cast(dtype).into();
|
||||
|
||||
Ok(PipelineData::Value(
|
||||
NuExpression::into_value(expr, call.head),
|
||||
None,
|
||||
))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn df_args(
|
||||
engine_state: &EngineState,
|
||||
stack: &mut Stack,
|
||||
call: &Call,
|
||||
) -> Result<(DataType, String), ShellError> {
|
||||
let dtype = dtype_arg(engine_state, stack, call)?;
|
||||
let column_nm: String =
|
||||
call.opt(engine_state, stack, 1)?
|
||||
.ok_or(ShellError::MissingParameter {
|
||||
param_name: "column_name".into(),
|
||||
span: call.head,
|
||||
})?;
|
||||
Ok((dtype, column_nm))
|
||||
}
|
||||
|
||||
fn dtype_arg(
|
||||
engine_state: &EngineState,
|
||||
stack: &mut Stack,
|
||||
call: &Call,
|
||||
) -> Result<DataType, ShellError> {
|
||||
let dtype: String = call.req(engine_state, stack, 0)?;
|
||||
str_to_dtype(&dtype, call.head)
|
||||
}
|
||||
|
||||
fn command_lazy(
|
||||
call: &Call,
|
||||
column_nm: String,
|
||||
dtype: DataType,
|
||||
lazy: NuLazyFrame,
|
||||
) -> Result<PipelineData, ShellError> {
|
||||
let column = col(&column_nm).cast(dtype);
|
||||
let lazy = lazy.into_polars().with_columns(&[column]);
|
||||
let lazy = NuLazyFrame::new(false, lazy);
|
||||
|
||||
Ok(PipelineData::Value(
|
||||
NuLazyFrame::into_value(lazy, call.head)?,
|
||||
None,
|
||||
))
|
||||
}
|
||||
|
||||
fn command_eager(
|
||||
call: &Call,
|
||||
column_nm: String,
|
||||
dtype: DataType,
|
||||
nu_df: NuDataFrame,
|
||||
) -> Result<PipelineData, ShellError> {
|
||||
let mut df = nu_df.df;
|
||||
let column = df
|
||||
.column(&column_nm)
|
||||
.map_err(|e| ShellError::GenericError {
|
||||
error: format!("{e}"),
|
||||
msg: "".into(),
|
||||
span: Some(call.head),
|
||||
help: None,
|
||||
inner: vec![],
|
||||
})?;
|
||||
|
||||
let casted = column.cast(&dtype).map_err(|e| ShellError::GenericError {
|
||||
error: format!("{e}"),
|
||||
msg: "".into(),
|
||||
span: Some(call.head),
|
||||
help: None,
|
||||
inner: vec![],
|
||||
})?;
|
||||
|
||||
let _ = df
|
||||
.with_column(casted)
|
||||
.map_err(|e| ShellError::GenericError {
|
||||
error: format!("{e}"),
|
||||
msg: "".into(),
|
||||
span: Some(call.head),
|
||||
help: None,
|
||||
inner: vec![],
|
||||
})?;
|
||||
|
||||
let df = NuDataFrame::new(false, df);
|
||||
Ok(PipelineData::Value(df.into_value(call.head), None))
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod test {
|
||||
|
||||
use super::super::super::test_dataframe::test_dataframe;
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_examples() {
|
||||
test_dataframe(vec![Box::new(CastDF {})])
|
||||
}
|
||||
}
|
@ -1,4 +1,5 @@
|
||||
mod append;
|
||||
mod cast;
|
||||
mod columns;
|
||||
mod drop;
|
||||
mod drop_duplicates;
|
||||
@ -35,6 +36,7 @@ use nu_protocol::engine::StateWorkingSet;
|
||||
|
||||
pub use self::open::OpenDataFrame;
|
||||
pub use append::AppendDF;
|
||||
pub use cast::CastDF;
|
||||
pub use columns::ColumnsDF;
|
||||
pub use drop::DropDF;
|
||||
pub use drop_duplicates::DropDuplicates;
|
||||
@ -78,6 +80,7 @@ pub fn add_eager_decls(working_set: &mut StateWorkingSet) {
|
||||
// Dataframe commands
|
||||
bind_command!(
|
||||
AppendDF,
|
||||
CastDF,
|
||||
ColumnsDF,
|
||||
DataTypes,
|
||||
Summary,
|
||||
|
@ -5,7 +5,7 @@ use nu_protocol::{
|
||||
Example, PipelineData, Span,
|
||||
};
|
||||
|
||||
use super::eager::ToDataFrame;
|
||||
use super::eager::{SchemaDF, ToDataFrame};
|
||||
use super::expressions::ExprCol;
|
||||
use super::lazy::{LazyCollect, ToLazyFrame};
|
||||
use nu_cmd_lang::Let;
|
||||
@ -36,6 +36,7 @@ pub fn build_test_engine_state(cmds: Vec<Box<dyn Command + 'static>>) -> Box<Eng
|
||||
working_set.add_decl(Box::new(ToLazyFrame));
|
||||
working_set.add_decl(Box::new(LazyCollect));
|
||||
working_set.add_decl(Box::new(ExprCol));
|
||||
working_set.add_decl(Box::new(SchemaDF));
|
||||
|
||||
// Adding the command that is being tested to the working set
|
||||
for cmd in cmds.clone() {
|
||||
|
@ -10,5 +10,5 @@ pub use nu_dataframe::{Axis, Column, NuDataFrame};
|
||||
pub use nu_expression::NuExpression;
|
||||
pub use nu_lazyframe::NuLazyFrame;
|
||||
pub use nu_lazygroupby::NuLazyGroupBy;
|
||||
pub use nu_schema::NuSchema;
|
||||
pub use nu_schema::{str_to_dtype, NuSchema};
|
||||
pub use nu_when::NuWhen;
|
||||
|
@ -73,7 +73,7 @@ fn value_to_fields(value: &Value, span: Span) -> Result<Vec<Field>, ShellError>
|
||||
Ok(Field::new(col, dtype))
|
||||
}
|
||||
_ => {
|
||||
let dtype = dtype_str_to_schema(&val.as_string()?, span)?;
|
||||
let dtype = str_to_dtype(&val.as_string()?, span)?;
|
||||
Ok(Field::new(col, dtype))
|
||||
}
|
||||
})
|
||||
@ -81,7 +81,7 @@ fn value_to_fields(value: &Value, span: Span) -> Result<Vec<Field>, ShellError>
|
||||
Ok(fields)
|
||||
}
|
||||
|
||||
fn dtype_str_to_schema(dtype: &str, span: Span) -> Result<DataType, ShellError> {
|
||||
pub fn str_to_dtype(dtype: &str, span: Span) -> Result<DataType, ShellError> {
|
||||
match dtype {
|
||||
"bool" => Ok(DataType::Boolean),
|
||||
"u8" => Ok(DataType::UInt8),
|
||||
@ -107,7 +107,7 @@ fn dtype_str_to_schema(dtype: &str, span: Span) -> Result<DataType, ShellError>
|
||||
.trim_start_matches('<')
|
||||
.trim_end_matches('>')
|
||||
.trim();
|
||||
let dtype = dtype_str_to_schema(dtype, span)?;
|
||||
let dtype = str_to_dtype(dtype, span)?;
|
||||
Ok(DataType::List(Box::new(dtype)))
|
||||
}
|
||||
_ if dtype.starts_with("datetime") => {
|
||||
@ -242,82 +242,82 @@ mod test {
|
||||
#[test]
|
||||
fn test_dtype_str_to_schema_simple_types() {
|
||||
let dtype = "bool";
|
||||
let schema = dtype_str_to_schema(dtype, Span::unknown()).unwrap();
|
||||
let schema = str_to_dtype(dtype, Span::unknown()).unwrap();
|
||||
let expected = DataType::Boolean;
|
||||
assert_eq!(schema, expected);
|
||||
|
||||
let dtype = "u8";
|
||||
let schema = dtype_str_to_schema(dtype, Span::unknown()).unwrap();
|
||||
let schema = str_to_dtype(dtype, Span::unknown()).unwrap();
|
||||
let expected = DataType::UInt8;
|
||||
assert_eq!(schema, expected);
|
||||
|
||||
let dtype = "u16";
|
||||
let schema = dtype_str_to_schema(dtype, Span::unknown()).unwrap();
|
||||
let schema = str_to_dtype(dtype, Span::unknown()).unwrap();
|
||||
let expected = DataType::UInt16;
|
||||
assert_eq!(schema, expected);
|
||||
|
||||
let dtype = "u32";
|
||||
let schema = dtype_str_to_schema(dtype, Span::unknown()).unwrap();
|
||||
let schema = str_to_dtype(dtype, Span::unknown()).unwrap();
|
||||
let expected = DataType::UInt32;
|
||||
assert_eq!(schema, expected);
|
||||
|
||||
let dtype = "u64";
|
||||
let schema = dtype_str_to_schema(dtype, Span::unknown()).unwrap();
|
||||
let schema = str_to_dtype(dtype, Span::unknown()).unwrap();
|
||||
let expected = DataType::UInt64;
|
||||
assert_eq!(schema, expected);
|
||||
|
||||
let dtype = "i8";
|
||||
let schema = dtype_str_to_schema(dtype, Span::unknown()).unwrap();
|
||||
let schema = str_to_dtype(dtype, Span::unknown()).unwrap();
|
||||
let expected = DataType::Int8;
|
||||
assert_eq!(schema, expected);
|
||||
|
||||
let dtype = "i16";
|
||||
let schema = dtype_str_to_schema(dtype, Span::unknown()).unwrap();
|
||||
let schema = str_to_dtype(dtype, Span::unknown()).unwrap();
|
||||
let expected = DataType::Int16;
|
||||
assert_eq!(schema, expected);
|
||||
|
||||
let dtype = "i32";
|
||||
let schema = dtype_str_to_schema(dtype, Span::unknown()).unwrap();
|
||||
let schema = str_to_dtype(dtype, Span::unknown()).unwrap();
|
||||
let expected = DataType::Int32;
|
||||
assert_eq!(schema, expected);
|
||||
|
||||
let dtype = "i64";
|
||||
let schema = dtype_str_to_schema(dtype, Span::unknown()).unwrap();
|
||||
let schema = str_to_dtype(dtype, Span::unknown()).unwrap();
|
||||
let expected = DataType::Int64;
|
||||
assert_eq!(schema, expected);
|
||||
|
||||
let dtype = "str";
|
||||
let schema = dtype_str_to_schema(dtype, Span::unknown()).unwrap();
|
||||
let schema = str_to_dtype(dtype, Span::unknown()).unwrap();
|
||||
let expected = DataType::String;
|
||||
assert_eq!(schema, expected);
|
||||
|
||||
let dtype = "binary";
|
||||
let schema = dtype_str_to_schema(dtype, Span::unknown()).unwrap();
|
||||
let schema = str_to_dtype(dtype, Span::unknown()).unwrap();
|
||||
let expected = DataType::Binary;
|
||||
assert_eq!(schema, expected);
|
||||
|
||||
let dtype = "date";
|
||||
let schema = dtype_str_to_schema(dtype, Span::unknown()).unwrap();
|
||||
let schema = str_to_dtype(dtype, Span::unknown()).unwrap();
|
||||
let expected = DataType::Date;
|
||||
assert_eq!(schema, expected);
|
||||
|
||||
let dtype = "time";
|
||||
let schema = dtype_str_to_schema(dtype, Span::unknown()).unwrap();
|
||||
let schema = str_to_dtype(dtype, Span::unknown()).unwrap();
|
||||
let expected = DataType::Time;
|
||||
assert_eq!(schema, expected);
|
||||
|
||||
let dtype = "null";
|
||||
let schema = dtype_str_to_schema(dtype, Span::unknown()).unwrap();
|
||||
let schema = str_to_dtype(dtype, Span::unknown()).unwrap();
|
||||
let expected = DataType::Null;
|
||||
assert_eq!(schema, expected);
|
||||
|
||||
let dtype = "unknown";
|
||||
let schema = dtype_str_to_schema(dtype, Span::unknown()).unwrap();
|
||||
let schema = str_to_dtype(dtype, Span::unknown()).unwrap();
|
||||
let expected = DataType::Unknown;
|
||||
assert_eq!(schema, expected);
|
||||
|
||||
let dtype = "object";
|
||||
let schema = dtype_str_to_schema(dtype, Span::unknown()).unwrap();
|
||||
let schema = str_to_dtype(dtype, Span::unknown()).unwrap();
|
||||
let expected = DataType::Object("unknown", None);
|
||||
assert_eq!(schema, expected);
|
||||
}
|
||||
@ -325,54 +325,54 @@ mod test {
|
||||
#[test]
|
||||
fn test_dtype_str_schema_datetime() {
|
||||
let dtype = "datetime<ms, *>";
|
||||
let schema = dtype_str_to_schema(dtype, Span::unknown()).unwrap();
|
||||
let schema = str_to_dtype(dtype, Span::unknown()).unwrap();
|
||||
let expected = DataType::Datetime(TimeUnit::Milliseconds, None);
|
||||
assert_eq!(schema, expected);
|
||||
|
||||
let dtype = "datetime<us, *>";
|
||||
let schema = dtype_str_to_schema(dtype, Span::unknown()).unwrap();
|
||||
let schema = str_to_dtype(dtype, Span::unknown()).unwrap();
|
||||
let expected = DataType::Datetime(TimeUnit::Microseconds, None);
|
||||
assert_eq!(schema, expected);
|
||||
|
||||
let dtype = "datetime<μs, *>";
|
||||
let schema = dtype_str_to_schema(dtype, Span::unknown()).unwrap();
|
||||
let schema = str_to_dtype(dtype, Span::unknown()).unwrap();
|
||||
let expected = DataType::Datetime(TimeUnit::Microseconds, None);
|
||||
assert_eq!(schema, expected);
|
||||
|
||||
let dtype = "datetime<ns, *>";
|
||||
let schema = dtype_str_to_schema(dtype, Span::unknown()).unwrap();
|
||||
let schema = str_to_dtype(dtype, Span::unknown()).unwrap();
|
||||
let expected = DataType::Datetime(TimeUnit::Nanoseconds, None);
|
||||
assert_eq!(schema, expected);
|
||||
|
||||
let dtype = "datetime<ms, UTC>";
|
||||
let schema = dtype_str_to_schema(dtype, Span::unknown()).unwrap();
|
||||
let schema = str_to_dtype(dtype, Span::unknown()).unwrap();
|
||||
let expected = DataType::Datetime(TimeUnit::Milliseconds, Some("UTC".into()));
|
||||
assert_eq!(schema, expected);
|
||||
|
||||
let dtype = "invalid";
|
||||
let schema = dtype_str_to_schema(dtype, Span::unknown());
|
||||
let schema = str_to_dtype(dtype, Span::unknown());
|
||||
assert!(schema.is_err())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_dtype_str_schema_duration() {
|
||||
let dtype = "duration<ms>";
|
||||
let schema = dtype_str_to_schema(dtype, Span::unknown()).unwrap();
|
||||
let schema = str_to_dtype(dtype, Span::unknown()).unwrap();
|
||||
let expected = DataType::Duration(TimeUnit::Milliseconds);
|
||||
assert_eq!(schema, expected);
|
||||
|
||||
let dtype = "duration<us>";
|
||||
let schema = dtype_str_to_schema(dtype, Span::unknown()).unwrap();
|
||||
let schema = str_to_dtype(dtype, Span::unknown()).unwrap();
|
||||
let expected = DataType::Duration(TimeUnit::Microseconds);
|
||||
assert_eq!(schema, expected);
|
||||
|
||||
let dtype = "duration<μs>";
|
||||
let schema = dtype_str_to_schema(dtype, Span::unknown()).unwrap();
|
||||
let schema = str_to_dtype(dtype, Span::unknown()).unwrap();
|
||||
let expected = DataType::Duration(TimeUnit::Microseconds);
|
||||
assert_eq!(schema, expected);
|
||||
|
||||
let dtype = "duration<ns>";
|
||||
let schema = dtype_str_to_schema(dtype, Span::unknown()).unwrap();
|
||||
let schema = str_to_dtype(dtype, Span::unknown()).unwrap();
|
||||
let expected = DataType::Duration(TimeUnit::Nanoseconds);
|
||||
assert_eq!(schema, expected);
|
||||
}
|
||||
@ -380,17 +380,17 @@ mod test {
|
||||
#[test]
|
||||
fn test_dtype_str_to_schema_list_types() {
|
||||
let dtype = "list<i32>";
|
||||
let schema = dtype_str_to_schema(dtype, Span::unknown()).unwrap();
|
||||
let schema = str_to_dtype(dtype, Span::unknown()).unwrap();
|
||||
let expected = DataType::List(Box::new(DataType::Int32));
|
||||
assert_eq!(schema, expected);
|
||||
|
||||
let dtype = "list<duration<ms>>";
|
||||
let schema = dtype_str_to_schema(dtype, Span::unknown()).unwrap();
|
||||
let schema = str_to_dtype(dtype, Span::unknown()).unwrap();
|
||||
let expected = DataType::List(Box::new(DataType::Duration(TimeUnit::Milliseconds)));
|
||||
assert_eq!(schema, expected);
|
||||
|
||||
let dtype = "list<datetime<ms, *>>";
|
||||
let schema = dtype_str_to_schema(dtype, Span::unknown()).unwrap();
|
||||
let schema = str_to_dtype(dtype, Span::unknown()).unwrap();
|
||||
let expected = DataType::List(Box::new(DataType::Datetime(TimeUnit::Milliseconds, None)));
|
||||
assert_eq!(schema, expected);
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user