diff --git a/crates/nu_plugin_polars/src/cache/list.rs b/crates/nu_plugin_polars/src/cache/list.rs index c68f910a81..fc599dedd7 100644 --- a/crates/nu_plugin_polars/src/cache/list.rs +++ b/crates/nu_plugin_polars/src/cache/list.rs @@ -132,7 +132,37 @@ impl PluginCommand for ListDF { "reference_count" => Value::int(value.reference_count as i64, call.head), }, call.head, - ))) + ))), + PolarsPluginObject::NuDataType(_) => Ok(Some(Value::record( + record! { + "key" => Value::string(key.to_string(), call.head), + "created" => Value::date(value.created, call.head), + "columns" => Value::nothing(call.head), + "rows" => Value::nothing(call.head), + "type" => Value::string("DataType", call.head), + "estimated_size" => Value::nothing(call.head), + "span_contents" => Value::string(span_contents, value.span), + "span_start" => Value::int(value.span.start as i64, call.head), + "span_end" => Value::int(value.span.end as i64, call.head), + "reference_count" => Value::int(value.reference_count as i64, call.head), + }, + call.head, + ))), + PolarsPluginObject::NuSchema(_) => Ok(Some(Value::record( + record! { + "key" => Value::string(key.to_string(), call.head), + "created" => Value::date(value.created, call.head), + "columns" => Value::nothing(call.head), + "rows" => Value::nothing(call.head), + "type" => Value::string("Schema", call.head), + "estimated_size" => Value::nothing(call.head), + "span_contents" => Value::string(span_contents, value.span), + "span_start" => Value::int(value.span.start as i64, call.head), + "span_end" => Value::int(value.span.end as i64, call.head), + "reference_count" => Value::int(value.reference_count as i64, call.head), + }, + call.head, + ))), } })?; let vals = vals.into_iter().flatten().collect(); diff --git a/crates/nu_plugin_polars/src/dataframe/command/core/mod.rs b/crates/nu_plugin_polars/src/dataframe/command/core/mod.rs index fecbba2d69..c2b3d26b18 100644 --- a/crates/nu_plugin_polars/src/dataframe/command/core/mod.rs +++ b/crates/nu_plugin_polars/src/dataframe/command/core/mod.rs @@ -9,6 +9,7 @@ mod schema; mod shape; mod summary; mod to_df; +mod to_dtype; mod to_lazy; mod to_nu; mod to_repr; @@ -40,5 +41,6 @@ pub(crate) fn core_commands() -> Vec match PolarsFileType::from(ext.as_str()) { @@ -388,7 +388,7 @@ fn from_json( })?; let maybe_schema = call .get_flag("schema")? - .map(|schema| NuSchema::try_from(&schema)) + .map(|schema| NuSchema::try_from_value(plugin, &schema)) .transpose()?; let buf_reader = BufReader::new(file); @@ -429,11 +429,7 @@ fn from_ndjson( NonZeroUsize::new(DEFAULT_INFER_SCHEMA) .expect("The default infer-schema should be non zero"), ); - let maybe_schema = call - .get_flag("schema")? - .map(|schema| NuSchema::try_from(&schema)) - .transpose()?; - + let maybe_schema = get_schema(plugin, call)?; if !is_eager { let start_time = std::time::Instant::now(); @@ -511,10 +507,7 @@ fn from_csv( .unwrap_or(DEFAULT_INFER_SCHEMA); let skip_rows: Option = call.get_flag("skip-rows")?; let columns: Option> = call.get_flag("columns")?; - let maybe_schema = call - .get_flag("schema")? - .map(|schema| NuSchema::try_from(&schema)) - .transpose()?; + let maybe_schema = get_schema(plugin, call)?; let truncate_ragged_lines: bool = call.has_flag("truncate-ragged-lines")?; if !is_eager { @@ -627,12 +620,15 @@ fn cloud_not_supported(file_type: PolarsFileType, span: Span) -> ShellError { } } -fn build_hive_options(call: &EvaluatedCall) -> Result { +fn build_hive_options( + plugin: &PolarsPlugin, + call: &EvaluatedCall, +) -> Result { let enabled: Option = call.get_flag("hive-enabled")?; let hive_start_idx: Option = call.get_flag("hive-start-idx")?; let schema: Option = call .get_flag::("hive-schema")? - .map(|schema| NuSchema::try_from(&schema)) + .map(|schema| NuSchema::try_from_value(plugin, &schema)) .transpose()?; let try_parse_dates: bool = call.has_flag("hive-try-parse-dates")?; @@ -643,3 +639,11 @@ fn build_hive_options(call: &EvaluatedCall) -> Result { try_parse_dates, }) } + +fn get_schema(plugin: &PolarsPlugin, call: &EvaluatedCall) -> Result, ShellError> { + let schema: Option = call + .get_flag("schema")? + .map(|schema| NuSchema::try_from_value(plugin, &schema)) + .transpose()?; + Ok(schema) +} diff --git a/crates/nu_plugin_polars/src/dataframe/command/core/schema.rs b/crates/nu_plugin_polars/src/dataframe/command/core/schema.rs index 01451ff443..4ce1f3404c 100644 --- a/crates/nu_plugin_polars/src/dataframe/command/core/schema.rs +++ b/crates/nu_plugin_polars/src/dataframe/command/core/schema.rs @@ -1,4 +1,7 @@ -use crate::{values::PolarsPluginObject, PolarsPlugin}; +use crate::{ + values::{datatype_list, CustomValueSupport, PolarsPluginObject}, + PolarsPlugin, +}; use nu_plugin::{EngineInterface, EvaluatedCall, PluginCommand}; use nu_protocol::{ @@ -67,12 +70,12 @@ fn command( match PolarsPluginObject::try_from_pipeline(plugin, input, call.head)? { PolarsPluginObject::NuDataFrame(df) => { let schema = df.schema(); - let value: Value = schema.into(); + let value = schema.base_value(call.head)?; Ok(PipelineData::Value(value, None)) } PolarsPluginObject::NuLazyFrame(mut lazy) => { let schema = lazy.schema()?; - let value: Value = schema.into(); + let value = schema.base_value(call.head)?; Ok(PipelineData::Value(value, None)) } _ => Err(ShellError::GenericError { @@ -85,42 +88,6 @@ fn command( } } -fn datatype_list(span: Span) -> Value { - let types: Vec = [ - ("null", ""), - ("bool", ""), - ("u8", ""), - ("u16", ""), - ("u32", ""), - ("u64", ""), - ("i8", ""), - ("i16", ""), - ("i32", ""), - ("i64", ""), - ("f32", ""), - ("f64", ""), - ("str", ""), - ("binary", ""), - ("date", ""), - ("datetime", "Time Unit can be: milliseconds: ms, microseconds: us, nanoseconds: ns. Timezone wildcard is *. Other Timezone examples: UTC, America/Los_Angeles."), - ("duration", "Time Unit can be: milliseconds: ms, microseconds: us, nanoseconds: ns."), - ("time", ""), - ("object", ""), - ("unknown", ""), - ("list", ""), - ] - .iter() - .map(|(dtype, note)| { - Value::record(record! { - "dtype" => Value::string(*dtype, span), - "note" => Value::string(*note, span), - }, - span) - }) - .collect(); - Value::list(types, span) -} - #[cfg(test)] mod test { use super::*; diff --git a/crates/nu_plugin_polars/src/dataframe/command/core/to_df.rs b/crates/nu_plugin_polars/src/dataframe/command/core/to_df.rs index 90f96a8a99..72da002333 100644 --- a/crates/nu_plugin_polars/src/dataframe/command/core/to_df.rs +++ b/crates/nu_plugin_polars/src/dataframe/command/core/to_df.rs @@ -216,7 +216,7 @@ impl PluginCommand for ToDataFrame { ) -> Result { let maybe_schema = call .get_flag("schema")? - .map(|schema| NuSchema::try_from(&schema)) + .map(|schema| NuSchema::try_from_value(plugin, &schema)) .transpose()?; debug!("schema: {:?}", maybe_schema); diff --git a/crates/nu_plugin_polars/src/dataframe/command/core/to_dtype.rs b/crates/nu_plugin_polars/src/dataframe/command/core/to_dtype.rs new file mode 100644 index 0000000000..f62d0c2b23 --- /dev/null +++ b/crates/nu_plugin_polars/src/dataframe/command/core/to_dtype.rs @@ -0,0 +1,55 @@ +use nu_plugin::PluginCommand; +use nu_protocol::{Category, Example, ShellError, Signature, Span, Type, Value}; + +use crate::{ + values::{CustomValueSupport, NuDataType}, + PolarsPlugin, +}; + +pub struct ToDataType; + +impl PluginCommand for ToDataType { + type Plugin = PolarsPlugin; + + fn name(&self) -> &str { + "polars into-dtype" + } + + fn description(&self) -> &str { + "Convert a string to a specific datatype." + } + + fn signature(&self) -> Signature { + Signature::build(self.name()) + .input_output_type(Type::String, Type::Custom("datatype".into())) + .category(Category::Custom("dataframe".into())) + } + + fn examples(&self) -> Vec { + vec![Example { + description: "Convert a string to a specific datatype", + example: r#""i64" | polars into-dtype"#, + result: Some(Value::string("i64", Span::test_data())), + }] + } + + fn run( + &self, + plugin: &Self::Plugin, + engine: &nu_plugin::EngineInterface, + call: &nu_plugin::EvaluatedCall, + input: nu_protocol::PipelineData, + ) -> Result { + command(plugin, engine, call, input).map_err(nu_protocol::LabeledError::from) + } +} + +fn command( + plugin: &PolarsPlugin, + engine: &nu_plugin::EngineInterface, + call: &nu_plugin::EvaluatedCall, + input: nu_protocol::PipelineData, +) -> Result { + NuDataType::try_from_pipeline(plugin, input, call.head)? + .to_pipeline_data(plugin, engine, call.head) +} diff --git a/crates/nu_plugin_polars/src/dataframe/command/core/to_lazy.rs b/crates/nu_plugin_polars/src/dataframe/command/core/to_lazy.rs index 6086925aac..ccffde21ff 100644 --- a/crates/nu_plugin_polars/src/dataframe/command/core/to_lazy.rs +++ b/crates/nu_plugin_polars/src/dataframe/command/core/to_lazy.rs @@ -56,7 +56,7 @@ impl PluginCommand for ToLazyFrame { ) -> Result { let maybe_schema = call .get_flag("schema")? - .map(|schema| NuSchema::try_from(&schema)) + .map(|schema| NuSchema::try_from_value(plugin, &schema)) .transpose()?; let df = NuDataFrame::try_from_iter(plugin, input.into_iter(), maybe_schema)?; diff --git a/crates/nu_plugin_polars/src/dataframe/values/mod.rs b/crates/nu_plugin_polars/src/dataframe/values/mod.rs index fec45fe1bb..5184049ed7 100644 --- a/crates/nu_plugin_polars/src/dataframe/values/mod.rs +++ b/crates/nu_plugin_polars/src/dataframe/values/mod.rs @@ -1,5 +1,6 @@ mod file_type; mod nu_dataframe; +mod nu_dtype; mod nu_expression; mod nu_lazyframe; mod nu_lazygroupby; @@ -8,19 +9,23 @@ mod nu_when; pub mod utils; use crate::{Cacheable, PolarsPlugin}; +use nu_dtype::custom_value::NuDataTypeCustomValue; use nu_plugin::EngineInterface; use nu_protocol::{ ast::Operator, CustomValue, PipelineData, ShellError, Span, Spanned, Type, Value, }; +use nu_schema::custom_value::NuSchemaCustomValue; use std::{cmp::Ordering, fmt}; use uuid::Uuid; pub use file_type::PolarsFileType; pub use nu_dataframe::{Axis, Column, NuDataFrame, NuDataFrameCustomValue}; +pub use nu_dtype::NuDataType; +pub use nu_dtype::{datatype_list, str_to_dtype}; pub use nu_expression::{NuExpression, NuExpressionCustomValue}; pub use nu_lazyframe::{NuLazyFrame, NuLazyFrameCustomValue}; pub use nu_lazygroupby::{NuLazyGroupBy, NuLazyGroupByCustomValue}; -pub use nu_schema::{str_to_dtype, NuSchema}; +pub use nu_schema::NuSchema; pub use nu_when::{NuWhen, NuWhenCustomValue, NuWhenType}; #[derive(Debug, Clone)] @@ -31,6 +36,8 @@ pub enum PolarsPluginType { NuLazyGroupBy, NuWhen, NuPolarsTestData, + NuDataType, + NuSchema, } impl fmt::Display for PolarsPluginType { @@ -42,6 +49,8 @@ impl fmt::Display for PolarsPluginType { Self::NuLazyGroupBy => write!(f, "NuLazyGroupBy"), Self::NuWhen => write!(f, "NuWhen"), Self::NuPolarsTestData => write!(f, "NuPolarsTestData"), + Self::NuDataType => write!(f, "NuDataType"), + Self::NuSchema => write!(f, "NuSchema"), } } } @@ -54,6 +63,8 @@ pub enum PolarsPluginObject { NuLazyGroupBy(NuLazyGroupBy), NuWhen(NuWhen), NuPolarsTestData(Uuid, String), + NuDataType(NuDataType), + NuSchema(NuSchema), } impl PolarsPluginObject { @@ -71,6 +82,10 @@ impl PolarsPluginObject { NuLazyGroupBy::try_from_value(plugin, value).map(PolarsPluginObject::NuLazyGroupBy) } else if NuWhen::can_downcast(value) { NuWhen::try_from_value(plugin, value).map(PolarsPluginObject::NuWhen) + } else if NuSchema::can_downcast(value) { + NuSchema::try_from_value(plugin, value).map(PolarsPluginObject::NuSchema) + } else if NuDataType::can_downcast(value) { + NuDataType::try_from_value(plugin, value).map(PolarsPluginObject::NuDataType) } else { Err(cant_convert_err( value, @@ -80,6 +95,8 @@ impl PolarsPluginObject { PolarsPluginType::NuExpression, PolarsPluginType::NuLazyGroupBy, PolarsPluginType::NuWhen, + PolarsPluginType::NuDataType, + PolarsPluginType::NuSchema, ], )) } @@ -102,6 +119,8 @@ impl PolarsPluginObject { Self::NuLazyGroupBy(_) => PolarsPluginType::NuLazyGroupBy, Self::NuWhen(_) => PolarsPluginType::NuWhen, Self::NuPolarsTestData(_, _) => PolarsPluginType::NuPolarsTestData, + Self::NuDataType(_) => PolarsPluginType::NuDataType, + Self::NuSchema(_) => PolarsPluginType::NuSchema, } } @@ -113,6 +132,8 @@ impl PolarsPluginObject { PolarsPluginObject::NuLazyGroupBy(lg) => lg.id, PolarsPluginObject::NuWhen(w) => w.id, PolarsPluginObject::NuPolarsTestData(id, _) => *id, + PolarsPluginObject::NuDataType(dt) => dt.id, + PolarsPluginObject::NuSchema(schema) => schema.id, } } @@ -126,6 +147,8 @@ impl PolarsPluginObject { PolarsPluginObject::NuPolarsTestData(id, s) => { Value::string(format!("{id}:{s}"), Span::test_data()) } + PolarsPluginObject::NuDataType(dt) => dt.into_value(span), + PolarsPluginObject::NuSchema(schema) => schema.into_value(span), } } @@ -151,6 +174,8 @@ pub enum CustomValueType { NuExpression(NuExpressionCustomValue), NuLazyGroupBy(NuLazyGroupByCustomValue), NuWhen(NuWhenCustomValue), + NuDataType(NuDataTypeCustomValue), + NuSchema(NuSchemaCustomValue), } impl CustomValueType { @@ -161,6 +186,8 @@ impl CustomValueType { CustomValueType::NuExpression(e_cv) => e_cv.id, CustomValueType::NuLazyGroupBy(lg_cv) => lg_cv.id, CustomValueType::NuWhen(w_cv) => w_cv.id, + CustomValueType::NuDataType(dt_cv) => dt_cv.id, + CustomValueType::NuSchema(schema_cv) => schema_cv.id, } } @@ -175,6 +202,10 @@ impl CustomValueType { Ok(CustomValueType::NuLazyGroupBy(lg_cv.clone())) } else if let Some(w_cv) = val.as_any().downcast_ref::() { Ok(CustomValueType::NuWhen(w_cv.clone())) + } else if let Some(w_cv) = val.as_any().downcast_ref::() { + Ok(CustomValueType::NuDataType(w_cv.clone())) + } else if let Some(w_cv) = val.as_any().downcast_ref::() { + Ok(CustomValueType::NuSchema(w_cv.clone())) } else { Err(ShellError::CantConvert { to_type: "physical type".into(), @@ -381,3 +412,198 @@ pub trait CustomValueSupport: Cacheable { )) } } + +#[cfg(test)] +mod test { + use polars::prelude::{DataType, TimeUnit, UnknownKind}; + + use super::*; + + #[test] + fn test_dtype_str_to_schema_simple_types() { + let dtype = "bool"; + let schema = str_to_dtype(dtype, Span::unknown()).unwrap(); + let expected = DataType::Boolean; + assert_eq!(schema, expected); + + let dtype = "u8"; + let schema = str_to_dtype(dtype, Span::unknown()).unwrap(); + let expected = DataType::UInt8; + assert_eq!(schema, expected); + + let dtype = "u16"; + let schema = str_to_dtype(dtype, Span::unknown()).unwrap(); + let expected = DataType::UInt16; + assert_eq!(schema, expected); + + let dtype = "u32"; + let schema = str_to_dtype(dtype, Span::unknown()).unwrap(); + let expected = DataType::UInt32; + assert_eq!(schema, expected); + + let dtype = "u64"; + let schema = str_to_dtype(dtype, Span::unknown()).unwrap(); + let expected = DataType::UInt64; + assert_eq!(schema, expected); + + let dtype = "i8"; + let schema = str_to_dtype(dtype, Span::unknown()).unwrap(); + let expected = DataType::Int8; + assert_eq!(schema, expected); + + let dtype = "i16"; + let schema = str_to_dtype(dtype, Span::unknown()).unwrap(); + let expected = DataType::Int16; + assert_eq!(schema, expected); + + let dtype = "i32"; + let schema = str_to_dtype(dtype, Span::unknown()).unwrap(); + let expected = DataType::Int32; + assert_eq!(schema, expected); + + let dtype = "i64"; + let schema = str_to_dtype(dtype, Span::unknown()).unwrap(); + let expected = DataType::Int64; + assert_eq!(schema, expected); + + let dtype = "str"; + let schema = str_to_dtype(dtype, Span::unknown()).unwrap(); + let expected = DataType::String; + assert_eq!(schema, expected); + + let dtype = "binary"; + let schema = str_to_dtype(dtype, Span::unknown()).unwrap(); + let expected = DataType::Binary; + assert_eq!(schema, expected); + + let dtype = "date"; + let schema = str_to_dtype(dtype, Span::unknown()).unwrap(); + let expected = DataType::Date; + assert_eq!(schema, expected); + + let dtype = "time"; + let schema = str_to_dtype(dtype, Span::unknown()).unwrap(); + let expected = DataType::Time; + assert_eq!(schema, expected); + + let dtype = "null"; + let schema = str_to_dtype(dtype, Span::unknown()).unwrap(); + let expected = DataType::Null; + assert_eq!(schema, expected); + + let dtype = "unknown"; + let schema = str_to_dtype(dtype, Span::unknown()).unwrap(); + let expected = DataType::Unknown(UnknownKind::Any); + assert_eq!(schema, expected); + + let dtype = "object"; + let schema = str_to_dtype(dtype, Span::unknown()).unwrap(); + let expected = DataType::Object("unknown", None); + assert_eq!(schema, expected); + } + + #[test] + fn test_dtype_str_schema_datetime() { + let dtype = "datetime"; + let schema = str_to_dtype(dtype, Span::unknown()).unwrap(); + let expected = DataType::Datetime(TimeUnit::Milliseconds, None); + assert_eq!(schema, expected); + + let dtype = "datetime"; + let schema = str_to_dtype(dtype, Span::unknown()).unwrap(); + let expected = DataType::Datetime(TimeUnit::Microseconds, None); + assert_eq!(schema, expected); + + let dtype = "datetime<μs, *>"; + let schema = str_to_dtype(dtype, Span::unknown()).unwrap(); + let expected = DataType::Datetime(TimeUnit::Microseconds, None); + assert_eq!(schema, expected); + + let dtype = "datetime"; + let schema = str_to_dtype(dtype, Span::unknown()).unwrap(); + let expected = DataType::Datetime(TimeUnit::Nanoseconds, None); + assert_eq!(schema, expected); + + let dtype = "datetime"; + let schema = str_to_dtype(dtype, Span::unknown()).unwrap(); + let expected = DataType::Datetime(TimeUnit::Milliseconds, Some("UTC".into())); + assert_eq!(schema, expected); + + let dtype = "invalid"; + let schema = str_to_dtype(dtype, Span::unknown()); + assert!(schema.is_err()) + } + + #[test] + fn test_dtype_str_schema_duration() { + let dtype = "duration"; + let schema = str_to_dtype(dtype, Span::unknown()).unwrap(); + let expected = DataType::Duration(TimeUnit::Milliseconds); + assert_eq!(schema, expected); + + let dtype = "duration"; + let schema = str_to_dtype(dtype, Span::unknown()).unwrap(); + let expected = DataType::Duration(TimeUnit::Microseconds); + assert_eq!(schema, expected); + + let dtype = "duration<μs>"; + let schema = str_to_dtype(dtype, Span::unknown()).unwrap(); + let expected = DataType::Duration(TimeUnit::Microseconds); + assert_eq!(schema, expected); + + let dtype = "duration"; + let schema = str_to_dtype(dtype, Span::unknown()).unwrap(); + let expected = DataType::Duration(TimeUnit::Nanoseconds); + assert_eq!(schema, expected); + } + + #[test] + fn test_dtype_str_schema_decimal() { + let dtype = "decimal<7,2>"; + let schema = str_to_dtype(dtype, Span::unknown()).unwrap(); + let expected = DataType::Decimal(Some(7usize), Some(2usize)); + assert_eq!(schema, expected); + + // "*" is not a permitted value for scale + let dtype = "decimal<7,*>"; + let schema = str_to_dtype(dtype, Span::unknown()); + assert!(matches!(schema, Err(ShellError::GenericError { .. }))); + + let dtype = "decimal<*,2>"; + let schema = str_to_dtype(dtype, Span::unknown()).unwrap(); + let expected = DataType::Decimal(None, Some(2usize)); + assert_eq!(schema, expected); + } + + #[test] + fn test_dtype_str_to_schema_list_types() { + let dtype = "list"; + let schema = str_to_dtype(dtype, Span::unknown()).unwrap(); + let expected = DataType::List(Box::new(DataType::Int32)); + assert_eq!(schema, expected); + + let dtype = "list>"; + let schema = str_to_dtype(dtype, Span::unknown()).unwrap(); + let expected = DataType::List(Box::new(DataType::Duration(TimeUnit::Milliseconds))); + assert_eq!(schema, expected); + + let dtype = "list>"; + let schema = str_to_dtype(dtype, Span::unknown()).unwrap(); + let expected = DataType::List(Box::new(DataType::Datetime(TimeUnit::Milliseconds, None))); + assert_eq!(schema, expected); + + let dtype = "list>"; + let schema = str_to_dtype(dtype, Span::unknown()).unwrap(); + let expected = DataType::List(Box::new(DataType::Decimal(Some(7usize), Some(2usize)))); + assert_eq!(schema, expected); + + let dtype = "list>"; + let schema = str_to_dtype(dtype, Span::unknown()).unwrap(); + let expected = DataType::List(Box::new(DataType::Decimal(None, Some(2usize)))); + assert_eq!(schema, expected); + + let dtype = "list>"; + let schema = str_to_dtype(dtype, Span::unknown()); + assert!(matches!(schema, Err(ShellError::GenericError { .. }))); + } +} diff --git a/crates/nu_plugin_polars/src/dataframe/values/nu_dtype/custom_value.rs b/crates/nu_plugin_polars/src/dataframe/values/nu_dtype/custom_value.rs new file mode 100644 index 0000000000..0be9ceb71e --- /dev/null +++ b/crates/nu_plugin_polars/src/dataframe/values/nu_dtype/custom_value.rs @@ -0,0 +1,65 @@ +use nu_protocol::{CustomValue, ShellError, Span, Value}; +use serde::{Deserialize, Serialize}; +use uuid::Uuid; + +use crate::values::{CustomValueSupport, PolarsPluginCustomValue}; + +use super::NuDataType; + +#[derive(Clone, Debug, Deserialize, Serialize)] +pub struct NuDataTypeCustomValue { + pub id: Uuid, + #[serde(skip)] + pub datatype: Option, +} + +#[typetag::serde] +impl CustomValue for NuDataTypeCustomValue { + fn clone_value(&self, span: nu_protocol::Span) -> Value { + Value::custom(Box::new(self.clone()), span) + } + + fn type_name(&self) -> String { + "NuDataType".into() + } + + fn to_base_value(&self, span: Span) -> Result { + Ok(Value::string( + "NuDataType: custom_value_to_base_value should've been called", + span, + )) + } + + fn as_mut_any(&mut self) -> &mut dyn std::any::Any { + self + } + + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn notify_plugin_on_drop(&self) -> bool { + true + } +} + +impl PolarsPluginCustomValue for NuDataTypeCustomValue { + type PolarsPluginObjectType = NuDataType; + + fn id(&self) -> &Uuid { + &self.id + } + + fn internal(&self) -> &Option { + &self.datatype + } + + fn custom_value_to_base_value( + &self, + plugin: &crate::PolarsPlugin, + _engine: &nu_plugin::EngineInterface, + ) -> Result { + let dtype = NuDataType::try_from_custom_value(plugin, self)?; + dtype.base_value(Span::unknown()) + } +} diff --git a/crates/nu_plugin_polars/src/dataframe/values/nu_dtype/mod.rs b/crates/nu_plugin_polars/src/dataframe/values/nu_dtype/mod.rs new file mode 100644 index 0000000000..58bda94719 --- /dev/null +++ b/crates/nu_plugin_polars/src/dataframe/values/nu_dtype/mod.rs @@ -0,0 +1,316 @@ +pub mod custom_value; + +use custom_value::NuDataTypeCustomValue; +use nu_protocol::{record, ShellError, Span, Value}; +use polars::prelude::{DataType, PlSmallStr, TimeUnit, UnknownKind}; +use uuid::Uuid; + +use crate::{Cacheable, PolarsPlugin}; + +use super::{nu_schema::dtype_to_value, CustomValueSupport, PolarsPluginObject, PolarsPluginType}; + +#[derive(Debug, Clone)] +pub struct NuDataType { + pub id: uuid::Uuid, + dtype: DataType, +} + +impl NuDataType { + pub fn new(dtype: DataType) -> Self { + Self { + id: uuid::Uuid::new_v4(), + dtype, + } + } + + pub fn new_with_str(dtype: &str, span: Span) -> Result { + let dtype = str_to_dtype(dtype, span)?; + Ok(Self { + id: uuid::Uuid::new_v4(), + dtype, + }) + } + + pub fn to_polars(&self) -> DataType { + self.dtype.clone() + } +} + +impl From for Value { + fn from(nu_dtype: NuDataType) -> Self { + Value::String { + val: nu_dtype.dtype.to_string(), + internal_span: Span::unknown(), + } + } +} + +impl Cacheable for NuDataType { + fn cache_id(&self) -> &Uuid { + &self.id + } + + fn to_cache_value(&self) -> Result { + Ok(PolarsPluginObject::NuDataType(self.clone())) + } + + fn from_cache_value(cv: super::PolarsPluginObject) -> Result { + match cv { + PolarsPluginObject::NuDataType(dt) => Ok(dt), + _ => Err(ShellError::GenericError { + error: "Cache value is not a dataframe".into(), + msg: "".into(), + span: None, + help: None, + inner: vec![], + }), + } + } +} + +impl CustomValueSupport for NuDataType { + type CV = NuDataTypeCustomValue; + + fn get_type_static() -> super::PolarsPluginType { + PolarsPluginType::NuDataType + } + + fn custom_value(self) -> Self::CV { + NuDataTypeCustomValue { + id: self.id, + datatype: Some(self), + } + } + + fn base_value(self, span: Span) -> Result { + Ok(dtype_to_value(&self.dtype, span)) + } + + fn try_from_value(plugin: &PolarsPlugin, value: &Value) -> Result { + match value { + Value::Custom { val, .. } => { + if let Some(cv) = val.as_any().downcast_ref::() { + Self::try_from_custom_value(plugin, cv) + } else { + Err(ShellError::CantConvert { + to_type: Self::get_type_static().to_string(), + from_type: value.get_type().to_string(), + span: value.span(), + help: None, + }) + } + } + Value::String { val, internal_span } => NuDataType::new_with_str(val, *internal_span), + _ => Err(ShellError::CantConvert { + to_type: Self::get_type_static().to_string(), + from_type: value.get_type().to_string(), + span: value.span(), + help: None, + }), + } + } +} + +pub fn datatype_list(span: Span) -> Value { + let types: Vec = [ + ("null", ""), + ("bool", ""), + ("u8", ""), + ("u16", ""), + ("u32", ""), + ("u64", ""), + ("i8", ""), + ("i16", ""), + ("i32", ""), + ("i64", ""), + ("f32", ""), + ("f64", ""), + ("str", ""), + ("binary", ""), + ("date", ""), + ("datetime", "Time Unit can be: milliseconds: ms, microseconds: us, nanoseconds: ns. Timezone wildcard is *. Other Timezone examples: UTC, America/Los_Angeles."), + ("duration", "Time Unit can be: milliseconds: ms, microseconds: us, nanoseconds: ns."), + ("time", ""), + ("object", ""), + ("unknown", ""), + ("list", ""), + ] + .iter() + .map(|(dtype, note)| { + Value::record(record! { + "dtype" => Value::string(*dtype, span), + "note" => Value::string(*note, span), + }, + span) + }) + .collect(); + Value::list(types, span) +} + +pub fn str_to_dtype(dtype: &str, span: Span) -> Result { + match dtype { + "bool" => Ok(DataType::Boolean), + "u8" => Ok(DataType::UInt8), + "u16" => Ok(DataType::UInt16), + "u32" => Ok(DataType::UInt32), + "u64" => Ok(DataType::UInt64), + "i8" => Ok(DataType::Int8), + "i16" => Ok(DataType::Int16), + "i32" => Ok(DataType::Int32), + "i64" => Ok(DataType::Int64), + "f32" => Ok(DataType::Float32), + "f64" => Ok(DataType::Float64), + "str" => Ok(DataType::String), + "binary" => Ok(DataType::Binary), + "date" => Ok(DataType::Date), + "time" => Ok(DataType::Time), + "null" => Ok(DataType::Null), + "unknown" => Ok(DataType::Unknown(UnknownKind::Any)), + "object" => Ok(DataType::Object("unknown", None)), + _ if dtype.starts_with("list") => { + let dtype = dtype + .trim_start_matches("list") + .trim_start_matches('<') + .trim_end_matches('>') + .trim(); + let dtype = str_to_dtype(dtype, span)?; + Ok(DataType::List(Box::new(dtype))) + } + _ if dtype.starts_with("datetime") => { + let dtype = dtype + .trim_start_matches("datetime") + .trim_start_matches('<') + .trim_end_matches('>'); + let mut split = dtype.split(','); + let next = split + .next() + .ok_or_else(|| ShellError::GenericError { + error: "Invalid polars data type".into(), + msg: "Missing time unit".into(), + span: Some(span), + help: None, + inner: vec![], + })? + .trim(); + let time_unit = str_to_time_unit(next, span)?; + let next = split + .next() + .ok_or_else(|| ShellError::GenericError { + error: "Invalid polars data type".into(), + msg: "Missing time zone".into(), + span: Some(span), + help: None, + inner: vec![], + })? + .trim(); + let timezone = if "*" == next { + None + } else { + Some(next.to_string()) + }; + Ok(DataType::Datetime( + time_unit, + timezone.map(PlSmallStr::from), + )) + } + _ if dtype.starts_with("duration") => { + let inner = dtype.trim_start_matches("duration<").trim_end_matches('>'); + let next = inner + .split(',') + .next() + .ok_or_else(|| ShellError::GenericError { + error: "Invalid polars data type".into(), + msg: "Missing time unit".into(), + span: Some(span), + help: None, + inner: vec![], + })? + .trim(); + let time_unit = str_to_time_unit(next, span)?; + Ok(DataType::Duration(time_unit)) + } + _ if dtype.starts_with("decimal") => { + let dtype = dtype + .trim_start_matches("decimal") + .trim_start_matches('<') + .trim_end_matches('>'); + let mut split = dtype.split(','); + let next = split + .next() + .ok_or_else(|| ShellError::GenericError { + error: "Invalid polars data type".into(), + msg: "Missing decimal precision".into(), + span: Some(span), + help: None, + inner: vec![], + })? + .trim(); + let precision = match next { + "*" => None, // infer + _ => Some( + next.parse::() + .map_err(|e| ShellError::GenericError { + error: "Invalid polars data type".into(), + msg: format!("Error in parsing decimal precision: {e}"), + span: Some(span), + help: None, + inner: vec![], + })?, + ), + }; + + let next = split + .next() + .ok_or_else(|| ShellError::GenericError { + error: "Invalid polars data type".into(), + msg: "Missing decimal scale".into(), + span: Some(span), + help: None, + inner: vec![], + })? + .trim(); + let scale = match next { + "*" => Err(ShellError::GenericError { + error: "Invalid polars data type".into(), + msg: "`*` is not a permitted value for scale".into(), + span: Some(span), + help: None, + inner: vec![], + }), + _ => next + .parse::() + .map(Some) + .map_err(|e| ShellError::GenericError { + error: "Invalid polars data type".into(), + msg: format!("Error in parsing decimal precision: {e}"), + span: Some(span), + help: None, + inner: vec![], + }), + }?; + Ok(DataType::Decimal(precision, scale)) + } + _ => Err(ShellError::GenericError { + error: "Invalid polars data type".into(), + msg: format!("Unknown type: {dtype}"), + span: Some(span), + help: None, + inner: vec![], + }), + } +} + +fn str_to_time_unit(ts_string: &str, span: Span) -> Result { + match ts_string { + "ms" => Ok(TimeUnit::Milliseconds), + "us" | "μs" => Ok(TimeUnit::Microseconds), + "ns" => Ok(TimeUnit::Nanoseconds), + _ => Err(ShellError::GenericError { + error: "Invalid polars data type".into(), + msg: "Invalid time unit".into(), + span: Some(span), + help: None, + inner: vec![], + }), + } +} diff --git a/crates/nu_plugin_polars/src/dataframe/values/nu_schema.rs b/crates/nu_plugin_polars/src/dataframe/values/nu_schema.rs deleted file mode 100644 index 1e2ae4723a..0000000000 --- a/crates/nu_plugin_polars/src/dataframe/values/nu_schema.rs +++ /dev/null @@ -1,480 +0,0 @@ -use std::sync::Arc; - -use nu_protocol::{ShellError, Span, Value}; -use polars::{ - datatypes::UnknownKind, - prelude::{DataType, Field, PlSmallStr, Schema, SchemaExt, SchemaRef, TimeUnit}, -}; - -#[derive(Debug, Clone)] -pub struct NuSchema { - pub schema: SchemaRef, -} - -impl NuSchema { - pub fn new(schema: SchemaRef) -> Self { - Self { schema } - } -} - -impl TryFrom<&Value> for NuSchema { - type Error = ShellError; - fn try_from(value: &Value) -> Result { - let schema = value_to_schema(value, Span::unknown())?; - Ok(Self::new(Arc::new(schema))) - } -} - -impl From for Value { - fn from(schema: NuSchema) -> Self { - fields_to_value(schema.schema.iter_fields(), Span::unknown()) - } -} - -impl From for SchemaRef { - fn from(val: NuSchema) -> Self { - Arc::clone(&val.schema) - } -} - -impl From for NuSchema { - fn from(val: SchemaRef) -> Self { - Self { schema: val } - } -} - -fn fields_to_value(fields: impl Iterator, span: Span) -> Value { - let record = fields - .map(|field| { - let col = field.name().to_string(); - let val = dtype_to_value(field.dtype(), span); - (col, val) - }) - .collect(); - - Value::record(record, Span::unknown()) -} - -fn dtype_to_value(dtype: &DataType, span: Span) -> Value { - match dtype { - DataType::Struct(fields) => fields_to_value(fields.iter().cloned(), span), - _ => Value::string(dtype.to_string().replace('[', "<").replace(']', ">"), span), - } -} - -fn value_to_schema(value: &Value, span: Span) -> Result { - let fields = value_to_fields(value, span)?; - let schema = Schema::from_iter(fields); - Ok(schema) -} - -fn value_to_fields(value: &Value, span: Span) -> Result, ShellError> { - let fields = value - .as_record()? - .into_iter() - .map(|(col, val)| match val { - Value::Record { .. } => { - let fields = value_to_fields(val, span)?; - let dtype = DataType::Struct(fields); - Ok(Field::new(col.into(), dtype)) - } - _ => { - let dtype = str_to_dtype(&val.coerce_string()?, span)?; - Ok(Field::new(col.into(), dtype)) - } - }) - .collect::, ShellError>>()?; - Ok(fields) -} - -pub fn str_to_dtype(dtype: &str, span: Span) -> Result { - match dtype { - "bool" => Ok(DataType::Boolean), - "u8" => Ok(DataType::UInt8), - "u16" => Ok(DataType::UInt16), - "u32" => Ok(DataType::UInt32), - "u64" => Ok(DataType::UInt64), - "i8" => Ok(DataType::Int8), - "i16" => Ok(DataType::Int16), - "i32" => Ok(DataType::Int32), - "i64" => Ok(DataType::Int64), - "f32" => Ok(DataType::Float32), - "f64" => Ok(DataType::Float64), - "str" => Ok(DataType::String), - "binary" => Ok(DataType::Binary), - "date" => Ok(DataType::Date), - "time" => Ok(DataType::Time), - "null" => Ok(DataType::Null), - "unknown" => Ok(DataType::Unknown(UnknownKind::Any)), - "object" => Ok(DataType::Object("unknown", None)), - _ if dtype.starts_with("list") => { - let dtype = dtype - .trim_start_matches("list") - .trim_start_matches('<') - .trim_end_matches('>') - .trim(); - let dtype = str_to_dtype(dtype, span)?; - Ok(DataType::List(Box::new(dtype))) - } - _ if dtype.starts_with("datetime") => { - let dtype = dtype - .trim_start_matches("datetime") - .trim_start_matches('<') - .trim_end_matches('>'); - let mut split = dtype.split(','); - let next = split - .next() - .ok_or_else(|| ShellError::GenericError { - error: "Invalid polars data type".into(), - msg: "Missing time unit".into(), - span: Some(span), - help: None, - inner: vec![], - })? - .trim(); - let time_unit = str_to_time_unit(next, span)?; - let next = split - .next() - .ok_or_else(|| ShellError::GenericError { - error: "Invalid polars data type".into(), - msg: "Missing time zone".into(), - span: Some(span), - help: None, - inner: vec![], - })? - .trim(); - let timezone = if "*" == next { - None - } else { - Some(next.to_string()) - }; - Ok(DataType::Datetime( - time_unit, - timezone.map(PlSmallStr::from), - )) - } - _ if dtype.starts_with("duration") => { - let inner = dtype.trim_start_matches("duration<").trim_end_matches('>'); - let next = inner - .split(',') - .next() - .ok_or_else(|| ShellError::GenericError { - error: "Invalid polars data type".into(), - msg: "Missing time unit".into(), - span: Some(span), - help: None, - inner: vec![], - })? - .trim(); - let time_unit = str_to_time_unit(next, span)?; - Ok(DataType::Duration(time_unit)) - } - _ if dtype.starts_with("decimal") => { - let dtype = dtype - .trim_start_matches("decimal") - .trim_start_matches('<') - .trim_end_matches('>'); - let mut split = dtype.split(','); - let next = split - .next() - .ok_or_else(|| ShellError::GenericError { - error: "Invalid polars data type".into(), - msg: "Missing decimal precision".into(), - span: Some(span), - help: None, - inner: vec![], - })? - .trim(); - let precision = match next { - "*" => None, // infer - _ => Some( - next.parse::() - .map_err(|e| ShellError::GenericError { - error: "Invalid polars data type".into(), - msg: format!("Error in parsing decimal precision: {e}"), - span: Some(span), - help: None, - inner: vec![], - })?, - ), - }; - - let next = split - .next() - .ok_or_else(|| ShellError::GenericError { - error: "Invalid polars data type".into(), - msg: "Missing decimal scale".into(), - span: Some(span), - help: None, - inner: vec![], - })? - .trim(); - let scale = match next { - "*" => Err(ShellError::GenericError { - error: "Invalid polars data type".into(), - msg: "`*` is not a permitted value for scale".into(), - span: Some(span), - help: None, - inner: vec![], - }), - _ => next - .parse::() - .map(Some) - .map_err(|e| ShellError::GenericError { - error: "Invalid polars data type".into(), - msg: format!("Error in parsing decimal precision: {e}"), - span: Some(span), - help: None, - inner: vec![], - }), - }?; - Ok(DataType::Decimal(precision, scale)) - } - _ => Err(ShellError::GenericError { - error: "Invalid polars data type".into(), - msg: format!("Unknown type: {dtype}"), - span: Some(span), - help: None, - inner: vec![], - }), - } -} - -fn str_to_time_unit(ts_string: &str, span: Span) -> Result { - match ts_string { - "ms" => Ok(TimeUnit::Milliseconds), - "us" | "μs" => Ok(TimeUnit::Microseconds), - "ns" => Ok(TimeUnit::Nanoseconds), - _ => Err(ShellError::GenericError { - error: "Invalid polars data type".into(), - msg: "Invalid time unit".into(), - span: Some(span), - help: None, - inner: vec![], - }), - } -} - -#[cfg(test)] -mod test { - - use nu_protocol::record; - - use super::*; - - #[test] - fn test_value_to_schema() { - let address = record! { - "street" => Value::test_string("str"), - "city" => Value::test_string("str"), - }; - - let value = Value::test_record(record! { - "name" => Value::test_string("str"), - "age" => Value::test_string("i32"), - "address" => Value::test_record(address) - }); - - let schema = value_to_schema(&value, Span::unknown()).unwrap(); - let expected = Schema::from_iter(vec![ - Field::new("name".into(), DataType::String), - Field::new("age".into(), DataType::Int32), - Field::new( - "address".into(), - DataType::Struct(vec![ - Field::new("street".into(), DataType::String), - Field::new("city".into(), DataType::String), - ]), - ), - ]); - assert_eq!(schema, expected); - } - - #[test] - fn test_dtype_str_to_schema_simple_types() { - let dtype = "bool"; - let schema = str_to_dtype(dtype, Span::unknown()).unwrap(); - let expected = DataType::Boolean; - assert_eq!(schema, expected); - - let dtype = "u8"; - let schema = str_to_dtype(dtype, Span::unknown()).unwrap(); - let expected = DataType::UInt8; - assert_eq!(schema, expected); - - let dtype = "u16"; - let schema = str_to_dtype(dtype, Span::unknown()).unwrap(); - let expected = DataType::UInt16; - assert_eq!(schema, expected); - - let dtype = "u32"; - let schema = str_to_dtype(dtype, Span::unknown()).unwrap(); - let expected = DataType::UInt32; - assert_eq!(schema, expected); - - let dtype = "u64"; - let schema = str_to_dtype(dtype, Span::unknown()).unwrap(); - let expected = DataType::UInt64; - assert_eq!(schema, expected); - - let dtype = "i8"; - let schema = str_to_dtype(dtype, Span::unknown()).unwrap(); - let expected = DataType::Int8; - assert_eq!(schema, expected); - - let dtype = "i16"; - let schema = str_to_dtype(dtype, Span::unknown()).unwrap(); - let expected = DataType::Int16; - assert_eq!(schema, expected); - - let dtype = "i32"; - let schema = str_to_dtype(dtype, Span::unknown()).unwrap(); - let expected = DataType::Int32; - assert_eq!(schema, expected); - - let dtype = "i64"; - let schema = str_to_dtype(dtype, Span::unknown()).unwrap(); - let expected = DataType::Int64; - assert_eq!(schema, expected); - - let dtype = "str"; - let schema = str_to_dtype(dtype, Span::unknown()).unwrap(); - let expected = DataType::String; - assert_eq!(schema, expected); - - let dtype = "binary"; - let schema = str_to_dtype(dtype, Span::unknown()).unwrap(); - let expected = DataType::Binary; - assert_eq!(schema, expected); - - let dtype = "date"; - let schema = str_to_dtype(dtype, Span::unknown()).unwrap(); - let expected = DataType::Date; - assert_eq!(schema, expected); - - let dtype = "time"; - let schema = str_to_dtype(dtype, Span::unknown()).unwrap(); - let expected = DataType::Time; - assert_eq!(schema, expected); - - let dtype = "null"; - let schema = str_to_dtype(dtype, Span::unknown()).unwrap(); - let expected = DataType::Null; - assert_eq!(schema, expected); - - let dtype = "unknown"; - let schema = str_to_dtype(dtype, Span::unknown()).unwrap(); - let expected = DataType::Unknown(UnknownKind::Any); - assert_eq!(schema, expected); - - let dtype = "object"; - let schema = str_to_dtype(dtype, Span::unknown()).unwrap(); - let expected = DataType::Object("unknown", None); - assert_eq!(schema, expected); - } - - #[test] - fn test_dtype_str_schema_datetime() { - let dtype = "datetime"; - let schema = str_to_dtype(dtype, Span::unknown()).unwrap(); - let expected = DataType::Datetime(TimeUnit::Milliseconds, None); - assert_eq!(schema, expected); - - let dtype = "datetime"; - let schema = str_to_dtype(dtype, Span::unknown()).unwrap(); - let expected = DataType::Datetime(TimeUnit::Microseconds, None); - assert_eq!(schema, expected); - - let dtype = "datetime<μs, *>"; - let schema = str_to_dtype(dtype, Span::unknown()).unwrap(); - let expected = DataType::Datetime(TimeUnit::Microseconds, None); - assert_eq!(schema, expected); - - let dtype = "datetime"; - let schema = str_to_dtype(dtype, Span::unknown()).unwrap(); - let expected = DataType::Datetime(TimeUnit::Nanoseconds, None); - assert_eq!(schema, expected); - - let dtype = "datetime"; - let schema = str_to_dtype(dtype, Span::unknown()).unwrap(); - let expected = DataType::Datetime(TimeUnit::Milliseconds, Some("UTC".into())); - assert_eq!(schema, expected); - - let dtype = "invalid"; - let schema = str_to_dtype(dtype, Span::unknown()); - assert!(schema.is_err()) - } - - #[test] - fn test_dtype_str_schema_duration() { - let dtype = "duration"; - let schema = str_to_dtype(dtype, Span::unknown()).unwrap(); - let expected = DataType::Duration(TimeUnit::Milliseconds); - assert_eq!(schema, expected); - - let dtype = "duration"; - let schema = str_to_dtype(dtype, Span::unknown()).unwrap(); - let expected = DataType::Duration(TimeUnit::Microseconds); - assert_eq!(schema, expected); - - let dtype = "duration<μs>"; - let schema = str_to_dtype(dtype, Span::unknown()).unwrap(); - let expected = DataType::Duration(TimeUnit::Microseconds); - assert_eq!(schema, expected); - - let dtype = "duration"; - let schema = str_to_dtype(dtype, Span::unknown()).unwrap(); - let expected = DataType::Duration(TimeUnit::Nanoseconds); - assert_eq!(schema, expected); - } - - #[test] - fn test_dtype_str_schema_decimal() { - let dtype = "decimal<7,2>"; - let schema = str_to_dtype(dtype, Span::unknown()).unwrap(); - let expected = DataType::Decimal(Some(7usize), Some(2usize)); - assert_eq!(schema, expected); - - // "*" is not a permitted value for scale - let dtype = "decimal<7,*>"; - let schema = str_to_dtype(dtype, Span::unknown()); - assert!(matches!(schema, Err(ShellError::GenericError { .. }))); - - let dtype = "decimal<*,2>"; - let schema = str_to_dtype(dtype, Span::unknown()).unwrap(); - let expected = DataType::Decimal(None, Some(2usize)); - assert_eq!(schema, expected); - } - - #[test] - fn test_dtype_str_to_schema_list_types() { - let dtype = "list"; - let schema = str_to_dtype(dtype, Span::unknown()).unwrap(); - let expected = DataType::List(Box::new(DataType::Int32)); - assert_eq!(schema, expected); - - let dtype = "list>"; - let schema = str_to_dtype(dtype, Span::unknown()).unwrap(); - let expected = DataType::List(Box::new(DataType::Duration(TimeUnit::Milliseconds))); - assert_eq!(schema, expected); - - let dtype = "list>"; - let schema = str_to_dtype(dtype, Span::unknown()).unwrap(); - let expected = DataType::List(Box::new(DataType::Datetime(TimeUnit::Milliseconds, None))); - assert_eq!(schema, expected); - - let dtype = "list>"; - let schema = str_to_dtype(dtype, Span::unknown()).unwrap(); - let expected = DataType::List(Box::new(DataType::Decimal(Some(7usize), Some(2usize)))); - assert_eq!(schema, expected); - - let dtype = "list>"; - let schema = str_to_dtype(dtype, Span::unknown()).unwrap(); - let expected = DataType::List(Box::new(DataType::Decimal(None, Some(2usize)))); - assert_eq!(schema, expected); - - let dtype = "list>"; - let schema = str_to_dtype(dtype, Span::unknown()); - assert!(matches!(schema, Err(ShellError::GenericError { .. }))); - } -} diff --git a/crates/nu_plugin_polars/src/dataframe/values/nu_schema/custom_value.rs b/crates/nu_plugin_polars/src/dataframe/values/nu_schema/custom_value.rs new file mode 100644 index 0000000000..01a5aed996 --- /dev/null +++ b/crates/nu_plugin_polars/src/dataframe/values/nu_schema/custom_value.rs @@ -0,0 +1,65 @@ +use nu_protocol::{CustomValue, ShellError, Span, Value}; +use serde::{Deserialize, Serialize}; +use uuid::Uuid; + +use crate::values::{CustomValueSupport, PolarsPluginCustomValue}; + +use super::NuSchema; + +#[derive(Clone, Debug, Deserialize, Serialize)] +pub struct NuSchemaCustomValue { + pub id: Uuid, + #[serde(skip)] + pub datatype: Option, +} + +#[typetag::serde] +impl CustomValue for NuSchemaCustomValue { + fn clone_value(&self, span: nu_protocol::Span) -> Value { + Value::custom(Box::new(self.clone()), span) + } + + fn type_name(&self) -> String { + "NuSchema".into() + } + + fn to_base_value(&self, span: Span) -> Result { + Ok(Value::string( + "NuSchema: custom_value_to_base_value should've been called", + span, + )) + } + + fn as_mut_any(&mut self) -> &mut dyn std::any::Any { + self + } + + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn notify_plugin_on_drop(&self) -> bool { + true + } +} + +impl PolarsPluginCustomValue for NuSchemaCustomValue { + type PolarsPluginObjectType = NuSchema; + + fn id(&self) -> &Uuid { + &self.id + } + + fn internal(&self) -> &Option { + &self.datatype + } + + fn custom_value_to_base_value( + &self, + plugin: &crate::PolarsPlugin, + _engine: &nu_plugin::EngineInterface, + ) -> Result { + let dtype = NuSchema::try_from_custom_value(plugin, self)?; + dtype.base_value(Span::unknown()) + } +} diff --git a/crates/nu_plugin_polars/src/dataframe/values/nu_schema/mod.rs b/crates/nu_plugin_polars/src/dataframe/values/nu_schema/mod.rs new file mode 100644 index 0000000000..0dbd8cdd53 --- /dev/null +++ b/crates/nu_plugin_polars/src/dataframe/values/nu_schema/mod.rs @@ -0,0 +1,189 @@ +pub mod custom_value; + +use std::sync::Arc; + +use custom_value::NuSchemaCustomValue; +use nu_protocol::{ShellError, Span, Value}; +use polars::prelude::{DataType, Field, Schema, SchemaExt, SchemaRef}; +use uuid::Uuid; + +use crate::{Cacheable, PolarsPlugin}; + +use super::{str_to_dtype, CustomValueSupport, NuDataType, PolarsPluginObject, PolarsPluginType}; + +#[derive(Debug, Clone)] +pub struct NuSchema { + pub id: Uuid, + pub schema: SchemaRef, +} + +impl NuSchema { + pub fn new(schema: SchemaRef) -> Self { + Self { + id: Uuid::new_v4(), + schema, + } + } +} + +impl From for SchemaRef { + fn from(val: NuSchema) -> Self { + Arc::clone(&val.schema) + } +} + +impl From for NuSchema { + fn from(val: SchemaRef) -> Self { + Self::new(val) + } +} + +impl Cacheable for NuSchema { + fn cache_id(&self) -> &Uuid { + &self.id + } + + fn to_cache_value(&self) -> Result { + Ok(PolarsPluginObject::NuSchema(self.clone())) + } + + fn from_cache_value(cv: super::PolarsPluginObject) -> Result { + match cv { + PolarsPluginObject::NuSchema(dt) => Ok(dt), + _ => Err(ShellError::GenericError { + error: "Cache value is not a dataframe".into(), + msg: "".into(), + span: None, + help: None, + inner: vec![], + }), + } + } +} + +impl CustomValueSupport for NuSchema { + type CV = NuSchemaCustomValue; + + fn get_type_static() -> super::PolarsPluginType { + PolarsPluginType::NuSchema + } + + fn custom_value(self) -> Self::CV { + NuSchemaCustomValue { + id: self.id, + datatype: Some(self), + } + } + + fn base_value(self, span: Span) -> Result { + Ok(fields_to_value(self.schema.iter_fields(), span)) + } + + fn try_from_value(plugin: &PolarsPlugin, value: &Value) -> Result { + if let Value::Custom { val, .. } = value { + if let Some(cv) = val.as_any().downcast_ref::() { + Self::try_from_custom_value(plugin, cv) + } else { + Err(ShellError::CantConvert { + to_type: Self::get_type_static().to_string(), + from_type: value.get_type().to_string(), + span: value.span(), + help: None, + }) + } + } else { + let schema = value_to_schema(plugin, value, Span::unknown())?; + Ok(Self::new(Arc::new(schema))) + } + } +} + +fn fields_to_value(fields: impl Iterator, span: Span) -> Value { + let record = fields + .map(|field| { + let col = field.name().to_string(); + let val = dtype_to_value(field.dtype(), span); + (col, val) + }) + .collect(); + + Value::record(record, Span::unknown()) +} + +pub fn dtype_to_value(dtype: &DataType, span: Span) -> Value { + match dtype { + DataType::Struct(fields) => fields_to_value(fields.iter().cloned(), span), + _ => Value::string(dtype.to_string().replace('[', "<").replace(']', ">"), span), + } +} + +fn value_to_schema(plugin: &PolarsPlugin, value: &Value, span: Span) -> Result { + let fields = value_to_fields(plugin, value, span)?; + let schema = Schema::from_iter(fields); + Ok(schema) +} + +fn value_to_fields( + plugin: &PolarsPlugin, + value: &Value, + span: Span, +) -> Result, ShellError> { + let fields = value + .as_record()? + .into_iter() + .map(|(col, val)| match val { + Value::Record { .. } => { + let fields = value_to_fields(plugin, val, span)?; + let dtype = DataType::Struct(fields); + Ok(Field::new(col.into(), dtype)) + } + Value::Custom { .. } => { + let dtype = NuDataType::try_from_value(plugin, val)?; + Ok(Field::new(col.into(), dtype.to_polars())) + } + _ => { + let dtype = str_to_dtype(&val.coerce_string()?, span)?; + Ok(Field::new(col.into(), dtype)) + } + }) + .collect::, ShellError>>()?; + Ok(fields) +} + +#[cfg(test)] +mod test { + + use nu_protocol::record; + + use super::*; + + #[test] + fn test_value_to_schema() { + let plugin = PolarsPlugin::new_test_mode().expect("Failed to create plugin"); + + let address = record! { + "street" => Value::test_string("str"), + "city" => Value::test_string("str"), + }; + + let value = Value::test_record(record! { + "name" => Value::test_string("str"), + "age" => Value::test_string("i32"), + "address" => Value::test_record(address) + }); + + let schema = value_to_schema(&plugin, &value, Span::unknown()).unwrap(); + let expected = Schema::from_iter(vec![ + Field::new("name".into(), DataType::String), + Field::new("age".into(), DataType::Int32), + Field::new( + "address".into(), + DataType::Struct(vec![ + Field::new("street".into(), DataType::String), + Field::new("city".into(), DataType::String), + ]), + ), + ]); + assert_eq!(schema, expected); + } +} diff --git a/crates/nu_plugin_polars/src/lib.rs b/crates/nu_plugin_polars/src/lib.rs index f497f59567..6a0644c82c 100644 --- a/crates/nu_plugin_polars/src/lib.rs +++ b/crates/nu_plugin_polars/src/lib.rs @@ -123,6 +123,8 @@ impl Plugin for PolarsPlugin { CustomValueType::NuExpression(cv) => cv.custom_value_to_base_value(self, engine), CustomValueType::NuLazyGroupBy(cv) => cv.custom_value_to_base_value(self, engine), CustomValueType::NuWhen(cv) => cv.custom_value_to_base_value(self, engine), + CustomValueType::NuDataType(cv) => cv.custom_value_to_base_value(self, engine), + CustomValueType::NuSchema(cv) => cv.custom_value_to_base_value(self, engine), }; Ok(result?) } @@ -150,6 +152,12 @@ impl Plugin for PolarsPlugin { CustomValueType::NuWhen(cv) => { cv.custom_value_operation(self, engine, left.span, operator, right) } + CustomValueType::NuDataType(cv) => { + cv.custom_value_operation(self, engine, left.span, operator, right) + } + CustomValueType::NuSchema(cv) => { + cv.custom_value_operation(self, engine, left.span, operator, right) + } }; Ok(result?) } @@ -176,6 +184,12 @@ impl Plugin for PolarsPlugin { CustomValueType::NuWhen(cv) => { cv.custom_value_follow_path_int(self, engine, custom_value.span, index) } + CustomValueType::NuDataType(cv) => { + cv.custom_value_follow_path_int(self, engine, custom_value.span, index) + } + CustomValueType::NuSchema(cv) => { + cv.custom_value_follow_path_int(self, engine, custom_value.span, index) + } }; Ok(result?) } @@ -202,6 +216,12 @@ impl Plugin for PolarsPlugin { CustomValueType::NuWhen(cv) => { cv.custom_value_follow_path_string(self, engine, custom_value.span, column_name) } + CustomValueType::NuDataType(cv) => { + cv.custom_value_follow_path_string(self, engine, custom_value.span, column_name) + } + CustomValueType::NuSchema(cv) => { + cv.custom_value_follow_path_string(self, engine, custom_value.span, column_name) + } }; Ok(result?) } @@ -226,6 +246,10 @@ impl Plugin for PolarsPlugin { cv.custom_value_partial_cmp(self, engine, other_value) } CustomValueType::NuWhen(cv) => cv.custom_value_partial_cmp(self, engine, other_value), + CustomValueType::NuDataType(cv) => { + cv.custom_value_partial_cmp(self, engine, other_value) + } + CustomValueType::NuSchema(cv) => cv.custom_value_partial_cmp(self, engine, other_value), }; Ok(result?) }