Using try_from_value for NuSchema

This commit is contained in:
Jack Wright 2025-04-08 16:38:35 -07:00
parent 6e0c2cc511
commit f44bd9863a
4 changed files with 38 additions and 24 deletions

View File

@ -163,7 +163,7 @@ fn command(
});
}
let hive_options = build_hive_options(call)?;
let hive_options = build_hive_options(plugin, call)?;
match type_option {
Some((ext, blamed)) => match PolarsFileType::from(ext.as_str()) {
@ -388,7 +388,7 @@ fn from_json(
})?;
let maybe_schema = call
.get_flag("schema")?
.map(|schema| NuSchema::try_from(&schema))
.map(|schema| NuSchema::try_from_value(plugin, &schema))
.transpose()?;
let buf_reader = BufReader::new(file);
@ -429,11 +429,7 @@ fn from_ndjson(
NonZeroUsize::new(DEFAULT_INFER_SCHEMA)
.expect("The default infer-schema should be non zero"),
);
let maybe_schema = call
.get_flag("schema")?
.map(|schema| NuSchema::try_from(&schema))
.transpose()?;
let maybe_schema = get_schema(plugin, call)?;
if !is_eager {
let start_time = std::time::Instant::now();
@ -511,10 +507,7 @@ fn from_csv(
.unwrap_or(DEFAULT_INFER_SCHEMA);
let skip_rows: Option<usize> = call.get_flag("skip-rows")?;
let columns: Option<Vec<String>> = call.get_flag("columns")?;
let maybe_schema = call
.get_flag("schema")?
.map(|schema| NuSchema::try_from(&schema))
.transpose()?;
let maybe_schema = get_schema(plugin, call)?;
let truncate_ragged_lines: bool = call.has_flag("truncate-ragged-lines")?;
if !is_eager {
@ -627,12 +620,15 @@ fn cloud_not_supported(file_type: PolarsFileType, span: Span) -> ShellError {
}
}
fn build_hive_options(call: &EvaluatedCall) -> Result<HiveOptions, ShellError> {
fn build_hive_options(
plugin: &PolarsPlugin,
call: &EvaluatedCall,
) -> Result<HiveOptions, ShellError> {
let enabled: Option<bool> = call.get_flag("hive-enabled")?;
let hive_start_idx: Option<usize> = call.get_flag("hive-start-idx")?;
let schema: Option<NuSchema> = call
.get_flag::<Value>("hive-schema")?
.map(|schema| NuSchema::try_from(&schema))
.map(|schema| NuSchema::try_from_value(plugin, &schema))
.transpose()?;
let try_parse_dates: bool = call.has_flag("hive-try-parse-dates")?;
@ -643,3 +639,11 @@ fn build_hive_options(call: &EvaluatedCall) -> Result<HiveOptions, ShellError> {
try_parse_dates,
})
}
fn get_schema(plugin: &PolarsPlugin, call: &EvaluatedCall) -> Result<Option<NuSchema>, ShellError> {
let schema: Option<NuSchema> = call
.get_flag("schema")?
.map(|schema| NuSchema::try_from_value(plugin, &schema))
.transpose()?;
Ok(schema)
}

View File

@ -206,7 +206,7 @@ impl PluginCommand for ToDataFrame {
) -> Result<PipelineData, LabeledError> {
let maybe_schema = call
.get_flag("schema")?
.map(|schema| NuSchema::try_from(&schema))
.map(|schema| NuSchema::try_from_value(plugin, &schema))
.transpose()?;
debug!("schema: {:?}", maybe_schema);

View File

@ -54,7 +54,7 @@ impl PluginCommand for ToLazyFrame {
) -> Result<PipelineData, LabeledError> {
let maybe_schema = call
.get_flag("schema")?
.map(|schema| NuSchema::try_from(&schema))
.map(|schema| NuSchema::try_from_value(plugin, &schema))
.transpose()?;
let df = NuDataFrame::try_from_iter(plugin, input.into_iter(), maybe_schema)?;

View File

@ -7,7 +7,7 @@ use nu_protocol::{ShellError, Span, Value};
use polars::prelude::{DataType, Field, Schema, SchemaExt, SchemaRef};
use uuid::Uuid;
use crate::Cacheable;
use crate::{Cacheable, PolarsPlugin};
use super::{str_to_dtype, CustomValueSupport, PolarsPluginObject, PolarsPluginType};
@ -26,14 +26,6 @@ impl NuSchema {
}
}
impl TryFrom<&Value> for NuSchema {
type Error = ShellError;
fn try_from(value: &Value) -> Result<Self, Self::Error> {
let schema = value_to_schema(value, Span::unknown())?;
Ok(Self::new(Arc::new(schema)))
}
}
impl From<NuSchema> for SchemaRef {
fn from(val: NuSchema) -> Self {
Arc::clone(&val.schema)
@ -86,6 +78,24 @@ impl CustomValueSupport for NuSchema {
fn base_value(self, span: Span) -> Result<Value, ShellError> {
Ok(fields_to_value(self.schema.iter_fields(), span))
}
fn try_from_value(plugin: &PolarsPlugin, value: &Value) -> Result<Self, ShellError> {
if let Value::Custom { val, .. } = value {
if let Some(cv) = val.as_any().downcast_ref::<Self::CV>() {
Self::try_from_custom_value(plugin, cv)
} else {
Err(ShellError::CantConvert {
to_type: Self::get_type_static().to_string(),
from_type: value.get_type().to_string(),
span: value.span(),
help: None,
})
}
} else {
let schema = value_to_schema(value, Span::unknown())?;
Ok(Self::new(Arc::new(schema)))
}
}
}
fn fields_to_value(fields: impl Iterator<Item = Field>, span: Span) -> Value {