diff --git a/crates/nu_plugin_polars/src/dataframe/command/core/open.rs b/crates/nu_plugin_polars/src/dataframe/command/core/open.rs index f01fc76824..21c84267d9 100644 --- a/crates/nu_plugin_polars/src/dataframe/command/core/open.rs +++ b/crates/nu_plugin_polars/src/dataframe/command/core/open.rs @@ -9,8 +9,8 @@ use nu_utils::perf; use nu_plugin::{EvaluatedCall, PluginCommand}; use nu_protocol::{ - shell_error::io::IoError, Category, Example, LabeledError, PipelineData, ShellError, Signature, - Span, Spanned, SyntaxShape, Type, Value, + shell_error::io::IoError, Category, DataSource, Example, LabeledError, PipelineData, + PipelineMetadata, ShellError, Signature, Span, Spanned, SyntaxShape, Type, Value, }; use std::{fs::File, io::BufReader, num::NonZeroUsize, path::PathBuf, sync::Arc}; @@ -164,6 +164,8 @@ fn command( } let hive_options = build_hive_options(plugin, call)?; + let metadata = PipelineMetadata::default() + .with_data_source(DataSource::FilePath(spanned_file.item.clone().into())); match type_option { Some((ext, blamed)) => match PolarsFileType::from(ext.as_str()) { @@ -199,7 +201,7 @@ fn command( "File without extension", ))), } - .map(|value| PipelineData::Value(value, None)) + .map(|value| PipelineData::Value(value, Some(metadata))) } fn from_parquet( diff --git a/crates/nu_plugin_polars/src/dataframe/command/core/save/mod.rs b/crates/nu_plugin_polars/src/dataframe/command/core/save/mod.rs index ce56acd984..cab96c308e 100644 --- a/crates/nu_plugin_polars/src/dataframe/command/core/save/mod.rs +++ b/crates/nu_plugin_polars/src/dataframe/command/core/save/mod.rs @@ -15,8 +15,8 @@ use crate::{ use log::debug; use nu_plugin::{EngineInterface, EvaluatedCall, PluginCommand}; use nu_protocol::{ - shell_error::io::IoError, Category, Example, LabeledError, PipelineData, ShellError, Signature, - Span, Spanned, SyntaxShape, Type, + shell_error::io::IoError, Category, DataSource, Example, LabeledError, PipelineData, + PipelineMetadata, ShellError, Signature, Span, Spanned, SyntaxShape, Type, }; use polars::error::PolarsError; @@ -112,11 +112,20 @@ impl PluginCommand for SaveDF { call: &EvaluatedCall, input: PipelineData, ) -> Result { + let spanned_file: Spanned = call.req(0)?; + debug!("file: {}", spanned_file.item); + + let metadata = input.metadata(); let value = input.into_value(call.head)?; + check_writing_into_source_file( + metadata.as_ref(), + &spanned_file.as_ref().map(PathBuf::from), + )?; + match PolarsPluginObject::try_from_value(plugin, &value)? { po @ PolarsPluginObject::NuDataFrame(_) | po @ PolarsPluginObject::NuLazyFrame(_) => { - command(plugin, engine, call, po) + command(plugin, engine, call, po, spanned_file) } _ => Err(cant_convert_err( &value, @@ -132,10 +141,8 @@ fn command( engine: &EngineInterface, call: &EvaluatedCall, polars_object: PolarsPluginObject, + spanned_file: Spanned, ) -> Result { - let spanned_file: Spanned = call.req(0)?; - debug!("file: {}", spanned_file.item); - let resource = Resource::new(plugin, engine, &spanned_file)?; let type_option: Option<(String, Span)> = call .get_flag("type")? @@ -223,6 +230,28 @@ fn command( Ok(PipelineData::empty()) } +fn check_writing_into_source_file( + metadata: Option<&PipelineMetadata>, + dest: &Spanned, +) -> Result<(), ShellError> { + let Some(DataSource::FilePath(source)) = metadata.map(|meta| &meta.data_source) else { + return Ok(()); + }; + + if &dest.item == source { + return Err(write_into_source_error(dest.span)); + } + + Ok(()) +} + +fn write_into_source_error(span: Span) -> ShellError { + polars_file_save_error( + PolarsError::InvalidOperation("attempted to save into source".into()), + span, + ) +} + pub(crate) fn polars_file_save_error(e: PolarsError, span: Span) -> ShellError { ShellError::GenericError { error: format!("Error saving file: {e}"), @@ -247,17 +276,13 @@ pub fn unknown_file_save_error(span: Span) -> ShellError { pub(crate) mod test { use nu_plugin_test_support::PluginTest; use nu_protocol::{Span, Value}; + use tempfile::TempDir; use uuid::Uuid; use crate::PolarsPlugin; - fn test_save(cmd: &'static str, extension: &str) -> Result<(), Box> { + fn tmp_dir_sandbox() -> Result<(TempDir, PluginTest), Box> { let tmp_dir = tempfile::tempdir()?; - let mut tmp_file = tmp_dir.path().to_owned(); - tmp_file.push(format!("{}.{}", Uuid::new_v4(), extension)); - let tmp_file_str = tmp_file.to_str().expect("should be able to get file path"); - - let cmd = format!("{cmd} {tmp_file_str}"); let mut plugin_test = PluginTest::new("polars", PolarsPlugin::new()?.into())?; plugin_test.engine_state_mut().add_env_var( "PWD".to_string(), @@ -270,6 +295,17 @@ pub(crate) mod test { Span::test_data(), ), ); + + Ok((tmp_dir, plugin_test)) + } + + fn test_save(cmd: &'static str, extension: &str) -> Result<(), Box> { + let (tmp_dir, mut plugin_test) = tmp_dir_sandbox()?; + let mut tmp_file = tmp_dir.path().to_owned(); + tmp_file.push(format!("{}.{}", Uuid::new_v4(), extension)); + let tmp_file_str = tmp_file.to_str().expect("should be able to get file path"); + + let cmd = format!("{cmd} {tmp_file_str}"); let _pipeline_data = plugin_test.eval(&cmd)?; assert!(tmp_file.exists()); @@ -290,4 +326,27 @@ pub(crate) mod test { extension, ) } + + #[test] + fn test_write_to_source_guard() -> Result<(), Box> { + let (tmp_dir, mut plugin_test) = tmp_dir_sandbox()?; + let mut tmp_file = tmp_dir.path().to_owned(); + dbg!(&tmp_dir); + tmp_file.push(format!("{}.{}", Uuid::new_v4(), "parquet")); + let tmp_file_str = tmp_file.to_str().expect("Should be able to get file path"); + + let _setup = plugin_test.eval(&format!( + "[1 2 3] | polars into-df | polars save {tmp_file_str}", + ))?; + + let output = plugin_test.eval(&format!( + "polars open {tmp_file_str} | polars save {tmp_file_str}" + )); + + assert!(output.is_err_and(|e| e + .to_string() + .contains("Error saving file: attempted to save into source"))); + + Ok(()) + } }