mirror of
https://github.com/nushell/nushell.git
synced 2025-04-14 16:28:17 +02:00
Introduction of NuDataType and polars dtype
(#15529)
# Description This pull request does a lot of the heavy lifting needed to supported more complex dtypes like categorical dtypes. It introduces a new CustomValue, NuDataType and makes NuSchema a full CustomValue. Further more it introduces a new command `polars into-dtype` that allows a dtype to be created. This can then be passed into schemas when they are created. ```nu > ❯ : let dt = ("str" | polars to-dtype) > ❯ : [[a b]; ["one" "two"]] | polars into-df -s {a: $dt, b: str} | polars schema ╭───┬─────╮ │ a │ str │ │ b │ str │ ╰───┴─────╯ ``` # User-Facing Changes - Introduces new command `polars into-dtype`, allows dtype variables to be passed in during schema creation.
This commit is contained in:
parent
173162df2e
commit
b0f9cda9b5
32
crates/nu_plugin_polars/src/cache/list.rs
vendored
32
crates/nu_plugin_polars/src/cache/list.rs
vendored
@ -132,7 +132,37 @@ impl PluginCommand for ListDF {
|
||||
"reference_count" => Value::int(value.reference_count as i64, call.head),
|
||||
},
|
||||
call.head,
|
||||
)))
|
||||
))),
|
||||
PolarsPluginObject::NuDataType(_) => Ok(Some(Value::record(
|
||||
record! {
|
||||
"key" => Value::string(key.to_string(), call.head),
|
||||
"created" => Value::date(value.created, call.head),
|
||||
"columns" => Value::nothing(call.head),
|
||||
"rows" => Value::nothing(call.head),
|
||||
"type" => Value::string("DataType", call.head),
|
||||
"estimated_size" => Value::nothing(call.head),
|
||||
"span_contents" => Value::string(span_contents, value.span),
|
||||
"span_start" => Value::int(value.span.start as i64, call.head),
|
||||
"span_end" => Value::int(value.span.end as i64, call.head),
|
||||
"reference_count" => Value::int(value.reference_count as i64, call.head),
|
||||
},
|
||||
call.head,
|
||||
))),
|
||||
PolarsPluginObject::NuSchema(_) => Ok(Some(Value::record(
|
||||
record! {
|
||||
"key" => Value::string(key.to_string(), call.head),
|
||||
"created" => Value::date(value.created, call.head),
|
||||
"columns" => Value::nothing(call.head),
|
||||
"rows" => Value::nothing(call.head),
|
||||
"type" => Value::string("Schema", call.head),
|
||||
"estimated_size" => Value::nothing(call.head),
|
||||
"span_contents" => Value::string(span_contents, value.span),
|
||||
"span_start" => Value::int(value.span.start as i64, call.head),
|
||||
"span_end" => Value::int(value.span.end as i64, call.head),
|
||||
"reference_count" => Value::int(value.reference_count as i64, call.head),
|
||||
},
|
||||
call.head,
|
||||
))),
|
||||
}
|
||||
})?;
|
||||
let vals = vals.into_iter().flatten().collect();
|
||||
|
@ -9,6 +9,7 @@ mod schema;
|
||||
mod shape;
|
||||
mod summary;
|
||||
mod to_df;
|
||||
mod to_dtype;
|
||||
mod to_lazy;
|
||||
mod to_nu;
|
||||
mod to_repr;
|
||||
@ -40,5 +41,6 @@ pub(crate) fn core_commands() -> Vec<Box<dyn PluginCommand<Plugin = PolarsPlugin
|
||||
Box::new(save::SaveDF),
|
||||
Box::new(ToLazyFrame),
|
||||
Box::new(ToRepr),
|
||||
Box::new(to_dtype::ToDataType),
|
||||
]
|
||||
}
|
||||
|
@ -163,7 +163,7 @@ fn command(
|
||||
});
|
||||
}
|
||||
|
||||
let hive_options = build_hive_options(call)?;
|
||||
let hive_options = build_hive_options(plugin, call)?;
|
||||
|
||||
match type_option {
|
||||
Some((ext, blamed)) => match PolarsFileType::from(ext.as_str()) {
|
||||
@ -388,7 +388,7 @@ fn from_json(
|
||||
})?;
|
||||
let maybe_schema = call
|
||||
.get_flag("schema")?
|
||||
.map(|schema| NuSchema::try_from(&schema))
|
||||
.map(|schema| NuSchema::try_from_value(plugin, &schema))
|
||||
.transpose()?;
|
||||
|
||||
let buf_reader = BufReader::new(file);
|
||||
@ -429,11 +429,7 @@ fn from_ndjson(
|
||||
NonZeroUsize::new(DEFAULT_INFER_SCHEMA)
|
||||
.expect("The default infer-schema should be non zero"),
|
||||
);
|
||||
let maybe_schema = call
|
||||
.get_flag("schema")?
|
||||
.map(|schema| NuSchema::try_from(&schema))
|
||||
.transpose()?;
|
||||
|
||||
let maybe_schema = get_schema(plugin, call)?;
|
||||
if !is_eager {
|
||||
let start_time = std::time::Instant::now();
|
||||
|
||||
@ -511,10 +507,7 @@ fn from_csv(
|
||||
.unwrap_or(DEFAULT_INFER_SCHEMA);
|
||||
let skip_rows: Option<usize> = call.get_flag("skip-rows")?;
|
||||
let columns: Option<Vec<String>> = call.get_flag("columns")?;
|
||||
let maybe_schema = call
|
||||
.get_flag("schema")?
|
||||
.map(|schema| NuSchema::try_from(&schema))
|
||||
.transpose()?;
|
||||
let maybe_schema = get_schema(plugin, call)?;
|
||||
let truncate_ragged_lines: bool = call.has_flag("truncate-ragged-lines")?;
|
||||
|
||||
if !is_eager {
|
||||
@ -627,12 +620,15 @@ fn cloud_not_supported(file_type: PolarsFileType, span: Span) -> ShellError {
|
||||
}
|
||||
}
|
||||
|
||||
fn build_hive_options(call: &EvaluatedCall) -> Result<HiveOptions, ShellError> {
|
||||
fn build_hive_options(
|
||||
plugin: &PolarsPlugin,
|
||||
call: &EvaluatedCall,
|
||||
) -> Result<HiveOptions, ShellError> {
|
||||
let enabled: Option<bool> = call.get_flag("hive-enabled")?;
|
||||
let hive_start_idx: Option<usize> = call.get_flag("hive-start-idx")?;
|
||||
let schema: Option<NuSchema> = call
|
||||
.get_flag::<Value>("hive-schema")?
|
||||
.map(|schema| NuSchema::try_from(&schema))
|
||||
.map(|schema| NuSchema::try_from_value(plugin, &schema))
|
||||
.transpose()?;
|
||||
let try_parse_dates: bool = call.has_flag("hive-try-parse-dates")?;
|
||||
|
||||
@ -643,3 +639,11 @@ fn build_hive_options(call: &EvaluatedCall) -> Result<HiveOptions, ShellError> {
|
||||
try_parse_dates,
|
||||
})
|
||||
}
|
||||
|
||||
fn get_schema(plugin: &PolarsPlugin, call: &EvaluatedCall) -> Result<Option<NuSchema>, ShellError> {
|
||||
let schema: Option<NuSchema> = call
|
||||
.get_flag("schema")?
|
||||
.map(|schema| NuSchema::try_from_value(plugin, &schema))
|
||||
.transpose()?;
|
||||
Ok(schema)
|
||||
}
|
||||
|
@ -1,4 +1,7 @@
|
||||
use crate::{values::PolarsPluginObject, PolarsPlugin};
|
||||
use crate::{
|
||||
values::{datatype_list, CustomValueSupport, PolarsPluginObject},
|
||||
PolarsPlugin,
|
||||
};
|
||||
|
||||
use nu_plugin::{EngineInterface, EvaluatedCall, PluginCommand};
|
||||
use nu_protocol::{
|
||||
@ -67,12 +70,12 @@ fn command(
|
||||
match PolarsPluginObject::try_from_pipeline(plugin, input, call.head)? {
|
||||
PolarsPluginObject::NuDataFrame(df) => {
|
||||
let schema = df.schema();
|
||||
let value: Value = schema.into();
|
||||
let value = schema.base_value(call.head)?;
|
||||
Ok(PipelineData::Value(value, None))
|
||||
}
|
||||
PolarsPluginObject::NuLazyFrame(mut lazy) => {
|
||||
let schema = lazy.schema()?;
|
||||
let value: Value = schema.into();
|
||||
let value = schema.base_value(call.head)?;
|
||||
Ok(PipelineData::Value(value, None))
|
||||
}
|
||||
_ => Err(ShellError::GenericError {
|
||||
@ -85,42 +88,6 @@ fn command(
|
||||
}
|
||||
}
|
||||
|
||||
fn datatype_list(span: Span) -> Value {
|
||||
let types: Vec<Value> = [
|
||||
("null", ""),
|
||||
("bool", ""),
|
||||
("u8", ""),
|
||||
("u16", ""),
|
||||
("u32", ""),
|
||||
("u64", ""),
|
||||
("i8", ""),
|
||||
("i16", ""),
|
||||
("i32", ""),
|
||||
("i64", ""),
|
||||
("f32", ""),
|
||||
("f64", ""),
|
||||
("str", ""),
|
||||
("binary", ""),
|
||||
("date", ""),
|
||||
("datetime<time_unit: (ms, us, ns) timezone (optional)>", "Time Unit can be: milliseconds: ms, microseconds: us, nanoseconds: ns. Timezone wildcard is *. Other Timezone examples: UTC, America/Los_Angeles."),
|
||||
("duration<time_unit: (ms, us, ns)>", "Time Unit can be: milliseconds: ms, microseconds: us, nanoseconds: ns."),
|
||||
("time", ""),
|
||||
("object", ""),
|
||||
("unknown", ""),
|
||||
("list<dtype>", ""),
|
||||
]
|
||||
.iter()
|
||||
.map(|(dtype, note)| {
|
||||
Value::record(record! {
|
||||
"dtype" => Value::string(*dtype, span),
|
||||
"note" => Value::string(*note, span),
|
||||
},
|
||||
span)
|
||||
})
|
||||
.collect();
|
||||
Value::list(types, span)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod test {
|
||||
use super::*;
|
||||
|
@ -216,7 +216,7 @@ impl PluginCommand for ToDataFrame {
|
||||
) -> Result<PipelineData, LabeledError> {
|
||||
let maybe_schema = call
|
||||
.get_flag("schema")?
|
||||
.map(|schema| NuSchema::try_from(&schema))
|
||||
.map(|schema| NuSchema::try_from_value(plugin, &schema))
|
||||
.transpose()?;
|
||||
|
||||
debug!("schema: {:?}", maybe_schema);
|
||||
|
@ -0,0 +1,55 @@
|
||||
use nu_plugin::PluginCommand;
|
||||
use nu_protocol::{Category, Example, ShellError, Signature, Span, Type, Value};
|
||||
|
||||
use crate::{
|
||||
values::{CustomValueSupport, NuDataType},
|
||||
PolarsPlugin,
|
||||
};
|
||||
|
||||
pub struct ToDataType;
|
||||
|
||||
impl PluginCommand for ToDataType {
|
||||
type Plugin = PolarsPlugin;
|
||||
|
||||
fn name(&self) -> &str {
|
||||
"polars into-dtype"
|
||||
}
|
||||
|
||||
fn description(&self) -> &str {
|
||||
"Convert a string to a specific datatype."
|
||||
}
|
||||
|
||||
fn signature(&self) -> Signature {
|
||||
Signature::build(self.name())
|
||||
.input_output_type(Type::String, Type::Custom("datatype".into()))
|
||||
.category(Category::Custom("dataframe".into()))
|
||||
}
|
||||
|
||||
fn examples(&self) -> Vec<Example> {
|
||||
vec![Example {
|
||||
description: "Convert a string to a specific datatype",
|
||||
example: r#""i64" | polars into-dtype"#,
|
||||
result: Some(Value::string("i64", Span::test_data())),
|
||||
}]
|
||||
}
|
||||
|
||||
fn run(
|
||||
&self,
|
||||
plugin: &Self::Plugin,
|
||||
engine: &nu_plugin::EngineInterface,
|
||||
call: &nu_plugin::EvaluatedCall,
|
||||
input: nu_protocol::PipelineData,
|
||||
) -> Result<nu_protocol::PipelineData, nu_protocol::LabeledError> {
|
||||
command(plugin, engine, call, input).map_err(nu_protocol::LabeledError::from)
|
||||
}
|
||||
}
|
||||
|
||||
fn command(
|
||||
plugin: &PolarsPlugin,
|
||||
engine: &nu_plugin::EngineInterface,
|
||||
call: &nu_plugin::EvaluatedCall,
|
||||
input: nu_protocol::PipelineData,
|
||||
) -> Result<nu_protocol::PipelineData, ShellError> {
|
||||
NuDataType::try_from_pipeline(plugin, input, call.head)?
|
||||
.to_pipeline_data(plugin, engine, call.head)
|
||||
}
|
@ -56,7 +56,7 @@ impl PluginCommand for ToLazyFrame {
|
||||
) -> Result<PipelineData, LabeledError> {
|
||||
let maybe_schema = call
|
||||
.get_flag("schema")?
|
||||
.map(|schema| NuSchema::try_from(&schema))
|
||||
.map(|schema| NuSchema::try_from_value(plugin, &schema))
|
||||
.transpose()?;
|
||||
|
||||
let df = NuDataFrame::try_from_iter(plugin, input.into_iter(), maybe_schema)?;
|
||||
|
@ -1,5 +1,6 @@
|
||||
mod file_type;
|
||||
mod nu_dataframe;
|
||||
mod nu_dtype;
|
||||
mod nu_expression;
|
||||
mod nu_lazyframe;
|
||||
mod nu_lazygroupby;
|
||||
@ -8,19 +9,23 @@ mod nu_when;
|
||||
pub mod utils;
|
||||
|
||||
use crate::{Cacheable, PolarsPlugin};
|
||||
use nu_dtype::custom_value::NuDataTypeCustomValue;
|
||||
use nu_plugin::EngineInterface;
|
||||
use nu_protocol::{
|
||||
ast::Operator, CustomValue, PipelineData, ShellError, Span, Spanned, Type, Value,
|
||||
};
|
||||
use nu_schema::custom_value::NuSchemaCustomValue;
|
||||
use std::{cmp::Ordering, fmt};
|
||||
use uuid::Uuid;
|
||||
|
||||
pub use file_type::PolarsFileType;
|
||||
pub use nu_dataframe::{Axis, Column, NuDataFrame, NuDataFrameCustomValue};
|
||||
pub use nu_dtype::NuDataType;
|
||||
pub use nu_dtype::{datatype_list, str_to_dtype};
|
||||
pub use nu_expression::{NuExpression, NuExpressionCustomValue};
|
||||
pub use nu_lazyframe::{NuLazyFrame, NuLazyFrameCustomValue};
|
||||
pub use nu_lazygroupby::{NuLazyGroupBy, NuLazyGroupByCustomValue};
|
||||
pub use nu_schema::{str_to_dtype, NuSchema};
|
||||
pub use nu_schema::NuSchema;
|
||||
pub use nu_when::{NuWhen, NuWhenCustomValue, NuWhenType};
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
@ -31,6 +36,8 @@ pub enum PolarsPluginType {
|
||||
NuLazyGroupBy,
|
||||
NuWhen,
|
||||
NuPolarsTestData,
|
||||
NuDataType,
|
||||
NuSchema,
|
||||
}
|
||||
|
||||
impl fmt::Display for PolarsPluginType {
|
||||
@ -42,6 +49,8 @@ impl fmt::Display for PolarsPluginType {
|
||||
Self::NuLazyGroupBy => write!(f, "NuLazyGroupBy"),
|
||||
Self::NuWhen => write!(f, "NuWhen"),
|
||||
Self::NuPolarsTestData => write!(f, "NuPolarsTestData"),
|
||||
Self::NuDataType => write!(f, "NuDataType"),
|
||||
Self::NuSchema => write!(f, "NuSchema"),
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -54,6 +63,8 @@ pub enum PolarsPluginObject {
|
||||
NuLazyGroupBy(NuLazyGroupBy),
|
||||
NuWhen(NuWhen),
|
||||
NuPolarsTestData(Uuid, String),
|
||||
NuDataType(NuDataType),
|
||||
NuSchema(NuSchema),
|
||||
}
|
||||
|
||||
impl PolarsPluginObject {
|
||||
@ -71,6 +82,10 @@ impl PolarsPluginObject {
|
||||
NuLazyGroupBy::try_from_value(plugin, value).map(PolarsPluginObject::NuLazyGroupBy)
|
||||
} else if NuWhen::can_downcast(value) {
|
||||
NuWhen::try_from_value(plugin, value).map(PolarsPluginObject::NuWhen)
|
||||
} else if NuSchema::can_downcast(value) {
|
||||
NuSchema::try_from_value(plugin, value).map(PolarsPluginObject::NuSchema)
|
||||
} else if NuDataType::can_downcast(value) {
|
||||
NuDataType::try_from_value(plugin, value).map(PolarsPluginObject::NuDataType)
|
||||
} else {
|
||||
Err(cant_convert_err(
|
||||
value,
|
||||
@ -80,6 +95,8 @@ impl PolarsPluginObject {
|
||||
PolarsPluginType::NuExpression,
|
||||
PolarsPluginType::NuLazyGroupBy,
|
||||
PolarsPluginType::NuWhen,
|
||||
PolarsPluginType::NuDataType,
|
||||
PolarsPluginType::NuSchema,
|
||||
],
|
||||
))
|
||||
}
|
||||
@ -102,6 +119,8 @@ impl PolarsPluginObject {
|
||||
Self::NuLazyGroupBy(_) => PolarsPluginType::NuLazyGroupBy,
|
||||
Self::NuWhen(_) => PolarsPluginType::NuWhen,
|
||||
Self::NuPolarsTestData(_, _) => PolarsPluginType::NuPolarsTestData,
|
||||
Self::NuDataType(_) => PolarsPluginType::NuDataType,
|
||||
Self::NuSchema(_) => PolarsPluginType::NuSchema,
|
||||
}
|
||||
}
|
||||
|
||||
@ -113,6 +132,8 @@ impl PolarsPluginObject {
|
||||
PolarsPluginObject::NuLazyGroupBy(lg) => lg.id,
|
||||
PolarsPluginObject::NuWhen(w) => w.id,
|
||||
PolarsPluginObject::NuPolarsTestData(id, _) => *id,
|
||||
PolarsPluginObject::NuDataType(dt) => dt.id,
|
||||
PolarsPluginObject::NuSchema(schema) => schema.id,
|
||||
}
|
||||
}
|
||||
|
||||
@ -126,6 +147,8 @@ impl PolarsPluginObject {
|
||||
PolarsPluginObject::NuPolarsTestData(id, s) => {
|
||||
Value::string(format!("{id}:{s}"), Span::test_data())
|
||||
}
|
||||
PolarsPluginObject::NuDataType(dt) => dt.into_value(span),
|
||||
PolarsPluginObject::NuSchema(schema) => schema.into_value(span),
|
||||
}
|
||||
}
|
||||
|
||||
@ -151,6 +174,8 @@ pub enum CustomValueType {
|
||||
NuExpression(NuExpressionCustomValue),
|
||||
NuLazyGroupBy(NuLazyGroupByCustomValue),
|
||||
NuWhen(NuWhenCustomValue),
|
||||
NuDataType(NuDataTypeCustomValue),
|
||||
NuSchema(NuSchemaCustomValue),
|
||||
}
|
||||
|
||||
impl CustomValueType {
|
||||
@ -161,6 +186,8 @@ impl CustomValueType {
|
||||
CustomValueType::NuExpression(e_cv) => e_cv.id,
|
||||
CustomValueType::NuLazyGroupBy(lg_cv) => lg_cv.id,
|
||||
CustomValueType::NuWhen(w_cv) => w_cv.id,
|
||||
CustomValueType::NuDataType(dt_cv) => dt_cv.id,
|
||||
CustomValueType::NuSchema(schema_cv) => schema_cv.id,
|
||||
}
|
||||
}
|
||||
|
||||
@ -175,6 +202,10 @@ impl CustomValueType {
|
||||
Ok(CustomValueType::NuLazyGroupBy(lg_cv.clone()))
|
||||
} else if let Some(w_cv) = val.as_any().downcast_ref::<NuWhenCustomValue>() {
|
||||
Ok(CustomValueType::NuWhen(w_cv.clone()))
|
||||
} else if let Some(w_cv) = val.as_any().downcast_ref::<NuDataTypeCustomValue>() {
|
||||
Ok(CustomValueType::NuDataType(w_cv.clone()))
|
||||
} else if let Some(w_cv) = val.as_any().downcast_ref::<NuSchemaCustomValue>() {
|
||||
Ok(CustomValueType::NuSchema(w_cv.clone()))
|
||||
} else {
|
||||
Err(ShellError::CantConvert {
|
||||
to_type: "physical type".into(),
|
||||
@ -381,3 +412,198 @@ pub trait CustomValueSupport: Cacheable {
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod test {
|
||||
use polars::prelude::{DataType, TimeUnit, UnknownKind};
|
||||
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_dtype_str_to_schema_simple_types() {
|
||||
let dtype = "bool";
|
||||
let schema = str_to_dtype(dtype, Span::unknown()).unwrap();
|
||||
let expected = DataType::Boolean;
|
||||
assert_eq!(schema, expected);
|
||||
|
||||
let dtype = "u8";
|
||||
let schema = str_to_dtype(dtype, Span::unknown()).unwrap();
|
||||
let expected = DataType::UInt8;
|
||||
assert_eq!(schema, expected);
|
||||
|
||||
let dtype = "u16";
|
||||
let schema = str_to_dtype(dtype, Span::unknown()).unwrap();
|
||||
let expected = DataType::UInt16;
|
||||
assert_eq!(schema, expected);
|
||||
|
||||
let dtype = "u32";
|
||||
let schema = str_to_dtype(dtype, Span::unknown()).unwrap();
|
||||
let expected = DataType::UInt32;
|
||||
assert_eq!(schema, expected);
|
||||
|
||||
let dtype = "u64";
|
||||
let schema = str_to_dtype(dtype, Span::unknown()).unwrap();
|
||||
let expected = DataType::UInt64;
|
||||
assert_eq!(schema, expected);
|
||||
|
||||
let dtype = "i8";
|
||||
let schema = str_to_dtype(dtype, Span::unknown()).unwrap();
|
||||
let expected = DataType::Int8;
|
||||
assert_eq!(schema, expected);
|
||||
|
||||
let dtype = "i16";
|
||||
let schema = str_to_dtype(dtype, Span::unknown()).unwrap();
|
||||
let expected = DataType::Int16;
|
||||
assert_eq!(schema, expected);
|
||||
|
||||
let dtype = "i32";
|
||||
let schema = str_to_dtype(dtype, Span::unknown()).unwrap();
|
||||
let expected = DataType::Int32;
|
||||
assert_eq!(schema, expected);
|
||||
|
||||
let dtype = "i64";
|
||||
let schema = str_to_dtype(dtype, Span::unknown()).unwrap();
|
||||
let expected = DataType::Int64;
|
||||
assert_eq!(schema, expected);
|
||||
|
||||
let dtype = "str";
|
||||
let schema = str_to_dtype(dtype, Span::unknown()).unwrap();
|
||||
let expected = DataType::String;
|
||||
assert_eq!(schema, expected);
|
||||
|
||||
let dtype = "binary";
|
||||
let schema = str_to_dtype(dtype, Span::unknown()).unwrap();
|
||||
let expected = DataType::Binary;
|
||||
assert_eq!(schema, expected);
|
||||
|
||||
let dtype = "date";
|
||||
let schema = str_to_dtype(dtype, Span::unknown()).unwrap();
|
||||
let expected = DataType::Date;
|
||||
assert_eq!(schema, expected);
|
||||
|
||||
let dtype = "time";
|
||||
let schema = str_to_dtype(dtype, Span::unknown()).unwrap();
|
||||
let expected = DataType::Time;
|
||||
assert_eq!(schema, expected);
|
||||
|
||||
let dtype = "null";
|
||||
let schema = str_to_dtype(dtype, Span::unknown()).unwrap();
|
||||
let expected = DataType::Null;
|
||||
assert_eq!(schema, expected);
|
||||
|
||||
let dtype = "unknown";
|
||||
let schema = str_to_dtype(dtype, Span::unknown()).unwrap();
|
||||
let expected = DataType::Unknown(UnknownKind::Any);
|
||||
assert_eq!(schema, expected);
|
||||
|
||||
let dtype = "object";
|
||||
let schema = str_to_dtype(dtype, Span::unknown()).unwrap();
|
||||
let expected = DataType::Object("unknown", None);
|
||||
assert_eq!(schema, expected);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_dtype_str_schema_datetime() {
|
||||
let dtype = "datetime<ms, *>";
|
||||
let schema = str_to_dtype(dtype, Span::unknown()).unwrap();
|
||||
let expected = DataType::Datetime(TimeUnit::Milliseconds, None);
|
||||
assert_eq!(schema, expected);
|
||||
|
||||
let dtype = "datetime<us, *>";
|
||||
let schema = str_to_dtype(dtype, Span::unknown()).unwrap();
|
||||
let expected = DataType::Datetime(TimeUnit::Microseconds, None);
|
||||
assert_eq!(schema, expected);
|
||||
|
||||
let dtype = "datetime<μs, *>";
|
||||
let schema = str_to_dtype(dtype, Span::unknown()).unwrap();
|
||||
let expected = DataType::Datetime(TimeUnit::Microseconds, None);
|
||||
assert_eq!(schema, expected);
|
||||
|
||||
let dtype = "datetime<ns, *>";
|
||||
let schema = str_to_dtype(dtype, Span::unknown()).unwrap();
|
||||
let expected = DataType::Datetime(TimeUnit::Nanoseconds, None);
|
||||
assert_eq!(schema, expected);
|
||||
|
||||
let dtype = "datetime<ms, UTC>";
|
||||
let schema = str_to_dtype(dtype, Span::unknown()).unwrap();
|
||||
let expected = DataType::Datetime(TimeUnit::Milliseconds, Some("UTC".into()));
|
||||
assert_eq!(schema, expected);
|
||||
|
||||
let dtype = "invalid";
|
||||
let schema = str_to_dtype(dtype, Span::unknown());
|
||||
assert!(schema.is_err())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_dtype_str_schema_duration() {
|
||||
let dtype = "duration<ms>";
|
||||
let schema = str_to_dtype(dtype, Span::unknown()).unwrap();
|
||||
let expected = DataType::Duration(TimeUnit::Milliseconds);
|
||||
assert_eq!(schema, expected);
|
||||
|
||||
let dtype = "duration<us>";
|
||||
let schema = str_to_dtype(dtype, Span::unknown()).unwrap();
|
||||
let expected = DataType::Duration(TimeUnit::Microseconds);
|
||||
assert_eq!(schema, expected);
|
||||
|
||||
let dtype = "duration<μs>";
|
||||
let schema = str_to_dtype(dtype, Span::unknown()).unwrap();
|
||||
let expected = DataType::Duration(TimeUnit::Microseconds);
|
||||
assert_eq!(schema, expected);
|
||||
|
||||
let dtype = "duration<ns>";
|
||||
let schema = str_to_dtype(dtype, Span::unknown()).unwrap();
|
||||
let expected = DataType::Duration(TimeUnit::Nanoseconds);
|
||||
assert_eq!(schema, expected);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_dtype_str_schema_decimal() {
|
||||
let dtype = "decimal<7,2>";
|
||||
let schema = str_to_dtype(dtype, Span::unknown()).unwrap();
|
||||
let expected = DataType::Decimal(Some(7usize), Some(2usize));
|
||||
assert_eq!(schema, expected);
|
||||
|
||||
// "*" is not a permitted value for scale
|
||||
let dtype = "decimal<7,*>";
|
||||
let schema = str_to_dtype(dtype, Span::unknown());
|
||||
assert!(matches!(schema, Err(ShellError::GenericError { .. })));
|
||||
|
||||
let dtype = "decimal<*,2>";
|
||||
let schema = str_to_dtype(dtype, Span::unknown()).unwrap();
|
||||
let expected = DataType::Decimal(None, Some(2usize));
|
||||
assert_eq!(schema, expected);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_dtype_str_to_schema_list_types() {
|
||||
let dtype = "list<i32>";
|
||||
let schema = str_to_dtype(dtype, Span::unknown()).unwrap();
|
||||
let expected = DataType::List(Box::new(DataType::Int32));
|
||||
assert_eq!(schema, expected);
|
||||
|
||||
let dtype = "list<duration<ms>>";
|
||||
let schema = str_to_dtype(dtype, Span::unknown()).unwrap();
|
||||
let expected = DataType::List(Box::new(DataType::Duration(TimeUnit::Milliseconds)));
|
||||
assert_eq!(schema, expected);
|
||||
|
||||
let dtype = "list<datetime<ms, *>>";
|
||||
let schema = str_to_dtype(dtype, Span::unknown()).unwrap();
|
||||
let expected = DataType::List(Box::new(DataType::Datetime(TimeUnit::Milliseconds, None)));
|
||||
assert_eq!(schema, expected);
|
||||
|
||||
let dtype = "list<decimal<7,2>>";
|
||||
let schema = str_to_dtype(dtype, Span::unknown()).unwrap();
|
||||
let expected = DataType::List(Box::new(DataType::Decimal(Some(7usize), Some(2usize))));
|
||||
assert_eq!(schema, expected);
|
||||
|
||||
let dtype = "list<decimal<*,2>>";
|
||||
let schema = str_to_dtype(dtype, Span::unknown()).unwrap();
|
||||
let expected = DataType::List(Box::new(DataType::Decimal(None, Some(2usize))));
|
||||
assert_eq!(schema, expected);
|
||||
|
||||
let dtype = "list<decimal<7,*>>";
|
||||
let schema = str_to_dtype(dtype, Span::unknown());
|
||||
assert!(matches!(schema, Err(ShellError::GenericError { .. })));
|
||||
}
|
||||
}
|
||||
|
@ -0,0 +1,65 @@
|
||||
use nu_protocol::{CustomValue, ShellError, Span, Value};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::values::{CustomValueSupport, PolarsPluginCustomValue};
|
||||
|
||||
use super::NuDataType;
|
||||
|
||||
#[derive(Clone, Debug, Deserialize, Serialize)]
|
||||
pub struct NuDataTypeCustomValue {
|
||||
pub id: Uuid,
|
||||
#[serde(skip)]
|
||||
pub datatype: Option<NuDataType>,
|
||||
}
|
||||
|
||||
#[typetag::serde]
|
||||
impl CustomValue for NuDataTypeCustomValue {
|
||||
fn clone_value(&self, span: nu_protocol::Span) -> Value {
|
||||
Value::custom(Box::new(self.clone()), span)
|
||||
}
|
||||
|
||||
fn type_name(&self) -> String {
|
||||
"NuDataType".into()
|
||||
}
|
||||
|
||||
fn to_base_value(&self, span: Span) -> Result<Value, ShellError> {
|
||||
Ok(Value::string(
|
||||
"NuDataType: custom_value_to_base_value should've been called",
|
||||
span,
|
||||
))
|
||||
}
|
||||
|
||||
fn as_mut_any(&mut self) -> &mut dyn std::any::Any {
|
||||
self
|
||||
}
|
||||
|
||||
fn as_any(&self) -> &dyn std::any::Any {
|
||||
self
|
||||
}
|
||||
|
||||
fn notify_plugin_on_drop(&self) -> bool {
|
||||
true
|
||||
}
|
||||
}
|
||||
|
||||
impl PolarsPluginCustomValue for NuDataTypeCustomValue {
|
||||
type PolarsPluginObjectType = NuDataType;
|
||||
|
||||
fn id(&self) -> &Uuid {
|
||||
&self.id
|
||||
}
|
||||
|
||||
fn internal(&self) -> &Option<Self::PolarsPluginObjectType> {
|
||||
&self.datatype
|
||||
}
|
||||
|
||||
fn custom_value_to_base_value(
|
||||
&self,
|
||||
plugin: &crate::PolarsPlugin,
|
||||
_engine: &nu_plugin::EngineInterface,
|
||||
) -> Result<Value, ShellError> {
|
||||
let dtype = NuDataType::try_from_custom_value(plugin, self)?;
|
||||
dtype.base_value(Span::unknown())
|
||||
}
|
||||
}
|
316
crates/nu_plugin_polars/src/dataframe/values/nu_dtype/mod.rs
Normal file
316
crates/nu_plugin_polars/src/dataframe/values/nu_dtype/mod.rs
Normal file
@ -0,0 +1,316 @@
|
||||
pub mod custom_value;
|
||||
|
||||
use custom_value::NuDataTypeCustomValue;
|
||||
use nu_protocol::{record, ShellError, Span, Value};
|
||||
use polars::prelude::{DataType, PlSmallStr, TimeUnit, UnknownKind};
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::{Cacheable, PolarsPlugin};
|
||||
|
||||
use super::{nu_schema::dtype_to_value, CustomValueSupport, PolarsPluginObject, PolarsPluginType};
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct NuDataType {
|
||||
pub id: uuid::Uuid,
|
||||
dtype: DataType,
|
||||
}
|
||||
|
||||
impl NuDataType {
|
||||
pub fn new(dtype: DataType) -> Self {
|
||||
Self {
|
||||
id: uuid::Uuid::new_v4(),
|
||||
dtype,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn new_with_str(dtype: &str, span: Span) -> Result<Self, ShellError> {
|
||||
let dtype = str_to_dtype(dtype, span)?;
|
||||
Ok(Self {
|
||||
id: uuid::Uuid::new_v4(),
|
||||
dtype,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn to_polars(&self) -> DataType {
|
||||
self.dtype.clone()
|
||||
}
|
||||
}
|
||||
|
||||
impl From<NuDataType> for Value {
|
||||
fn from(nu_dtype: NuDataType) -> Self {
|
||||
Value::String {
|
||||
val: nu_dtype.dtype.to_string(),
|
||||
internal_span: Span::unknown(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Cacheable for NuDataType {
|
||||
fn cache_id(&self) -> &Uuid {
|
||||
&self.id
|
||||
}
|
||||
|
||||
fn to_cache_value(&self) -> Result<super::PolarsPluginObject, ShellError> {
|
||||
Ok(PolarsPluginObject::NuDataType(self.clone()))
|
||||
}
|
||||
|
||||
fn from_cache_value(cv: super::PolarsPluginObject) -> Result<Self, ShellError> {
|
||||
match cv {
|
||||
PolarsPluginObject::NuDataType(dt) => Ok(dt),
|
||||
_ => Err(ShellError::GenericError {
|
||||
error: "Cache value is not a dataframe".into(),
|
||||
msg: "".into(),
|
||||
span: None,
|
||||
help: None,
|
||||
inner: vec![],
|
||||
}),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl CustomValueSupport for NuDataType {
|
||||
type CV = NuDataTypeCustomValue;
|
||||
|
||||
fn get_type_static() -> super::PolarsPluginType {
|
||||
PolarsPluginType::NuDataType
|
||||
}
|
||||
|
||||
fn custom_value(self) -> Self::CV {
|
||||
NuDataTypeCustomValue {
|
||||
id: self.id,
|
||||
datatype: Some(self),
|
||||
}
|
||||
}
|
||||
|
||||
fn base_value(self, span: Span) -> Result<Value, ShellError> {
|
||||
Ok(dtype_to_value(&self.dtype, span))
|
||||
}
|
||||
|
||||
fn try_from_value(plugin: &PolarsPlugin, value: &Value) -> Result<Self, ShellError> {
|
||||
match value {
|
||||
Value::Custom { val, .. } => {
|
||||
if let Some(cv) = val.as_any().downcast_ref::<Self::CV>() {
|
||||
Self::try_from_custom_value(plugin, cv)
|
||||
} else {
|
||||
Err(ShellError::CantConvert {
|
||||
to_type: Self::get_type_static().to_string(),
|
||||
from_type: value.get_type().to_string(),
|
||||
span: value.span(),
|
||||
help: None,
|
||||
})
|
||||
}
|
||||
}
|
||||
Value::String { val, internal_span } => NuDataType::new_with_str(val, *internal_span),
|
||||
_ => Err(ShellError::CantConvert {
|
||||
to_type: Self::get_type_static().to_string(),
|
||||
from_type: value.get_type().to_string(),
|
||||
span: value.span(),
|
||||
help: None,
|
||||
}),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn datatype_list(span: Span) -> Value {
|
||||
let types: Vec<Value> = [
|
||||
("null", ""),
|
||||
("bool", ""),
|
||||
("u8", ""),
|
||||
("u16", ""),
|
||||
("u32", ""),
|
||||
("u64", ""),
|
||||
("i8", ""),
|
||||
("i16", ""),
|
||||
("i32", ""),
|
||||
("i64", ""),
|
||||
("f32", ""),
|
||||
("f64", ""),
|
||||
("str", ""),
|
||||
("binary", ""),
|
||||
("date", ""),
|
||||
("datetime<time_unit: (ms, us, ns) timezone (optional)>", "Time Unit can be: milliseconds: ms, microseconds: us, nanoseconds: ns. Timezone wildcard is *. Other Timezone examples: UTC, America/Los_Angeles."),
|
||||
("duration<time_unit: (ms, us, ns)>", "Time Unit can be: milliseconds: ms, microseconds: us, nanoseconds: ns."),
|
||||
("time", ""),
|
||||
("object", ""),
|
||||
("unknown", ""),
|
||||
("list<dtype>", ""),
|
||||
]
|
||||
.iter()
|
||||
.map(|(dtype, note)| {
|
||||
Value::record(record! {
|
||||
"dtype" => Value::string(*dtype, span),
|
||||
"note" => Value::string(*note, span),
|
||||
},
|
||||
span)
|
||||
})
|
||||
.collect();
|
||||
Value::list(types, span)
|
||||
}
|
||||
|
||||
pub fn str_to_dtype(dtype: &str, span: Span) -> Result<DataType, ShellError> {
|
||||
match dtype {
|
||||
"bool" => Ok(DataType::Boolean),
|
||||
"u8" => Ok(DataType::UInt8),
|
||||
"u16" => Ok(DataType::UInt16),
|
||||
"u32" => Ok(DataType::UInt32),
|
||||
"u64" => Ok(DataType::UInt64),
|
||||
"i8" => Ok(DataType::Int8),
|
||||
"i16" => Ok(DataType::Int16),
|
||||
"i32" => Ok(DataType::Int32),
|
||||
"i64" => Ok(DataType::Int64),
|
||||
"f32" => Ok(DataType::Float32),
|
||||
"f64" => Ok(DataType::Float64),
|
||||
"str" => Ok(DataType::String),
|
||||
"binary" => Ok(DataType::Binary),
|
||||
"date" => Ok(DataType::Date),
|
||||
"time" => Ok(DataType::Time),
|
||||
"null" => Ok(DataType::Null),
|
||||
"unknown" => Ok(DataType::Unknown(UnknownKind::Any)),
|
||||
"object" => Ok(DataType::Object("unknown", None)),
|
||||
_ if dtype.starts_with("list") => {
|
||||
let dtype = dtype
|
||||
.trim_start_matches("list")
|
||||
.trim_start_matches('<')
|
||||
.trim_end_matches('>')
|
||||
.trim();
|
||||
let dtype = str_to_dtype(dtype, span)?;
|
||||
Ok(DataType::List(Box::new(dtype)))
|
||||
}
|
||||
_ if dtype.starts_with("datetime") => {
|
||||
let dtype = dtype
|
||||
.trim_start_matches("datetime")
|
||||
.trim_start_matches('<')
|
||||
.trim_end_matches('>');
|
||||
let mut split = dtype.split(',');
|
||||
let next = split
|
||||
.next()
|
||||
.ok_or_else(|| ShellError::GenericError {
|
||||
error: "Invalid polars data type".into(),
|
||||
msg: "Missing time unit".into(),
|
||||
span: Some(span),
|
||||
help: None,
|
||||
inner: vec![],
|
||||
})?
|
||||
.trim();
|
||||
let time_unit = str_to_time_unit(next, span)?;
|
||||
let next = split
|
||||
.next()
|
||||
.ok_or_else(|| ShellError::GenericError {
|
||||
error: "Invalid polars data type".into(),
|
||||
msg: "Missing time zone".into(),
|
||||
span: Some(span),
|
||||
help: None,
|
||||
inner: vec![],
|
||||
})?
|
||||
.trim();
|
||||
let timezone = if "*" == next {
|
||||
None
|
||||
} else {
|
||||
Some(next.to_string())
|
||||
};
|
||||
Ok(DataType::Datetime(
|
||||
time_unit,
|
||||
timezone.map(PlSmallStr::from),
|
||||
))
|
||||
}
|
||||
_ if dtype.starts_with("duration") => {
|
||||
let inner = dtype.trim_start_matches("duration<").trim_end_matches('>');
|
||||
let next = inner
|
||||
.split(',')
|
||||
.next()
|
||||
.ok_or_else(|| ShellError::GenericError {
|
||||
error: "Invalid polars data type".into(),
|
||||
msg: "Missing time unit".into(),
|
||||
span: Some(span),
|
||||
help: None,
|
||||
inner: vec![],
|
||||
})?
|
||||
.trim();
|
||||
let time_unit = str_to_time_unit(next, span)?;
|
||||
Ok(DataType::Duration(time_unit))
|
||||
}
|
||||
_ if dtype.starts_with("decimal") => {
|
||||
let dtype = dtype
|
||||
.trim_start_matches("decimal")
|
||||
.trim_start_matches('<')
|
||||
.trim_end_matches('>');
|
||||
let mut split = dtype.split(',');
|
||||
let next = split
|
||||
.next()
|
||||
.ok_or_else(|| ShellError::GenericError {
|
||||
error: "Invalid polars data type".into(),
|
||||
msg: "Missing decimal precision".into(),
|
||||
span: Some(span),
|
||||
help: None,
|
||||
inner: vec![],
|
||||
})?
|
||||
.trim();
|
||||
let precision = match next {
|
||||
"*" => None, // infer
|
||||
_ => Some(
|
||||
next.parse::<usize>()
|
||||
.map_err(|e| ShellError::GenericError {
|
||||
error: "Invalid polars data type".into(),
|
||||
msg: format!("Error in parsing decimal precision: {e}"),
|
||||
span: Some(span),
|
||||
help: None,
|
||||
inner: vec![],
|
||||
})?,
|
||||
),
|
||||
};
|
||||
|
||||
let next = split
|
||||
.next()
|
||||
.ok_or_else(|| ShellError::GenericError {
|
||||
error: "Invalid polars data type".into(),
|
||||
msg: "Missing decimal scale".into(),
|
||||
span: Some(span),
|
||||
help: None,
|
||||
inner: vec![],
|
||||
})?
|
||||
.trim();
|
||||
let scale = match next {
|
||||
"*" => Err(ShellError::GenericError {
|
||||
error: "Invalid polars data type".into(),
|
||||
msg: "`*` is not a permitted value for scale".into(),
|
||||
span: Some(span),
|
||||
help: None,
|
||||
inner: vec![],
|
||||
}),
|
||||
_ => next
|
||||
.parse::<usize>()
|
||||
.map(Some)
|
||||
.map_err(|e| ShellError::GenericError {
|
||||
error: "Invalid polars data type".into(),
|
||||
msg: format!("Error in parsing decimal precision: {e}"),
|
||||
span: Some(span),
|
||||
help: None,
|
||||
inner: vec![],
|
||||
}),
|
||||
}?;
|
||||
Ok(DataType::Decimal(precision, scale))
|
||||
}
|
||||
_ => Err(ShellError::GenericError {
|
||||
error: "Invalid polars data type".into(),
|
||||
msg: format!("Unknown type: {dtype}"),
|
||||
span: Some(span),
|
||||
help: None,
|
||||
inner: vec![],
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
||||
fn str_to_time_unit(ts_string: &str, span: Span) -> Result<TimeUnit, ShellError> {
|
||||
match ts_string {
|
||||
"ms" => Ok(TimeUnit::Milliseconds),
|
||||
"us" | "μs" => Ok(TimeUnit::Microseconds),
|
||||
"ns" => Ok(TimeUnit::Nanoseconds),
|
||||
_ => Err(ShellError::GenericError {
|
||||
error: "Invalid polars data type".into(),
|
||||
msg: "Invalid time unit".into(),
|
||||
span: Some(span),
|
||||
help: None,
|
||||
inner: vec![],
|
||||
}),
|
||||
}
|
||||
}
|
@ -1,480 +0,0 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use nu_protocol::{ShellError, Span, Value};
|
||||
use polars::{
|
||||
datatypes::UnknownKind,
|
||||
prelude::{DataType, Field, PlSmallStr, Schema, SchemaExt, SchemaRef, TimeUnit},
|
||||
};
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct NuSchema {
|
||||
pub schema: SchemaRef,
|
||||
}
|
||||
|
||||
impl NuSchema {
|
||||
pub fn new(schema: SchemaRef) -> Self {
|
||||
Self { schema }
|
||||
}
|
||||
}
|
||||
|
||||
impl TryFrom<&Value> for NuSchema {
|
||||
type Error = ShellError;
|
||||
fn try_from(value: &Value) -> Result<Self, Self::Error> {
|
||||
let schema = value_to_schema(value, Span::unknown())?;
|
||||
Ok(Self::new(Arc::new(schema)))
|
||||
}
|
||||
}
|
||||
|
||||
impl From<NuSchema> for Value {
|
||||
fn from(schema: NuSchema) -> Self {
|
||||
fields_to_value(schema.schema.iter_fields(), Span::unknown())
|
||||
}
|
||||
}
|
||||
|
||||
impl From<NuSchema> for SchemaRef {
|
||||
fn from(val: NuSchema) -> Self {
|
||||
Arc::clone(&val.schema)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<SchemaRef> for NuSchema {
|
||||
fn from(val: SchemaRef) -> Self {
|
||||
Self { schema: val }
|
||||
}
|
||||
}
|
||||
|
||||
fn fields_to_value(fields: impl Iterator<Item = Field>, span: Span) -> Value {
|
||||
let record = fields
|
||||
.map(|field| {
|
||||
let col = field.name().to_string();
|
||||
let val = dtype_to_value(field.dtype(), span);
|
||||
(col, val)
|
||||
})
|
||||
.collect();
|
||||
|
||||
Value::record(record, Span::unknown())
|
||||
}
|
||||
|
||||
fn dtype_to_value(dtype: &DataType, span: Span) -> Value {
|
||||
match dtype {
|
||||
DataType::Struct(fields) => fields_to_value(fields.iter().cloned(), span),
|
||||
_ => Value::string(dtype.to_string().replace('[', "<").replace(']', ">"), span),
|
||||
}
|
||||
}
|
||||
|
||||
fn value_to_schema(value: &Value, span: Span) -> Result<Schema, ShellError> {
|
||||
let fields = value_to_fields(value, span)?;
|
||||
let schema = Schema::from_iter(fields);
|
||||
Ok(schema)
|
||||
}
|
||||
|
||||
fn value_to_fields(value: &Value, span: Span) -> Result<Vec<Field>, ShellError> {
|
||||
let fields = value
|
||||
.as_record()?
|
||||
.into_iter()
|
||||
.map(|(col, val)| match val {
|
||||
Value::Record { .. } => {
|
||||
let fields = value_to_fields(val, span)?;
|
||||
let dtype = DataType::Struct(fields);
|
||||
Ok(Field::new(col.into(), dtype))
|
||||
}
|
||||
_ => {
|
||||
let dtype = str_to_dtype(&val.coerce_string()?, span)?;
|
||||
Ok(Field::new(col.into(), dtype))
|
||||
}
|
||||
})
|
||||
.collect::<Result<Vec<Field>, ShellError>>()?;
|
||||
Ok(fields)
|
||||
}
|
||||
|
||||
pub fn str_to_dtype(dtype: &str, span: Span) -> Result<DataType, ShellError> {
|
||||
match dtype {
|
||||
"bool" => Ok(DataType::Boolean),
|
||||
"u8" => Ok(DataType::UInt8),
|
||||
"u16" => Ok(DataType::UInt16),
|
||||
"u32" => Ok(DataType::UInt32),
|
||||
"u64" => Ok(DataType::UInt64),
|
||||
"i8" => Ok(DataType::Int8),
|
||||
"i16" => Ok(DataType::Int16),
|
||||
"i32" => Ok(DataType::Int32),
|
||||
"i64" => Ok(DataType::Int64),
|
||||
"f32" => Ok(DataType::Float32),
|
||||
"f64" => Ok(DataType::Float64),
|
||||
"str" => Ok(DataType::String),
|
||||
"binary" => Ok(DataType::Binary),
|
||||
"date" => Ok(DataType::Date),
|
||||
"time" => Ok(DataType::Time),
|
||||
"null" => Ok(DataType::Null),
|
||||
"unknown" => Ok(DataType::Unknown(UnknownKind::Any)),
|
||||
"object" => Ok(DataType::Object("unknown", None)),
|
||||
_ if dtype.starts_with("list") => {
|
||||
let dtype = dtype
|
||||
.trim_start_matches("list")
|
||||
.trim_start_matches('<')
|
||||
.trim_end_matches('>')
|
||||
.trim();
|
||||
let dtype = str_to_dtype(dtype, span)?;
|
||||
Ok(DataType::List(Box::new(dtype)))
|
||||
}
|
||||
_ if dtype.starts_with("datetime") => {
|
||||
let dtype = dtype
|
||||
.trim_start_matches("datetime")
|
||||
.trim_start_matches('<')
|
||||
.trim_end_matches('>');
|
||||
let mut split = dtype.split(',');
|
||||
let next = split
|
||||
.next()
|
||||
.ok_or_else(|| ShellError::GenericError {
|
||||
error: "Invalid polars data type".into(),
|
||||
msg: "Missing time unit".into(),
|
||||
span: Some(span),
|
||||
help: None,
|
||||
inner: vec![],
|
||||
})?
|
||||
.trim();
|
||||
let time_unit = str_to_time_unit(next, span)?;
|
||||
let next = split
|
||||
.next()
|
||||
.ok_or_else(|| ShellError::GenericError {
|
||||
error: "Invalid polars data type".into(),
|
||||
msg: "Missing time zone".into(),
|
||||
span: Some(span),
|
||||
help: None,
|
||||
inner: vec![],
|
||||
})?
|
||||
.trim();
|
||||
let timezone = if "*" == next {
|
||||
None
|
||||
} else {
|
||||
Some(next.to_string())
|
||||
};
|
||||
Ok(DataType::Datetime(
|
||||
time_unit,
|
||||
timezone.map(PlSmallStr::from),
|
||||
))
|
||||
}
|
||||
_ if dtype.starts_with("duration") => {
|
||||
let inner = dtype.trim_start_matches("duration<").trim_end_matches('>');
|
||||
let next = inner
|
||||
.split(',')
|
||||
.next()
|
||||
.ok_or_else(|| ShellError::GenericError {
|
||||
error: "Invalid polars data type".into(),
|
||||
msg: "Missing time unit".into(),
|
||||
span: Some(span),
|
||||
help: None,
|
||||
inner: vec![],
|
||||
})?
|
||||
.trim();
|
||||
let time_unit = str_to_time_unit(next, span)?;
|
||||
Ok(DataType::Duration(time_unit))
|
||||
}
|
||||
_ if dtype.starts_with("decimal") => {
|
||||
let dtype = dtype
|
||||
.trim_start_matches("decimal")
|
||||
.trim_start_matches('<')
|
||||
.trim_end_matches('>');
|
||||
let mut split = dtype.split(',');
|
||||
let next = split
|
||||
.next()
|
||||
.ok_or_else(|| ShellError::GenericError {
|
||||
error: "Invalid polars data type".into(),
|
||||
msg: "Missing decimal precision".into(),
|
||||
span: Some(span),
|
||||
help: None,
|
||||
inner: vec![],
|
||||
})?
|
||||
.trim();
|
||||
let precision = match next {
|
||||
"*" => None, // infer
|
||||
_ => Some(
|
||||
next.parse::<usize>()
|
||||
.map_err(|e| ShellError::GenericError {
|
||||
error: "Invalid polars data type".into(),
|
||||
msg: format!("Error in parsing decimal precision: {e}"),
|
||||
span: Some(span),
|
||||
help: None,
|
||||
inner: vec![],
|
||||
})?,
|
||||
),
|
||||
};
|
||||
|
||||
let next = split
|
||||
.next()
|
||||
.ok_or_else(|| ShellError::GenericError {
|
||||
error: "Invalid polars data type".into(),
|
||||
msg: "Missing decimal scale".into(),
|
||||
span: Some(span),
|
||||
help: None,
|
||||
inner: vec![],
|
||||
})?
|
||||
.trim();
|
||||
let scale = match next {
|
||||
"*" => Err(ShellError::GenericError {
|
||||
error: "Invalid polars data type".into(),
|
||||
msg: "`*` is not a permitted value for scale".into(),
|
||||
span: Some(span),
|
||||
help: None,
|
||||
inner: vec![],
|
||||
}),
|
||||
_ => next
|
||||
.parse::<usize>()
|
||||
.map(Some)
|
||||
.map_err(|e| ShellError::GenericError {
|
||||
error: "Invalid polars data type".into(),
|
||||
msg: format!("Error in parsing decimal precision: {e}"),
|
||||
span: Some(span),
|
||||
help: None,
|
||||
inner: vec![],
|
||||
}),
|
||||
}?;
|
||||
Ok(DataType::Decimal(precision, scale))
|
||||
}
|
||||
_ => Err(ShellError::GenericError {
|
||||
error: "Invalid polars data type".into(),
|
||||
msg: format!("Unknown type: {dtype}"),
|
||||
span: Some(span),
|
||||
help: None,
|
||||
inner: vec![],
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
||||
fn str_to_time_unit(ts_string: &str, span: Span) -> Result<TimeUnit, ShellError> {
|
||||
match ts_string {
|
||||
"ms" => Ok(TimeUnit::Milliseconds),
|
||||
"us" | "μs" => Ok(TimeUnit::Microseconds),
|
||||
"ns" => Ok(TimeUnit::Nanoseconds),
|
||||
_ => Err(ShellError::GenericError {
|
||||
error: "Invalid polars data type".into(),
|
||||
msg: "Invalid time unit".into(),
|
||||
span: Some(span),
|
||||
help: None,
|
||||
inner: vec![],
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod test {
|
||||
|
||||
use nu_protocol::record;
|
||||
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_value_to_schema() {
|
||||
let address = record! {
|
||||
"street" => Value::test_string("str"),
|
||||
"city" => Value::test_string("str"),
|
||||
};
|
||||
|
||||
let value = Value::test_record(record! {
|
||||
"name" => Value::test_string("str"),
|
||||
"age" => Value::test_string("i32"),
|
||||
"address" => Value::test_record(address)
|
||||
});
|
||||
|
||||
let schema = value_to_schema(&value, Span::unknown()).unwrap();
|
||||
let expected = Schema::from_iter(vec![
|
||||
Field::new("name".into(), DataType::String),
|
||||
Field::new("age".into(), DataType::Int32),
|
||||
Field::new(
|
||||
"address".into(),
|
||||
DataType::Struct(vec![
|
||||
Field::new("street".into(), DataType::String),
|
||||
Field::new("city".into(), DataType::String),
|
||||
]),
|
||||
),
|
||||
]);
|
||||
assert_eq!(schema, expected);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_dtype_str_to_schema_simple_types() {
|
||||
let dtype = "bool";
|
||||
let schema = str_to_dtype(dtype, Span::unknown()).unwrap();
|
||||
let expected = DataType::Boolean;
|
||||
assert_eq!(schema, expected);
|
||||
|
||||
let dtype = "u8";
|
||||
let schema = str_to_dtype(dtype, Span::unknown()).unwrap();
|
||||
let expected = DataType::UInt8;
|
||||
assert_eq!(schema, expected);
|
||||
|
||||
let dtype = "u16";
|
||||
let schema = str_to_dtype(dtype, Span::unknown()).unwrap();
|
||||
let expected = DataType::UInt16;
|
||||
assert_eq!(schema, expected);
|
||||
|
||||
let dtype = "u32";
|
||||
let schema = str_to_dtype(dtype, Span::unknown()).unwrap();
|
||||
let expected = DataType::UInt32;
|
||||
assert_eq!(schema, expected);
|
||||
|
||||
let dtype = "u64";
|
||||
let schema = str_to_dtype(dtype, Span::unknown()).unwrap();
|
||||
let expected = DataType::UInt64;
|
||||
assert_eq!(schema, expected);
|
||||
|
||||
let dtype = "i8";
|
||||
let schema = str_to_dtype(dtype, Span::unknown()).unwrap();
|
||||
let expected = DataType::Int8;
|
||||
assert_eq!(schema, expected);
|
||||
|
||||
let dtype = "i16";
|
||||
let schema = str_to_dtype(dtype, Span::unknown()).unwrap();
|
||||
let expected = DataType::Int16;
|
||||
assert_eq!(schema, expected);
|
||||
|
||||
let dtype = "i32";
|
||||
let schema = str_to_dtype(dtype, Span::unknown()).unwrap();
|
||||
let expected = DataType::Int32;
|
||||
assert_eq!(schema, expected);
|
||||
|
||||
let dtype = "i64";
|
||||
let schema = str_to_dtype(dtype, Span::unknown()).unwrap();
|
||||
let expected = DataType::Int64;
|
||||
assert_eq!(schema, expected);
|
||||
|
||||
let dtype = "str";
|
||||
let schema = str_to_dtype(dtype, Span::unknown()).unwrap();
|
||||
let expected = DataType::String;
|
||||
assert_eq!(schema, expected);
|
||||
|
||||
let dtype = "binary";
|
||||
let schema = str_to_dtype(dtype, Span::unknown()).unwrap();
|
||||
let expected = DataType::Binary;
|
||||
assert_eq!(schema, expected);
|
||||
|
||||
let dtype = "date";
|
||||
let schema = str_to_dtype(dtype, Span::unknown()).unwrap();
|
||||
let expected = DataType::Date;
|
||||
assert_eq!(schema, expected);
|
||||
|
||||
let dtype = "time";
|
||||
let schema = str_to_dtype(dtype, Span::unknown()).unwrap();
|
||||
let expected = DataType::Time;
|
||||
assert_eq!(schema, expected);
|
||||
|
||||
let dtype = "null";
|
||||
let schema = str_to_dtype(dtype, Span::unknown()).unwrap();
|
||||
let expected = DataType::Null;
|
||||
assert_eq!(schema, expected);
|
||||
|
||||
let dtype = "unknown";
|
||||
let schema = str_to_dtype(dtype, Span::unknown()).unwrap();
|
||||
let expected = DataType::Unknown(UnknownKind::Any);
|
||||
assert_eq!(schema, expected);
|
||||
|
||||
let dtype = "object";
|
||||
let schema = str_to_dtype(dtype, Span::unknown()).unwrap();
|
||||
let expected = DataType::Object("unknown", None);
|
||||
assert_eq!(schema, expected);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_dtype_str_schema_datetime() {
|
||||
let dtype = "datetime<ms, *>";
|
||||
let schema = str_to_dtype(dtype, Span::unknown()).unwrap();
|
||||
let expected = DataType::Datetime(TimeUnit::Milliseconds, None);
|
||||
assert_eq!(schema, expected);
|
||||
|
||||
let dtype = "datetime<us, *>";
|
||||
let schema = str_to_dtype(dtype, Span::unknown()).unwrap();
|
||||
let expected = DataType::Datetime(TimeUnit::Microseconds, None);
|
||||
assert_eq!(schema, expected);
|
||||
|
||||
let dtype = "datetime<μs, *>";
|
||||
let schema = str_to_dtype(dtype, Span::unknown()).unwrap();
|
||||
let expected = DataType::Datetime(TimeUnit::Microseconds, None);
|
||||
assert_eq!(schema, expected);
|
||||
|
||||
let dtype = "datetime<ns, *>";
|
||||
let schema = str_to_dtype(dtype, Span::unknown()).unwrap();
|
||||
let expected = DataType::Datetime(TimeUnit::Nanoseconds, None);
|
||||
assert_eq!(schema, expected);
|
||||
|
||||
let dtype = "datetime<ms, UTC>";
|
||||
let schema = str_to_dtype(dtype, Span::unknown()).unwrap();
|
||||
let expected = DataType::Datetime(TimeUnit::Milliseconds, Some("UTC".into()));
|
||||
assert_eq!(schema, expected);
|
||||
|
||||
let dtype = "invalid";
|
||||
let schema = str_to_dtype(dtype, Span::unknown());
|
||||
assert!(schema.is_err())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_dtype_str_schema_duration() {
|
||||
let dtype = "duration<ms>";
|
||||
let schema = str_to_dtype(dtype, Span::unknown()).unwrap();
|
||||
let expected = DataType::Duration(TimeUnit::Milliseconds);
|
||||
assert_eq!(schema, expected);
|
||||
|
||||
let dtype = "duration<us>";
|
||||
let schema = str_to_dtype(dtype, Span::unknown()).unwrap();
|
||||
let expected = DataType::Duration(TimeUnit::Microseconds);
|
||||
assert_eq!(schema, expected);
|
||||
|
||||
let dtype = "duration<μs>";
|
||||
let schema = str_to_dtype(dtype, Span::unknown()).unwrap();
|
||||
let expected = DataType::Duration(TimeUnit::Microseconds);
|
||||
assert_eq!(schema, expected);
|
||||
|
||||
let dtype = "duration<ns>";
|
||||
let schema = str_to_dtype(dtype, Span::unknown()).unwrap();
|
||||
let expected = DataType::Duration(TimeUnit::Nanoseconds);
|
||||
assert_eq!(schema, expected);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_dtype_str_schema_decimal() {
|
||||
let dtype = "decimal<7,2>";
|
||||
let schema = str_to_dtype(dtype, Span::unknown()).unwrap();
|
||||
let expected = DataType::Decimal(Some(7usize), Some(2usize));
|
||||
assert_eq!(schema, expected);
|
||||
|
||||
// "*" is not a permitted value for scale
|
||||
let dtype = "decimal<7,*>";
|
||||
let schema = str_to_dtype(dtype, Span::unknown());
|
||||
assert!(matches!(schema, Err(ShellError::GenericError { .. })));
|
||||
|
||||
let dtype = "decimal<*,2>";
|
||||
let schema = str_to_dtype(dtype, Span::unknown()).unwrap();
|
||||
let expected = DataType::Decimal(None, Some(2usize));
|
||||
assert_eq!(schema, expected);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_dtype_str_to_schema_list_types() {
|
||||
let dtype = "list<i32>";
|
||||
let schema = str_to_dtype(dtype, Span::unknown()).unwrap();
|
||||
let expected = DataType::List(Box::new(DataType::Int32));
|
||||
assert_eq!(schema, expected);
|
||||
|
||||
let dtype = "list<duration<ms>>";
|
||||
let schema = str_to_dtype(dtype, Span::unknown()).unwrap();
|
||||
let expected = DataType::List(Box::new(DataType::Duration(TimeUnit::Milliseconds)));
|
||||
assert_eq!(schema, expected);
|
||||
|
||||
let dtype = "list<datetime<ms, *>>";
|
||||
let schema = str_to_dtype(dtype, Span::unknown()).unwrap();
|
||||
let expected = DataType::List(Box::new(DataType::Datetime(TimeUnit::Milliseconds, None)));
|
||||
assert_eq!(schema, expected);
|
||||
|
||||
let dtype = "list<decimal<7,2>>";
|
||||
let schema = str_to_dtype(dtype, Span::unknown()).unwrap();
|
||||
let expected = DataType::List(Box::new(DataType::Decimal(Some(7usize), Some(2usize))));
|
||||
assert_eq!(schema, expected);
|
||||
|
||||
let dtype = "list<decimal<*,2>>";
|
||||
let schema = str_to_dtype(dtype, Span::unknown()).unwrap();
|
||||
let expected = DataType::List(Box::new(DataType::Decimal(None, Some(2usize))));
|
||||
assert_eq!(schema, expected);
|
||||
|
||||
let dtype = "list<decimal<7,*>>";
|
||||
let schema = str_to_dtype(dtype, Span::unknown());
|
||||
assert!(matches!(schema, Err(ShellError::GenericError { .. })));
|
||||
}
|
||||
}
|
@ -0,0 +1,65 @@
|
||||
use nu_protocol::{CustomValue, ShellError, Span, Value};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::values::{CustomValueSupport, PolarsPluginCustomValue};
|
||||
|
||||
use super::NuSchema;
|
||||
|
||||
#[derive(Clone, Debug, Deserialize, Serialize)]
|
||||
pub struct NuSchemaCustomValue {
|
||||
pub id: Uuid,
|
||||
#[serde(skip)]
|
||||
pub datatype: Option<NuSchema>,
|
||||
}
|
||||
|
||||
#[typetag::serde]
|
||||
impl CustomValue for NuSchemaCustomValue {
|
||||
fn clone_value(&self, span: nu_protocol::Span) -> Value {
|
||||
Value::custom(Box::new(self.clone()), span)
|
||||
}
|
||||
|
||||
fn type_name(&self) -> String {
|
||||
"NuSchema".into()
|
||||
}
|
||||
|
||||
fn to_base_value(&self, span: Span) -> Result<Value, ShellError> {
|
||||
Ok(Value::string(
|
||||
"NuSchema: custom_value_to_base_value should've been called",
|
||||
span,
|
||||
))
|
||||
}
|
||||
|
||||
fn as_mut_any(&mut self) -> &mut dyn std::any::Any {
|
||||
self
|
||||
}
|
||||
|
||||
fn as_any(&self) -> &dyn std::any::Any {
|
||||
self
|
||||
}
|
||||
|
||||
fn notify_plugin_on_drop(&self) -> bool {
|
||||
true
|
||||
}
|
||||
}
|
||||
|
||||
impl PolarsPluginCustomValue for NuSchemaCustomValue {
|
||||
type PolarsPluginObjectType = NuSchema;
|
||||
|
||||
fn id(&self) -> &Uuid {
|
||||
&self.id
|
||||
}
|
||||
|
||||
fn internal(&self) -> &Option<Self::PolarsPluginObjectType> {
|
||||
&self.datatype
|
||||
}
|
||||
|
||||
fn custom_value_to_base_value(
|
||||
&self,
|
||||
plugin: &crate::PolarsPlugin,
|
||||
_engine: &nu_plugin::EngineInterface,
|
||||
) -> Result<Value, ShellError> {
|
||||
let dtype = NuSchema::try_from_custom_value(plugin, self)?;
|
||||
dtype.base_value(Span::unknown())
|
||||
}
|
||||
}
|
189
crates/nu_plugin_polars/src/dataframe/values/nu_schema/mod.rs
Normal file
189
crates/nu_plugin_polars/src/dataframe/values/nu_schema/mod.rs
Normal file
@ -0,0 +1,189 @@
|
||||
pub mod custom_value;
|
||||
|
||||
use std::sync::Arc;
|
||||
|
||||
use custom_value::NuSchemaCustomValue;
|
||||
use nu_protocol::{ShellError, Span, Value};
|
||||
use polars::prelude::{DataType, Field, Schema, SchemaExt, SchemaRef};
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::{Cacheable, PolarsPlugin};
|
||||
|
||||
use super::{str_to_dtype, CustomValueSupport, NuDataType, PolarsPluginObject, PolarsPluginType};
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct NuSchema {
|
||||
pub id: Uuid,
|
||||
pub schema: SchemaRef,
|
||||
}
|
||||
|
||||
impl NuSchema {
|
||||
pub fn new(schema: SchemaRef) -> Self {
|
||||
Self {
|
||||
id: Uuid::new_v4(),
|
||||
schema,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<NuSchema> for SchemaRef {
|
||||
fn from(val: NuSchema) -> Self {
|
||||
Arc::clone(&val.schema)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<SchemaRef> for NuSchema {
|
||||
fn from(val: SchemaRef) -> Self {
|
||||
Self::new(val)
|
||||
}
|
||||
}
|
||||
|
||||
impl Cacheable for NuSchema {
|
||||
fn cache_id(&self) -> &Uuid {
|
||||
&self.id
|
||||
}
|
||||
|
||||
fn to_cache_value(&self) -> Result<super::PolarsPluginObject, ShellError> {
|
||||
Ok(PolarsPluginObject::NuSchema(self.clone()))
|
||||
}
|
||||
|
||||
fn from_cache_value(cv: super::PolarsPluginObject) -> Result<Self, ShellError> {
|
||||
match cv {
|
||||
PolarsPluginObject::NuSchema(dt) => Ok(dt),
|
||||
_ => Err(ShellError::GenericError {
|
||||
error: "Cache value is not a dataframe".into(),
|
||||
msg: "".into(),
|
||||
span: None,
|
||||
help: None,
|
||||
inner: vec![],
|
||||
}),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl CustomValueSupport for NuSchema {
|
||||
type CV = NuSchemaCustomValue;
|
||||
|
||||
fn get_type_static() -> super::PolarsPluginType {
|
||||
PolarsPluginType::NuSchema
|
||||
}
|
||||
|
||||
fn custom_value(self) -> Self::CV {
|
||||
NuSchemaCustomValue {
|
||||
id: self.id,
|
||||
datatype: Some(self),
|
||||
}
|
||||
}
|
||||
|
||||
fn base_value(self, span: Span) -> Result<Value, ShellError> {
|
||||
Ok(fields_to_value(self.schema.iter_fields(), span))
|
||||
}
|
||||
|
||||
fn try_from_value(plugin: &PolarsPlugin, value: &Value) -> Result<Self, ShellError> {
|
||||
if let Value::Custom { val, .. } = value {
|
||||
if let Some(cv) = val.as_any().downcast_ref::<Self::CV>() {
|
||||
Self::try_from_custom_value(plugin, cv)
|
||||
} else {
|
||||
Err(ShellError::CantConvert {
|
||||
to_type: Self::get_type_static().to_string(),
|
||||
from_type: value.get_type().to_string(),
|
||||
span: value.span(),
|
||||
help: None,
|
||||
})
|
||||
}
|
||||
} else {
|
||||
let schema = value_to_schema(plugin, value, Span::unknown())?;
|
||||
Ok(Self::new(Arc::new(schema)))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn fields_to_value(fields: impl Iterator<Item = Field>, span: Span) -> Value {
|
||||
let record = fields
|
||||
.map(|field| {
|
||||
let col = field.name().to_string();
|
||||
let val = dtype_to_value(field.dtype(), span);
|
||||
(col, val)
|
||||
})
|
||||
.collect();
|
||||
|
||||
Value::record(record, Span::unknown())
|
||||
}
|
||||
|
||||
pub fn dtype_to_value(dtype: &DataType, span: Span) -> Value {
|
||||
match dtype {
|
||||
DataType::Struct(fields) => fields_to_value(fields.iter().cloned(), span),
|
||||
_ => Value::string(dtype.to_string().replace('[', "<").replace(']', ">"), span),
|
||||
}
|
||||
}
|
||||
|
||||
fn value_to_schema(plugin: &PolarsPlugin, value: &Value, span: Span) -> Result<Schema, ShellError> {
|
||||
let fields = value_to_fields(plugin, value, span)?;
|
||||
let schema = Schema::from_iter(fields);
|
||||
Ok(schema)
|
||||
}
|
||||
|
||||
fn value_to_fields(
|
||||
plugin: &PolarsPlugin,
|
||||
value: &Value,
|
||||
span: Span,
|
||||
) -> Result<Vec<Field>, ShellError> {
|
||||
let fields = value
|
||||
.as_record()?
|
||||
.into_iter()
|
||||
.map(|(col, val)| match val {
|
||||
Value::Record { .. } => {
|
||||
let fields = value_to_fields(plugin, val, span)?;
|
||||
let dtype = DataType::Struct(fields);
|
||||
Ok(Field::new(col.into(), dtype))
|
||||
}
|
||||
Value::Custom { .. } => {
|
||||
let dtype = NuDataType::try_from_value(plugin, val)?;
|
||||
Ok(Field::new(col.into(), dtype.to_polars()))
|
||||
}
|
||||
_ => {
|
||||
let dtype = str_to_dtype(&val.coerce_string()?, span)?;
|
||||
Ok(Field::new(col.into(), dtype))
|
||||
}
|
||||
})
|
||||
.collect::<Result<Vec<Field>, ShellError>>()?;
|
||||
Ok(fields)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod test {
|
||||
|
||||
use nu_protocol::record;
|
||||
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_value_to_schema() {
|
||||
let plugin = PolarsPlugin::new_test_mode().expect("Failed to create plugin");
|
||||
|
||||
let address = record! {
|
||||
"street" => Value::test_string("str"),
|
||||
"city" => Value::test_string("str"),
|
||||
};
|
||||
|
||||
let value = Value::test_record(record! {
|
||||
"name" => Value::test_string("str"),
|
||||
"age" => Value::test_string("i32"),
|
||||
"address" => Value::test_record(address)
|
||||
});
|
||||
|
||||
let schema = value_to_schema(&plugin, &value, Span::unknown()).unwrap();
|
||||
let expected = Schema::from_iter(vec![
|
||||
Field::new("name".into(), DataType::String),
|
||||
Field::new("age".into(), DataType::Int32),
|
||||
Field::new(
|
||||
"address".into(),
|
||||
DataType::Struct(vec![
|
||||
Field::new("street".into(), DataType::String),
|
||||
Field::new("city".into(), DataType::String),
|
||||
]),
|
||||
),
|
||||
]);
|
||||
assert_eq!(schema, expected);
|
||||
}
|
||||
}
|
@ -123,6 +123,8 @@ impl Plugin for PolarsPlugin {
|
||||
CustomValueType::NuExpression(cv) => cv.custom_value_to_base_value(self, engine),
|
||||
CustomValueType::NuLazyGroupBy(cv) => cv.custom_value_to_base_value(self, engine),
|
||||
CustomValueType::NuWhen(cv) => cv.custom_value_to_base_value(self, engine),
|
||||
CustomValueType::NuDataType(cv) => cv.custom_value_to_base_value(self, engine),
|
||||
CustomValueType::NuSchema(cv) => cv.custom_value_to_base_value(self, engine),
|
||||
};
|
||||
Ok(result?)
|
||||
}
|
||||
@ -150,6 +152,12 @@ impl Plugin for PolarsPlugin {
|
||||
CustomValueType::NuWhen(cv) => {
|
||||
cv.custom_value_operation(self, engine, left.span, operator, right)
|
||||
}
|
||||
CustomValueType::NuDataType(cv) => {
|
||||
cv.custom_value_operation(self, engine, left.span, operator, right)
|
||||
}
|
||||
CustomValueType::NuSchema(cv) => {
|
||||
cv.custom_value_operation(self, engine, left.span, operator, right)
|
||||
}
|
||||
};
|
||||
Ok(result?)
|
||||
}
|
||||
@ -176,6 +184,12 @@ impl Plugin for PolarsPlugin {
|
||||
CustomValueType::NuWhen(cv) => {
|
||||
cv.custom_value_follow_path_int(self, engine, custom_value.span, index)
|
||||
}
|
||||
CustomValueType::NuDataType(cv) => {
|
||||
cv.custom_value_follow_path_int(self, engine, custom_value.span, index)
|
||||
}
|
||||
CustomValueType::NuSchema(cv) => {
|
||||
cv.custom_value_follow_path_int(self, engine, custom_value.span, index)
|
||||
}
|
||||
};
|
||||
Ok(result?)
|
||||
}
|
||||
@ -202,6 +216,12 @@ impl Plugin for PolarsPlugin {
|
||||
CustomValueType::NuWhen(cv) => {
|
||||
cv.custom_value_follow_path_string(self, engine, custom_value.span, column_name)
|
||||
}
|
||||
CustomValueType::NuDataType(cv) => {
|
||||
cv.custom_value_follow_path_string(self, engine, custom_value.span, column_name)
|
||||
}
|
||||
CustomValueType::NuSchema(cv) => {
|
||||
cv.custom_value_follow_path_string(self, engine, custom_value.span, column_name)
|
||||
}
|
||||
};
|
||||
Ok(result?)
|
||||
}
|
||||
@ -226,6 +246,10 @@ impl Plugin for PolarsPlugin {
|
||||
cv.custom_value_partial_cmp(self, engine, other_value)
|
||||
}
|
||||
CustomValueType::NuWhen(cv) => cv.custom_value_partial_cmp(self, engine, other_value),
|
||||
CustomValueType::NuDataType(cv) => {
|
||||
cv.custom_value_partial_cmp(self, engine, other_value)
|
||||
}
|
||||
CustomValueType::NuSchema(cv) => cv.custom_value_partial_cmp(self, engine, other_value),
|
||||
};
|
||||
Ok(result?)
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user