polars cast: add decimal option for dtype parameter (#15464)

<!--
if this PR closes one or more issues, you can automatically link the PR
with
them by using one of the [*linking
keywords*](https://docs.github.com/en/issues/tracking-your-work-with-issues/linking-a-pull-request-to-an-issue#linking-a-pull-request-to-an-issue-using-a-keyword),
e.g.
- this PR should close #xxxx
- fixes #xxxx

you can also mention related issues, PRs or discussions!
-->

# Description
This PR expands the `dtype` parameter of the `polars cast` command to
include the `decimal<precision, scale>` type. Setting precision to "*"
causes the precision to be inferred. Note, however, that setting scale
to a non-integer value throws an explicit error (the underlying polars
crate assigns scale = 0 in such a case, but I opted for throwing an
error instead).

```
$ [[a b]; [1 2] [3 4]] | polars into-df | polars cast decimal<4,2> a | polars schema
╭───┬──────────────╮
│ a │ decimal<4,2> │
│ b │ i64          │
╰───┴──────────────╯

$ [[a b]; [10.5 2] [3.1 4]] | polars into-df | polars cast decimal<*,2> a | polars schema
╭───┬──────────────╮
│ a │ decimal<*,2> │
│ b │ i64          │
╰───┴──────────────╯

$ [[a b]; [10.05 2] [3.1 4]] | polars into-df | polars cast decimal<5,*> a | polars schema
Error:   × Invalid polars data type
   ╭─[entry #25:1:47]
 1 │ [[a b]; [10.05 2] [3.1 4]] | polars into-df | polars cast decimal<5,*> a | polars schema
   ·                                               ─────┬─────
   ·                                                    ╰── `*` is not a permitted value for scale
   ╰────
```

# User-Facing Changes
<!-- List of all changes that impact the user experience here. This
helps us keep track of breaking changes. -->
There are no breaking changes. The user has the additional option to
`polars cast` to a decimal type.

# Tests + Formatting
Tests have been added to
`nu_plugin_polars/src/dataframe/values/nu_schema.rs`
This commit is contained in:
pyz4 2025-04-01 19:22:05 -04:00 committed by GitHub
parent a23e96c945
commit 470d130289
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -169,6 +169,67 @@ pub fn str_to_dtype(dtype: &str, span: Span) -> Result<DataType, ShellError> {
let time_unit = str_to_time_unit(next, span)?;
Ok(DataType::Duration(time_unit))
}
_ if dtype.starts_with("decimal") => {
    // Parses `decimal<precision,scale>` into `DataType::Decimal`.
    //   * precision: an unsigned integer, or `*` to leave it as `None`.
    //   * scale: an unsigned integer; `*` is rejected explicitly.
    // Every failure is reported as a `ShellError::GenericError` at `span`.

    // One place to build the error so all decimal failures share a shape.
    let decimal_err = |msg: String| ShellError::GenericError {
        error: "Invalid polars data type".into(),
        msg,
        span: Some(span),
        help: None,
        inner: vec![],
    };

    // Strip the `decimal<` ... `>` wrapper, leaving `precision,scale`.
    let args = dtype
        .trim_start_matches("decimal")
        .trim_start_matches('<')
        .trim_end_matches('>');
    let mut parts = args.split(',');

    let precision_str = parts
        .next()
        .ok_or_else(|| decimal_err("Missing decimal precision".into()))?
        .trim();
    let precision = match precision_str {
        "*" => None, // infer
        _ => Some(precision_str.parse::<usize>().map_err(|e| {
            decimal_err(format!("Error in parsing decimal precision: {e}"))
        })?),
    };

    let scale_str = parts
        .next()
        .ok_or_else(|| decimal_err("Missing decimal scale".into()))?
        .trim();
    let scale = match scale_str {
        "*" => Err(decimal_err(
            "`*` is not a permitted value for scale".into(),
        )),
        // Report scale failures as *scale* failures; this message previously
        // said "precision", pointing the user at the wrong argument.
        _ => scale_str
            .parse::<usize>()
            .map(Some)
            .map_err(|e| decimal_err(format!("Error in parsing decimal scale: {e}"))),
    }?;

    Ok(DataType::Decimal(precision, scale))
}
_ => Err(ShellError::GenericError {
error: "Invalid polars data type".into(),
msg: format!("Unknown type: {dtype}"),
@ -367,6 +428,24 @@ mod test {
assert_eq!(schema, expected);
}
#[test]
fn test_dtype_str_schema_decimal() {
    // Explicit precision and scale are both carried through.
    let parsed = str_to_dtype("decimal<7,2>", Span::unknown()).unwrap();
    assert_eq!(parsed, DataType::Decimal(Some(7usize), Some(2usize)));

    // "*" is not a permitted value for scale
    let rejected = str_to_dtype("decimal<7,*>", Span::unknown());
    assert!(matches!(rejected, Err(ShellError::GenericError { .. })));

    // "*" for precision is accepted and parsed as `None`.
    let inferred = str_to_dtype("decimal<*,2>", Span::unknown()).unwrap();
    assert_eq!(inferred, DataType::Decimal(None, Some(2usize)));
}
#[test]
fn test_dtype_str_to_schema_list_types() {
let dtype = "list<i32>";
@ -383,5 +462,19 @@ mod test {
let schema = str_to_dtype(dtype, Span::unknown()).unwrap();
let expected = DataType::List(Box::new(DataType::Datetime(TimeUnit::Milliseconds, None)));
assert_eq!(schema, expected);
let dtype = "list<decimal<7,2>>";
let schema = str_to_dtype(dtype, Span::unknown()).unwrap();
let expected = DataType::List(Box::new(DataType::Decimal(Some(7usize), Some(2usize))));
assert_eq!(schema, expected);
let dtype = "list<decimal<*,2>>";
let schema = str_to_dtype(dtype, Span::unknown()).unwrap();
let expected = DataType::List(Box::new(DataType::Decimal(None, Some(2usize))));
assert_eq!(schema, expected);
let dtype = "list<decimal<7,*>>";
let schema = str_to_dtype(dtype, Span::unknown());
assert!(matches!(schema, Err(ShellError::GenericError { .. })));
}
}