diff --git a/crates/nu-cmd-dataframe/src/dataframe/values/nu_dataframe/custom_value.rs b/crates/nu-cmd-dataframe/src/dataframe/values/nu_dataframe/custom_value.rs index 7bde7cb539..da8b27398b 100644 --- a/crates/nu-cmd-dataframe/src/dataframe/values/nu_dataframe/custom_value.rs +++ b/crates/nu-cmd-dataframe/src/dataframe/values/nu_dataframe/custom_value.rs @@ -34,6 +34,10 @@ impl CustomValue for NuDataFrame { self } + fn as_mut_any(&mut self) -> &mut dyn std::any::Any { + self + } + fn follow_path_int( &self, _self_span: Span, diff --git a/crates/nu-cmd-dataframe/src/dataframe/values/nu_expression/custom_value.rs b/crates/nu-cmd-dataframe/src/dataframe/values/nu_expression/custom_value.rs index 0859746fb9..7a7f59e648 100644 --- a/crates/nu-cmd-dataframe/src/dataframe/values/nu_expression/custom_value.rs +++ b/crates/nu-cmd-dataframe/src/dataframe/values/nu_expression/custom_value.rs @@ -34,6 +34,10 @@ impl CustomValue for NuExpression { self } + fn as_mut_any(&mut self) -> &mut dyn std::any::Any { + self + } + fn operation( &self, lhs_span: Span, diff --git a/crates/nu-cmd-dataframe/src/dataframe/values/nu_lazyframe/custom_value.rs b/crates/nu-cmd-dataframe/src/dataframe/values/nu_lazyframe/custom_value.rs index 0990e95663..f747ae4d18 100644 --- a/crates/nu-cmd-dataframe/src/dataframe/values/nu_lazyframe/custom_value.rs +++ b/crates/nu-cmd-dataframe/src/dataframe/values/nu_lazyframe/custom_value.rs @@ -43,4 +43,8 @@ impl CustomValue for NuLazyFrame { fn as_any(&self) -> &dyn std::any::Any { self } + + fn as_mut_any(&mut self) -> &mut dyn std::any::Any { + self + } } diff --git a/crates/nu-cmd-dataframe/src/dataframe/values/nu_lazygroupby/custom_value.rs b/crates/nu-cmd-dataframe/src/dataframe/values/nu_lazygroupby/custom_value.rs index 0f686c6bd2..6ac6cc6046 100644 --- a/crates/nu-cmd-dataframe/src/dataframe/values/nu_lazygroupby/custom_value.rs +++ b/crates/nu-cmd-dataframe/src/dataframe/values/nu_lazygroupby/custom_value.rs @@ -37,4 +37,8 @@ impl CustomValue for NuLazyGroupBy { fn as_any(&self) -> &dyn std::any::Any { self } + + fn as_mut_any(&mut self) -> &mut dyn std::any::Any { + self + } } diff --git a/crates/nu-cmd-dataframe/src/dataframe/values/nu_when/custom_value.rs b/crates/nu-cmd-dataframe/src/dataframe/values/nu_when/custom_value.rs index 60d13d7088..e2b73bcef1 100644 --- a/crates/nu-cmd-dataframe/src/dataframe/values/nu_when/custom_value.rs +++ b/crates/nu-cmd-dataframe/src/dataframe/values/nu_when/custom_value.rs @@ -34,4 +34,8 @@ impl CustomValue for NuWhen { fn as_any(&self) -> &dyn std::any::Any { self } + + fn as_mut_any(&mut self) -> &mut dyn std::any::Any { + self + } } diff --git a/crates/nu-command/src/database/values/sqlite.rs b/crates/nu-command/src/database/values/sqlite.rs index d865f4a26a..6225d12897 100644 --- a/crates/nu-command/src/database/values/sqlite.rs +++ b/crates/nu-command/src/database/values/sqlite.rs @@ -379,6 +379,10 @@ impl CustomValue for SQLiteDatabase { self } + fn as_mut_any(&mut self) -> &mut dyn std::any::Any { + self + } + fn follow_path_int( &self, _self_span: Span, diff --git a/crates/nu-plugin-test-support/src/plugin_test.rs b/crates/nu-plugin-test-support/src/plugin_test.rs index cf6f2dddff..54ab84d3aa 100644 --- a/crates/nu-plugin-test-support/src/plugin_test.rs +++ b/crates/nu-plugin-test-support/src/plugin_test.rs @@ -136,7 +136,7 @@ impl PluginTest { move |mut value| match PluginCustomValue::serialize_custom_values_in(&mut value) { Ok(()) => { // Make sure to mark them with the source so they pass correctly, too. - PluginCustomValue::add_source(&mut value, &source); + let _ = PluginCustomValue::add_source_in(&mut value, &source); value } Err(err) => Value::error(err, value.span()), diff --git a/crates/nu-plugin-test-support/tests/custom_value/mod.rs b/crates/nu-plugin-test-support/tests/custom_value/mod.rs index 735a2daa9e..aaae5538ff 100644 --- a/crates/nu-plugin-test-support/tests/custom_value/mod.rs +++ b/crates/nu-plugin-test-support/tests/custom_value/mod.rs @@ -35,6 +35,10 @@ impl CustomValue for CustomU32 { self } + fn as_mut_any(&mut self) -> &mut dyn std::any::Any { + self + } + fn partial_cmp(&self, other: &Value) -> Option { other .as_custom_value() diff --git a/crates/nu-plugin/src/lib.rs b/crates/nu-plugin/src/lib.rs index d5c6c27157..f3e4bf9ce5 100644 --- a/crates/nu-plugin/src/lib.rs +++ b/crates/nu-plugin/src/lib.rs @@ -67,7 +67,6 @@ mod plugin; mod protocol; mod sequence; mod serializers; -mod util; pub use plugin::{ serve_plugin, EngineInterface, Plugin, PluginCommand, PluginEncoder, PluginRead, PluginWrite, @@ -88,6 +87,8 @@ pub use plugin::{ pub use protocol::{PluginCustomValue, PluginInput, PluginOutput}; #[doc(hidden)] pub use serializers::EncodingType; +#[doc(hidden)] +pub mod util; // Used by external benchmarks. #[doc(hidden)] diff --git a/crates/nu-plugin/src/plugin/interface.rs b/crates/nu-plugin/src/plugin/interface.rs index a200bc9703..45fbd75c12 100644 --- a/crates/nu-plugin/src/plugin/interface.rs +++ b/crates/nu-plugin/src/plugin/interface.rs @@ -231,6 +231,9 @@ pub trait Interface: Clone + Send { /// The output message type, which must be capable of encapsulating a [`StreamMessage`]. type Output: From; + /// Any context required to construct [`PipelineData`]. Can be `()` if not needed. + type DataContext; + /// Write an output message. fn write(&self, output: Self::Output) -> Result<(), ShellError>; @@ -245,7 +248,11 @@ pub trait Interface: Clone + Send { /// Prepare [`PipelineData`] to be written. This is called by `init_write_pipeline_data()` as /// a hook so that values that need special handling can be taken care of. - fn prepare_pipeline_data(&self, data: PipelineData) -> Result; + fn prepare_pipeline_data( + &self, + data: PipelineData, + context: &Self::DataContext, + ) -> Result; /// Initialize a write for [`PipelineData`]. This returns two parts: the header, which can be /// embedded in the particular message that references the stream, and a writer, which will @@ -258,6 +265,7 @@ pub trait Interface: Clone + Send { fn init_write_pipeline_data( &self, data: PipelineData, + context: &Self::DataContext, ) -> Result<(PipelineDataHeader, PipelineDataWriter), ShellError> { // Allocate a stream id and a writer let new_stream = |high_pressure_mark: i32| { @@ -269,7 +277,7 @@ pub trait Interface: Clone + Send { .write_stream(id, self.clone(), high_pressure_mark)?; Ok::<_, ShellError>((id, writer)) }; - match self.prepare_pipeline_data(data)? { + match self.prepare_pipeline_data(data, context)? { PipelineData::Value(value, _) => { Ok((PipelineDataHeader::Value(value), PipelineDataWriter::None)) } diff --git a/crates/nu-plugin/src/plugin/interface/engine.rs b/crates/nu-plugin/src/plugin/interface/engine.rs index 5791477939..79433ae9e7 100644 --- a/crates/nu-plugin/src/plugin/interface/engine.rs +++ b/crates/nu-plugin/src/plugin/interface/engine.rs @@ -377,7 +377,7 @@ impl EngineInterface { ) -> Result, ShellError> { match result { Ok(data) => { - let (header, writer) = match self.init_write_pipeline_data(data) { + let (header, writer) = match self.init_write_pipeline_data(data, &()) { Ok(tup) => tup, // If we get an error while trying to construct the pipeline data, send that // instead @@ -438,7 +438,7 @@ impl EngineInterface { let mut writer = None; let call = call.map_data(|input| { - let (input_header, input_writer) = self.init_write_pipeline_data(input)?; + let (input_header, input_writer) = self.init_write_pipeline_data(input, &())?; writer = Some(input_writer); Ok(input_header) })?; @@ -809,6 +809,7 @@ impl EngineInterface { impl Interface for EngineInterface { type Output = PluginOutput; + type DataContext = (); fn write(&self, output: PluginOutput) -> Result<(), ShellError> { log::trace!("to engine: {:?}", output); @@ -827,7 +828,11 @@ impl Interface for EngineInterface { &self.stream_manager_handle } - fn prepare_pipeline_data(&self, mut data: PipelineData) -> Result { + fn prepare_pipeline_data( + &self, + mut data: PipelineData, + _context: &(), + ) -> Result { // Serialize custom values in the pipeline data match data { PipelineData::Value(ref mut value, _) => { diff --git a/crates/nu-plugin/src/plugin/interface/engine/tests.rs b/crates/nu-plugin/src/plugin/interface/engine/tests.rs index 1ebe2491f7..fb98ac123f 100644 --- a/crates/nu-plugin/src/plugin/interface/engine/tests.rs +++ b/crates/nu-plugin/src/plugin/interface/engine/tests.rs @@ -1085,10 +1085,13 @@ fn interface_eval_closure_with_stream() -> Result<(), ShellError> { fn interface_prepare_pipeline_data_serializes_custom_values() -> Result<(), ShellError> { let interface = TestCase::new().engine().get_interface(); - let data = interface.prepare_pipeline_data(PipelineData::Value( - Value::test_custom_value(Box::new(expected_test_custom_value())), - None, - ))?; + let data = interface.prepare_pipeline_data( + PipelineData::Value( + Value::test_custom_value(Box::new(expected_test_custom_value())), + None, + ), + &(), + )?; let value = data .into_iter() @@ -1117,6 +1120,7 @@ fn interface_prepare_pipeline_data_serializes_custom_values_in_streams() -> Resu expected_test_custom_value(), ))] .into_pipeline_data(None), + &(), )?; let value = data @@ -1161,6 +1165,10 @@ impl CustomValue for CantSerialize { fn as_any(&self) -> &dyn std::any::Any { self } + + fn as_mut_any(&mut self) -> &mut dyn std::any::Any { + self + } } #[test] @@ -1171,6 +1179,7 @@ fn interface_prepare_pipeline_data_embeds_serialization_errors_in_streams() -> R let span = Span::new(40, 60); let data = interface.prepare_pipeline_data( [Value::custom(Box::new(CantSerialize::BadVariant), span)].into_pipeline_data(None), + &(), )?; let value = data diff --git a/crates/nu-plugin/src/plugin/interface/plugin.rs b/crates/nu-plugin/src/plugin/interface/plugin.rs index bf8b1b44fd..91e64f1407 100644 --- a/crates/nu-plugin/src/plugin/interface/plugin.rs +++ b/crates/nu-plugin/src/plugin/interface/plugin.rs @@ -12,10 +12,11 @@ use crate::{ PluginOutput, ProtocolInfo, StreamId, StreamMessage, }, sequence::Sequence, + util::with_custom_values_in, }; use nu_protocol::{ - ast::Operator, IntoInterruptiblePipelineData, IntoSpanned, ListStream, PipelineData, - PluginSignature, ShellError, Span, Spanned, Value, + ast::Operator, CustomValue, IntoInterruptiblePipelineData, IntoSpanned, ListStream, + PipelineData, PluginSignature, ShellError, Span, Spanned, Value, }; use std::{ collections::{btree_map, BTreeMap}, @@ -96,10 +97,28 @@ struct PluginCallState { ctrlc: Option>, /// Channel to receive context on to be used if needed context_rx: Option>, + /// Channel for plugin custom values that should be kept alive for the duration of the plugin + /// call. The plugin custom values on this channel are never read, we just hold on to it to keep + /// them in memory so they can be dropped at the end of the call. We hold the sender as well so + /// we can generate the CurrentCallState. + keep_plugin_custom_values: ( + mpsc::Sender, + mpsc::Receiver, + ), /// Number of streams that still need to be read from the plugin call response remaining_streams_to_read: i32, } +impl Drop for PluginCallState { + fn drop(&mut self) { + // Clear the keep custom values channel, so drop notifications can be sent + for value in self.keep_plugin_custom_values.1.try_iter() { + log::trace!("Dropping custom value that was kept: {:?}", value); + drop(value); + } + } +} + /// Manages reading and dispatching messages for [`PluginInterface`]s. /// /// This is not a public API. @@ -264,12 +283,20 @@ impl PluginInterfaceManager { msg: "Tried to spawn the fallback engine call handler more than once" .into(), })?; + + // Generate the state needed to handle engine calls + let current_call_state = CurrentCallState { + context_tx: None, + keep_plugin_custom_values_tx: Some(state.keep_plugin_custom_values.0.clone()), + }; + let handler = move || { // We receive on the thread so that we don't block the reader thread let mut context = context_rx .recv() .ok() // The plugin call won't send context if it's not required. .map(|c| c.0); + for msg in rx { // This thread only handles engine calls. match msg { @@ -277,6 +304,7 @@ impl PluginInterfaceManager { if let Err(err) = interface.handle_engine_call( engine_call_id, engine_call, + ¤t_call_state, context.as_deref_mut(), ) { log::warn!( @@ -486,26 +514,35 @@ impl InterfaceManager for PluginInterfaceManager { result } PluginOutput::EngineCall { context, id, call } => { - // Handle reading the pipeline data, if any - let mut call = call.map_data(|input| { - let ctrlc = self.get_ctrlc(context)?; - self.read_pipeline_data(input, ctrlc.as_ref()) - }); - // Add source to any plugin custom values in the arguments - if let Ok(EngineCall::EvalClosure { - ref mut positional, .. - }) = call - { - for arg in positional.iter_mut() { - PluginCustomValue::add_source(arg, &self.state.source); - } - } + let call = call + // Handle reading the pipeline data, if any + .map_data(|input| { + let ctrlc = self.get_ctrlc(context)?; + self.read_pipeline_data(input, ctrlc.as_ref()) + }) + // Do anything extra needed for each engine call setup + .and_then(|mut engine_call| { + match engine_call { + EngineCall::EvalClosure { + ref mut positional, .. + } => { + for arg in positional.iter_mut() { + // Add source to any plugin custom values in the arguments + PluginCustomValue::add_source_in(arg, &self.state.source)?; + } + Ok(engine_call) + } + _ => Ok(engine_call), + } + }); match call { Ok(call) => self.send_engine_call(context, id, call), // If there was an error with setting up the call, just write the error - Err(err) => self - .get_interface() - .write_engine_call_response(id, EngineCallResponse::Error(err)), + Err(err) => self.get_interface().write_engine_call_response( + id, + EngineCallResponse::Error(err), + &CurrentCallState::default(), + ), } } } @@ -519,14 +556,17 @@ impl InterfaceManager for PluginInterfaceManager { // Add source to any values match data { PipelineData::Value(ref mut value, _) => { - PluginCustomValue::add_source(value, &self.state.source); + with_custom_values_in(value, |custom_value| { + PluginCustomValue::add_source(custom_value.item, &self.state.source); + Ok::<_, ShellError>(()) + })?; Ok(data) } PipelineData::ListStream(ListStream { stream, ctrlc, .. }, meta) => { let source = self.state.source.clone(); Ok(stream .map(move |mut value| { - PluginCustomValue::add_source(&mut value, &source); + let _ = PluginCustomValue::add_source_in(&mut value, &source); value }) .into_pipeline_data_with_metadata(meta, ctrlc)) @@ -581,11 +621,12 @@ impl PluginInterface { &self, id: EngineCallId, response: EngineCallResponse, + state: &CurrentCallState, ) -> Result<(), ShellError> { // Set up any stream if necessary let mut writer = None; let response = response.map_data(|data| { - let (data_header, data_writer) = self.init_write_pipeline_data(data)?; + let (data_header, data_writer) = self.init_write_pipeline_data(data, state)?; writer = Some(data_writer); Ok(data_header) })?; @@ -602,22 +643,26 @@ impl PluginInterface { Ok(()) } - /// Write a plugin call message. Returns the writer for the stream, and the receiver for - /// messages - i.e. response and engine calls - related to the plugin call + /// Write a plugin call message. Returns the writer for the stream. fn write_plugin_call( &self, - call: PluginCall, - ctrlc: Option>, - context_rx: mpsc::Receiver, - ) -> Result< - ( - PipelineDataWriter, - mpsc::Receiver, - ), - ShellError, - > { + mut call: PluginCall, + context: Option<&dyn PluginExecutionContext>, + ) -> Result { let id = self.state.plugin_call_id_sequence.next()?; + let ctrlc = context.and_then(|c| c.ctrlc().cloned()); let (tx, rx) = mpsc::channel(); + let (context_tx, context_rx) = mpsc::channel(); + let keep_plugin_custom_values = mpsc::channel(); + + // Set up the state that will stay alive during the call. + let state = CurrentCallState { + context_tx: Some(context_tx), + keep_plugin_custom_values_tx: Some(keep_plugin_custom_values.0.clone()), + }; + + // Prepare the call with the state. + state.prepare_plugin_call(&mut call, &self.state.source)?; // Convert the call into one with a header and handle the stream, if necessary let (call, writer) = match call { @@ -630,8 +675,8 @@ impl PluginInterface { mut call, input, }) => { - verify_call_args(&mut call, &self.state.source)?; - let (header, writer) = self.init_write_pipeline_data(input)?; + state.prepare_call_args(&mut call, &self.state.source)?; + let (header, writer) = self.init_write_pipeline_data(input, &state)?; ( PluginCall::Run(CallInfo { name, @@ -652,6 +697,7 @@ impl PluginInterface { sender: Some(tx), ctrlc, context_rx: Some(context_rx), + keep_plugin_custom_values, remaining_streams_to_read: 0, }, )) @@ -659,9 +705,9 @@ impl PluginInterface { error: format!("Plugin `{}` closed unexpectedly", self.state.source.name()), msg: "can't complete this operation because the plugin is closed".into(), span: match &call { - PluginCall::CustomValueOp(value, _) => Some(value.span), - PluginCall::Run(info) => Some(info.call.head), - _ => None, + PluginCall::Signature => None, + PluginCall::Run(CallInfo { call, .. }) => Some(call.head), + PluginCall::CustomValueOp(val, _) => Some(val.span), }, help: Some(format!( "the plugin may have experienced an error. Try registering the plugin again \ @@ -679,11 +725,22 @@ impl PluginInterface { inner: vec![], })?; + // Starting a plugin call adds a lock on the GC. Locks are not added for streams being read + // by the plugin, so the plugin would have to explicitly tell us if it expects to stay alive + // while reading streams in the background after the response ends. + if let Some(ref gc) = self.gc { + gc.increment_locks(1); + } + // Write request self.write(PluginInput::Call(id, call))?; self.flush()?; - Ok((writer, rx)) + Ok(WritePluginCallResult { + receiver: rx, + writer, + state, + }) } /// Read the channel for plugin call messages and handle them until the response is received. @@ -691,7 +748,7 @@ impl PluginInterface { &self, rx: mpsc::Receiver, mut context: Option<&mut (dyn PluginExecutionContext + '_)>, - context_tx: mpsc::Sender, + state: CurrentCallState, ) -> Result, ShellError> { // Handle message from receiver for msg in rx { @@ -700,7 +757,9 @@ impl PluginInterface { if resp.has_stream() { // If the response has a stream, we need to register the context if let Some(context) = context { - let _ = context_tx.send(Context(context.boxed())); + if let Some(ref context_tx) = state.context_tx { + let _ = context_tx.send(Context(context.boxed())); + } } } return Ok(resp); @@ -709,7 +768,12 @@ impl PluginInterface { return Err(err); } ReceivedPluginCallMessage::EngineCall(engine_call_id, engine_call) => { - self.handle_engine_call(engine_call_id, engine_call, context.as_deref_mut())?; + self.handle_engine_call( + engine_call_id, + engine_call, + &state, + context.as_deref_mut(), + )?; } } } @@ -724,6 +788,7 @@ impl PluginInterface { &self, engine_call_id: EngineCallId, engine_call: EngineCall, + state: &CurrentCallState, context: Option<&mut (dyn PluginExecutionContext + '_)>, ) -> Result<(), ShellError> { let resp = @@ -732,7 +797,7 @@ impl PluginInterface { let mut writer = None; let resp = resp .map_data(|data| { - let (data_header, data_writer) = self.init_write_pipeline_data(data)?; + let (data_header, data_writer) = self.init_write_pipeline_data(data, state)?; writer = Some(data_writer); Ok(data_header) }) @@ -762,26 +827,12 @@ impl PluginInterface { return Err(error.clone()); } - // Starting a plugin call adds a lock on the GC. Locks are not added for streams being read - // by the plugin, so the plugin would have to explicitly tell us if it expects to stay alive - // while reading streams in the background after the response ends. - if let Some(ref gc) = self.gc { - gc.increment_locks(1); - } - - // Create the channel to send context on if needed - let (context_tx, context_rx) = mpsc::channel(); - - let (writer, rx) = self.write_plugin_call( - call, - context.as_ref().and_then(|c| c.ctrlc().cloned()), - context_rx, - )?; + let result = self.write_plugin_call(call, context.as_deref())?; // Finish writing stream in the background - writer.write_background()?; + result.writer.write_background()?; - self.receive_plugin_call_response(rx, context, context_tx) + self.receive_plugin_call_response(result.receiver, context, result.state) } /// Get the command signatures from the plugin. @@ -858,9 +909,8 @@ impl PluginInterface { pub fn custom_value_partial_cmp( &self, value: PluginCustomValue, - mut other_value: Value, + other_value: Value, ) -> Result, ShellError> { - PluginCustomValue::verify_source(&mut other_value, &self.state.source)?; // Note: the protocol is always designed to have a span with the custom value, but this // operation doesn't support one. let call = PluginCall::CustomValueOp( @@ -881,9 +931,8 @@ impl PluginInterface { &self, left: Spanned, operator: Spanned, - mut right: Value, + right: Value, ) -> Result { - PluginCustomValue::verify_source(&mut right, &self.state.source)?; self.custom_value_op_expecting_value(left, CustomValueOp::Operation(operator, right)) } @@ -899,22 +948,9 @@ impl PluginInterface { } } -/// Check that custom values in call arguments come from the right source -fn verify_call_args( - call: &mut crate::EvaluatedCall, - source: &Arc, -) -> Result<(), ShellError> { - for arg in call.positional.iter_mut() { - PluginCustomValue::verify_source(arg, source)?; - } - for arg in call.named.iter_mut().flat_map(|(_, arg)| arg.as_mut()) { - PluginCustomValue::verify_source(arg, source)?; - } - Ok(()) -} - impl Interface for PluginInterface { type Output = PluginInput; + type DataContext = CurrentCallState; fn write(&self, input: PluginInput) -> Result<(), ShellError> { log::trace!("to plugin: {:?}", input); @@ -933,18 +969,23 @@ impl Interface for PluginInterface { &self.stream_manager_handle } - fn prepare_pipeline_data(&self, data: PipelineData) -> Result { + fn prepare_pipeline_data( + &self, + data: PipelineData, + state: &CurrentCallState, + ) -> Result { // Validate the destination of values in the pipeline data match data { PipelineData::Value(mut value, meta) => { - PluginCustomValue::verify_source(&mut value, &self.state.source)?; + state.prepare_value(&mut value, &self.state.source)?; Ok(PipelineData::Value(value, meta)) } PipelineData::ListStream(ListStream { stream, ctrlc, .. }, meta) => { let source = self.state.source.clone(); + let state = state.clone(); Ok(stream .map(move |mut value| { - match PluginCustomValue::verify_source(&mut value, &source) { + match state.prepare_value(&mut value, &source) { Ok(()) => value, // Put the error in the stream instead Err(err) => Value::error(err, value.span()), @@ -972,6 +1013,113 @@ impl Drop for PluginInterface { } } +/// Return value of [`PluginInterface::write_plugin_call()`]. +#[must_use] +struct WritePluginCallResult { + /// Receiver for plugin call messages related to the written plugin call. + receiver: mpsc::Receiver, + /// Writer for the stream, if any. + writer: PipelineDataWriter, + /// State to be kept for the duration of the plugin call. + state: CurrentCallState, +} + +/// State related to the current plugin call being executed. +/// +/// This is not a public API. +#[doc(hidden)] +#[derive(Default, Clone)] +pub struct CurrentCallState { + /// Sender for context, which should be sent if the plugin call returned a stream so that + /// engine calls may continue to be handled. + context_tx: Option>, + /// Sender for a channel that retains plugin custom values that need to stay alive for the + /// duration of a plugin call. + keep_plugin_custom_values_tx: Option>, +} + +impl CurrentCallState { + /// Prepare a custom value for write. Verifies custom value origin, and keeps custom values that + /// shouldn't be dropped immediately. + fn prepare_custom_value( + &self, + custom_value: Spanned<&mut (dyn CustomValue + '_)>, + source: &PluginSource, + ) -> Result<(), ShellError> { + // Ensure we can use it + PluginCustomValue::verify_source(custom_value.as_deref(), source)?; + + // Check whether we need to keep it + if let Some(keep_tx) = &self.keep_plugin_custom_values_tx { + if let Some(custom_value) = custom_value + .item + .as_any() + .downcast_ref::() + { + if custom_value.notify_on_drop() { + log::trace!("Keeping custom value for drop later: {:?}", custom_value); + keep_tx + .send(custom_value.clone()) + .map_err(|_| ShellError::NushellFailed { + msg: "Failed to custom value to keep channel".into(), + })?; + } + } + } + Ok(()) + } + + /// Prepare a value for write, including all contained custom values. + fn prepare_value(&self, value: &mut Value, source: &PluginSource) -> Result<(), ShellError> { + with_custom_values_in(value, |custom_value| { + self.prepare_custom_value(custom_value, source) + }) + } + + /// Prepare call arguments for write. + fn prepare_call_args( + &self, + call: &mut crate::EvaluatedCall, + source: &PluginSource, + ) -> Result<(), ShellError> { + for arg in call.positional.iter_mut() { + self.prepare_value(arg, source)?; + } + for arg in call.named.iter_mut().flat_map(|(_, arg)| arg.as_mut()) { + self.prepare_value(arg, source)?; + } + Ok(()) + } + + /// Prepare a plugin call for write. Does not affect pipeline data, which is handled by + /// `prepare_pipeline_data()` instead. + fn prepare_plugin_call( + &self, + call: &mut PluginCall, + source: &PluginSource, + ) -> Result<(), ShellError> { + match call { + PluginCall::Signature => Ok(()), + PluginCall::Run(CallInfo { call, .. }) => self.prepare_call_args(call, source), + PluginCall::CustomValueOp(custom_value, op) => { + // `source` isn't present on Dropped. + if !matches!(op, CustomValueOp::Dropped) { + self.prepare_custom_value(custom_value.as_mut().map(|r| r as &mut _), source)?; + } + // Handle anything within the op. + match op { + CustomValueOp::ToBaseValue => Ok(()), + CustomValueOp::FollowPathInt(_) => Ok(()), + CustomValueOp::FollowPathString(_) => Ok(()), + CustomValueOp::PartialCmp(value) => self.prepare_value(value, source), + CustomValueOp::Operation(_, value) => self.prepare_value(value, source), + CustomValueOp::Dropped => Ok(()), + } + } + } + } +} + /// Handle an engine call. pub(crate) fn handle_engine_call( call: EngineCall, diff --git a/crates/nu-plugin/src/plugin/interface/plugin/tests.rs b/crates/nu-plugin/src/plugin/interface/plugin/tests.rs index 2a25c8e50b..3d66fe9624 100644 --- a/crates/nu-plugin/src/plugin/interface/plugin/tests.rs +++ b/crates/nu-plugin/src/plugin/interface/plugin/tests.rs @@ -4,11 +4,14 @@ use super::{ use crate::{ plugin::{ context::PluginExecutionBogusContext, - interface::{test_util::TestCase, Interface, InterfaceManager}, + interface::{plugin::CurrentCallState, test_util::TestCase, Interface, InterfaceManager}, PluginSource, }, protocol::{ - test_util::{expected_test_custom_value, test_plugin_custom_value}, + test_util::{ + expected_test_custom_value, test_plugin_custom_value, + test_plugin_custom_value_with_source, + }, CallInfo, CustomValueOp, EngineCall, EngineCallResponse, ExternalStreamInfo, ListStreamInfo, PipelineDataHeader, PluginCall, PluginCallId, PluginCustomValue, PluginInput, Protocol, ProtocolInfo, RawStreamInfo, StreamData, StreamMessage, @@ -16,10 +19,16 @@ use crate::{ EvaluatedCall, PluginCallResponse, PluginOutput, }; use nu_protocol::{ - engine::Closure, IntoInterruptiblePipelineData, PipelineData, PluginSignature, ShellError, - Span, Spanned, Value, + ast::{Math, Operator}, + engine::Closure, + CustomValue, IntoInterruptiblePipelineData, IntoSpanned, PipelineData, PluginSignature, + ShellError, Span, Spanned, Value, +}; +use serde::{Deserialize, Serialize}; +use std::{ + sync::{mpsc, Arc}, + time::Duration, }; -use std::{sync::mpsc, time::Duration}; #[test] fn manager_consume_all_consumes_messages() -> Result<(), ShellError> { @@ -186,6 +195,7 @@ fn fake_plugin_call( sender: Some(tx), ctrlc: None, context_rx: None, + keep_plugin_custom_values: mpsc::channel(), remaining_streams_to_read: 0, }, ); @@ -488,6 +498,7 @@ fn manager_handle_engine_call_after_response_received() -> Result<(), ShellError sender: None, ctrlc: None, context_rx: Some(context_rx), + keep_plugin_custom_values: mpsc::channel(), remaining_streams_to_read: 1, }, ); @@ -551,6 +562,7 @@ fn manager_send_plugin_call_response_removes_context_only_if_no_streams_to_read( sender: None, ctrlc: None, context_rx: None, + keep_plugin_custom_values: mpsc::channel(), remaining_streams_to_read: n as i32, }, ); @@ -584,6 +596,7 @@ fn manager_consume_stream_end_removes_context_only_if_last_stream() -> Result<() sender: None, ctrlc: None, context_rx: None, + keep_plugin_custom_values: mpsc::channel(), remaining_streams_to_read: n as i32, }, ); @@ -734,7 +747,7 @@ fn interface_write_plugin_call_registers_subscription() -> Result<(), ShellError ); let interface = manager.get_interface(); - let _ = interface.write_plugin_call(PluginCall::Signature, None, mpsc::channel().1)?; + let _ = interface.write_plugin_call(PluginCall::Signature, None)?; manager.receive_plugin_call_subscriptions(); assert!(!manager.plugin_call_states.is_empty(), "not registered"); @@ -747,9 +760,8 @@ fn interface_write_plugin_call_writes_signature() -> Result<(), ShellError> { let manager = test.plugin("test"); let interface = manager.get_interface(); - let (writer, _) = - interface.write_plugin_call(PluginCall::Signature, None, mpsc::channel().1)?; - writer.write()?; + let result = interface.write_plugin_call(PluginCall::Signature, None)?; + result.writer.write()?; let written = test.next_written().expect("nothing written"); match written { @@ -768,18 +780,17 @@ fn interface_write_plugin_call_writes_custom_value_op() -> Result<(), ShellError let manager = test.plugin("test"); let interface = manager.get_interface(); - let (writer, _) = interface.write_plugin_call( + let result = interface.write_plugin_call( PluginCall::CustomValueOp( Spanned { - item: test_plugin_custom_value(), + item: test_plugin_custom_value_with_source(), span: Span::test_data(), }, CustomValueOp::ToBaseValue, ), None, - mpsc::channel().1, )?; - writer.write()?; + result.writer.write()?; let written = test.next_written().expect("nothing written"); match written { @@ -801,7 +812,7 @@ fn interface_write_plugin_call_writes_run_with_value_input() -> Result<(), Shell let manager = test.plugin("test"); let interface = manager.get_interface(); - let (writer, _) = interface.write_plugin_call( + let result = interface.write_plugin_call( PluginCall::Run(CallInfo { name: "foo".into(), call: EvaluatedCall { @@ -812,9 +823,8 @@ fn interface_write_plugin_call_writes_run_with_value_input() -> Result<(), Shell input: PipelineData::Value(Value::test_int(-1), None), }), None, - mpsc::channel().1, )?; - writer.write()?; + result.writer.write()?; let written = test.next_written().expect("nothing written"); match written { @@ -840,7 +850,7 @@ fn interface_write_plugin_call_writes_run_with_stream_input() -> Result<(), Shel let interface = manager.get_interface(); let values = vec![Value::test_int(1), Value::test_int(2)]; - let (writer, _) = interface.write_plugin_call( + let result = interface.write_plugin_call( PluginCall::Run(CallInfo { name: "foo".into(), call: EvaluatedCall { @@ -851,9 +861,8 @@ fn interface_write_plugin_call_writes_run_with_stream_input() -> Result<(), Shel input: values.clone().into_pipeline_data(None), }), None, - mpsc::channel().1, )?; - writer.write()?; + result.writer.write()?; let written = test.next_written().expect("nothing written"); let info = match written { @@ -914,7 +923,7 @@ fn interface_receive_plugin_call_receives_response() -> Result<(), ShellError> { .expect("failed to send on new channel"); drop(tx); // so we don't deadlock on recv() - let response = interface.receive_plugin_call_response(rx, None, mpsc::channel().0)?; + let response = interface.receive_plugin_call_response(rx, None, CurrentCallState::default())?; assert!( matches!(response, PluginCallResponse::Signature(_)), "wrong response: {response:?}" @@ -937,7 +946,7 @@ fn interface_receive_plugin_call_receives_error() -> Result<(), ShellError> { drop(tx); // so we don't deadlock on recv() let error = interface - .receive_plugin_call_response(rx, None, mpsc::channel().0) + .receive_plugin_call_response(rx, None, CurrentCallState::default()) .expect_err("did not receive error"); assert!( matches!(error, ShellError::ExternalNotSupported { .. }), @@ -966,7 +975,7 @@ fn interface_receive_plugin_call_handles_engine_call() -> Result<(), ShellError> // an error, but it should still do the engine call drop(tx); interface - .receive_plugin_call_response(rx, Some(&mut context), mpsc::channel().0) + .receive_plugin_call_response(rx, Some(&mut context), CurrentCallState::default()) .expect_err("no error even though there was no response"); // Check for the engine call response output @@ -1083,7 +1092,7 @@ fn interface_custom_value_to_base_value() -> Result<(), ShellError> { }); let result = interface.custom_value_to_base_value(Spanned { - item: test_plugin_custom_value(), + item: test_plugin_custom_value_with_source(), span: Span::test_data(), })?; @@ -1108,8 +1117,9 @@ fn normal_values(interface: &PluginInterface) -> Vec { #[test] fn interface_prepare_pipeline_data_accepts_normal_values() -> Result<(), ShellError> { let interface = TestCase::new().plugin("test").get_interface(); + let state = CurrentCallState::default(); for value in normal_values(&interface) { - match interface.prepare_pipeline_data(PipelineData::Value(value.clone(), None)) { + match interface.prepare_pipeline_data(PipelineData::Value(value.clone(), None), &state) { Ok(data) => assert_eq!( value.get_type(), data.into_value(Span::test_data()).get_type() @@ -1124,7 +1134,8 @@ fn interface_prepare_pipeline_data_accepts_normal_values() -> Result<(), ShellEr fn interface_prepare_pipeline_data_accepts_normal_streams() -> Result<(), ShellError> { let interface = TestCase::new().plugin("test").get_interface(); let values = normal_values(&interface); - let data = interface.prepare_pipeline_data(values.clone().into_pipeline_data(None))?; + let state = CurrentCallState::default(); + let data = interface.prepare_pipeline_data(values.clone().into_pipeline_data(None), &state)?; let mut count = 0; for (expected_value, actual_value) in values.iter().zip(data) { @@ -1168,8 +1179,9 @@ fn bad_custom_values() -> Vec { #[test] fn interface_prepare_pipeline_data_rejects_bad_custom_value() -> Result<(), ShellError> { let interface = TestCase::new().plugin("test").get_interface(); + let state = CurrentCallState::default(); for value in bad_custom_values() { - match interface.prepare_pipeline_data(PipelineData::Value(value.clone(), None)) { + match interface.prepare_pipeline_data(PipelineData::Value(value.clone(), None), &state) { Err(err) => match err { ShellError::CustomValueIncorrectForPlugin { .. } => (), _ => panic!("expected error type CustomValueIncorrectForPlugin, but got {err:?}"), @@ -1185,7 +1197,8 @@ fn interface_prepare_pipeline_data_rejects_bad_custom_value_in_a_stream() -> Res { let interface = TestCase::new().plugin("test").get_interface(); let values = bad_custom_values(); - let data = interface.prepare_pipeline_data(values.clone().into_pipeline_data(None))?; + let state = CurrentCallState::default(); + let data = interface.prepare_pipeline_data(values.clone().into_pipeline_data(None), &state)?; let mut count = 0; for value in data { @@ -1199,3 +1212,297 @@ fn interface_prepare_pipeline_data_rejects_bad_custom_value_in_a_stream() -> Res ); Ok(()) } + +#[test] +fn prepare_custom_value_verifies_source() { + let span = Span::test_data(); + let source = Arc::new(PluginSource::new_fake("test")); + + let mut val = test_plugin_custom_value(); + assert!(CurrentCallState::default() + .prepare_custom_value( + Spanned { + item: &mut val, + span, + }, + &source + ) + .is_err()); + + let mut val = test_plugin_custom_value().with_source(Some(source.clone())); + assert!(CurrentCallState::default() + .prepare_custom_value( + Spanned { + item: &mut val, + span, + }, + &source + ) + .is_ok()); +} + +#[derive(Debug, Serialize, Deserialize)] +struct DropCustomVal; +#[typetag::serde] +impl CustomValue for DropCustomVal { + fn clone_value(&self, _span: Span) -> Value { + unimplemented!() + } + + fn type_name(&self) -> String { + "DropCustomVal".into() + } + + fn to_base_value(&self, _span: Span) -> Result { + unimplemented!() + } + + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn as_mut_any(&mut self) -> &mut dyn std::any::Any { + self + } + + fn notify_plugin_on_drop(&self) -> bool { + true + } +} + +#[test] +fn prepare_custom_value_sends_to_keep_channel_if_drop_notify() -> Result<(), ShellError> { + let span = Span::test_data(); + let source = Arc::new(PluginSource::new_fake("test")); + let (tx, rx) = mpsc::channel(); + let state = CurrentCallState { + context_tx: None, + keep_plugin_custom_values_tx: Some(tx), + }; + // Try with a custom val that has drop check set + let mut drop_val = PluginCustomValue::serialize_from_custom_value(&DropCustomVal, span)? + .with_source(Some(source.clone())); + state.prepare_custom_value( + Spanned { + item: &mut drop_val, + span, + }, + &source, + )?; + // Check that the custom value was actually sent + assert!(rx.try_recv().is_ok()); + // Now try with one that doesn't have it + let mut not_drop_val = test_plugin_custom_value().with_source(Some(source.clone())); + state.prepare_custom_value( + Spanned { + item: &mut not_drop_val, + span, + }, + &source, + )?; + // Should not have been sent to the channel + assert!(rx.try_recv().is_err()); + Ok(()) +} + +#[test] +fn prepare_plugin_call_run() { + // Check that args are handled + let span = Span::test_data(); + let source = Arc::new(PluginSource::new_fake("test")); + let other_source = Arc::new(PluginSource::new_fake("other")); + let cv_ok = test_plugin_custom_value() + .with_source(Some(source.clone())) + .into_value(span); + let cv_bad = test_plugin_custom_value() + .with_source(Some(other_source)) + .into_value(span); + + let fixtures = [ + ( + true, // should succeed + PluginCall::Run(CallInfo { + name: "".into(), + call: EvaluatedCall { + head: span, + positional: vec![Value::test_int(4)], + named: vec![("x".to_owned().into_spanned(span), Some(Value::test_int(6)))], + }, + input: PipelineData::Empty, + }), + ), + ( + true, // should succeed + PluginCall::Run(CallInfo { + name: "".into(), + call: EvaluatedCall { + head: span, + positional: vec![cv_ok.clone()], + named: vec![("ok".to_owned().into_spanned(span), Some(cv_ok.clone()))], + }, + input: PipelineData::Empty, + }), + ), + ( + false, // should fail + PluginCall::Run(CallInfo { + name: "".into(), + call: EvaluatedCall { + head: span, + positional: vec![cv_bad.clone()], + named: vec![], + }, + input: PipelineData::Empty, + }), + ), + ( + false, // should fail + PluginCall::Run(CallInfo { + name: "".into(), + call: EvaluatedCall { + head: span, + positional: vec![], + named: vec![("bad".to_owned().into_spanned(span), Some(cv_bad.clone()))], + }, + input: PipelineData::Empty, + }), + ), + ( + true, // should succeed + PluginCall::Run(CallInfo { + name: "".into(), + call: EvaluatedCall { + head: span, + positional: vec![], + named: vec![], + }, + // Shouldn't check input - that happens somewhere else + input: PipelineData::Value(cv_bad.clone(), None), + }), + ), + ]; + + for (should_succeed, mut fixture) in fixtures { + let result = CurrentCallState::default().prepare_plugin_call(&mut fixture, &source); + if should_succeed { + assert!( + result.is_ok(), + "Expected success, but failed with {:?} on {fixture:#?}", + result.unwrap_err(), + ); + } else { + assert!( + result.is_err(), + "Expected failure, but succeeded on {fixture:#?}", + ); + } + } +} + +#[test] +fn prepare_plugin_call_custom_value_op() { + // Check behavior with custom value ops + let span = Span::test_data(); + let source = Arc::new(PluginSource::new_fake("test")); + let other_source = Arc::new(PluginSource::new_fake("other")); + let cv_ok = test_plugin_custom_value().with_source(Some(source.clone())); + let cv_ok_val = cv_ok.clone_value(span); + let cv_bad = test_plugin_custom_value().with_source(Some(other_source)); + let cv_bad_val = cv_bad.clone_value(span); + + let fixtures = [ + ( + true, // should succeed + PluginCall::CustomValueOp::( + Spanned { + item: cv_ok.clone(), + span, + }, + CustomValueOp::ToBaseValue, + ), + ), + ( + false, // should fail + PluginCall::CustomValueOp( + Spanned { + item: cv_bad.clone(), + span, + }, + CustomValueOp::ToBaseValue, + ), + ), + ( + true, // should succeed + PluginCall::CustomValueOp( + Spanned { + item: test_plugin_custom_value(), + span, + }, + // Dropped shouldn't check. We don't have a source set. + CustomValueOp::Dropped, + ), + ), + ( + true, // should succeed + PluginCall::CustomValueOp::( + Spanned { + item: cv_ok.clone(), + span, + }, + CustomValueOp::PartialCmp(cv_ok_val.clone()), + ), + ), + ( + false, // should fail + PluginCall::CustomValueOp( + Spanned { + item: cv_ok.clone(), + span, + }, + CustomValueOp::PartialCmp(cv_bad_val.clone()), + ), + ), + ( + true, // should succeed + PluginCall::CustomValueOp::( + Spanned { + item: cv_ok.clone(), + span, + }, + CustomValueOp::Operation( + Operator::Math(Math::Append).into_spanned(span), + cv_ok_val.clone(), + ), + ), + ), + ( + false, // should fail + PluginCall::CustomValueOp( + Spanned { + item: cv_ok.clone(), + span, + }, + CustomValueOp::Operation( + Operator::Math(Math::Append).into_spanned(span), + cv_bad_val.clone(), + ), + ), + ), + ]; + + for (should_succeed, mut fixture) in fixtures { + let result = CurrentCallState::default().prepare_plugin_call(&mut fixture, &source); + if should_succeed { + assert!( + result.is_ok(), + "Expected success, but failed with {:?} on {fixture:#?}", + result.unwrap_err(), + ); + } else { + assert!( + result.is_err(), + "Expected failure, but succeeded on {fixture:#?}", + ); + } + } +} diff --git a/crates/nu-plugin/src/plugin/interface/tests.rs b/crates/nu-plugin/src/plugin/interface/tests.rs index a96be38f5a..706798b31f 100644 --- a/crates/nu-plugin/src/plugin/interface/tests.rs +++ b/crates/nu-plugin/src/plugin/interface/tests.rs @@ -82,6 +82,7 @@ impl InterfaceManager for TestInterfaceManager { impl Interface for TestInterface { type Output = PluginOutput; + type DataContext = (); fn write(&self, output: Self::Output) -> Result<(), ShellError> { self.test.write(&output) @@ -99,7 +100,11 @@ impl Interface for TestInterface { &self.stream_manager_handle } - fn prepare_pipeline_data(&self, data: PipelineData) -> Result { + fn prepare_pipeline_data( + &self, + data: PipelineData, + _context: &(), + ) -> Result { // Add an arbitrary check to the data to verify this is being called match data { PipelineData::Value(Value::Binary { .. }, None) => Err(ShellError::NushellFailed { @@ -318,7 +323,7 @@ fn write_pipeline_data_empty() -> Result<(), ShellError> { let manager = TestInterfaceManager::new(&test); let interface = manager.get_interface(); - let (header, writer) = interface.init_write_pipeline_data(PipelineData::Empty)?; + let (header, writer) = interface.init_write_pipeline_data(PipelineData::Empty, &())?; assert!(matches!(header, PipelineDataHeader::Empty)); @@ -340,7 +345,7 @@ fn write_pipeline_data_value() -> Result<(), ShellError> { let value = Value::test_int(7); let (header, writer) = - interface.init_write_pipeline_data(PipelineData::Value(value.clone(), None))?; + interface.init_write_pipeline_data(PipelineData::Value(value.clone(), None), &())?; match header { PipelineDataHeader::Value(read_value) => assert_eq!(value, read_value), @@ -365,7 +370,7 @@ fn write_pipeline_data_prepared_properly() { // Sending a binary should be an error in our test scenario let value = Value::test_binary(vec![7, 8]); - match interface.init_write_pipeline_data(PipelineData::Value(value, None)) { + match interface.init_write_pipeline_data(PipelineData::Value(value, None), &()) { Ok(_) => panic!("prepare_pipeline_data was not called"), Err(err) => { assert_eq!( @@ -397,7 +402,7 @@ fn write_pipeline_data_list_stream() -> Result<(), ShellError> { None, ); - let (header, writer) = interface.init_write_pipeline_data(pipe)?; + let (header, writer) = interface.init_write_pipeline_data(pipe, &())?; let info = match header { PipelineDataHeader::ListStream(info) => info, @@ -472,7 +477,7 @@ fn write_pipeline_data_external_stream() -> Result<(), ShellError> { trim_end_newline: true, }; - let (header, writer) = interface.init_write_pipeline_data(pipe)?; + let (header, writer) = interface.init_write_pipeline_data(pipe, &())?; let info = match header { PipelineDataHeader::ExternalStream(info) => info, diff --git a/crates/nu-plugin/src/protocol/plugin_custom_value.rs b/crates/nu-plugin/src/protocol/plugin_custom_value.rs index bd10c03029..b4bab9f013 100644 --- a/crates/nu-plugin/src/protocol/plugin_custom_value.rs +++ b/crates/nu-plugin/src/protocol/plugin_custom_value.rs @@ -1,7 +1,11 @@ -use crate::plugin::{PluginInterface, PluginSource}; -use nu_protocol::{ast::Operator, CustomValue, IntoSpanned, ShellError, Span, Value}; +use std::{cmp::Ordering, sync::Arc}; + +use crate::{ + plugin::{PluginInterface, PluginSource}, + util::with_custom_values_in, +}; +use nu_protocol::{ast::Operator, CustomValue, IntoSpanned, ShellError, Span, Spanned, Value}; use serde::{Deserialize, Serialize}; -use std::{cmp::Ordering, convert::Infallible, sync::Arc}; #[cfg(test)] mod tests; @@ -50,10 +54,16 @@ fn is_false(b: &bool) -> bool { !b } +impl PluginCustomValue { + pub fn into_value(self, span: Span) -> Value { + Value::custom(Box::new(self), span) + } +} + #[typetag::serde] impl CustomValue for PluginCustomValue { fn clone_value(&self, span: Span) -> Value { - Value::custom(Box::new(self.clone()), span) + self.clone().into_value(span) } fn type_name(&self) -> String { @@ -127,6 +137,10 @@ impl CustomValue for PluginCustomValue { fn as_any(&self) -> &dyn std::any::Any { self } + + fn as_mut_any(&mut self) -> &mut dyn std::any::Any { + self + } } impl PluginCustomValue { @@ -164,11 +178,15 @@ impl PluginCustomValue { /// Which plugin the custom value came from. This is not defined on the plugin side. The engine /// side is responsible for maintaining it, and it is not sent over the serialization boundary. - #[cfg(test)] - pub(crate) fn source(&self) -> &Option> { + pub fn source(&self) -> &Option> { &self.source } + /// Set the [`PluginSource`] for this [`PluginCustomValue`]. + pub fn set_source(&mut self, source: Option>) { + self.source = source; + } + /// Create the [`PluginCustomValue`] with the given source. #[cfg(test)] pub(crate) fn with_source(mut self, source: Option>) -> PluginCustomValue { @@ -234,84 +252,55 @@ impl PluginCustomValue { }) } - /// Add a [`PluginSource`] to all [`PluginCustomValue`]s within a value, recursively. - pub fn add_source(value: &mut Value, source: &Arc) { - // This can't cause an error. - let _: Result<(), Infallible> = value.recurse_mut(&mut |value| { - let span = value.span(); - match value { - // Set source on custom value - Value::Custom { ref val, .. } => { - if let Some(custom_value) = val.as_any().downcast_ref::() { - // Since there's no `as_mut_any()`, we have to copy the whole thing - let mut custom_value = custom_value.clone(); - custom_value.source = Some(source.clone()); - *value = Value::custom(Box::new(custom_value), span); - } - Ok(()) - } - // LazyRecord could generate other values, but we shouldn't be receiving it anyway - // - // It's better to handle this as a bug - Value::LazyRecord { .. } => unimplemented!("add_source for LazyRecord"), - _ => Ok(()), - } - }); + /// Add a [`PluginSource`] to the given [`CustomValue`] if it is a [`PluginCustomValue`]. + pub fn add_source(value: &mut dyn CustomValue, source: &Arc) { + if let Some(custom_value) = value.as_mut_any().downcast_mut::() { + custom_value.set_source(Some(source.clone())); + } } - /// Check that all [`CustomValue`]s present within the `value` are [`PluginCustomValue`]s that - /// come from the given `source`, and return an error if not. + /// Add a [`PluginSource`] to all [`PluginCustomValue`]s within the value, recursively. + pub fn add_source_in(value: &mut Value, source: &Arc) -> Result<(), ShellError> { + with_custom_values_in(value, |custom_value| { + Self::add_source(custom_value.item, source); + Ok::<_, ShellError>(()) + }) + } + + /// Check that a [`CustomValue`] is a [`PluginCustomValue`] that come from the given `source`, + /// and return an error if not. /// /// This method will collapse `LazyRecord` in-place as necessary to make the guarantee, /// since `LazyRecord` could return something different the next time it is called. pub(crate) fn verify_source( - value: &mut Value, + value: Spanned<&dyn CustomValue>, source: &PluginSource, ) -> Result<(), ShellError> { - value.recurse_mut(&mut |value| { - let span = value.span(); - match value { - // Set source on custom value - Value::Custom { val, .. } => { - if let Some(custom_value) = val.as_any().downcast_ref::() { - if custom_value - .source - .as_ref() - .map(|s| s.is_compatible(source)) - .unwrap_or(false) - { - Ok(()) - } else { - Err(ShellError::CustomValueIncorrectForPlugin { - name: custom_value.name().to_owned(), - span, - dest_plugin: source.name().to_owned(), - src_plugin: custom_value - .source - .as_ref() - .map(|s| s.name().to_owned()), - }) - } - } else { - // Only PluginCustomValues can be sent - Err(ShellError::CustomValueIncorrectForPlugin { - name: val.type_name(), - span, - dest_plugin: source.name().to_owned(), - src_plugin: None, - }) - } - } - // LazyRecord would be a problem for us, since it could return something else the - // next time, and we have to collect it anyway to serialize it. Collect it in place, - // and then verify the source of the result - Value::LazyRecord { val, .. } => { - *value = val.collect()?; - Ok(()) - } - _ => Ok(()), + if let Some(custom_value) = value.item.as_any().downcast_ref::() { + if custom_value + .source + .as_ref() + .map(|s| s.is_compatible(source)) + .unwrap_or(false) + { + Ok(()) + } else { + Err(ShellError::CustomValueIncorrectForPlugin { + name: custom_value.name().to_owned(), + span: value.span, + dest_plugin: source.name().to_owned(), + src_plugin: custom_value.source.as_ref().map(|s| s.name().to_owned()), + }) } - }) + } else { + // Only PluginCustomValues can be sent + Err(ShellError::CustomValueIncorrectForPlugin { + name: value.item.type_name(), + span: value.span, + dest_plugin: source.name().to_owned(), + src_plugin: None, + }) + } } /// Convert all plugin-native custom values to [`PluginCustomValue`] within the given `value`, diff --git a/crates/nu-plugin/src/protocol/plugin_custom_value/tests.rs b/crates/nu-plugin/src/protocol/plugin_custom_value/tests.rs index aee87e0c2a..5136dc665e 100644 --- a/crates/nu-plugin/src/protocol/plugin_custom_value/tests.rs +++ b/crates/nu-plugin/src/protocol/plugin_custom_value/tests.rs @@ -7,7 +7,8 @@ use crate::{ }, }; use nu_protocol::{ - ast::RangeInclusion, engine::Closure, record, CustomValue, Range, ShellError, Span, Value, + ast::RangeInclusion, engine::Closure, record, CustomValue, IntoSpanned, Range, ShellError, + Span, Value, }; use std::sync::Arc; @@ -42,10 +43,10 @@ fn expected_serialize_output() -> Result<(), ShellError> { } #[test] -fn add_source_at_root() -> Result<(), ShellError> { +fn add_source_in_at_root() -> Result<(), ShellError> { let mut val = Value::test_custom_value(Box::new(test_plugin_custom_value())); let source = Arc::new(PluginSource::new_fake("foo")); - PluginCustomValue::add_source(&mut val, &source); + PluginCustomValue::add_source_in(&mut val, &source)?; let custom_value = val.as_custom_value()?; let plugin_custom_value: &PluginCustomValue = custom_value @@ -78,7 +79,7 @@ fn check_range_custom_values( } #[test] -fn add_source_nested_range() -> Result<(), ShellError> { +fn add_source_in_nested_range() -> Result<(), ShellError> { let orig_custom_val = Value::test_custom_value(Box::new(test_plugin_custom_value())); let mut val = Value::test_range(Range { from: orig_custom_val.clone(), @@ -87,7 +88,7 @@ fn add_source_nested_range() -> Result<(), ShellError> { inclusion: RangeInclusion::Inclusive, }); let source = Arc::new(PluginSource::new_fake("foo")); - PluginCustomValue::add_source(&mut val, &source); + PluginCustomValue::add_source_in(&mut val, &source)?; check_range_custom_values(&val, |name, custom_value| { let plugin_custom_value: &PluginCustomValue = custom_value @@ -122,14 +123,14 @@ fn check_record_custom_values( } #[test] -fn add_source_nested_record() -> Result<(), ShellError> { +fn add_source_in_nested_record() -> Result<(), ShellError> { let orig_custom_val = Value::test_custom_value(Box::new(test_plugin_custom_value())); let mut val = Value::test_record(record! { "foo" => orig_custom_val.clone(), "bar" => orig_custom_val.clone(), }); let source = Arc::new(PluginSource::new_fake("foo")); - PluginCustomValue::add_source(&mut val, &source); + PluginCustomValue::add_source_in(&mut val, &source)?; check_record_custom_values(&val, &["foo", "bar"], |key, custom_value| { let plugin_custom_value: &PluginCustomValue = custom_value @@ -164,11 +165,11 @@ fn check_list_custom_values( } #[test] -fn add_source_nested_list() -> Result<(), ShellError> { +fn add_source_in_nested_list() -> Result<(), ShellError> { let orig_custom_val = Value::test_custom_value(Box::new(test_plugin_custom_value())); let mut val = Value::test_list(vec![orig_custom_val.clone(), orig_custom_val.clone()]); let source = Arc::new(PluginSource::new_fake("foo")); - PluginCustomValue::add_source(&mut val, &source); + PluginCustomValue::add_source_in(&mut val, &source)?; check_list_custom_values(&val, 0..=1, |index, custom_value| { let plugin_custom_value: &PluginCustomValue = custom_value @@ -205,14 +206,14 @@ fn check_closure_custom_values( } #[test] -fn add_source_nested_closure() -> Result<(), ShellError> { +fn add_source_in_nested_closure() -> Result<(), ShellError> { let orig_custom_val = Value::test_custom_value(Box::new(test_plugin_custom_value())); let mut val = Value::test_closure(Closure { block_id: 0, captures: vec![(0, orig_custom_val.clone()), (1, orig_custom_val.clone())], }); let source = Arc::new(PluginSource::new_fake("foo")); - PluginCustomValue::add_source(&mut val, &source); + PluginCustomValue::add_source_in(&mut val, &source)?; check_closure_custom_values(&val, 0..=1, |index, custom_value| { let plugin_custom_value: &PluginCustomValue = custom_value @@ -231,21 +232,25 @@ fn add_source_nested_closure() -> Result<(), ShellError> { #[test] fn verify_source_error_message() -> Result<(), ShellError> { let span = Span::new(5, 7); - let mut ok_val = Value::custom(Box::new(test_plugin_custom_value_with_source()), span); - let mut native_val = Value::custom(Box::new(TestCustomValue(32)), span); - let mut foreign_val = { + let ok_val = test_plugin_custom_value_with_source(); + let native_val = TestCustomValue(32); + let foreign_val = { let mut val = test_plugin_custom_value(); val.source = Some(Arc::new(PluginSource::new_fake("other"))); - Value::custom(Box::new(val), span) + val }; let source = PluginSource::new_fake("test"); - PluginCustomValue::verify_source(&mut ok_val, &source).expect("ok_val should be verified ok"); + PluginCustomValue::verify_source((&ok_val as &dyn CustomValue).into_spanned(span), &source) + .expect("ok_val should be verified ok"); - for (val, src_plugin) in [(&mut native_val, None), (&mut foreign_val, Some("other"))] { - let error = PluginCustomValue::verify_source(val, &source).expect_err(&format!( - "a custom value from {src_plugin:?} should result in an error" - )); + for (val, src_plugin) in [ + (&native_val as &dyn CustomValue, None), + (&foreign_val as &dyn CustomValue, Some("other")), + ] { + let error = PluginCustomValue::verify_source(val.into_spanned(span), &source).expect_err( + &format!("a custom value from {src_plugin:?} should result in an error"), + ); if let ShellError::CustomValueIncorrectForPlugin { name, span: err_span, @@ -265,145 +270,6 @@ fn verify_source_error_message() -> Result<(), ShellError> { Ok(()) } -#[test] -fn verify_source_nested_range() -> Result<(), ShellError> { - let native_val = Value::test_custom_value(Box::new(TestCustomValue(32))); - let source = PluginSource::new_fake("test"); - for (name, mut val) in [ - ( - "from", - Value::test_range(Range { - from: native_val.clone(), - incr: Value::test_nothing(), - to: Value::test_nothing(), - inclusion: RangeInclusion::RightExclusive, - }), - ), - ( - "incr", - Value::test_range(Range { - from: Value::test_nothing(), - incr: native_val.clone(), - to: Value::test_nothing(), - inclusion: RangeInclusion::RightExclusive, - }), - ), - ( - "to", - Value::test_range(Range { - from: Value::test_nothing(), - incr: Value::test_nothing(), - to: native_val.clone(), - inclusion: RangeInclusion::RightExclusive, - }), - ), - ] { - PluginCustomValue::verify_source(&mut val, &source) - .expect_err(&format!("error not generated on {name}")); - } - - let mut ok_range = Value::test_range(Range { - from: Value::test_nothing(), - incr: Value::test_nothing(), - to: Value::test_nothing(), - inclusion: RangeInclusion::RightExclusive, - }); - PluginCustomValue::verify_source(&mut ok_range, &source) - .expect("ok_range should not generate error"); - - Ok(()) -} - -#[test] -fn verify_source_nested_record() -> Result<(), ShellError> { - let native_val = Value::test_custom_value(Box::new(TestCustomValue(32))); - let source = PluginSource::new_fake("test"); - for (name, mut val) in [ - ( - "first element foo", - Value::test_record(record! { - "foo" => native_val.clone(), - "bar" => Value::test_nothing(), - }), - ), - ( - "second element bar", - Value::test_record(record! { - "foo" => Value::test_nothing(), - "bar" => native_val.clone(), - }), - ), - ] { - PluginCustomValue::verify_source(&mut val, &source) - .expect_err(&format!("error not generated on {name}")); - } - - let mut ok_record = Value::test_record(record! {"foo" => Value::test_nothing()}); - PluginCustomValue::verify_source(&mut ok_record, &source) - .expect("ok_record should not generate error"); - - Ok(()) -} - -#[test] -fn verify_source_nested_list() -> Result<(), ShellError> { - let native_val = Value::test_custom_value(Box::new(TestCustomValue(32))); - let source = PluginSource::new_fake("test"); - for (name, mut val) in [ - ( - "first element", - Value::test_list(vec![native_val.clone(), Value::test_nothing()]), - ), - ( - "second element", - Value::test_list(vec![Value::test_nothing(), native_val.clone()]), - ), - ] { - PluginCustomValue::verify_source(&mut val, &source) - .expect_err(&format!("error not generated on {name}")); - } - - let mut ok_list = Value::test_list(vec![Value::test_nothing()]); - PluginCustomValue::verify_source(&mut ok_list, &source) - .expect("ok_list should not generate error"); - - Ok(()) -} - -#[test] -fn verify_source_nested_closure() -> Result<(), ShellError> { - let native_val = Value::test_custom_value(Box::new(TestCustomValue(32))); - let source = PluginSource::new_fake("test"); - for (name, mut val) in [ - ( - "first capture", - Value::test_closure(Closure { - block_id: 0, - captures: vec![(0, native_val.clone()), (1, Value::test_nothing())], - }), - ), - ( - "second capture", - Value::test_closure(Closure { - block_id: 0, - captures: vec![(0, Value::test_nothing()), (1, native_val.clone())], - }), - ), - ] { - PluginCustomValue::verify_source(&mut val, &source) - .expect_err(&format!("error not generated on {name}")); - } - - let mut ok_closure = Value::test_closure(Closure { - block_id: 0, - captures: vec![(0, Value::test_nothing())], - }); - PluginCustomValue::verify_source(&mut ok_closure, &source) - .expect("ok_closure should not generate error"); - - Ok(()) -} - #[test] fn serialize_in_root() -> Result<(), ShellError> { let span = Span::new(4, 10); diff --git a/crates/nu-plugin/src/protocol/test_util.rs b/crates/nu-plugin/src/protocol/test_util.rs index 338177fb2c..6e1fe8cd75 100644 --- a/crates/nu-plugin/src/protocol/test_util.rs +++ b/crates/nu-plugin/src/protocol/test_util.rs @@ -23,6 +23,10 @@ impl CustomValue for TestCustomValue { fn as_any(&self) -> &dyn std::any::Any { self } + + fn as_mut_any(&mut self) -> &mut dyn std::any::Any { + self + } } pub(crate) fn test_plugin_custom_value() -> PluginCustomValue { diff --git a/crates/nu-plugin/src/util/mod.rs b/crates/nu-plugin/src/util/mod.rs index 92818a4054..6a4fe5e5d4 100644 --- a/crates/nu-plugin/src/util/mod.rs +++ b/crates/nu-plugin/src/util/mod.rs @@ -1,3 +1,5 @@ mod mutable_cow; +mod with_custom_values_in; pub(crate) use mutable_cow::*; +pub use with_custom_values_in::*; diff --git a/crates/nu-plugin/src/util/with_custom_values_in.rs b/crates/nu-plugin/src/util/with_custom_values_in.rs new file mode 100644 index 0000000000..bffb6b3737 --- /dev/null +++ b/crates/nu-plugin/src/util/with_custom_values_in.rs @@ -0,0 +1,102 @@ +use nu_protocol::{CustomValue, IntoSpanned, ShellError, Spanned, Value}; + +/// Do something with all [`CustomValue`]s recursively within a `Value`. This is not limited to +/// plugin custom values. +/// +/// `LazyRecord`s will be collected to plain values for completeness. +pub fn with_custom_values_in( + value: &mut Value, + mut f: impl FnMut(Spanned<&mut (dyn CustomValue + '_)>) -> Result<(), E>, +) -> Result<(), E> +where + E: From, +{ + value.recurse_mut(&mut |value| { + let span = value.span(); + match value { + Value::Custom { val, .. } => { + // Operate on a CustomValue. + f(val.as_mut().into_spanned(span)) + } + // LazyRecord would be a problem for us, since it could return something else the + // next time, and we have to collect it anyway to serialize it. Collect it in place, + // and then use the result + Value::LazyRecord { val, .. } => { + *value = val.collect()?; + Ok(()) + } + _ => Ok(()), + } + }) +} + +#[test] +fn find_custom_values() { + use crate::protocol::test_util::test_plugin_custom_value; + use nu_protocol::{ast::RangeInclusion, engine::Closure, record, LazyRecord, Range, Span}; + + #[derive(Debug, Clone)] + struct Lazy; + impl<'a> LazyRecord<'a> for Lazy { + fn column_names(&'a self) -> Vec<&'a str> { + vec!["custom", "plain"] + } + + fn get_column_value(&self, column: &str) -> Result { + Ok(match column { + "custom" => Value::test_custom_value(Box::new(test_plugin_custom_value())), + "plain" => Value::test_int(42), + _ => unimplemented!(), + }) + } + + fn span(&self) -> Span { + Span::test_data() + } + + fn clone_value(&self, span: Span) -> Value { + Value::lazy_record(Box::new(self.clone()), span) + } + } + + let mut cv = Value::test_custom_value(Box::new(test_plugin_custom_value())); + + let mut value = Value::test_record(record! { + "bare" => cv.clone(), + "list" => Value::test_list(vec![ + cv.clone(), + Value::test_int(4), + ]), + "closure" => Value::test_closure( + Closure { + block_id: 0, + captures: vec![(0, cv.clone()), (1, Value::test_string("foo"))] + } + ), + "range" => Value::test_range(Range { + from: cv.clone(), + incr: cv.clone(), + to: cv.clone(), + inclusion: RangeInclusion::Inclusive + }), + "lazy" => Value::test_lazy_record(Box::new(Lazy)), + }); + + // Do with_custom_values_in, and count the number of custom values found + let mut found = 0; + with_custom_values_in::(&mut value, |_| { + found += 1; + Ok(()) + }) + .expect("error"); + assert_eq!(7, found, "found in value"); + + // Try it on bare custom value too + found = 0; + with_custom_values_in::(&mut cv, |_| { + found += 1; + Ok(()) + }) + .expect("error"); + assert_eq!(1, found, "bare custom value didn't work"); +} diff --git a/crates/nu-protocol/src/span.rs b/crates/nu-protocol/src/span.rs index 1184864b8b..db2885c1a4 100644 --- a/crates/nu-protocol/src/span.rs +++ b/crates/nu-protocol/src/span.rs @@ -1,13 +1,54 @@ +use std::ops::Deref; + use miette::SourceSpan; use serde::{Deserialize, Serialize}; /// A spanned area of interest, generic over what kind of thing is of interest -#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq)] +#[derive(Clone, Copy, Debug, Serialize, Deserialize, PartialEq, Eq)] pub struct Spanned { pub item: T, pub span: Span, } +impl Spanned { + /// Map to a spanned reference of the inner type, i.e. `Spanned -> Spanned<&T>`. + pub fn as_ref(&self) -> Spanned<&T> { + Spanned { + item: &self.item, + span: self.span, + } + } + + /// Map to a mutable reference of the inner type, i.e. `Spanned -> Spanned<&mut T>`. + pub fn as_mut(&mut self) -> Spanned<&mut T> { + Spanned { + item: &mut self.item, + span: self.span, + } + } + + /// Map to the result of [`.deref()`](std::ops::Deref::deref) on the inner type. + /// + /// This can be used for example to turn `Spanned>` into `Spanned<&[T]>`. + pub fn as_deref(&self) -> Spanned<&::Target> + where + T: Deref, + { + Spanned { + item: self.item.deref(), + span: self.span, + } + } + + /// Map the spanned item with a function. + pub fn map(self, f: impl FnOnce(T) -> U) -> Spanned { + Spanned { + item: f(self.item), + span: self.span, + } + } +} + /// Helper trait to create [`Spanned`] more ergonomically. pub trait IntoSpanned: Sized { /// Wrap items together with a span into [`Spanned`]. diff --git a/crates/nu-protocol/src/value/custom_value.rs b/crates/nu-protocol/src/value/custom_value.rs index a500753a0b..480ca0018e 100644 --- a/crates/nu-protocol/src/value/custom_value.rs +++ b/crates/nu-protocol/src/value/custom_value.rs @@ -27,6 +27,9 @@ pub trait CustomValue: fmt::Debug + Send + Sync { /// Any representation used to downcast object to its original type fn as_any(&self) -> &dyn std::any::Any; + /// Any representation used to downcast object to its original type (mutable reference) + fn as_mut_any(&mut self) -> &mut dyn std::any::Any; + /// Follow cell path by numeric index (e.g. rows) fn follow_path_int( &self, diff --git a/crates/nu_plugin_custom_values/src/cool_custom_value.rs b/crates/nu_plugin_custom_values/src/cool_custom_value.rs index 058e66da90..d838aa3f9e 100644 --- a/crates/nu_plugin_custom_values/src/cool_custom_value.rs +++ b/crates/nu_plugin_custom_values/src/cool_custom_value.rs @@ -144,4 +144,8 @@ impl CustomValue for CoolCustomValue { fn as_any(&self) -> &dyn std::any::Any { self } + + fn as_mut_any(&mut self) -> &mut dyn std::any::Any { + self + } } diff --git a/crates/nu_plugin_custom_values/src/drop_check.rs b/crates/nu_plugin_custom_values/src/drop_check.rs index 4cb7f54bf9..b23090c37d 100644 --- a/crates/nu_plugin_custom_values/src/drop_check.rs +++ b/crates/nu_plugin_custom_values/src/drop_check.rs @@ -47,6 +47,10 @@ impl CustomValue for DropCheckValue { self } + fn as_mut_any(&mut self) -> &mut dyn std::any::Any { + self + } + fn notify_plugin_on_drop(&self) -> bool { // This is what causes Nushell to let us know when the value is dropped true diff --git a/crates/nu_plugin_custom_values/src/generate.rs b/crates/nu_plugin_custom_values/src/generate.rs index 1176f9ca9a..b4cc8bd6b1 100644 --- a/crates/nu_plugin_custom_values/src/generate.rs +++ b/crates/nu_plugin_custom_values/src/generate.rs @@ -42,6 +42,6 @@ impl SimplePluginCommand for Generate { fn test_examples() -> Result<(), nu_protocol::ShellError> { use nu_plugin_test_support::PluginTest; - PluginTest::new("custom_values", crate::CustomValuePlugin.into())? + PluginTest::new("custom_values", CustomValuePlugin::new().into())? .test_command_examples(&Generate) } diff --git a/crates/nu_plugin_custom_values/src/generate2.rs b/crates/nu_plugin_custom_values/src/generate2.rs index 1f5f2e8ef3..806086f4fb 100644 --- a/crates/nu_plugin_custom_values/src/generate2.rs +++ b/crates/nu_plugin_custom_values/src/generate2.rs @@ -66,6 +66,6 @@ impl SimplePluginCommand for Generate2 { fn test_examples() -> Result<(), nu_protocol::ShellError> { use nu_plugin_test_support::PluginTest; - PluginTest::new("custom_values", crate::CustomValuePlugin.into())? + PluginTest::new("custom_values", crate::CustomValuePlugin::new().into())? .test_command_examples(&Generate2) } diff --git a/crates/nu_plugin_custom_values/src/handle_custom_value.rs b/crates/nu_plugin_custom_values/src/handle_custom_value.rs new file mode 100644 index 0000000000..ac4fc6bbea --- /dev/null +++ b/crates/nu_plugin_custom_values/src/handle_custom_value.rs @@ -0,0 +1,42 @@ +use nu_protocol::{CustomValue, LabeledError, ShellError, Span, Value}; +use serde::{Deserialize, Serialize}; + +/// References a stored handle within the plugin +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct HandleCustomValue(pub u64); + +impl HandleCustomValue { + pub fn into_value(self, span: Span) -> Value { + Value::custom(Box::new(self), span) + } +} + +#[typetag::serde] +impl CustomValue for HandleCustomValue { + fn clone_value(&self, span: Span) -> Value { + self.clone().into_value(span) + } + + fn type_name(&self) -> String { + "HandleCustomValue".into() + } + + fn to_base_value(&self, span: Span) -> Result { + Err(LabeledError::new("Unsupported operation") + .with_label("can't call to_base_value() directly on this", span) + .with_help("HandleCustomValue uses custom_value_to_base_value() on the plugin instead") + .into()) + } + + fn notify_plugin_on_drop(&self) -> bool { + true + } + + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn as_mut_any(&mut self) -> &mut dyn std::any::Any { + self + } +} diff --git a/crates/nu_plugin_custom_values/src/handle_get.rs b/crates/nu_plugin_custom_values/src/handle_get.rs new file mode 100644 index 0000000000..017ac42477 --- /dev/null +++ b/crates/nu_plugin_custom_values/src/handle_get.rs @@ -0,0 +1,61 @@ +use nu_plugin::{EngineInterface, EvaluatedCall, SimplePluginCommand}; +use nu_protocol::{LabeledError, ShellError, Signature, Type, Value}; + +use crate::{handle_custom_value::HandleCustomValue, CustomValuePlugin}; + +pub struct HandleGet; + +impl SimplePluginCommand for HandleGet { + type Plugin = CustomValuePlugin; + + fn name(&self) -> &str { + "custom-value handle get" + } + + fn signature(&self) -> Signature { + Signature::build(self.name()) + .input_output_type(Type::Custom("HandleCustomValue".into()), Type::Any) + } + + fn usage(&self) -> &str { + "Get a value previously stored in a handle" + } + + fn run( + &self, + plugin: &Self::Plugin, + _engine: &EngineInterface, + call: &EvaluatedCall, + input: &Value, + ) -> Result { + if let Some(handle) = input + .as_custom_value()? + .as_any() + .downcast_ref::() + { + // Find the handle + let value = plugin + .handles + .lock() + .map_err(|err| LabeledError::new(err.to_string()))? + .get(&handle.0) + .cloned(); + + if let Some(value) = value { + Ok(value) + } else { + Err(LabeledError::new("Handle expired") + .with_label("this handle is no longer valid", input.span()) + .with_help("the plugin may have exited, or there was a bug")) + } + } else { + Err(ShellError::UnsupportedInput { + msg: "requires HandleCustomValue".into(), + input: format!("got {}", input.get_type()), + msg_span: call.head, + input_span: input.span(), + } + .into()) + } + } +} diff --git a/crates/nu_plugin_custom_values/src/handle_make.rs b/crates/nu_plugin_custom_values/src/handle_make.rs new file mode 100644 index 0000000000..afc3e914a2 --- /dev/null +++ b/crates/nu_plugin_custom_values/src/handle_make.rs @@ -0,0 +1,47 @@ +use std::sync::atomic; + +use nu_plugin::{EngineInterface, EvaluatedCall, SimplePluginCommand}; +use nu_protocol::{LabeledError, Signature, Type, Value}; + +use crate::{handle_custom_value::HandleCustomValue, CustomValuePlugin}; + +pub struct HandleMake; + +impl SimplePluginCommand for HandleMake { + type Plugin = CustomValuePlugin; + + fn name(&self) -> &str { + "custom-value handle make" + } + + fn signature(&self) -> Signature { + Signature::build(self.name()) + .input_output_type(Type::Any, Type::Custom("HandleCustomValue".into())) + } + + fn usage(&self) -> &str { + "Store a value in plugin memory and return a handle to it" + } + + fn run( + &self, + plugin: &Self::Plugin, + _engine: &EngineInterface, + call: &EvaluatedCall, + input: &Value, + ) -> Result { + // Generate an id and store in the plugin. + let new_id = plugin.counter.fetch_add(1, atomic::Ordering::Relaxed); + + plugin + .handles + .lock() + .map_err(|err| LabeledError::new(err.to_string()))? + .insert(new_id, input.clone()); + + Ok(Value::custom( + Box::new(HandleCustomValue(new_id)), + call.head, + )) + } +} diff --git a/crates/nu_plugin_custom_values/src/main.rs b/crates/nu_plugin_custom_values/src/main.rs index d29bbc3dee..c2df39115f 100644 --- a/crates/nu_plugin_custom_values/src/main.rs +++ b/crates/nu_plugin_custom_values/src/main.rs @@ -1,22 +1,43 @@ +use std::{ + collections::BTreeMap, + sync::{atomic::AtomicU64, Mutex}, +}; + +use handle_custom_value::HandleCustomValue; use nu_plugin::{serve_plugin, EngineInterface, MsgPackSerializer, Plugin, PluginCommand}; mod cool_custom_value; +mod handle_custom_value; mod second_custom_value; mod drop_check; mod generate; mod generate2; +mod handle_get; +mod handle_make; mod update; mod update_arg; use drop_check::{DropCheck, DropCheckValue}; use generate::Generate; use generate2::Generate2; -use nu_protocol::{CustomValue, LabeledError}; +use handle_get::HandleGet; +use handle_make::HandleMake; +use nu_protocol::{CustomValue, LabeledError, Spanned, Value}; use update::Update; use update_arg::UpdateArg; -pub struct CustomValuePlugin; +#[derive(Default)] +pub struct CustomValuePlugin { + counter: AtomicU64, + handles: Mutex>, +} + +impl CustomValuePlugin { + pub fn new() -> Self { + Self::default() + } +} impl Plugin for CustomValuePlugin { fn commands(&self) -> Vec>> { @@ -26,22 +47,54 @@ impl Plugin for CustomValuePlugin { Box::new(Update), Box::new(UpdateArg), Box::new(DropCheck), + Box::new(HandleGet), + Box::new(HandleMake), ] } + fn custom_value_to_base_value( + &self, + _engine: &EngineInterface, + custom_value: Spanned>, + ) -> Result { + // HandleCustomValue depends on the plugin state to get. + if let Some(handle) = custom_value + .item + .as_any() + .downcast_ref::() + { + Ok(self + .handles + .lock() + .map_err(|err| LabeledError::new(err.to_string()))? + .get(&handle.0) + .cloned() + .unwrap_or_else(|| Value::nothing(custom_value.span))) + } else { + custom_value + .item + .to_base_value(custom_value.span) + .map_err(|err| err.into()) + } + } + fn custom_value_dropped( &self, _engine: &EngineInterface, custom_value: Box, ) -> Result<(), LabeledError> { - // This is how we implement our drop behavior for DropCheck. + // This is how we implement our drop behavior. if let Some(drop_check) = custom_value.as_any().downcast_ref::() { drop_check.notify(); + } else if let Some(handle) = custom_value.as_any().downcast_ref::() { + if let Ok(mut handles) = self.handles.lock() { + handles.remove(&handle.0); + } } Ok(()) } } fn main() { - serve_plugin(&CustomValuePlugin, MsgPackSerializer {}) + serve_plugin(&CustomValuePlugin::default(), MsgPackSerializer {}) } diff --git a/crates/nu_plugin_custom_values/src/second_custom_value.rs b/crates/nu_plugin_custom_values/src/second_custom_value.rs index 7a7647f521..fde02cfade 100644 --- a/crates/nu_plugin_custom_values/src/second_custom_value.rs +++ b/crates/nu_plugin_custom_values/src/second_custom_value.rs @@ -74,4 +74,8 @@ impl CustomValue for SecondCustomValue { fn as_any(&self) -> &dyn std::any::Any { self } + + fn as_mut_any(&mut self) -> &mut dyn std::any::Any { + self + } } diff --git a/crates/nu_plugin_custom_values/src/update.rs b/crates/nu_plugin_custom_values/src/update.rs index 7c4c5e4960..0ce7c09ec9 100644 --- a/crates/nu_plugin_custom_values/src/update.rs +++ b/crates/nu_plugin_custom_values/src/update.rs @@ -67,6 +67,6 @@ impl SimplePluginCommand for Update { fn test_examples() -> Result<(), nu_protocol::ShellError> { use nu_plugin_test_support::PluginTest; - PluginTest::new("custom_values", crate::CustomValuePlugin.into())? + PluginTest::new("custom_values", crate::CustomValuePlugin::new().into())? .test_command_examples(&Update) } diff --git a/tests/plugins/custom_values.rs b/tests/plugins/custom_values.rs index 34aa3ddf6e..359e532449 100644 --- a/tests/plugins/custom_values.rs +++ b/tests/plugins/custom_values.rs @@ -181,6 +181,20 @@ fn drop_check_custom_value_prints_message_on_drop() { assert!(actual.status.success()); } +#[test] +fn handle_make_then_get_success() { + // The drop notification must wait until the `handle get` call has finished in order for this + // to succeed + let actual = nu_with_plugins!( + cwd: "tests", + plugin: ("nu_plugin_custom_values"), + "42 | custom-value handle make | custom-value handle get" + ); + + assert_eq!(actual.out, "42"); + assert!(actual.status.success()); +} + #[test] fn custom_value_in_example_is_rendered() { let actual = nu_with_plugins!(