Allow NuDataType values to be passed in when creating schemas

This commit is contained in:
Jack Wright 2025-04-08 16:52:11 -07:00
parent f44bd9863a
commit b2509a8084

View File

@ -9,7 +9,7 @@ use uuid::Uuid;
use crate::{Cacheable, PolarsPlugin};
use super::{str_to_dtype, CustomValueSupport, PolarsPluginObject, PolarsPluginType};
use super::{str_to_dtype, CustomValueSupport, NuDataType, PolarsPluginObject, PolarsPluginType};
#[derive(Debug, Clone)]
pub struct NuSchema {
@ -92,7 +92,7 @@ impl CustomValueSupport for NuSchema {
})
}
} else {
let schema = value_to_schema(value, Span::unknown())?;
let schema = value_to_schema(plugin, value, Span::unknown())?;
Ok(Self::new(Arc::new(schema)))
}
}
@ -117,22 +117,30 @@ pub fn dtype_to_value(dtype: &DataType, span: Span) -> Value {
}
}
fn value_to_schema(value: &Value, span: Span) -> Result<Schema, ShellError> {
let fields = value_to_fields(value, span)?;
fn value_to_schema(plugin: &PolarsPlugin, value: &Value, span: Span) -> Result<Schema, ShellError> {
let fields = value_to_fields(plugin, value, span)?;
let schema = Schema::from_iter(fields);
Ok(schema)
}
fn value_to_fields(value: &Value, span: Span) -> Result<Vec<Field>, ShellError> {
fn value_to_fields(
plugin: &PolarsPlugin,
value: &Value,
span: Span,
) -> Result<Vec<Field>, ShellError> {
let fields = value
.as_record()?
.into_iter()
.map(|(col, val)| match val {
Value::Record { .. } => {
let fields = value_to_fields(val, span)?;
let fields = value_to_fields(plugin, val, span)?;
let dtype = DataType::Struct(fields);
Ok(Field::new(col.into(), dtype))
}
Value::Custom { .. } => {
let dtype = NuDataType::try_from_value(plugin, val)?;
Ok(Field::new(col.into(), dtype.to_polars()))
}
_ => {
let dtype = str_to_dtype(&val.coerce_string()?, span)?;
Ok(Field::new(col.into(), dtype))
@ -151,6 +159,8 @@ mod test {
#[test]
fn test_value_to_schema() {
let plugin = PolarsPlugin::new_test_mode().expect("Failed to create plugin");
let address = record! {
"street" => Value::test_string("str"),
"city" => Value::test_string("str"),
@ -162,7 +172,7 @@ mod test {
"address" => Value::test_record(address)
});
let schema = value_to_schema(&value, Span::unknown()).unwrap();
let schema = value_to_schema(&plugin, &value, Span::unknown()).unwrap();
let expected = Schema::from_iter(vec![
Field::new("name".into(), DataType::String),
Field::new("age".into(), DataType::Int32),