mirror of
https://github.com/nushell/nushell.git
synced 2025-05-06 11:04:24 +02:00
polars cast
: add decimal option for dtype parameter (#15464)
<!-- if this PR closes one or more issues, you can automatically link the PR with them by using one of the [*linking keywords*](https://docs.github.com/en/issues/tracking-your-work-with-issues/linking-a-pull-request-to-an-issue#linking-a-pull-request-to-an-issue-using-a-keyword), e.g. - this PR should close #xxxx - fixes #xxxx you can also mention related issues, PRs or discussions! --> # Description This PR expands the `dtype` parameter of the `polars cast` command to include the `decimal<precision, scale>` type. Setting precision to "*" causes the precision to be inferred. Note, however, that setting scale to anything other than a non-negative integer (including "*") throws an explicit error (the underlying polars crate assigns scale = 0 in such a case, but I opted for throwing an error instead). ``` $ [[a b]; [1 2] [3 4]] | polars into-df | polars cast decimal<4,2> a | polars schema ╭───┬──────────────╮ │ a │ decimal<4,2> │ │ b │ i64 │ ╰───┴──────────────╯ $ [[a b]; [10.5 2] [3.1 4]] | polars into-df | polars cast decimal<*,2> a | polars schema ╭───┬──────────────╮ │ a │ decimal<*,2> │ │ b │ i64 │ ╰───┴──────────────╯ $ [[a b]; [10.05 2] [3.1 4]] | polars into-df | polars cast decimal<5,*> a | polars schema Error: × Invalid polars data type ╭─[entry #25:1:47] 1 │ [[a b]; [10.05 2] [3.1 4]] | polars into-df | polars cast decimal<5,*> a | polars schema · ─────┬───── · ╰── `*` is not a permitted value for scale ╰──── ``` # User-Facing Changes <!-- List of all changes that impact the user experience here. This helps us keep track of breaking changes. --> There are no breaking changes. The user has the additional option to `polars cast` to a decimal type. # Tests + Formatting Tests have been added to `nu_plugin_polars/src/dataframe/values/nu_schema.rs`.
This commit is contained in:
parent
a23e96c945
commit
470d130289
@ -169,6 +169,67 @@ pub fn str_to_dtype(dtype: &str, span: Span) -> Result<DataType, ShellError> {
|
||||
let time_unit = str_to_time_unit(next, span)?;
|
||||
Ok(DataType::Duration(time_unit))
|
||||
}
|
||||
_ if dtype.starts_with("decimal") => {
|
||||
let dtype = dtype
|
||||
.trim_start_matches("decimal")
|
||||
.trim_start_matches('<')
|
||||
.trim_end_matches('>');
|
||||
let mut split = dtype.split(',');
|
||||
let next = split
|
||||
.next()
|
||||
.ok_or_else(|| ShellError::GenericError {
|
||||
error: "Invalid polars data type".into(),
|
||||
msg: "Missing decimal precision".into(),
|
||||
span: Some(span),
|
||||
help: None,
|
||||
inner: vec![],
|
||||
})?
|
||||
.trim();
|
||||
let precision = match next {
|
||||
"*" => None, // infer
|
||||
_ => Some(
|
||||
next.parse::<usize>()
|
||||
.map_err(|e| ShellError::GenericError {
|
||||
error: "Invalid polars data type".into(),
|
||||
msg: format!("Error in parsing decimal precision: {e}"),
|
||||
span: Some(span),
|
||||
help: None,
|
||||
inner: vec![],
|
||||
})?,
|
||||
),
|
||||
};
|
||||
|
||||
let next = split
|
||||
.next()
|
||||
.ok_or_else(|| ShellError::GenericError {
|
||||
error: "Invalid polars data type".into(),
|
||||
msg: "Missing decimal scale".into(),
|
||||
span: Some(span),
|
||||
help: None,
|
||||
inner: vec![],
|
||||
})?
|
||||
.trim();
|
||||
let scale = match next {
|
||||
"*" => Err(ShellError::GenericError {
|
||||
error: "Invalid polars data type".into(),
|
||||
msg: "`*` is not a permitted value for scale".into(),
|
||||
span: Some(span),
|
||||
help: None,
|
||||
inner: vec![],
|
||||
}),
|
||||
_ => next
|
||||
.parse::<usize>()
|
||||
.map(Some)
|
||||
.map_err(|e| ShellError::GenericError {
|
||||
error: "Invalid polars data type".into(),
|
||||
msg: format!("Error in parsing decimal precision: {e}"),
|
||||
span: Some(span),
|
||||
help: None,
|
||||
inner: vec![],
|
||||
}),
|
||||
}?;
|
||||
Ok(DataType::Decimal(precision, scale))
|
||||
}
|
||||
_ => Err(ShellError::GenericError {
|
||||
error: "Invalid polars data type".into(),
|
||||
msg: format!("Unknown type: {dtype}"),
|
||||
@ -367,6 +428,24 @@ mod test {
|
||||
assert_eq!(schema, expected);
|
||||
}
|
||||
|
||||
#[test]
fn test_dtype_str_schema_decimal() {
    // Explicit precision and scale are both carried into the resulting dtype.
    let parsed = str_to_dtype("decimal<7,2>", Span::unknown()).unwrap();
    assert_eq!(parsed, DataType::Decimal(Some(7usize), Some(2usize)));

    // "*" is not a permitted value for scale
    let rejected = str_to_dtype("decimal<7,*>", Span::unknown());
    assert!(matches!(rejected, Err(ShellError::GenericError { .. })));

    // "*" as the precision yields `None`, leaving the precision to be inferred.
    let parsed = str_to_dtype("decimal<*,2>", Span::unknown()).unwrap();
    assert_eq!(parsed, DataType::Decimal(None, Some(2usize)));
}
|
||||
|
||||
#[test]
|
||||
fn test_dtype_str_to_schema_list_types() {
|
||||
let dtype = "list<i32>";
|
||||
@ -383,5 +462,19 @@ mod test {
|
||||
let schema = str_to_dtype(dtype, Span::unknown()).unwrap();
|
||||
let expected = DataType::List(Box::new(DataType::Datetime(TimeUnit::Milliseconds, None)));
|
||||
assert_eq!(schema, expected);
|
||||
|
||||
let dtype = "list<decimal<7,2>>";
|
||||
let schema = str_to_dtype(dtype, Span::unknown()).unwrap();
|
||||
let expected = DataType::List(Box::new(DataType::Decimal(Some(7usize), Some(2usize))));
|
||||
assert_eq!(schema, expected);
|
||||
|
||||
let dtype = "list<decimal<*,2>>";
|
||||
let schema = str_to_dtype(dtype, Span::unknown()).unwrap();
|
||||
let expected = DataType::List(Box::new(DataType::Decimal(None, Some(2usize))));
|
||||
assert_eq!(schema, expected);
|
||||
|
||||
let dtype = "list<decimal<7,*>>";
|
||||
let schema = str_to_dtype(dtype, Span::unknown());
|
||||
assert!(matches!(schema, Err(ShellError::GenericError { .. })));
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user