diff --git a/crates/nu_plugin_polars/src/dataframe/command/core/mod.rs b/crates/nu_plugin_polars/src/dataframe/command/core/mod.rs index fecbba2d69..c2b3d26b18 100644 --- a/crates/nu_plugin_polars/src/dataframe/command/core/mod.rs +++ b/crates/nu_plugin_polars/src/dataframe/command/core/mod.rs @@ -9,6 +9,7 @@ mod schema; mod shape; mod summary; mod to_df; +mod to_dtype; mod to_lazy; mod to_nu; mod to_repr; @@ -40,5 +41,6 @@ pub(crate) fn core_commands() -> Vec &str { + "polars to-dtype" + } + + fn description(&self) -> &str { + "Convert a string to a specific datatype." + } + + fn signature(&self) -> Signature { + Signature::build(self.name()) + .input_output_type(Type::String, Type::Custom("datatype".into())) + .category(Category::Custom("dataframe".into())) + } + + fn examples(&self) -> Vec { + vec![Example { + description: "Convert a string to a specific datatype", + example: r#""i64" | polars to-dtype"#, + result: Some(Value::string("i64", Span::test_data())), + }] + } + + fn run( + &self, + plugin: &Self::Plugin, + engine: &nu_plugin::EngineInterface, + call: &nu_plugin::EvaluatedCall, + input: nu_protocol::PipelineData, + ) -> Result { + command(plugin, engine, call, input).map_err(nu_protocol::LabeledError::from) + } +} + +fn command( + plugin: &PolarsPlugin, + engine: &nu_plugin::EngineInterface, + call: &nu_plugin::EvaluatedCall, + input: nu_protocol::PipelineData, +) -> Result { + NuDataType::try_from_pipeline(plugin, input, call.head)? + .to_pipeline_data(plugin, engine, call.head) +} diff --git a/crates/nu_plugin_polars/src/dataframe/values/mod.rs b/crates/nu_plugin_polars/src/dataframe/values/mod.rs index 2b30270ebe..5184049ed7 100644 --- a/crates/nu_plugin_polars/src/dataframe/values/mod.rs +++ b/crates/nu_plugin_polars/src/dataframe/values/mod.rs @@ -9,10 +9,12 @@ mod nu_when; pub mod utils; use crate::{Cacheable, PolarsPlugin}; +use nu_dtype::custom_value::NuDataTypeCustomValue; use nu_plugin::EngineInterface; use nu_protocol::{ ast::Operator, CustomValue, PipelineData, ShellError, Span, Spanned, Type, Value, }; +use nu_schema::custom_value::NuSchemaCustomValue; use std::{cmp::Ordering, fmt}; use uuid::Uuid; @@ -172,6 +174,8 @@ pub enum CustomValueType { NuExpression(NuExpressionCustomValue), NuLazyGroupBy(NuLazyGroupByCustomValue), NuWhen(NuWhenCustomValue), + NuDataType(NuDataTypeCustomValue), + NuSchema(NuSchemaCustomValue), } impl CustomValueType { @@ -182,6 +186,8 @@ impl CustomValueType { CustomValueType::NuExpression(e_cv) => e_cv.id, CustomValueType::NuLazyGroupBy(lg_cv) => lg_cv.id, CustomValueType::NuWhen(w_cv) => w_cv.id, + CustomValueType::NuDataType(dt_cv) => dt_cv.id, + CustomValueType::NuSchema(schema_cv) => schema_cv.id, } } @@ -196,6 +202,10 @@ impl CustomValueType { Ok(CustomValueType::NuLazyGroupBy(lg_cv.clone())) } else if let Some(w_cv) = val.as_any().downcast_ref::() { Ok(CustomValueType::NuWhen(w_cv.clone())) + } else if let Some(w_cv) = val.as_any().downcast_ref::() { + Ok(CustomValueType::NuDataType(w_cv.clone())) + } else if let Some(w_cv) = val.as_any().downcast_ref::() { + Ok(CustomValueType::NuSchema(w_cv.clone())) } else { Err(ShellError::CantConvert { to_type: "physical type".into(), diff --git a/crates/nu_plugin_polars/src/dataframe/values/nu_dtype/mod.rs b/crates/nu_plugin_polars/src/dataframe/values/nu_dtype/mod.rs index ec32bdc5ba..58bda94719 100644 --- a/crates/nu_plugin_polars/src/dataframe/values/nu_dtype/mod.rs +++ b/crates/nu_plugin_polars/src/dataframe/values/nu_dtype/mod.rs @@ -1,11 +1,11 @@ -mod custom_value; +pub mod custom_value; use custom_value::NuDataTypeCustomValue; use nu_protocol::{record, ShellError, Span, Value}; use polars::prelude::{DataType, PlSmallStr, TimeUnit, UnknownKind}; use uuid::Uuid; -use crate::Cacheable; +use crate::{Cacheable, PolarsPlugin}; use super::{nu_schema::dtype_to_value, CustomValueSupport, PolarsPluginObject, PolarsPluginType}; @@ -36,23 +36,6 @@ impl NuDataType { } } -impl TryFrom<&Value> for NuDataType { - type Error = ShellError; - - fn try_from(value: &Value) -> Result { - match value { - Value::String { val, internal_span } => NuDataType::new_with_str(val, *internal_span), - _ => Err(ShellError::GenericError { - error: format!("Unsupported value: {:?}", value), - msg: "".into(), - span: Some(value.span()), - help: None, - inner: vec![], - }), - } - } -} - impl From for Value { fn from(nu_dtype: NuDataType) -> Self { Value::String { @@ -102,6 +85,30 @@ impl CustomValueSupport for NuDataType { fn base_value(self, span: Span) -> Result { Ok(dtype_to_value(&self.dtype, span)) } + + fn try_from_value(plugin: &PolarsPlugin, value: &Value) -> Result { + match value { + Value::Custom { val, .. } => { + if let Some(cv) = val.as_any().downcast_ref::() { + Self::try_from_custom_value(plugin, cv) + } else { + Err(ShellError::CantConvert { + to_type: Self::get_type_static().to_string(), + from_type: value.get_type().to_string(), + span: value.span(), + help: None, + }) + } + } + Value::String { val, internal_span } => NuDataType::new_with_str(val, *internal_span), + _ => Err(ShellError::CantConvert { + to_type: Self::get_type_static().to_string(), + from_type: value.get_type().to_string(), + span: value.span(), + help: None, + }), + } + } } pub fn datatype_list(span: Span) -> Value { diff --git a/crates/nu_plugin_polars/src/dataframe/values/nu_schema/mod.rs b/crates/nu_plugin_polars/src/dataframe/values/nu_schema/mod.rs index ae2bd8750e..4011f21572 100644 --- a/crates/nu_plugin_polars/src/dataframe/values/nu_schema/mod.rs +++ b/crates/nu_plugin_polars/src/dataframe/values/nu_schema/mod.rs @@ -1,4 +1,4 @@ -mod custom_value; +pub mod custom_value; use std::sync::Arc; diff --git a/crates/nu_plugin_polars/src/lib.rs b/crates/nu_plugin_polars/src/lib.rs index f497f59567..6a0644c82c 100644 --- a/crates/nu_plugin_polars/src/lib.rs +++ b/crates/nu_plugin_polars/src/lib.rs @@ -123,6 +123,8 @@ impl Plugin for PolarsPlugin { CustomValueType::NuExpression(cv) => cv.custom_value_to_base_value(self, engine), CustomValueType::NuLazyGroupBy(cv) => cv.custom_value_to_base_value(self, engine), CustomValueType::NuWhen(cv) => cv.custom_value_to_base_value(self, engine), + CustomValueType::NuDataType(cv) => cv.custom_value_to_base_value(self, engine), + CustomValueType::NuSchema(cv) => cv.custom_value_to_base_value(self, engine), }; Ok(result?) } @@ -150,6 +152,12 @@ impl Plugin for PolarsPlugin { CustomValueType::NuWhen(cv) => { cv.custom_value_operation(self, engine, left.span, operator, right) } + CustomValueType::NuDataType(cv) => { + cv.custom_value_operation(self, engine, left.span, operator, right) + } + CustomValueType::NuSchema(cv) => { + cv.custom_value_operation(self, engine, left.span, operator, right) + } }; Ok(result?) } @@ -176,6 +184,12 @@ impl Plugin for PolarsPlugin { CustomValueType::NuWhen(cv) => { cv.custom_value_follow_path_int(self, engine, custom_value.span, index) } + CustomValueType::NuDataType(cv) => { + cv.custom_value_follow_path_int(self, engine, custom_value.span, index) + } + CustomValueType::NuSchema(cv) => { + cv.custom_value_follow_path_int(self, engine, custom_value.span, index) + } }; Ok(result?) } @@ -202,6 +216,12 @@ impl Plugin for PolarsPlugin { CustomValueType::NuWhen(cv) => { cv.custom_value_follow_path_string(self, engine, custom_value.span, column_name) } + CustomValueType::NuDataType(cv) => { + cv.custom_value_follow_path_string(self, engine, custom_value.span, column_name) + } + CustomValueType::NuSchema(cv) => { + cv.custom_value_follow_path_string(self, engine, custom_value.span, column_name) + } }; Ok(result?) } @@ -226,6 +246,10 @@ impl Plugin for PolarsPlugin { cv.custom_value_partial_cmp(self, engine, other_value) } CustomValueType::NuWhen(cv) => cv.custom_value_partial_cmp(self, engine, other_value), + CustomValueType::NuDataType(cv) => { + cv.custom_value_partial_cmp(self, engine, other_value) + } + CustomValueType::NuSchema(cv) => cv.custom_value_partial_cmp(self, engine, other_value), }; Ok(result?) }