diff --git a/crates/nu_plugin_polars/src/cache/list.rs b/crates/nu_plugin_polars/src/cache/list.rs index 4afaf9fa23..fc599dedd7 100644 --- a/crates/nu_plugin_polars/src/cache/list.rs +++ b/crates/nu_plugin_polars/src/cache/list.rs @@ -148,6 +148,21 @@ impl PluginCommand for ListDF { }, call.head, ))), + PolarsPluginObject::NuSchema(_) => Ok(Some(Value::record( + record! { + "key" => Value::string(key.to_string(), call.head), + "created" => Value::date(value.created, call.head), + "columns" => Value::nothing(call.head), + "rows" => Value::nothing(call.head), + "type" => Value::string("Schema", call.head), + "estimated_size" => Value::nothing(call.head), + "span_contents" => Value::string(span_contents, value.span), + "span_start" => Value::int(value.span.start as i64, call.head), + "span_end" => Value::int(value.span.end as i64, call.head), + "reference_count" => Value::int(value.reference_count as i64, call.head), + }, + call.head, + ))), } })?; let vals = vals.into_iter().flatten().collect(); diff --git a/crates/nu_plugin_polars/src/dataframe/command/core/schema.rs b/crates/nu_plugin_polars/src/dataframe/command/core/schema.rs index 01451ff443..4ce1f3404c 100644 --- a/crates/nu_plugin_polars/src/dataframe/command/core/schema.rs +++ b/crates/nu_plugin_polars/src/dataframe/command/core/schema.rs @@ -1,4 +1,7 @@ -use crate::{values::PolarsPluginObject, PolarsPlugin}; +use crate::{ + values::{datatype_list, CustomValueSupport, PolarsPluginObject}, + PolarsPlugin, +}; use nu_plugin::{EngineInterface, EvaluatedCall, PluginCommand}; use nu_protocol::{ @@ -67,12 +70,12 @@ fn command( match PolarsPluginObject::try_from_pipeline(plugin, input, call.head)? { PolarsPluginObject::NuDataFrame(df) => { let schema = df.schema(); - let value: Value = schema.into(); + let value = schema.base_value(call.head)?; Ok(PipelineData::Value(value, None)) } PolarsPluginObject::NuLazyFrame(mut lazy) => { let schema = lazy.schema()?; - let value: Value = schema.into(); + let value = schema.base_value(call.head)?; Ok(PipelineData::Value(value, None)) } _ => Err(ShellError::GenericError { @@ -85,42 +88,6 @@ fn command( } } -fn datatype_list(span: Span) -> Value { - let types: Vec = [ - ("null", ""), - ("bool", ""), - ("u8", ""), - ("u16", ""), - ("u32", ""), - ("u64", ""), - ("i8", ""), - ("i16", ""), - ("i32", ""), - ("i64", ""), - ("f32", ""), - ("f64", ""), - ("str", ""), - ("binary", ""), - ("date", ""), - ("datetime", "Time Unit can be: milliseconds: ms, microseconds: us, nanoseconds: ns. Timezone wildcard is *. Other Timezone examples: UTC, America/Los_Angeles."), - ("duration", "Time Unit can be: milliseconds: ms, microseconds: us, nanoseconds: ns."), - ("time", ""), - ("object", ""), - ("unknown", ""), - ("list", ""), - ] - .iter() - .map(|(dtype, note)| { - Value::record(record! { - "dtype" => Value::string(*dtype, span), - "note" => Value::string(*note, span), - }, - span) - }) - .collect(); - Value::list(types, span) -} - #[cfg(test)] mod test { use super::*; diff --git a/crates/nu_plugin_polars/src/dataframe/values/mod.rs b/crates/nu_plugin_polars/src/dataframe/values/mod.rs index efd183ed61..2b30270ebe 100644 --- a/crates/nu_plugin_polars/src/dataframe/values/mod.rs +++ b/crates/nu_plugin_polars/src/dataframe/values/mod.rs @@ -19,11 +19,11 @@ use uuid::Uuid; pub use file_type::PolarsFileType; pub use nu_dataframe::{Axis, Column, NuDataFrame, NuDataFrameCustomValue}; pub use nu_dtype::NuDataType; +pub use nu_dtype::{datatype_list, str_to_dtype}; pub use nu_expression::{NuExpression, NuExpressionCustomValue}; pub use nu_lazyframe::{NuLazyFrame, NuLazyFrameCustomValue}; pub use nu_lazygroupby::{NuLazyGroupBy, NuLazyGroupByCustomValue}; pub use nu_schema::NuSchema; -pub use nu_dtype::str_to_dtype; pub use nu_when::{NuWhen, NuWhenCustomValue, NuWhenType}; #[derive(Debug, Clone)] @@ -35,6 +35,7 @@ pub enum PolarsPluginType { NuWhen, NuPolarsTestData, NuDataType, + NuSchema, } impl fmt::Display for PolarsPluginType { @@ -47,6 +48,7 @@ impl fmt::Display for PolarsPluginType { Self::NuWhen => write!(f, "NuWhen"), Self::NuPolarsTestData => write!(f, "NuPolarsTestData"), Self::NuDataType => write!(f, "NuDataType"), + Self::NuSchema => write!(f, "NuSchema"), } } } @@ -60,6 +62,7 @@ pub enum PolarsPluginObject { NuWhen(NuWhen), NuPolarsTestData(Uuid, String), NuDataType(NuDataType), + NuSchema(NuSchema), } impl PolarsPluginObject { @@ -77,6 +80,10 @@ impl PolarsPluginObject { NuLazyGroupBy::try_from_value(plugin, value).map(PolarsPluginObject::NuLazyGroupBy) } else if NuWhen::can_downcast(value) { NuWhen::try_from_value(plugin, value).map(PolarsPluginObject::NuWhen) + } else if NuSchema::can_downcast(value) { + NuSchema::try_from_value(plugin, value).map(PolarsPluginObject::NuSchema) + } else if NuDataType::can_downcast(value) { + NuDataType::try_from_value(plugin, value).map(PolarsPluginObject::NuDataType) } else { Err(cant_convert_err( value, @@ -86,6 +93,8 @@ impl PolarsPluginObject { PolarsPluginType::NuExpression, PolarsPluginType::NuLazyGroupBy, PolarsPluginType::NuWhen, + PolarsPluginType::NuDataType, + PolarsPluginType::NuSchema, ], )) } @@ -109,6 +118,7 @@ impl PolarsPluginObject { Self::NuWhen(_) => PolarsPluginType::NuWhen, Self::NuPolarsTestData(_, _) => PolarsPluginType::NuPolarsTestData, Self::NuDataType(_) => PolarsPluginType::NuDataType, + Self::NuSchema(_) => PolarsPluginType::NuSchema, } } @@ -121,6 +131,7 @@ impl PolarsPluginObject { PolarsPluginObject::NuWhen(w) => w.id, PolarsPluginObject::NuPolarsTestData(id, _) => *id, PolarsPluginObject::NuDataType(dt) => dt.id, + PolarsPluginObject::NuSchema(schema) => schema.id, } } @@ -135,6 +146,7 @@ impl PolarsPluginObject { Value::string(format!("{id}:{s}"), Span::test_data()) } PolarsPluginObject::NuDataType(dt) => dt.into_value(span), + PolarsPluginObject::NuSchema(schema) => schema.into_value(span), } } @@ -391,7 +403,6 @@ pub trait CustomValueSupport: Cacheable { } } - #[cfg(test)] mod test { use polars::prelude::{DataType, TimeUnit, UnknownKind}; diff --git a/crates/nu_plugin_polars/src/dataframe/values/nu_dtype/mod.rs b/crates/nu_plugin_polars/src/dataframe/values/nu_dtype/mod.rs index d5a8f9e2d1..ec32bdc5ba 100644 --- a/crates/nu_plugin_polars/src/dataframe/values/nu_dtype/mod.rs +++ b/crates/nu_plugin_polars/src/dataframe/values/nu_dtype/mod.rs @@ -1,16 +1,13 @@ mod custom_value; use custom_value::NuDataTypeCustomValue; -use nu_protocol::{ShellError, Span, Value}; +use nu_protocol::{record, ShellError, Span, Value}; use polars::prelude::{DataType, PlSmallStr, TimeUnit, UnknownKind}; use uuid::Uuid; use crate::Cacheable; -use super::{ - nu_schema::dtype_to_value, CustomValueSupport, PolarsPluginObject, - PolarsPluginType, -}; +use super::{nu_schema::dtype_to_value, CustomValueSupport, PolarsPluginObject, PolarsPluginType}; #[derive(Debug, Clone)] pub struct NuDataType { @@ -107,6 +104,42 @@ impl CustomValueSupport for NuDataType { } } +pub fn datatype_list(span: Span) -> Value { + let types: Vec = [ + ("null", ""), + ("bool", ""), + ("u8", ""), + ("u16", ""), + ("u32", ""), + ("u64", ""), + ("i8", ""), + ("i16", ""), + ("i32", ""), + ("i64", ""), + ("f32", ""), + ("f64", ""), + ("str", ""), + ("binary", ""), + ("date", ""), + ("datetime", "Time Unit can be: milliseconds: ms, microseconds: us, nanoseconds: ns. Timezone wildcard is *. Other Timezone examples: UTC, America/Los_Angeles."), + ("duration", "Time Unit can be: milliseconds: ms, microseconds: us, nanoseconds: ns."), + ("time", ""), + ("object", ""), + ("unknown", ""), + ("list", ""), + ] + .iter() + .map(|(dtype, note)| { + Value::record(record! { + "dtype" => Value::string(*dtype, span), + "note" => Value::string(*note, span), + }, + span) + }) + .collect(); + Value::list(types, span) +} + pub fn str_to_dtype(dtype: &str, span: Span) -> Result { match dtype { "bool" => Ok(DataType::Boolean), @@ -274,4 +307,3 @@ fn str_to_time_unit(ts_string: &str, span: Span) -> Result }), } } - diff --git a/crates/nu_plugin_polars/src/dataframe/values/nu_schema/custom_value.rs b/crates/nu_plugin_polars/src/dataframe/values/nu_schema/custom_value.rs new file mode 100644 index 0000000000..01a5aed996 --- /dev/null +++ b/crates/nu_plugin_polars/src/dataframe/values/nu_schema/custom_value.rs @@ -0,0 +1,65 @@ +use nu_protocol::{CustomValue, ShellError, Span, Value}; +use serde::{Deserialize, Serialize}; +use uuid::Uuid; + +use crate::values::{CustomValueSupport, PolarsPluginCustomValue}; + +use super::NuSchema; + +#[derive(Clone, Debug, Deserialize, Serialize)] +pub struct NuSchemaCustomValue { + pub id: Uuid, + #[serde(skip)] + pub datatype: Option, +} + +#[typetag::serde] +impl CustomValue for NuSchemaCustomValue { + fn clone_value(&self, span: nu_protocol::Span) -> Value { + Value::custom(Box::new(self.clone()), span) + } + + fn type_name(&self) -> String { + "NuSchema".into() + } + + fn to_base_value(&self, span: Span) -> Result { + Ok(Value::string( + "NuSchema: custom_value_to_base_value should've been called", + span, + )) + } + + fn as_mut_any(&mut self) -> &mut dyn std::any::Any { + self + } + + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn notify_plugin_on_drop(&self) -> bool { + true + } +} + +impl PolarsPluginCustomValue for NuSchemaCustomValue { + type PolarsPluginObjectType = NuSchema; + + fn id(&self) -> &Uuid { + &self.id + } + + fn internal(&self) -> &Option { + &self.datatype + } + + fn custom_value_to_base_value( + &self, + plugin: &crate::PolarsPlugin, + _engine: &nu_plugin::EngineInterface, + ) -> Result { + let dtype = NuSchema::try_from_custom_value(plugin, self)?; + dtype.base_value(Span::unknown()) + } +} diff --git a/crates/nu_plugin_polars/src/dataframe/values/nu_schema/mod.rs b/crates/nu_plugin_polars/src/dataframe/values/nu_schema/mod.rs index 3a3b2416c6..ae2bd8750e 100644 --- a/crates/nu_plugin_polars/src/dataframe/values/nu_schema/mod.rs +++ b/crates/nu_plugin_polars/src/dataframe/values/nu_schema/mod.rs @@ -1,18 +1,28 @@ +mod custom_value; + use std::sync::Arc; +use custom_value::NuSchemaCustomValue; use nu_protocol::{ShellError, Span, Value}; use polars::prelude::{DataType, Field, Schema, SchemaExt, SchemaRef}; +use uuid::Uuid; -use super::str_to_dtype; +use crate::Cacheable; + +use super::{str_to_dtype, CustomValueSupport, PolarsPluginObject, PolarsPluginType}; #[derive(Debug, Clone)] pub struct NuSchema { + pub id: Uuid, pub schema: SchemaRef, } impl NuSchema { pub fn new(schema: SchemaRef) -> Self { - Self { schema } + Self { + id: Uuid::new_v4(), + schema, + } } } @@ -24,12 +34,6 @@ impl TryFrom<&Value> for NuSchema { } } -impl From for Value { - fn from(schema: NuSchema) -> Self { - fields_to_value(schema.schema.iter_fields(), Span::unknown()) - } -} - impl From for SchemaRef { fn from(val: NuSchema) -> Self { Arc::clone(&val.schema) @@ -38,7 +42,49 @@ impl From for SchemaRef { impl From for NuSchema { fn from(val: SchemaRef) -> Self { - Self { schema: val } + Self::new(val) + } +} + +impl Cacheable for NuSchema { + fn cache_id(&self) -> &Uuid { + &self.id + } + + fn to_cache_value(&self) -> Result { + Ok(PolarsPluginObject::NuSchema(self.clone())) + } + + fn from_cache_value(cv: super::PolarsPluginObject) -> Result { + match cv { + PolarsPluginObject::NuSchema(dt) => Ok(dt), + _ => Err(ShellError::GenericError { + error: "Cache value is not a dataframe".into(), + msg: "".into(), + span: None, + help: None, + inner: vec![], + }), + } + } +} + +impl CustomValueSupport for NuSchema { + type CV = NuSchemaCustomValue; + + fn get_type_static() -> super::PolarsPluginType { + PolarsPluginType::NuSchema + } + + fn custom_value(self) -> Self::CV { + NuSchemaCustomValue { + id: self.id, + datatype: Some(self), + } + } + + fn base_value(self, span: Span) -> Result { + Ok(fields_to_value(self.schema.iter_fields(), span)) } }