From 470d130289315c46f0bef21541b9d39803a0b1a7 Mon Sep 17 00:00:00 2001
From: pyz4 <42039243+pyz4@users.noreply.github.com>
Date: Tue, 1 Apr 2025 19:22:05 -0400
Subject: [PATCH] `polars cast`: add decimal option for dtype parameter (#15464)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

# Description
This PR expands the `dtype` parameter of the `polars cast` command to include the `decimal` type.
Setting precision to "*" causes the value to be inferred. Note, however, that setting scale to a non-integer value (such as "*") throws an explicit error (the underlying polars crate assigns scale = 0 in such a case, but I opted for throwing an error instead).

```
$ [[a b]; [1 2] [3 4]] | polars into-df | polars cast decimal<4,2> a | polars schema
╭───┬──────────────╮
│ a │ decimal<4,2> │
│ b │ i64          │
╰───┴──────────────╯
$ [[a b]; [10.5 2] [3.1 4]] | polars into-df | polars cast decimal<*,2> a | polars schema
╭───┬──────────────╮
│ a │ decimal<*,2> │
│ b │ i64          │
╰───┴──────────────╯
$ [[a b]; [10.05 2] [3.1 4]] | polars into-df | polars cast decimal<5,*> a | polars schema
Error:   × Invalid polars data type
   ╭─[entry #25:1:47]
 1 │ [[a b]; [10.05 2] [3.1 4]] | polars into-df | polars cast decimal<5,*> a | polars schema
   ·                                                           ─────┬─────
   ·                                                                ╰── `*` is not a permitted value for scale
   ╰────
```

# User-Facing Changes
There are no breaking changes.
The user has the additional option to `polars cast` to a decimal type

# Tests + Formatting
Tests have been added to `nu_plugin_polars/src/dataframe/values/nu_schema.rs`
---
 .../src/dataframe/values/nu_schema.rs         | 93 +++++++++++++++++++
 1 file changed, 93 insertions(+)

diff --git a/crates/nu_plugin_polars/src/dataframe/values/nu_schema.rs b/crates/nu_plugin_polars/src/dataframe/values/nu_schema.rs
index cf425eb046..1e2ae4723a 100644
--- a/crates/nu_plugin_polars/src/dataframe/values/nu_schema.rs
+++ b/crates/nu_plugin_polars/src/dataframe/values/nu_schema.rs
@@ -169,6 +169,67 @@ pub fn str_to_dtype(dtype: &str, span: Span) -> Result<DataType, ShellError> {
             let time_unit = str_to_time_unit(next, span)?;
             Ok(DataType::Duration(time_unit))
         }
+        _ if dtype.starts_with("decimal") => {
+            let dtype = dtype
+                .trim_start_matches("decimal")
+                .trim_start_matches('<')
+                .trim_end_matches('>');
+            let mut split = dtype.split(',');
+            let next = split
+                .next()
+                .ok_or_else(|| ShellError::GenericError {
+                    error: "Invalid polars data type".into(),
+                    msg: "Missing decimal precision".into(),
+                    span: Some(span),
+                    help: None,
+                    inner: vec![],
+                })?
+                .trim();
+            let precision = match next {
+                "*" => None, // infer
+                _ => Some(
+                    next.parse::<usize>()
+                        .map_err(|e| ShellError::GenericError {
+                            error: "Invalid polars data type".into(),
+                            msg: format!("Error in parsing decimal precision: {e}"),
+                            span: Some(span),
+                            help: None,
+                            inner: vec![],
+                        })?,
+                ),
+            };
+            // Unlike precision, scale must be an explicit integer: `*` is rejected below.
+            let next = split
+                .next()
+                .ok_or_else(|| ShellError::GenericError {
+                    error: "Invalid polars data type".into(),
+                    msg: "Missing decimal scale".into(),
+                    span: Some(span),
+                    help: None,
+                    inner: vec![],
+                })?
+                .trim();
+            let scale = match next {
+                "*" => Err(ShellError::GenericError {
+                    error: "Invalid polars data type".into(),
+                    msg: "`*` is not a permitted value for scale".into(),
+                    span: Some(span),
+                    help: None,
+                    inner: vec![],
+                }),
+                _ => next
+                    .parse::<usize>()
+                    .map(Some)
+                    .map_err(|e| ShellError::GenericError {
+                        error: "Invalid polars data type".into(),
+                        msg: format!("Error in parsing decimal scale: {e}"),
+                        span: Some(span),
+                        help: None,
+                        inner: vec![],
+                    }),
+            }?;
+            Ok(DataType::Decimal(precision, scale))
+        }
         _ => Err(ShellError::GenericError {
             error: "Invalid polars data type".into(),
             msg: format!("Unknown type: {dtype}"),
@@ -367,6 +428,24 @@ mod test {
         assert_eq!(schema, expected);
     }
 
+    #[test]
+    fn test_dtype_str_schema_decimal() {
+        let dtype = "decimal<7,2>";
+        let schema = str_to_dtype(dtype, Span::unknown()).unwrap();
+        let expected = DataType::Decimal(Some(7usize), Some(2usize));
+        assert_eq!(schema, expected);
+
+        // "*" is not a permitted value for scale
+        let dtype = "decimal<7,*>";
+        let schema = str_to_dtype(dtype, Span::unknown());
+        assert!(matches!(schema, Err(ShellError::GenericError { .. })));
+
+        let dtype = "decimal<*,2>";
+        let schema = str_to_dtype(dtype, Span::unknown()).unwrap();
+        let expected = DataType::Decimal(None, Some(2usize));
+        assert_eq!(schema, expected);
+    }
+
     #[test]
     fn test_dtype_str_to_schema_list_types() {
         let dtype = "list";
@@ -383,5 +462,19 @@ mod test {
         let schema = str_to_dtype(dtype, Span::unknown()).unwrap();
         let expected = DataType::List(Box::new(DataType::Datetime(TimeUnit::Milliseconds, None)));
         assert_eq!(schema, expected);
+
+        let dtype = "list<decimal<7,2>>";
+        let schema = str_to_dtype(dtype, Span::unknown()).unwrap();
+        let expected = DataType::List(Box::new(DataType::Decimal(Some(7usize), Some(2usize))));
+        assert_eq!(schema, expected);
+
+        let dtype = "list<decimal<*,2>>";
+        let schema = str_to_dtype(dtype, Span::unknown()).unwrap();
+        let expected = DataType::List(Box::new(DataType::Decimal(None, Some(2usize))));
+        assert_eq!(schema, expected);
+
+        let dtype = "list<decimal<7,*>>";
+        let schema = str_to_dtype(dtype, Span::unknown());
+        assert!(matches!(schema, Err(ShellError::GenericError { .. })));
     }
 }