Introduction of NuDataType and polars dtype (#15529)

# Description
This pull request does a lot of the heavy lifting needed to support
more complex dtypes like categorical dtypes. It introduces a new
CustomValue, NuDataType, and makes NuSchema a full CustomValue.
Furthermore, it introduces a new command `polars into-dtype` that allows
a dtype to be created. This can then be passed into schemas when they
are created.

```nu
> ❯ : let dt = ("str" | polars into-dtype)

> ❯ : [[a b]; ["one" "two"]] | polars into-df -s {a: $dt, b: str} | polars schema
╭───┬─────╮
│ a │ str │
│ b │ str │
╰───┴─────╯
```

# User-Facing Changes
- Introduces a new command, `polars into-dtype`, which allows dtype
variables to be passed in during schema creation.
This commit is contained in:
Jack Wright 2025-04-09 08:13:49 -07:00 committed by GitHub
parent 173162df2e
commit b0f9cda9b5
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
14 changed files with 999 additions and 536 deletions

View File

@ -132,7 +132,37 @@ impl PluginCommand for ListDF {
"reference_count" => Value::int(value.reference_count as i64, call.head),
},
call.head,
)))
))),
PolarsPluginObject::NuDataType(_) => Ok(Some(Value::record(
record! {
"key" => Value::string(key.to_string(), call.head),
"created" => Value::date(value.created, call.head),
"columns" => Value::nothing(call.head),
"rows" => Value::nothing(call.head),
"type" => Value::string("DataType", call.head),
"estimated_size" => Value::nothing(call.head),
"span_contents" => Value::string(span_contents, value.span),
"span_start" => Value::int(value.span.start as i64, call.head),
"span_end" => Value::int(value.span.end as i64, call.head),
"reference_count" => Value::int(value.reference_count as i64, call.head),
},
call.head,
))),
PolarsPluginObject::NuSchema(_) => Ok(Some(Value::record(
record! {
"key" => Value::string(key.to_string(), call.head),
"created" => Value::date(value.created, call.head),
"columns" => Value::nothing(call.head),
"rows" => Value::nothing(call.head),
"type" => Value::string("Schema", call.head),
"estimated_size" => Value::nothing(call.head),
"span_contents" => Value::string(span_contents, value.span),
"span_start" => Value::int(value.span.start as i64, call.head),
"span_end" => Value::int(value.span.end as i64, call.head),
"reference_count" => Value::int(value.reference_count as i64, call.head),
},
call.head,
))),
}
})?;
let vals = vals.into_iter().flatten().collect();

View File

@ -9,6 +9,7 @@ mod schema;
mod shape;
mod summary;
mod to_df;
mod to_dtype;
mod to_lazy;
mod to_nu;
mod to_repr;
@ -40,5 +41,6 @@ pub(crate) fn core_commands() -> Vec<Box<dyn PluginCommand<Plugin = PolarsPlugin
Box::new(save::SaveDF),
Box::new(ToLazyFrame),
Box::new(ToRepr),
Box::new(to_dtype::ToDataType),
]
}

View File

@ -163,7 +163,7 @@ fn command(
});
}
let hive_options = build_hive_options(call)?;
let hive_options = build_hive_options(plugin, call)?;
match type_option {
Some((ext, blamed)) => match PolarsFileType::from(ext.as_str()) {
@ -388,7 +388,7 @@ fn from_json(
})?;
let maybe_schema = call
.get_flag("schema")?
.map(|schema| NuSchema::try_from(&schema))
.map(|schema| NuSchema::try_from_value(plugin, &schema))
.transpose()?;
let buf_reader = BufReader::new(file);
@ -429,11 +429,7 @@ fn from_ndjson(
NonZeroUsize::new(DEFAULT_INFER_SCHEMA)
.expect("The default infer-schema should be non zero"),
);
let maybe_schema = call
.get_flag("schema")?
.map(|schema| NuSchema::try_from(&schema))
.transpose()?;
let maybe_schema = get_schema(plugin, call)?;
if !is_eager {
let start_time = std::time::Instant::now();
@ -511,10 +507,7 @@ fn from_csv(
.unwrap_or(DEFAULT_INFER_SCHEMA);
let skip_rows: Option<usize> = call.get_flag("skip-rows")?;
let columns: Option<Vec<String>> = call.get_flag("columns")?;
let maybe_schema = call
.get_flag("schema")?
.map(|schema| NuSchema::try_from(&schema))
.transpose()?;
let maybe_schema = get_schema(plugin, call)?;
let truncate_ragged_lines: bool = call.has_flag("truncate-ragged-lines")?;
if !is_eager {
@ -627,12 +620,15 @@ fn cloud_not_supported(file_type: PolarsFileType, span: Span) -> ShellError {
}
}
fn build_hive_options(call: &EvaluatedCall) -> Result<HiveOptions, ShellError> {
fn build_hive_options(
plugin: &PolarsPlugin,
call: &EvaluatedCall,
) -> Result<HiveOptions, ShellError> {
let enabled: Option<bool> = call.get_flag("hive-enabled")?;
let hive_start_idx: Option<usize> = call.get_flag("hive-start-idx")?;
let schema: Option<NuSchema> = call
.get_flag::<Value>("hive-schema")?
.map(|schema| NuSchema::try_from(&schema))
.map(|schema| NuSchema::try_from_value(plugin, &schema))
.transpose()?;
let try_parse_dates: bool = call.has_flag("hive-try-parse-dates")?;
@ -643,3 +639,11 @@ fn build_hive_options(call: &EvaluatedCall) -> Result<HiveOptions, ShellError> {
try_parse_dates,
})
}
fn get_schema(plugin: &PolarsPlugin, call: &EvaluatedCall) -> Result<Option<NuSchema>, ShellError> {
let schema: Option<NuSchema> = call
.get_flag("schema")?
.map(|schema| NuSchema::try_from_value(plugin, &schema))
.transpose()?;
Ok(schema)
}

View File

@ -1,4 +1,7 @@
use crate::{values::PolarsPluginObject, PolarsPlugin};
use crate::{
values::{datatype_list, CustomValueSupport, PolarsPluginObject},
PolarsPlugin,
};
use nu_plugin::{EngineInterface, EvaluatedCall, PluginCommand};
use nu_protocol::{
@ -67,12 +70,12 @@ fn command(
match PolarsPluginObject::try_from_pipeline(plugin, input, call.head)? {
PolarsPluginObject::NuDataFrame(df) => {
let schema = df.schema();
let value: Value = schema.into();
let value = schema.base_value(call.head)?;
Ok(PipelineData::Value(value, None))
}
PolarsPluginObject::NuLazyFrame(mut lazy) => {
let schema = lazy.schema()?;
let value: Value = schema.into();
let value = schema.base_value(call.head)?;
Ok(PipelineData::Value(value, None))
}
_ => Err(ShellError::GenericError {
@ -85,42 +88,6 @@ fn command(
}
}
fn datatype_list(span: Span) -> Value {
let types: Vec<Value> = [
("null", ""),
("bool", ""),
("u8", ""),
("u16", ""),
("u32", ""),
("u64", ""),
("i8", ""),
("i16", ""),
("i32", ""),
("i64", ""),
("f32", ""),
("f64", ""),
("str", ""),
("binary", ""),
("date", ""),
("datetime<time_unit: (ms, us, ns) timezone (optional)>", "Time Unit can be: milliseconds: ms, microseconds: us, nanoseconds: ns. Timezone wildcard is *. Other Timezone examples: UTC, America/Los_Angeles."),
("duration<time_unit: (ms, us, ns)>", "Time Unit can be: milliseconds: ms, microseconds: us, nanoseconds: ns."),
("time", ""),
("object", ""),
("unknown", ""),
("list<dtype>", ""),
]
.iter()
.map(|(dtype, note)| {
Value::record(record! {
"dtype" => Value::string(*dtype, span),
"note" => Value::string(*note, span),
},
span)
})
.collect();
Value::list(types, span)
}
#[cfg(test)]
mod test {
use super::*;

View File

@ -216,7 +216,7 @@ impl PluginCommand for ToDataFrame {
) -> Result<PipelineData, LabeledError> {
let maybe_schema = call
.get_flag("schema")?
.map(|schema| NuSchema::try_from(&schema))
.map(|schema| NuSchema::try_from_value(plugin, &schema))
.transpose()?;
debug!("schema: {:?}", maybe_schema);

View File

@ -0,0 +1,55 @@
use nu_plugin::PluginCommand;
use nu_protocol::{Category, Example, ShellError, Signature, Span, Type, Value};
use crate::{
values::{CustomValueSupport, NuDataType},
PolarsPlugin,
};
/// Plugin command `polars into-dtype`: takes a string from the pipeline and
/// produces a polars datatype custom value.
pub struct ToDataType;

impl PluginCommand for ToDataType {
    type Plugin = PolarsPlugin;

    fn name(&self) -> &str {
        "polars into-dtype"
    }

    fn description(&self) -> &str {
        "Convert a string to a specific datatype."
    }

    /// String input, custom `datatype` output, filed under the `dataframe` category.
    fn signature(&self) -> Signature {
        let output = Type::Custom("datatype".into());
        let category = Category::Custom("dataframe".into());
        Signature::build(self.name())
            .input_output_type(Type::String, output)
            .category(category)
    }

    fn examples(&self) -> Vec<Example> {
        let expected = Value::string("i64", Span::test_data());
        vec![Example {
            description: "Convert a string to a specific datatype",
            example: r#""i64" | polars into-dtype"#,
            result: Some(expected),
        }]
    }

    /// Delegates to [`command`], converting its `ShellError` into a `LabeledError`.
    fn run(
        &self,
        plugin: &Self::Plugin,
        engine: &nu_plugin::EngineInterface,
        call: &nu_plugin::EvaluatedCall,
        input: nu_protocol::PipelineData,
    ) -> Result<nu_protocol::PipelineData, nu_protocol::LabeledError> {
        let output = command(plugin, engine, call, input)?;
        Ok(output)
    }
}
/// Builds a [`NuDataType`] from the pipeline input and hands it back as
/// pipeline data, using the call head as the span for both steps.
fn command(
    plugin: &PolarsPlugin,
    engine: &nu_plugin::EngineInterface,
    call: &nu_plugin::EvaluatedCall,
    input: nu_protocol::PipelineData,
) -> Result<nu_protocol::PipelineData, ShellError> {
    let head = call.head;
    let dtype = NuDataType::try_from_pipeline(plugin, input, head)?;
    dtype.to_pipeline_data(plugin, engine, head)
}

View File

@ -56,7 +56,7 @@ impl PluginCommand for ToLazyFrame {
) -> Result<PipelineData, LabeledError> {
let maybe_schema = call
.get_flag("schema")?
.map(|schema| NuSchema::try_from(&schema))
.map(|schema| NuSchema::try_from_value(plugin, &schema))
.transpose()?;
let df = NuDataFrame::try_from_iter(plugin, input.into_iter(), maybe_schema)?;

View File

@ -1,5 +1,6 @@
mod file_type;
mod nu_dataframe;
mod nu_dtype;
mod nu_expression;
mod nu_lazyframe;
mod nu_lazygroupby;
@ -8,19 +9,23 @@ mod nu_when;
pub mod utils;
use crate::{Cacheable, PolarsPlugin};
use nu_dtype::custom_value::NuDataTypeCustomValue;
use nu_plugin::EngineInterface;
use nu_protocol::{
ast::Operator, CustomValue, PipelineData, ShellError, Span, Spanned, Type, Value,
};
use nu_schema::custom_value::NuSchemaCustomValue;
use std::{cmp::Ordering, fmt};
use uuid::Uuid;
pub use file_type::PolarsFileType;
pub use nu_dataframe::{Axis, Column, NuDataFrame, NuDataFrameCustomValue};
pub use nu_dtype::NuDataType;
pub use nu_dtype::{datatype_list, str_to_dtype};
pub use nu_expression::{NuExpression, NuExpressionCustomValue};
pub use nu_lazyframe::{NuLazyFrame, NuLazyFrameCustomValue};
pub use nu_lazygroupby::{NuLazyGroupBy, NuLazyGroupByCustomValue};
pub use nu_schema::{str_to_dtype, NuSchema};
pub use nu_schema::NuSchema;
pub use nu_when::{NuWhen, NuWhenCustomValue, NuWhenType};
#[derive(Debug, Clone)]
@ -31,6 +36,8 @@ pub enum PolarsPluginType {
NuLazyGroupBy,
NuWhen,
NuPolarsTestData,
NuDataType,
NuSchema,
}
impl fmt::Display for PolarsPluginType {
@ -42,6 +49,8 @@ impl fmt::Display for PolarsPluginType {
Self::NuLazyGroupBy => write!(f, "NuLazyGroupBy"),
Self::NuWhen => write!(f, "NuWhen"),
Self::NuPolarsTestData => write!(f, "NuPolarsTestData"),
Self::NuDataType => write!(f, "NuDataType"),
Self::NuSchema => write!(f, "NuSchema"),
}
}
}
@ -54,6 +63,8 @@ pub enum PolarsPluginObject {
NuLazyGroupBy(NuLazyGroupBy),
NuWhen(NuWhen),
NuPolarsTestData(Uuid, String),
NuDataType(NuDataType),
NuSchema(NuSchema),
}
impl PolarsPluginObject {
@ -71,6 +82,10 @@ impl PolarsPluginObject {
NuLazyGroupBy::try_from_value(plugin, value).map(PolarsPluginObject::NuLazyGroupBy)
} else if NuWhen::can_downcast(value) {
NuWhen::try_from_value(plugin, value).map(PolarsPluginObject::NuWhen)
} else if NuSchema::can_downcast(value) {
NuSchema::try_from_value(plugin, value).map(PolarsPluginObject::NuSchema)
} else if NuDataType::can_downcast(value) {
NuDataType::try_from_value(plugin, value).map(PolarsPluginObject::NuDataType)
} else {
Err(cant_convert_err(
value,
@ -80,6 +95,8 @@ impl PolarsPluginObject {
PolarsPluginType::NuExpression,
PolarsPluginType::NuLazyGroupBy,
PolarsPluginType::NuWhen,
PolarsPluginType::NuDataType,
PolarsPluginType::NuSchema,
],
))
}
@ -102,6 +119,8 @@ impl PolarsPluginObject {
Self::NuLazyGroupBy(_) => PolarsPluginType::NuLazyGroupBy,
Self::NuWhen(_) => PolarsPluginType::NuWhen,
Self::NuPolarsTestData(_, _) => PolarsPluginType::NuPolarsTestData,
Self::NuDataType(_) => PolarsPluginType::NuDataType,
Self::NuSchema(_) => PolarsPluginType::NuSchema,
}
}
@ -113,6 +132,8 @@ impl PolarsPluginObject {
PolarsPluginObject::NuLazyGroupBy(lg) => lg.id,
PolarsPluginObject::NuWhen(w) => w.id,
PolarsPluginObject::NuPolarsTestData(id, _) => *id,
PolarsPluginObject::NuDataType(dt) => dt.id,
PolarsPluginObject::NuSchema(schema) => schema.id,
}
}
@ -126,6 +147,8 @@ impl PolarsPluginObject {
PolarsPluginObject::NuPolarsTestData(id, s) => {
Value::string(format!("{id}:{s}"), Span::test_data())
}
PolarsPluginObject::NuDataType(dt) => dt.into_value(span),
PolarsPluginObject::NuSchema(schema) => schema.into_value(span),
}
}
@ -151,6 +174,8 @@ pub enum CustomValueType {
NuExpression(NuExpressionCustomValue),
NuLazyGroupBy(NuLazyGroupByCustomValue),
NuWhen(NuWhenCustomValue),
NuDataType(NuDataTypeCustomValue),
NuSchema(NuSchemaCustomValue),
}
impl CustomValueType {
@ -161,6 +186,8 @@ impl CustomValueType {
CustomValueType::NuExpression(e_cv) => e_cv.id,
CustomValueType::NuLazyGroupBy(lg_cv) => lg_cv.id,
CustomValueType::NuWhen(w_cv) => w_cv.id,
CustomValueType::NuDataType(dt_cv) => dt_cv.id,
CustomValueType::NuSchema(schema_cv) => schema_cv.id,
}
}
@ -175,6 +202,10 @@ impl CustomValueType {
Ok(CustomValueType::NuLazyGroupBy(lg_cv.clone()))
} else if let Some(w_cv) = val.as_any().downcast_ref::<NuWhenCustomValue>() {
Ok(CustomValueType::NuWhen(w_cv.clone()))
} else if let Some(w_cv) = val.as_any().downcast_ref::<NuDataTypeCustomValue>() {
Ok(CustomValueType::NuDataType(w_cv.clone()))
} else if let Some(w_cv) = val.as_any().downcast_ref::<NuSchemaCustomValue>() {
Ok(CustomValueType::NuSchema(w_cv.clone()))
} else {
Err(ShellError::CantConvert {
to_type: "physical type".into(),
@ -381,3 +412,198 @@ pub trait CustomValueSupport: Cacheable {
))
}
}
#[cfg(test)]
mod test {
    use polars::prelude::{DataType, TimeUnit, UnknownKind};

    use super::*;

    /// Parses `input` with an unknown span and asserts it yields `expected`.
    fn assert_dtype(input: &str, expected: DataType) {
        let parsed = str_to_dtype(input, Span::unknown()).unwrap();
        assert_eq!(parsed, expected);
    }

    /// Asserts that parsing `input` fails with a `ShellError::GenericError`.
    fn assert_generic_error(input: &str) {
        let parsed = str_to_dtype(input, Span::unknown());
        assert!(matches!(parsed, Err(ShellError::GenericError { .. })));
    }

    #[test]
    fn test_dtype_str_to_schema_simple_types() {
        let cases = [
            ("bool", DataType::Boolean),
            ("u8", DataType::UInt8),
            ("u16", DataType::UInt16),
            ("u32", DataType::UInt32),
            ("u64", DataType::UInt64),
            ("i8", DataType::Int8),
            ("i16", DataType::Int16),
            ("i32", DataType::Int32),
            ("i64", DataType::Int64),
            ("str", DataType::String),
            ("binary", DataType::Binary),
            ("date", DataType::Date),
            ("time", DataType::Time),
            ("null", DataType::Null),
            ("unknown", DataType::Unknown(UnknownKind::Any)),
            ("object", DataType::Object("unknown", None)),
        ];
        for (input, expected) in cases {
            assert_dtype(input, expected);
        }
    }

    #[test]
    fn test_dtype_str_schema_datetime() {
        let cases = [
            (
                "datetime<ms, *>",
                DataType::Datetime(TimeUnit::Milliseconds, None),
            ),
            (
                "datetime<us, *>",
                DataType::Datetime(TimeUnit::Microseconds, None),
            ),
            (
                "datetime<μs, *>",
                DataType::Datetime(TimeUnit::Microseconds, None),
            ),
            (
                "datetime<ns, *>",
                DataType::Datetime(TimeUnit::Nanoseconds, None),
            ),
            (
                "datetime<ms, UTC>",
                DataType::Datetime(TimeUnit::Milliseconds, Some("UTC".into())),
            ),
        ];
        for (input, expected) in cases {
            assert_dtype(input, expected);
        }

        // An unrecognised name must be rejected (any error kind is accepted here).
        let rejected = str_to_dtype("invalid", Span::unknown());
        assert!(rejected.is_err())
    }

    #[test]
    fn test_dtype_str_schema_duration() {
        let cases = [
            ("duration<ms>", TimeUnit::Milliseconds),
            ("duration<us>", TimeUnit::Microseconds),
            ("duration<μs>", TimeUnit::Microseconds),
            ("duration<ns>", TimeUnit::Nanoseconds),
        ];
        for (input, unit) in cases {
            assert_dtype(input, DataType::Duration(unit));
        }
    }

    #[test]
    fn test_dtype_str_schema_decimal() {
        assert_dtype("decimal<7,2>", DataType::Decimal(Some(7usize), Some(2usize)));

        // "*" is not a permitted value for scale
        assert_generic_error("decimal<7,*>");

        assert_dtype("decimal<*,2>", DataType::Decimal(None, Some(2usize)));
    }

    #[test]
    fn test_dtype_str_to_schema_list_types() {
        let cases = [
            ("list<i32>", DataType::Int32),
            (
                "list<duration<ms>>",
                DataType::Duration(TimeUnit::Milliseconds),
            ),
            (
                "list<datetime<ms, *>>",
                DataType::Datetime(TimeUnit::Milliseconds, None),
            ),
            (
                "list<decimal<7,2>>",
                DataType::Decimal(Some(7usize), Some(2usize)),
            ),
            (
                "list<decimal<*,2>>",
                DataType::Decimal(None, Some(2usize)),
            ),
        ];
        for (input, element) in cases {
            assert_dtype(input, DataType::List(Box::new(element)));
        }

        assert_generic_error("list<decimal<7,*>>");
    }
}

View File

@ -0,0 +1,65 @@
use nu_protocol::{CustomValue, ShellError, Span, Value};
use serde::{Deserialize, Serialize};
use uuid::Uuid;
use crate::values::{CustomValueSupport, PolarsPluginCustomValue};
use super::NuDataType;
/// Custom value that represents a [`NuDataType`] on the nushell side.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct NuDataTypeCustomValue {
    /// Identifier of the wrapped datatype.
    pub id: Uuid,
    /// The datatype itself, when available. Excluded from (de)serialization.
    #[serde(skip)]
    pub datatype: Option<NuDataType>,
}

#[typetag::serde]
impl CustomValue for NuDataTypeCustomValue {
    fn clone_value(&self, span: Span) -> Value {
        let boxed = Box::new(self.clone());
        Value::custom(boxed, span)
    }

    fn type_name(&self) -> String {
        String::from("NuDataType")
    }

    /// Returns a placeholder string; the message itself states that
    /// `custom_value_to_base_value` is the conversion that should be used.
    fn to_base_value(&self, span: Span) -> Result<Value, ShellError> {
        let placeholder = "NuDataType: custom_value_to_base_value should've been called";
        Ok(Value::string(placeholder, span))
    }

    fn as_any(&self) -> &dyn std::any::Any {
        self
    }

    fn as_mut_any(&mut self) -> &mut dyn std::any::Any {
        self
    }

    fn notify_plugin_on_drop(&self) -> bool {
        true
    }
}
impl PolarsPluginCustomValue for NuDataTypeCustomValue {
    type PolarsPluginObjectType = NuDataType;

    fn id(&self) -> &Uuid {
        &self.id
    }

    fn internal(&self) -> &Option<Self::PolarsPluginObjectType> {
        &self.datatype
    }

    /// Resolves this custom value to its [`NuDataType`] and renders that as a
    /// base value carrying an unknown span.
    fn custom_value_to_base_value(
        &self,
        plugin: &crate::PolarsPlugin,
        _engine: &nu_plugin::EngineInterface,
    ) -> Result<Value, ShellError> {
        NuDataType::try_from_custom_value(plugin, self)
            .and_then(|resolved| resolved.base_value(Span::unknown()))
    }
}

View File

@ -0,0 +1,316 @@
pub mod custom_value;
use custom_value::NuDataTypeCustomValue;
use nu_protocol::{record, ShellError, Span, Value};
use polars::prelude::{DataType, PlSmallStr, TimeUnit, UnknownKind};
use uuid::Uuid;
use crate::{Cacheable, PolarsPlugin};
use super::{nu_schema::dtype_to_value, CustomValueSupport, PolarsPluginObject, PolarsPluginType};
#[derive(Debug, Clone)]
pub struct NuDataType {
pub id: uuid::Uuid,
dtype: DataType,
}
impl NuDataType {
pub fn new(dtype: DataType) -> Self {
Self {
id: uuid::Uuid::new_v4(),
dtype,
}
}
pub fn new_with_str(dtype: &str, span: Span) -> Result<Self, ShellError> {
let dtype = str_to_dtype(dtype, span)?;
Ok(Self {
id: uuid::Uuid::new_v4(),
dtype,
})
}
pub fn to_polars(&self) -> DataType {
self.dtype.clone()
}
}
impl From<NuDataType> for Value {
    /// Renders the datatype through its `to_string` output as a string value
    /// with an unknown span.
    fn from(nu_dtype: NuDataType) -> Self {
        let rendered = nu_dtype.dtype.to_string();
        Value::string(rendered, Span::unknown())
    }
}
impl Cacheable for NuDataType {
    fn cache_id(&self) -> &Uuid {
        &self.id
    }

    /// Wraps a clone of this datatype in the plugin-object enum.
    fn to_cache_value(&self) -> Result<super::PolarsPluginObject, ShellError> {
        Ok(PolarsPluginObject::NuDataType(self.clone()))
    }

    /// Extracts the datatype from a plugin object.
    ///
    /// # Errors
    /// Returns a `GenericError` when `cv` holds any other variant.
    fn from_cache_value(cv: super::PolarsPluginObject) -> Result<Self, ShellError> {
        match cv {
            PolarsPluginObject::NuDataType(dt) => Ok(dt),
            _ => Err(ShellError::GenericError {
                // This impl is for datatypes; the previous text ("dataframe")
                // named the wrong kind of object.
                error: "Cache value is not a datatype".into(),
                msg: "".into(),
                span: None,
                help: None,
                inner: vec![],
            }),
        }
    }
}
impl CustomValueSupport for NuDataType {
type CV = NuDataTypeCustomValue;
fn get_type_static() -> super::PolarsPluginType {
PolarsPluginType::NuDataType
}
fn custom_value(self) -> Self::CV {
NuDataTypeCustomValue {
id: self.id,
datatype: Some(self),
}
}
fn base_value(self, span: Span) -> Result<Value, ShellError> {
Ok(dtype_to_value(&self.dtype, span))
}
fn try_from_value(plugin: &PolarsPlugin, value: &Value) -> Result<Self, ShellError> {
match value {
Value::Custom { val, .. } => {
if let Some(cv) = val.as_any().downcast_ref::<Self::CV>() {
Self::try_from_custom_value(plugin, cv)
} else {
Err(ShellError::CantConvert {
to_type: Self::get_type_static().to_string(),
from_type: value.get_type().to_string(),
span: value.span(),
help: None,
})
}
}
Value::String { val, internal_span } => NuDataType::new_with_str(val, *internal_span),
_ => Err(ShellError::CantConvert {
to_type: Self::get_type_static().to_string(),
from_type: value.get_type().to_string(),
span: value.span(),
help: None,
}),
}
}
}
/// Builds a list of records, one per dtype notation, each with a `dtype`
/// column and a `note` column (empty when there is nothing to add).
pub fn datatype_list(span: Span) -> Value {
    const TABLE: &[(&str, &str)] = &[
        ("null", ""),
        ("bool", ""),
        ("u8", ""),
        ("u16", ""),
        ("u32", ""),
        ("u64", ""),
        ("i8", ""),
        ("i16", ""),
        ("i32", ""),
        ("i64", ""),
        ("f32", ""),
        ("f64", ""),
        ("str", ""),
        ("binary", ""),
        ("date", ""),
        ("datetime<time_unit: (ms, us, ns) timezone (optional)>", "Time Unit can be: milliseconds: ms, microseconds: us, nanoseconds: ns. Timezone wildcard is *. Other Timezone examples: UTC, America/Los_Angeles."),
        ("duration<time_unit: (ms, us, ns)>", "Time Unit can be: milliseconds: ms, microseconds: us, nanoseconds: ns."),
        ("time", ""),
        ("object", ""),
        ("unknown", ""),
        ("list<dtype>", ""),
    ];

    let mut rows: Vec<Value> = Vec::new();
    for &(dtype, note) in TABLE {
        let row = record! {
            "dtype" => Value::string(dtype, span),
            "note" => Value::string(note, span),
        };
        rows.push(Value::record(row, span));
    }
    Value::list(rows, span)
}
/// Parses the textual dtype notation (e.g. `i64`, `list<i32>`,
/// `datetime<ms, UTC>`, `duration<ns>`, `decimal<7,2>`) into a polars
/// [`DataType`].
///
/// # Errors
/// Returns a `ShellError::GenericError` titled "Invalid polars data type",
/// labelled at `span`, when the name is unknown or a parameter is missing or
/// malformed.
pub fn str_to_dtype(dtype: &str, span: Span) -> Result<DataType, ShellError> {
    match dtype {
        "bool" => Ok(DataType::Boolean),
        "u8" => Ok(DataType::UInt8),
        "u16" => Ok(DataType::UInt16),
        "u32" => Ok(DataType::UInt32),
        "u64" => Ok(DataType::UInt64),
        "i8" => Ok(DataType::Int8),
        "i16" => Ok(DataType::Int16),
        "i32" => Ok(DataType::Int32),
        "i64" => Ok(DataType::Int64),
        "f32" => Ok(DataType::Float32),
        "f64" => Ok(DataType::Float64),
        "str" => Ok(DataType::String),
        "binary" => Ok(DataType::Binary),
        "date" => Ok(DataType::Date),
        "time" => Ok(DataType::Time),
        "null" => Ok(DataType::Null),
        "unknown" => Ok(DataType::Unknown(UnknownKind::Any)),
        "object" => Ok(DataType::Object("unknown", None)),
        _ if dtype.starts_with("list") => {
            // Stripping every trailing '>' is what lets nested notations such
            // as `list<duration<ms>>` reach the recursive call intact enough
            // to parse.
            let inner = dtype
                .trim_start_matches("list")
                .trim_start_matches('<')
                .trim_end_matches('>')
                .trim();
            let inner = str_to_dtype(inner, span)?;
            Ok(DataType::List(Box::new(inner)))
        }
        _ if dtype.starts_with("datetime") => parse_datetime_dtype(dtype, span),
        _ if dtype.starts_with("duration") => parse_duration_dtype(dtype, span),
        _ if dtype.starts_with("decimal") => parse_decimal_dtype(dtype, span),
        _ => Err(invalid_dtype(format!("Unknown type: {dtype}"), span)),
    }
}

/// Builds the error shared by every dtype-parsing failure.
fn invalid_dtype(msg: impl Into<String>, span: Span) -> ShellError {
    ShellError::GenericError {
        error: "Invalid polars data type".into(),
        msg: msg.into(),
        span: Some(span),
        help: None,
        inner: vec![],
    }
}

/// Parses `datetime<unit, timezone>`; a timezone of `*` means "no timezone".
fn parse_datetime_dtype(dtype: &str, span: Span) -> Result<DataType, ShellError> {
    let args = dtype
        .trim_start_matches("datetime")
        .trim_start_matches('<')
        .trim_end_matches('>');
    let mut parts = args.split(',');

    let unit = parts
        .next()
        .ok_or_else(|| invalid_dtype("Missing time unit", span))?
        .trim();
    let time_unit = str_to_time_unit(unit, span)?;

    let zone = parts
        .next()
        .ok_or_else(|| invalid_dtype("Missing time zone", span))?
        .trim();
    let timezone = match zone {
        "*" => None,
        name => Some(PlSmallStr::from(name.to_string())),
    };

    Ok(DataType::Datetime(time_unit, timezone))
}

/// Parses `duration<unit>`.
fn parse_duration_dtype(dtype: &str, span: Span) -> Result<DataType, ShellError> {
    let args = dtype.trim_start_matches("duration<").trim_end_matches('>');
    let unit = args
        .split(',')
        .next()
        .ok_or_else(|| invalid_dtype("Missing time unit", span))?
        .trim();
    str_to_time_unit(unit, span).map(DataType::Duration)
}

/// Parses `decimal<precision, scale>`; precision may be `*` (inferred), scale may not.
fn parse_decimal_dtype(dtype: &str, span: Span) -> Result<DataType, ShellError> {
    let args = dtype
        .trim_start_matches("decimal")
        .trim_start_matches('<')
        .trim_end_matches('>');
    let mut parts = args.split(',');

    let precision = match parts
        .next()
        .ok_or_else(|| invalid_dtype("Missing decimal precision", span))?
        .trim()
    {
        "*" => None, // infer
        digits => Some(digits.parse::<usize>().map_err(|e| {
            invalid_dtype(format!("Error in parsing decimal precision: {e}"), span)
        })?),
    };

    let scale = match parts
        .next()
        .ok_or_else(|| invalid_dtype("Missing decimal scale", span))?
        .trim()
    {
        "*" => {
            return Err(invalid_dtype(
                "`*` is not a permitted value for scale",
                span,
            ))
        }
        // Report the scale (not the precision) as the value that failed to parse.
        digits => digits
            .parse::<usize>()
            .map_err(|e| invalid_dtype(format!("Error in parsing decimal scale: {e}"), span))?,
    };

    Ok(DataType::Decimal(precision, Some(scale)))
}
/// Maps a time-unit abbreviation (`ms`, `us`/`μs`, `ns`) to a polars [`TimeUnit`].
///
/// # Errors
/// Returns a `GenericError` labelled at `span` for any other string.
fn str_to_time_unit(ts_string: &str, span: Span) -> Result<TimeUnit, ShellError> {
    let unit = match ts_string {
        "ms" => Some(TimeUnit::Milliseconds),
        "us" | "μs" => Some(TimeUnit::Microseconds),
        "ns" => Some(TimeUnit::Nanoseconds),
        _ => None,
    };
    unit.ok_or_else(|| ShellError::GenericError {
        error: "Invalid polars data type".into(),
        msg: "Invalid time unit".into(),
        span: Some(span),
        help: None,
        inner: vec![],
    })
}

View File

@ -1,480 +0,0 @@
use std::sync::Arc;
use nu_protocol::{ShellError, Span, Value};
use polars::{
datatypes::UnknownKind,
prelude::{DataType, Field, PlSmallStr, Schema, SchemaExt, SchemaRef, TimeUnit},
};
/// Thin wrapper around a shared polars schema reference.
#[derive(Debug, Clone)]
pub struct NuSchema {
    pub schema: SchemaRef,
}

impl NuSchema {
    /// Wraps an existing schema reference.
    pub fn new(schema: SchemaRef) -> Self {
        NuSchema { schema }
    }
}

impl TryFrom<&Value> for NuSchema {
    type Error = ShellError;

    /// Converts a value into a schema via `value_to_schema`, using an unknown
    /// span for any error it reports.
    fn try_from(value: &Value) -> Result<Self, Self::Error> {
        value_to_schema(value, Span::unknown()).map(|parsed| Self::new(Arc::new(parsed)))
    }
}

impl From<NuSchema> for Value {
    /// Renders the schema's fields through `fields_to_value` with an unknown span.
    fn from(schema: NuSchema) -> Self {
        let fields = schema.schema.iter_fields();
        fields_to_value(fields, Span::unknown())
    }
}

impl From<NuSchema> for SchemaRef {
    fn from(val: NuSchema) -> Self {
        let shared = &val.schema;
        Arc::clone(shared)
    }
}

impl From<SchemaRef> for NuSchema {
    fn from(val: SchemaRef) -> Self {
        Self::new(val)
    }
}
/// Converts schema fields into a record mapping each column name to the
/// rendered form of its dtype.
///
/// `span` is applied to the record itself as well as to every entry, so the
/// whole value carries the span the caller supplied.
fn fields_to_value(fields: impl Iterator<Item = Field>, span: Span) -> Value {
    let record = fields
        .map(|field| {
            let col = field.name().to_string();
            let val = dtype_to_value(field.dtype(), span);
            (col, val)
        })
        .collect();

    // Previously the record was tagged with `Span::unknown()`, discarding the
    // caller's span even though the entries inside it used it.
    Value::record(record, span)
}
/// Renders a dtype as a value: struct dtypes become a record of their fields,
/// everything else becomes its string form with `[`/`]` rewritten to `<`/`>`.
fn dtype_to_value(dtype: &DataType, span: Span) -> Value {
    if let DataType::Struct(fields) = dtype {
        return fields_to_value(fields.iter().cloned(), span);
    }

    let rendered = dtype.to_string().replace('[', "<").replace(']', ">");
    Value::string(rendered, span)
}
/// Builds a polars [`Schema`] from the fields described by `value`.
///
/// # Errors
/// Propagates any error from `value_to_fields`.
fn value_to_schema(value: &Value, span: Span) -> Result<Schema, ShellError> {
    let parsed = Schema::from_iter(value_to_fields(value, span)?);
    Ok(parsed)
}
/// Converts a record value into polars fields, one per column.
///
/// A nested record becomes a struct dtype (recursively); any other entry is
/// coerced to a string and parsed with `str_to_dtype`.
///
/// # Errors
/// Fails on the first entry that cannot be converted, or when `value` is not
/// a record.
fn value_to_fields(value: &Value, span: Span) -> Result<Vec<Field>, ShellError> {
    let mut fields = Vec::new();
    for (col, val) in value.as_record()? {
        let dtype = match val {
            Value::Record { .. } => DataType::Struct(value_to_fields(val, span)?),
            _ => str_to_dtype(&val.coerce_string()?, span)?,
        };
        fields.push(Field::new(col.into(), dtype));
    }
    Ok(fields)
}
pub fn str_to_dtype(dtype: &str, span: Span) -> Result<DataType, ShellError> {
match dtype {
"bool" => Ok(DataType::Boolean),
"u8" => Ok(DataType::UInt8),
"u16" => Ok(DataType::UInt16),
"u32" => Ok(DataType::UInt32),
"u64" => Ok(DataType::UInt64),
"i8" => Ok(DataType::Int8),
"i16" => Ok(DataType::Int16),
"i32" => Ok(DataType::Int32),
"i64" => Ok(DataType::Int64),
"f32" => Ok(DataType::Float32),
"f64" => Ok(DataType::Float64),
"str" => Ok(DataType::String),
"binary" => Ok(DataType::Binary),
"date" => Ok(DataType::Date),
"time" => Ok(DataType::Time),
"null" => Ok(DataType::Null),
"unknown" => Ok(DataType::Unknown(UnknownKind::Any)),
"object" => Ok(DataType::Object("unknown", None)),
_ if dtype.starts_with("list") => {
let dtype = dtype
.trim_start_matches("list")
.trim_start_matches('<')
.trim_end_matches('>')
.trim();
let dtype = str_to_dtype(dtype, span)?;
Ok(DataType::List(Box::new(dtype)))
}
_ if dtype.starts_with("datetime") => {
let dtype = dtype
.trim_start_matches("datetime")
.trim_start_matches('<')
.trim_end_matches('>');
let mut split = dtype.split(',');
let next = split
.next()
.ok_or_else(|| ShellError::GenericError {
error: "Invalid polars data type".into(),
msg: "Missing time unit".into(),
span: Some(span),
help: None,
inner: vec![],
})?
.trim();
let time_unit = str_to_time_unit(next, span)?;
let next = split
.next()
.ok_or_else(|| ShellError::GenericError {
error: "Invalid polars data type".into(),
msg: "Missing time zone".into(),
span: Some(span),
help: None,
inner: vec![],
})?
.trim();
let timezone = if "*" == next {
None
} else {
Some(next.to_string())
};
Ok(DataType::Datetime(
time_unit,
timezone.map(PlSmallStr::from),
))
}
_ if dtype.starts_with("duration") => {
let inner = dtype.trim_start_matches("duration<").trim_end_matches('>');
let next = inner
.split(',')
.next()
.ok_or_else(|| ShellError::GenericError {
error: "Invalid polars data type".into(),
msg: "Missing time unit".into(),
span: Some(span),
help: None,
inner: vec![],
})?
.trim();
let time_unit = str_to_time_unit(next, span)?;
Ok(DataType::Duration(time_unit))
}
_ if dtype.starts_with("decimal") => {
let dtype = dtype
.trim_start_matches("decimal")
.trim_start_matches('<')
.trim_end_matches('>');
let mut split = dtype.split(',');
let next = split
.next()
.ok_or_else(|| ShellError::GenericError {
error: "Invalid polars data type".into(),
msg: "Missing decimal precision".into(),
span: Some(span),
help: None,
inner: vec![],
})?
.trim();
let precision = match next {
"*" => None, // infer
_ => Some(
next.parse::<usize>()
.map_err(|e| ShellError::GenericError {
error: "Invalid polars data type".into(),
msg: format!("Error in parsing decimal precision: {e}"),
span: Some(span),
help: None,
inner: vec![],
})?,
),
};
let next = split
.next()
.ok_or_else(|| ShellError::GenericError {
error: "Invalid polars data type".into(),
msg: "Missing decimal scale".into(),
span: Some(span),
help: None,
inner: vec![],
})?
.trim();
let scale = match next {
"*" => Err(ShellError::GenericError {
error: "Invalid polars data type".into(),
msg: "`*` is not a permitted value for scale".into(),
span: Some(span),
help: None,
inner: vec![],
}),
_ => next
.parse::<usize>()
.map(Some)
.map_err(|e| ShellError::GenericError {
error: "Invalid polars data type".into(),
msg: format!("Error in parsing decimal precision: {e}"),
span: Some(span),
help: None,
inner: vec![],
}),
}?;
Ok(DataType::Decimal(precision, scale))
}
_ => Err(ShellError::GenericError {
error: "Invalid polars data type".into(),
msg: format!("Unknown type: {dtype}"),
span: Some(span),
help: None,
inner: vec![],
}),
}
}
/// Maps a time-unit abbreviation (`ms`, `us`/`μs`, `ns`) onto the matching [`TimeUnit`].
///
/// Any other text yields a `ShellError::GenericError` labelled with `span`.
fn str_to_time_unit(ts_string: &str, span: Span) -> Result<TimeUnit, ShellError> {
    let unit = match ts_string {
        "ms" => Some(TimeUnit::Milliseconds),
        "us" | "μs" => Some(TimeUnit::Microseconds),
        "ns" => Some(TimeUnit::Nanoseconds),
        _ => None,
    };

    unit.ok_or_else(|| ShellError::GenericError {
        error: "Invalid polars data type".into(),
        msg: "Invalid time unit".into(),
        span: Some(span),
        help: None,
        inner: vec![],
    })
}
#[cfg(test)]
mod test {
    use nu_protocol::record;

    use super::*;

    /// Parses every input with `str_to_dtype` and compares it with its paired expectation.
    fn assert_all_parse(cases: Vec<(&str, DataType)>) {
        for (text, expected) in cases {
            let parsed = str_to_dtype(text, Span::unknown()).unwrap();
            assert_eq!(parsed, expected);
        }
    }

    /// Asserts that parsing `text` fails with a `ShellError::GenericError`.
    fn assert_generic_error(text: &str) {
        let outcome = str_to_dtype(text, Span::unknown());
        assert!(matches!(outcome, Err(ShellError::GenericError { .. })));
    }

    #[test]
    fn test_value_to_schema() {
        let value = Value::test_record(record! {
            "name" => Value::test_string("str"),
            "age" => Value::test_string("i32"),
            "address" => Value::test_record(record! {
                "street" => Value::test_string("str"),
                "city" => Value::test_string("str"),
            })
        });

        let address_fields = vec![
            Field::new("street".into(), DataType::String),
            Field::new("city".into(), DataType::String),
        ];
        let expected = Schema::from_iter(vec![
            Field::new("name".into(), DataType::String),
            Field::new("age".into(), DataType::Int32),
            Field::new("address".into(), DataType::Struct(address_fields)),
        ]);

        let schema = value_to_schema(&value, Span::unknown()).unwrap();
        assert_eq!(schema, expected);
    }

    #[test]
    fn test_dtype_str_to_schema_simple_types() {
        assert_all_parse(vec![
            ("bool", DataType::Boolean),
            ("u8", DataType::UInt8),
            ("u16", DataType::UInt16),
            ("u32", DataType::UInt32),
            ("u64", DataType::UInt64),
            ("i8", DataType::Int8),
            ("i16", DataType::Int16),
            ("i32", DataType::Int32),
            ("i64", DataType::Int64),
            ("str", DataType::String),
            ("binary", DataType::Binary),
            ("date", DataType::Date),
            ("time", DataType::Time),
            ("null", DataType::Null),
            ("unknown", DataType::Unknown(UnknownKind::Any)),
            ("object", DataType::Object("unknown", None)),
        ]);
    }

    #[test]
    fn test_dtype_str_schema_datetime() {
        assert_all_parse(vec![
            (
                "datetime<ms, *>",
                DataType::Datetime(TimeUnit::Milliseconds, None),
            ),
            (
                "datetime<us, *>",
                DataType::Datetime(TimeUnit::Microseconds, None),
            ),
            (
                "datetime<μs, *>",
                DataType::Datetime(TimeUnit::Microseconds, None),
            ),
            (
                "datetime<ns, *>",
                DataType::Datetime(TimeUnit::Nanoseconds, None),
            ),
            (
                "datetime<ms, UTC>",
                DataType::Datetime(TimeUnit::Milliseconds, Some("UTC".into())),
            ),
        ]);

        // A name that matches no known dtype must be rejected.
        assert!(str_to_dtype("invalid", Span::unknown()).is_err())
    }

    #[test]
    fn test_dtype_str_schema_duration() {
        assert_all_parse(vec![
            ("duration<ms>", DataType::Duration(TimeUnit::Milliseconds)),
            ("duration<us>", DataType::Duration(TimeUnit::Microseconds)),
            ("duration<μs>", DataType::Duration(TimeUnit::Microseconds)),
            ("duration<ns>", DataType::Duration(TimeUnit::Nanoseconds)),
        ]);
    }

    #[test]
    fn test_dtype_str_schema_decimal() {
        assert_all_parse(vec![
            (
                "decimal<7,2>",
                DataType::Decimal(Some(7usize), Some(2usize)),
            ),
            ("decimal<*,2>", DataType::Decimal(None, Some(2usize))),
        ]);

        // "*" is not a permitted value for scale
        assert_generic_error("decimal<7,*>");
    }

    #[test]
    fn test_dtype_str_to_schema_list_types() {
        assert_all_parse(vec![
            ("list<i32>", DataType::List(Box::new(DataType::Int32))),
            (
                "list<duration<ms>>",
                DataType::List(Box::new(DataType::Duration(TimeUnit::Milliseconds))),
            ),
            (
                "list<datetime<ms, *>>",
                DataType::List(Box::new(DataType::Datetime(TimeUnit::Milliseconds, None))),
            ),
            (
                "list<decimal<7,2>>",
                DataType::List(Box::new(DataType::Decimal(Some(7usize), Some(2usize)))),
            ),
            (
                "list<decimal<*,2>>",
                DataType::List(Box::new(DataType::Decimal(None, Some(2usize)))),
            ),
        ]);

        // The scale wildcard is rejected inside a list as well.
        assert_generic_error("list<decimal<7,*>>");
    }
}

View File

@ -0,0 +1,65 @@
use nu_protocol::{CustomValue, ShellError, Span, Value};
use serde::{Deserialize, Serialize};
use uuid::Uuid;
use crate::values::{CustomValueSupport, PolarsPluginCustomValue};
use super::NuSchema;
/// Custom-value handle for a [`NuSchema`].
///
/// Only `id` takes part in (de)serialization; the schema itself is held in an optional,
/// serde-skipped field.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct NuSchemaCustomValue {
/// Identifier of the schema this handle stands for.
pub id: Uuid,
/// The schema itself, when this handle carries one. Marked `#[serde(skip)]`, so it is
/// never serialized and is left at its default (`None`) on deserialization.
#[serde(skip)]
pub datatype: Option<NuSchema>,
}
#[typetag::serde]
impl CustomValue for NuSchemaCustomValue {
    /// Wraps a copy of this handle in a new custom `Value` located at `span`.
    fn clone_value(&self, span: Span) -> Value {
        let copy: Box<dyn CustomValue> = Box::new(self.clone());
        Value::custom(copy, span)
    }

    /// Name under which this custom value is reported.
    fn type_name(&self) -> String {
        String::from("NuSchema")
    }

    /// Returns a placeholder string value; the text itself states that
    /// `custom_value_to_base_value` is the conversion that should have been called.
    fn to_base_value(&self, span: Span) -> Result<Value, ShellError> {
        let placeholder = "NuSchema: custom_value_to_base_value should've been called";
        Ok(Value::string(placeholder, span))
    }

    /// Mutable `Any` view used for downcasting.
    fn as_mut_any(&mut self) -> &mut dyn std::any::Any {
        self
    }

    /// Shared `Any` view used for downcasting.
    fn as_any(&self) -> &dyn std::any::Any {
        self
    }

    /// Always `true`: asks for the plugin to be notified when this value is dropped.
    fn notify_plugin_on_drop(&self) -> bool {
        true
    }
}
impl PolarsPluginCustomValue for NuSchemaCustomValue {
    type PolarsPluginObjectType = NuSchema;

    /// Identifier of the schema this handle refers to.
    fn id(&self) -> &Uuid {
        &self.id
    }

    /// The schema carried inline by this handle, if any.
    fn internal(&self) -> &Option<NuSchema> {
        &self.datatype
    }

    /// Resolves the handle to its `NuSchema` and renders that schema as a plain value.
    /// The resulting value is created with an unknown span.
    fn custom_value_to_base_value(
        &self,
        plugin: &crate::PolarsPlugin,
        _engine: &nu_plugin::EngineInterface,
    ) -> Result<Value, ShellError> {
        NuSchema::try_from_custom_value(plugin, self)?.base_value(Span::unknown())
    }
}

View File

@ -0,0 +1,189 @@
pub mod custom_value;
use std::sync::Arc;
use custom_value::NuSchemaCustomValue;
use nu_protocol::{ShellError, Span, Value};
use polars::prelude::{DataType, Field, Schema, SchemaExt, SchemaRef};
use uuid::Uuid;
use crate::{Cacheable, PolarsPlugin};
use super::{str_to_dtype, CustomValueSupport, NuDataType, PolarsPluginObject, PolarsPluginType};
/// A polars schema paired with the identifier used to cache and look it up.
#[derive(Debug, Clone)]
pub struct NuSchema {
/// Identifier of this schema; `NuSchema::new` assigns a fresh v4 UUID.
pub id: Uuid,
/// The wrapped polars schema reference.
pub schema: SchemaRef,
}
impl NuSchema {
    /// Wraps `schema` and tags it with a newly generated v4 UUID.
    pub fn new(schema: SchemaRef) -> Self {
        let id = Uuid::new_v4();
        Self { id, schema }
    }
}
impl From<NuSchema> for SchemaRef {
fn from(val: NuSchema) -> Self {
Arc::clone(&val.schema)
}
}
impl From<SchemaRef> for NuSchema {
fn from(val: SchemaRef) -> Self {
Self::new(val)
}
}
impl Cacheable for NuSchema {
fn cache_id(&self) -> &Uuid {
&self.id
}
fn to_cache_value(&self) -> Result<super::PolarsPluginObject, ShellError> {
Ok(PolarsPluginObject::NuSchema(self.clone()))
}
fn from_cache_value(cv: super::PolarsPluginObject) -> Result<Self, ShellError> {
match cv {
PolarsPluginObject::NuSchema(dt) => Ok(dt),
_ => Err(ShellError::GenericError {
error: "Cache value is not a dataframe".into(),
msg: "".into(),
span: None,
help: None,
inner: vec![],
}),
}
}
}
impl CustomValueSupport for NuSchema {
type CV = NuSchemaCustomValue;
fn get_type_static() -> super::PolarsPluginType {
PolarsPluginType::NuSchema
}
fn custom_value(self) -> Self::CV {
NuSchemaCustomValue {
id: self.id,
datatype: Some(self),
}
}
fn base_value(self, span: Span) -> Result<Value, ShellError> {
Ok(fields_to_value(self.schema.iter_fields(), span))
}
fn try_from_value(plugin: &PolarsPlugin, value: &Value) -> Result<Self, ShellError> {
if let Value::Custom { val, .. } = value {
if let Some(cv) = val.as_any().downcast_ref::<Self::CV>() {
Self::try_from_custom_value(plugin, cv)
} else {
Err(ShellError::CantConvert {
to_type: Self::get_type_static().to_string(),
from_type: value.get_type().to_string(),
span: value.span(),
help: None,
})
}
} else {
let schema = value_to_schema(plugin, value, Span::unknown())?;
Ok(Self::new(Arc::new(schema)))
}
}
}
/// Builds a record value mapping each field name to the value form of its dtype.
///
/// `span` is applied to the record itself as well as to every field value. The record
/// used to be created with an unknown span even though the caller supplied one.
fn fields_to_value(fields: impl Iterator<Item = Field>, span: Span) -> Value {
    let record = fields
        .map(|field| {
            let col = field.name().to_string();
            let val = dtype_to_value(field.dtype(), span);
            (col, val)
        })
        .collect();
    Value::record(record, span)
}
/// Converts a polars dtype into a value.
///
/// Struct dtypes become a record of their fields. Every other dtype becomes a string made
/// from its `Display` text with `[` and `]` replaced by `<` and `>`.
pub fn dtype_to_value(dtype: &DataType, span: Span) -> Value {
    if let DataType::Struct(fields) = dtype {
        return fields_to_value(fields.iter().cloned(), span);
    }

    let rendered = dtype.to_string().replace('[', "<").replace(']', ">");
    Value::string(rendered, span)
}
/// Parses `value` into fields with `value_to_fields` and assembles them into a `Schema`.
fn value_to_schema(plugin: &PolarsPlugin, value: &Value, span: Span) -> Result<Schema, ShellError> {
    Ok(Schema::from_iter(value_to_fields(plugin, value, span)?))
}
/// Turns a record value into a list of polars fields, one per column.
///
/// Nested records become struct dtypes (recursively), custom values are resolved as
/// `NuDataType`, and anything else is coerced to a string and parsed with `str_to_dtype`.
/// The first failure aborts the conversion and is returned.
fn value_to_fields(
    plugin: &PolarsPlugin,
    value: &Value,
    span: Span,
) -> Result<Vec<Field>, ShellError> {
    let mut fields = Vec::new();
    for (col, val) in value.as_record()? {
        let dtype = match val {
            Value::Record { .. } => DataType::Struct(value_to_fields(plugin, val, span)?),
            Value::Custom { .. } => NuDataType::try_from_value(plugin, val)?.to_polars(),
            _ => str_to_dtype(&val.coerce_string()?, span)?,
        };
        fields.push(Field::new(col.into(), dtype));
    }
    Ok(fields)
}
#[cfg(test)]
mod test {
    use nu_protocol::record;

    use super::*;

    /// A record with two string-typed columns and one nested record must become a schema
    /// with two primitive fields and one struct field.
    #[test]
    fn test_value_to_schema() {
        let plugin = PolarsPlugin::new_test_mode().expect("Failed to create plugin");

        let input = Value::test_record(record! {
            "name" => Value::test_string("str"),
            "age" => Value::test_string("i32"),
            "address" => Value::test_record(record! {
                "street" => Value::test_string("str"),
                "city" => Value::test_string("str"),
            })
        });

        let address_fields = vec![
            Field::new("street".into(), DataType::String),
            Field::new("city".into(), DataType::String),
        ];
        let expected = Schema::from_iter(vec![
            Field::new("name".into(), DataType::String),
            Field::new("age".into(), DataType::Int32),
            Field::new("address".into(), DataType::Struct(address_fields)),
        ]);

        let schema = value_to_schema(&plugin, &input, Span::unknown()).unwrap();
        assert_eq!(schema, expected);
    }
}

View File

@ -123,6 +123,8 @@ impl Plugin for PolarsPlugin {
CustomValueType::NuExpression(cv) => cv.custom_value_to_base_value(self, engine),
CustomValueType::NuLazyGroupBy(cv) => cv.custom_value_to_base_value(self, engine),
CustomValueType::NuWhen(cv) => cv.custom_value_to_base_value(self, engine),
CustomValueType::NuDataType(cv) => cv.custom_value_to_base_value(self, engine),
CustomValueType::NuSchema(cv) => cv.custom_value_to_base_value(self, engine),
};
Ok(result?)
}
@ -150,6 +152,12 @@ impl Plugin for PolarsPlugin {
CustomValueType::NuWhen(cv) => {
cv.custom_value_operation(self, engine, left.span, operator, right)
}
CustomValueType::NuDataType(cv) => {
cv.custom_value_operation(self, engine, left.span, operator, right)
}
CustomValueType::NuSchema(cv) => {
cv.custom_value_operation(self, engine, left.span, operator, right)
}
};
Ok(result?)
}
@ -176,6 +184,12 @@ impl Plugin for PolarsPlugin {
CustomValueType::NuWhen(cv) => {
cv.custom_value_follow_path_int(self, engine, custom_value.span, index)
}
CustomValueType::NuDataType(cv) => {
cv.custom_value_follow_path_int(self, engine, custom_value.span, index)
}
CustomValueType::NuSchema(cv) => {
cv.custom_value_follow_path_int(self, engine, custom_value.span, index)
}
};
Ok(result?)
}
@ -202,6 +216,12 @@ impl Plugin for PolarsPlugin {
CustomValueType::NuWhen(cv) => {
cv.custom_value_follow_path_string(self, engine, custom_value.span, column_name)
}
CustomValueType::NuDataType(cv) => {
cv.custom_value_follow_path_string(self, engine, custom_value.span, column_name)
}
CustomValueType::NuSchema(cv) => {
cv.custom_value_follow_path_string(self, engine, custom_value.span, column_name)
}
};
Ok(result?)
}
@ -226,6 +246,10 @@ impl Plugin for PolarsPlugin {
cv.custom_value_partial_cmp(self, engine, other_value)
}
CustomValueType::NuWhen(cv) => cv.custom_value_partial_cmp(self, engine, other_value),
CustomValueType::NuDataType(cv) => {
cv.custom_value_partial_cmp(self, engine, other_value)
}
CustomValueType::NuSchema(cv) => cv.custom_value_partial_cmp(self, engine, other_value),
};
Ok(result?)
}