diff --git a/crates/nu-plugin/src/lib.rs b/crates/nu-plugin/src/lib.rs index bac11db1ff..ea7563de1b 100644 --- a/crates/nu-plugin/src/lib.rs +++ b/crates/nu-plugin/src/lib.rs @@ -16,7 +16,7 @@ //! invoked by Nushell. //! //! ```rust,no_run -//! use nu_plugin::{EvaluatedCall, LabeledError, MsgPackSerializer, Plugin, serve_plugin}; +//! use nu_plugin::*; //! use nu_protocol::{PluginSignature, Value}; //! //! struct MyPlugin; @@ -26,9 +26,9 @@ //! todo!(); //! } //! fn run( -//! &mut self, +//! &self, //! name: &str, -//! config: &Option, +//! engine: &EngineInterface, //! call: &EvaluatedCall, //! input: &Value //! ) -> Result { @@ -37,7 +37,7 @@ //! } //! //! fn main() { -//! serve_plugin(&mut MyPlugin{}, MsgPackSerializer) +//! serve_plugin(&MyPlugin{}, MsgPackSerializer) //! } //! ``` //! @@ -49,7 +49,7 @@ mod protocol; mod sequence; mod serializers; -pub use plugin::{serve_plugin, Plugin, PluginEncoder, StreamingPlugin}; +pub use plugin::{serve_plugin, EngineInterface, Plugin, PluginEncoder, StreamingPlugin}; pub use protocol::{EvaluatedCall, LabeledError}; pub use serializers::{json::JsonSerializer, msgpack::MsgPackSerializer}; diff --git a/crates/nu-plugin/src/plugin/context.rs b/crates/nu-plugin/src/plugin/context.rs index cac5cd5b3e..a57a38c82a 100644 --- a/crates/nu-plugin/src/plugin/context.rs +++ b/crates/nu-plugin/src/plugin/context.rs @@ -1,36 +1,166 @@ use std::sync::{atomic::AtomicBool, Arc}; +use nu_engine::get_eval_block_with_early_return; use nu_protocol::{ ast::Call, - engine::{EngineState, Stack}, + engine::{Closure, EngineState, Stack}, + Config, PipelineData, ShellError, Span, Spanned, Value, }; +use super::PluginIdentity; + /// Object safe trait for abstracting operations required of the plugin context. pub(crate) trait PluginExecutionContext: Send + Sync { + /// The [Span] for the command execution (`call.head`) + fn command_span(&self) -> Span; + /// The name of the command being executed + fn command_name(&self) -> &str; /// The interrupt signal, if present fn ctrlc(&self) -> Option<&Arc>; + /// Get engine configuration + fn get_config(&self) -> Result; + /// Get plugin configuration + fn get_plugin_config(&self) -> Result, ShellError>; + /// Evaluate a closure passed to the plugin + fn eval_closure( + &self, + closure: Spanned, + positional: Vec, + input: PipelineData, + redirect_stdout: bool, + redirect_stderr: bool, + ) -> Result; } -/// The execution context of a plugin command. May be extended with more fields in the future. +/// The execution context of a plugin command. pub(crate) struct PluginExecutionCommandContext { - ctrlc: Option>, + identity: Arc, + engine_state: EngineState, + stack: Stack, + call: Call, } impl PluginExecutionCommandContext { pub fn new( + identity: Arc, engine_state: &EngineState, - _stack: &Stack, - _call: &Call, + stack: &Stack, + call: &Call, ) -> PluginExecutionCommandContext { PluginExecutionCommandContext { - ctrlc: engine_state.ctrlc.clone(), + identity, + engine_state: engine_state.clone(), + stack: stack.clone(), + call: call.clone(), } } } impl PluginExecutionContext for PluginExecutionCommandContext { + fn command_span(&self) -> Span { + self.call.head + } + + fn command_name(&self) -> &str { + self.engine_state.get_decl(self.call.decl_id).name() + } + fn ctrlc(&self) -> Option<&Arc> { - self.ctrlc.as_ref() + self.engine_state.ctrlc.as_ref() + } + + fn get_config(&self) -> Result { + Ok(nu_engine::get_config(&self.engine_state, &self.stack)) + } + + fn get_plugin_config(&self) -> Result, ShellError> { + // Fetch the configuration for a plugin + // + // The `plugin` must match the registered name of a plugin. For + // `register nu_plugin_example` the plugin config lookup uses `"example"` + Ok(self + .get_config()? + .plugins + .get(&self.identity.plugin_name) + .cloned() + .map(|value| { + let span = value.span(); + match value { + Value::Closure { val, .. } => { + let input = PipelineData::Empty; + + let block = self.engine_state.get_block(val.block_id).clone(); + let mut stack = self.stack.captures_to_stack(val.captures); + + let eval_block_with_early_return = + get_eval_block_with_early_return(&self.engine_state); + + match eval_block_with_early_return( + &self.engine_state, + &mut stack, + &block, + input, + false, + false, + ) { + Ok(v) => v.into_value(span), + Err(e) => Value::error(e, self.call.head), + } + } + _ => value.clone(), + } + })) + } + + fn eval_closure( + &self, + closure: Spanned, + positional: Vec, + input: PipelineData, + redirect_stdout: bool, + redirect_stderr: bool, + ) -> Result { + let block = self + .engine_state + .try_get_block(closure.item.block_id) + .ok_or_else(|| ShellError::GenericError { + error: "Plugin misbehaving".into(), + msg: format!( + "Tried to evaluate unknown block id: {}", + closure.item.block_id + ), + span: Some(closure.span), + help: None, + inner: vec![], + })?; + + let mut stack = self.stack.captures_to_stack(closure.item.captures); + + // Set up the positional arguments + for (idx, value) in positional.into_iter().enumerate() { + if let Some(arg) = block.signature.get_positional(idx) { + if let Some(var_id) = arg.var_id { + stack.add_var(var_id, value); + } else { + return Err(ShellError::NushellFailedSpanned { + msg: "Error while evaluating closure from plugin".into(), + label: "closure argument missing var_id".into(), + span: closure.span, + }); + } + } + } + + let eval_block_with_early_return = get_eval_block_with_early_return(&self.engine_state); + + eval_block_with_early_return( + &self.engine_state, + &mut stack, + block, + input, + redirect_stdout, + redirect_stderr, + ) } } @@ -40,7 +170,38 @@ pub(crate) struct PluginExecutionBogusContext; #[cfg(test)] impl PluginExecutionContext for PluginExecutionBogusContext { + fn command_span(&self) -> Span { + Span::test_data() + } + + fn command_name(&self) -> &str { + "bogus" + } + fn ctrlc(&self) -> Option<&Arc> { None } + + fn get_config(&self) -> Result { + Err(ShellError::NushellFailed { + msg: "get_config not implemented on bogus".into(), + }) + } + + fn get_plugin_config(&self) -> Result, ShellError> { + Ok(None) + } + + fn eval_closure( + &self, + _closure: Spanned, + _positional: Vec, + _input: PipelineData, + _redirect_stdout: bool, + _redirect_stderr: bool, + ) -> Result { + Err(ShellError::NushellFailed { + msg: "eval_closure not implemented on bogus".into(), + }) + } } diff --git a/crates/nu-plugin/src/plugin/declaration.rs b/crates/nu-plugin/src/plugin/declaration.rs index 0afadb82d2..66aad026fb 100644 --- a/crates/nu-plugin/src/plugin/declaration.rs +++ b/crates/nu-plugin/src/plugin/declaration.rs @@ -3,11 +3,11 @@ use crate::protocol::{CallInfo, EvaluatedCall}; use std::path::{Path, PathBuf}; use std::sync::Arc; -use nu_engine::{get_eval_block, get_eval_expression}; +use nu_engine::get_eval_expression; use nu_protocol::engine::{Command, EngineState, Stack}; use nu_protocol::{ast::Call, PluginSignature, Signature}; -use nu_protocol::{Example, PipelineData, ShellError, Value}; +use nu_protocol::{Example, PipelineData, ShellError}; #[doc(hidden)] // Note: not for plugin authors / only used in nu-parser #[derive(Clone)] @@ -72,7 +72,6 @@ impl Command for PluginDeclaration { call: &Call, input: PipelineData, ) -> Result { - let eval_block = get_eval_block(engine_state); let eval_expression = get_eval_expression(engine_state); // Create the EvaluatedCall to send to the plugin first - it's best for this to fail early, @@ -80,32 +79,6 @@ impl Command for PluginDeclaration { let evaluated_call = EvaluatedCall::try_from_call(call, engine_state, stack, eval_expression)?; - // Fetch the configuration for a plugin - // - // The `plugin` must match the registered name of a plugin. For - // `register nu_plugin_example` the plugin config lookup uses `"example"` - let config = nu_engine::get_config(engine_state, stack) - .plugins - .get(&self.identity.plugin_name) - .cloned() - .map(|value| { - let span = value.span(); - match value { - Value::Closure { val, .. } => { - let input = PipelineData::Empty; - - let block = engine_state.get_block(val.block_id).clone(); - let mut stack = stack.captures_to_stack(val.captures); - - match eval_block(engine_state, &mut stack, &block, input, false, false) { - Ok(v) => v.into_value(span), - Err(e) => Value::error(e, call.head), - } - } - _ => value.clone(), - } - }); - // We need the current environment variables for `python` based plugins // Or we'll likely have a problem when a plugin is implemented in a virtual Python environment. let current_envs = nu_engine::env::env_to_strings(engine_state, stack).unwrap_or_default(); @@ -122,8 +95,9 @@ impl Command for PluginDeclaration { } })?; - // Create the context to execute in + // Create the context to execute in - this supports engine calls and custom values let context = Arc::new(PluginExecutionCommandContext::new( + self.identity.clone(), engine_state, stack, call, @@ -134,7 +108,6 @@ impl Command for PluginDeclaration { name: self.name.clone(), call: evaluated_call, input, - config, }, context, ) diff --git a/crates/nu-plugin/src/plugin/interface.rs b/crates/nu-plugin/src/plugin/interface.rs index 3dcca89602..acfe714851 100644 --- a/crates/nu-plugin/src/plugin/interface.rs +++ b/crates/nu-plugin/src/plugin/interface.rs @@ -22,6 +22,7 @@ use crate::{ mod stream; mod engine; +pub use engine::EngineInterface; pub(crate) use engine::{EngineInterfaceManager, ReceivedPluginCall}; mod plugin; diff --git a/crates/nu-plugin/src/plugin/interface/engine.rs b/crates/nu-plugin/src/plugin/interface/engine.rs index 7d5012ed11..a8b533a4b3 100644 --- a/crates/nu-plugin/src/plugin/interface/engine.rs +++ b/crates/nu-plugin/src/plugin/interface/engine.rs @@ -1,16 +1,19 @@ //! Interface used by the plugin to communicate with the engine. -use std::sync::{mpsc, Arc}; +use std::{ + collections::{btree_map, BTreeMap}, + sync::{mpsc, Arc}, +}; use nu_protocol::{ - IntoInterruptiblePipelineData, ListStream, PipelineData, PluginSignature, ShellError, Spanned, - Value, + engine::Closure, Config, IntoInterruptiblePipelineData, ListStream, PipelineData, + PluginSignature, ShellError, Spanned, Value, }; use crate::{ protocol::{ - CallInfo, CustomValueOp, PluginCall, PluginCallId, PluginCallResponse, PluginCustomValue, - PluginInput, ProtocolInfo, + CallInfo, CustomValueOp, EngineCall, EngineCallId, EngineCallResponse, PluginCall, + PluginCallId, PluginCallResponse, PluginCustomValue, PluginInput, ProtocolInfo, }, LabeledError, PluginOutput, }; @@ -47,8 +50,13 @@ mod tests; /// Internal shared state between the manager and each interface. struct EngineInterfaceState { + /// Sequence for generating engine call ids + engine_call_id_sequence: Sequence, /// Sequence for generating stream ids stream_id_sequence: Sequence, + /// Sender to subscribe to an engine call response + engine_call_subscription_sender: + mpsc::Sender<(EngineCallId, mpsc::Sender>)>, /// The synchronized output writer writer: Box>, } @@ -56,7 +64,12 @@ struct EngineInterfaceState { impl std::fmt::Debug for EngineInterfaceState { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_struct("EngineInterfaceState") + .field("engine_call_id_sequence", &self.engine_call_id_sequence) .field("stream_id_sequence", &self.stream_id_sequence) + .field( + "engine_call_subscription_sender", + &self.engine_call_subscription_sender, + ) .finish_non_exhaustive() } } @@ -70,6 +83,12 @@ pub(crate) struct EngineInterfaceManager { plugin_call_sender: Option>, /// Receiver for PluginCalls. This is usually taken after initialization plugin_call_receiver: Option>, + /// Subscriptions for engine call responses + engine_call_subscriptions: + BTreeMap>>, + /// Receiver for engine call subscriptions + engine_call_subscription_receiver: + mpsc::Receiver<(EngineCallId, mpsc::Sender>)>, /// Manages stream messages and state stream_manager: StreamManager, /// Protocol version info, set after `Hello` received @@ -79,14 +98,19 @@ pub(crate) struct EngineInterfaceManager { impl EngineInterfaceManager { pub(crate) fn new(writer: impl PluginWrite + 'static) -> EngineInterfaceManager { let (plug_tx, plug_rx) = mpsc::channel(); + let (subscription_tx, subscription_rx) = mpsc::channel(); EngineInterfaceManager { state: Arc::new(EngineInterfaceState { + engine_call_id_sequence: Sequence::default(), stream_id_sequence: Sequence::default(), + engine_call_subscription_sender: subscription_tx, writer: Box::new(writer), }), plugin_call_sender: Some(plug_tx), plugin_call_receiver: Some(plug_rx), + engine_call_subscriptions: BTreeMap::new(), + engine_call_subscription_receiver: subscription_rx, stream_manager: StreamManager::new(), protocol_info: None, } @@ -122,6 +146,38 @@ impl EngineInterfaceManager { }) } + /// Flush any remaining subscriptions in the receiver into the map + fn receive_engine_call_subscriptions(&mut self) { + for (id, subscription) in self.engine_call_subscription_receiver.try_iter() { + if let btree_map::Entry::Vacant(e) = self.engine_call_subscriptions.entry(id) { + e.insert(subscription); + } else { + log::warn!("Duplicate engine call ID ignored: {id}") + } + } + } + + /// Send a [`EngineCallResponse`] to the appropriate sender + fn send_engine_call_response( + &mut self, + id: EngineCallId, + response: EngineCallResponse, + ) -> Result<(), ShellError> { + // Ensure all of the subscriptions have been flushed out of the receiver + self.receive_engine_call_subscriptions(); + // Remove the sender - there is only one response per engine call + if let Some(sender) = self.engine_call_subscriptions.remove(&id) { + if sender.send(response).is_err() { + log::warn!("Received an engine call response for id={id}, but the caller hung up"); + } + Ok(()) + } else { + Err(ShellError::PluginFailedToDecode { + msg: format!("Unknown engine call ID: {id}"), + }) + } + } + /// True if there are no other copies of the state (which would mean there are no interfaces /// and no stream readers/writers) pub(crate) fn is_finished(&self) -> bool { @@ -141,7 +197,13 @@ impl EngineInterfaceManager { } if let Err(err) = msg.and_then(|msg| self.consume(msg)) { + // Error to streams let _ = self.stream_manager.broadcast_read_error(err.clone()); + // Error to engine call waiters + self.receive_engine_call_subscriptions(); + for sender in std::mem::take(&mut self.engine_call_subscriptions).into_values() { + let _ = sender.send(EngineCallResponse::Error(err.clone())); + } return Err(err); } } @@ -200,7 +262,6 @@ impl InterfaceManager for EngineInterfaceManager { name, mut call, input, - config, }) => { let interface = self.interface_for_context(id); // If there's an error with initialization of the input stream, just send @@ -214,12 +275,7 @@ impl InterfaceManager for EngineInterfaceManager { // Send the plugin call to the receiver self.send_plugin_call(ReceivedPluginCall::Run { engine: interface, - call: CallInfo { - name, - call, - input, - config, - }, + call: CallInfo { name, call, input }, }) } err @ Err(_) => interface.write_response(err)?.write(), @@ -239,6 +295,21 @@ impl InterfaceManager for EngineInterfaceManager { drop(self.plugin_call_sender.take()); Ok(()) } + PluginInput::EngineCallResponse(id, response) => { + let response = match response { + EngineCallResponse::Error(err) => EngineCallResponse::Error(err), + EngineCallResponse::Config(config) => EngineCallResponse::Config(config), + EngineCallResponse::PipelineData(header) => { + // If there's an error with initializing this stream, change it to an engine + // call error response, but send it anyway + match self.read_pipeline_data(header, None) { + Ok(data) => EngineCallResponse::PipelineData(data), + Err(err) => EngineCallResponse::Error(err), + } + } + }; + self.send_engine_call_response(id, response) + } } } @@ -341,6 +412,264 @@ impl EngineInterface { self.write(PluginOutput::CallResponse(self.context()?, response))?; self.flush() } + + /// Write an engine call message. Returns the writer for the stream, and the receiver for + /// the response to the engine call. + fn write_engine_call( + &self, + call: EngineCall, + ) -> Result< + ( + PipelineDataWriter, + mpsc::Receiver>, + ), + ShellError, + > { + let context = self.context()?; + let id = self.state.engine_call_id_sequence.next()?; + let (tx, rx) = mpsc::channel(); + + // Convert the call into one with a header and handle the stream, if necessary + let (call, writer) = match call { + EngineCall::EvalClosure { + closure, + positional, + input, + redirect_stdout, + redirect_stderr, + } => { + let (header, writer) = self.init_write_pipeline_data(input)?; + ( + EngineCall::EvalClosure { + closure, + positional, + input: header, + redirect_stdout, + redirect_stderr, + }, + writer, + ) + } + // These calls have no pipeline data, so they're just the same on both sides + EngineCall::GetConfig => (EngineCall::GetConfig, Default::default()), + EngineCall::GetPluginConfig => (EngineCall::GetPluginConfig, Default::default()), + }; + + // Register the channel + self.state + .engine_call_subscription_sender + .send((id, tx)) + .map_err(|_| ShellError::NushellFailed { + msg: "EngineInterfaceManager hung up and is no longer accepting engine calls" + .into(), + })?; + + // Write request + self.write(PluginOutput::EngineCall { context, id, call })?; + self.flush()?; + + Ok((writer, rx)) + } + + /// Perform an engine call. Input and output streams are handled. + fn engine_call( + &self, + call: EngineCall, + ) -> Result, ShellError> { + let (writer, rx) = self.write_engine_call(call)?; + + // Finish writing stream in the background + writer.write_background()?; + + // Wait on receiver to get the response + rx.recv().map_err(|_| ShellError::NushellFailed { + msg: "Failed to get response to engine call because the channel was closed".into(), + }) + } + + /// Get the full shell configuration from the engine. As this is quite a large object, it is + /// provided on request only. + /// + /// # Example + /// + /// Format a value in the user's preferred way: + /// + /// ``` + /// # use nu_protocol::{Value, ShellError}; + /// # use nu_plugin::EngineInterface; + /// # fn example(engine: &EngineInterface, value: &Value) -> Result<(), ShellError> { + /// let config = engine.get_config()?; + /// eprintln!("{}", value.to_expanded_string(", ", &config)); + /// # Ok(()) + /// # } + /// ``` + pub fn get_config(&self) -> Result, ShellError> { + match self.engine_call(EngineCall::GetConfig)? { + EngineCallResponse::Config(config) => Ok(config), + EngineCallResponse::Error(err) => Err(err), + _ => Err(ShellError::PluginFailedToDecode { + msg: "Received unexpected response for EngineCall::GetConfig".into(), + }), + } + } + + /// Get the plugin-specific configuration from the engine. This lives in + /// `$env.config.plugins.NAME` for a plugin named `NAME`. If the config is set to a closure, + /// it is automatically evaluated each time. + /// + /// # Example + /// + /// Print this plugin's config: + /// + /// ``` + /// # use nu_protocol::{Value, ShellError}; + /// # use nu_plugin::EngineInterface; + /// # fn example(engine: &EngineInterface, value: &Value) -> Result<(), ShellError> { + /// let config = engine.get_plugin_config()?; + /// eprintln!("{:?}", config); + /// # Ok(()) + /// # } + /// ``` + pub fn get_plugin_config(&self) -> Result, ShellError> { + match self.engine_call(EngineCall::GetPluginConfig)? { + EngineCallResponse::PipelineData(PipelineData::Empty) => Ok(None), + EngineCallResponse::PipelineData(PipelineData::Value(value, _)) => Ok(Some(value)), + EngineCallResponse::Error(err) => Err(err), + _ => Err(ShellError::PluginFailedToDecode { + msg: "Received unexpected response for EngineCall::GetConfig".into(), + }), + } + } + + /// Ask the engine to evaluate a closure. Input to the closure is passed as a stream, and the + /// output is available as a stream. + /// + /// Set `redirect_stdout` to `true` to capture the standard output stream of an external + /// command, if the closure results in an external command. + /// + /// Set `redirect_stderr` to `true` to capture the standard error stream of an external command, + /// if the closure results in an external command. + /// + /// # Example + /// + /// Invoked as: + /// + /// ```nushell + /// my_command { seq 1 $in | each { |n| $"Hello, ($n)" } } + /// ``` + /// + /// ``` + /// # use nu_protocol::{Value, ShellError, PipelineData}; + /// # use nu_plugin::{EngineInterface, EvaluatedCall}; + /// # fn example(engine: &EngineInterface, call: &EvaluatedCall) -> Result<(), ShellError> { + /// let closure = call.req(0)?; + /// let input = PipelineData::Value(Value::int(4, call.head), None); + /// let output = engine.eval_closure_with_stream( + /// &closure, + /// vec![], + /// input, + /// true, + /// false, + /// )?; + /// for value in output { + /// eprintln!("Closure says: {}", value.as_str()?); + /// } + /// # Ok(()) + /// # } + /// ``` + /// + /// Output: + /// + /// ```text + /// Closure says: Hello, 1 + /// Closure says: Hello, 2 + /// Closure says: Hello, 3 + /// Closure says: Hello, 4 + /// ``` + pub fn eval_closure_with_stream( + &self, + closure: &Spanned, + mut positional: Vec, + input: PipelineData, + redirect_stdout: bool, + redirect_stderr: bool, + ) -> Result { + // Ensure closure args have custom values serialized + positional + .iter_mut() + .try_for_each(PluginCustomValue::serialize_custom_values_in)?; + + let call = EngineCall::EvalClosure { + closure: closure.clone(), + positional, + input, + redirect_stdout, + redirect_stderr, + }; + + match self.engine_call(call)? { + EngineCallResponse::Error(error) => Err(error), + EngineCallResponse::PipelineData(data) => Ok(data), + _ => Err(ShellError::PluginFailedToDecode { + msg: "Received unexpected response type for EngineCall::EvalClosure".into(), + }), + } + } + + /// Ask the engine to evaluate a closure. Input is optionally passed as a [`Value`], and output + /// of the closure is collected to a [`Value`] even if it is a stream. + /// + /// If the closure results in an external command, the return value will be a collected string + /// or binary value of the standard output stream of that command, similar to calling + /// [`eval_closure_with_stream()`](Self::eval_closure_with_stream) with `redirect_stdout` = + /// `true` and `redirect_stderr` = `false`. + /// + /// Use [`eval_closure_with_stream()`](Self::eval_closure_with_stream) if more control over the + /// input and output is desired. + /// + /// # Example + /// + /// Invoked as: + /// + /// ```nushell + /// my_command { |number| $number + 1} + /// ``` + /// + /// ``` + /// # use nu_protocol::{Value, ShellError}; + /// # use nu_plugin::{EngineInterface, EvaluatedCall}; + /// # fn example(engine: &EngineInterface, call: &EvaluatedCall) -> Result<(), ShellError> { + /// let closure = call.req(0)?; + /// for n in 0..4 { + /// let result = engine.eval_closure(&closure, vec![Value::int(n, call.head)], None)?; + /// eprintln!("{} => {}", n, result.as_int()?); + /// } + /// # Ok(()) + /// # } + /// ``` + /// + /// Output: + /// + /// ```text + /// 0 => 1 + /// 1 => 2 + /// 2 => 3 + /// 3 => 4 + /// ``` + pub fn eval_closure( + &self, + closure: &Spanned, + positional: Vec, + input: Option, + ) -> Result { + let input = input.map_or_else(|| PipelineData::Empty, |v| PipelineData::Value(v, None)); + let output = self.eval_closure_with_stream(closure, positional, input, true, false)?; + // Unwrap an error value + match output.into_value(closure.span) { + Value::Error { error, .. } => Err(*error), + value => Ok(value), + } + } } impl Interface for EngineInterface { diff --git a/crates/nu-plugin/src/plugin/interface/engine/tests.rs b/crates/nu-plugin/src/plugin/interface/engine/tests.rs index c6be0e5374..0519ee9820 100644 --- a/crates/nu-plugin/src/plugin/interface/engine/tests.rs +++ b/crates/nu-plugin/src/plugin/interface/engine/tests.rs @@ -1,22 +1,22 @@ -use std::sync::mpsc::TryRecvError; +use std::sync::mpsc::{self, TryRecvError}; use nu_protocol::{ - CustomValue, IntoInterruptiblePipelineData, PipelineData, PluginSignature, ShellError, Span, - Spanned, Value, + engine::Closure, Config, CustomValue, IntoInterruptiblePipelineData, PipelineData, + PluginSignature, ShellError, Span, Spanned, Value, }; use crate::{ plugin::interface::{test_util::TestCase, Interface, InterfaceManager}, protocol::{ test_util::{expected_test_custom_value, test_plugin_custom_value, TestCustomValue}, - CallInfo, CustomValueOp, ExternalStreamInfo, ListStreamInfo, PipelineDataHeader, - PluginCall, PluginCustomValue, PluginInput, Protocol, ProtocolInfo, RawStreamInfo, - StreamData, StreamMessage, + CallInfo, CustomValueOp, EngineCall, EngineCallId, EngineCallResponse, ExternalStreamInfo, + ListStreamInfo, PipelineDataHeader, PluginCall, PluginCustomValue, PluginInput, Protocol, + ProtocolInfo, RawStreamInfo, StreamData, StreamMessage, }, EvaluatedCall, LabeledError, PluginCallResponse, PluginOutput, }; -use super::ReceivedPluginCall; +use super::{EngineInterfaceManager, ReceivedPluginCall}; #[test] fn manager_consume_all_consumes_messages() -> Result<(), ShellError> { @@ -90,7 +90,7 @@ fn check_test_io_error(error: &ShellError) { } #[test] -fn manager_consume_all_propagates_error_to_readers() -> Result<(), ShellError> { +fn manager_consume_all_propagates_io_error_to_readers() -> Result<(), ShellError> { let mut test = TestCase::new(); let mut manager = test.engine(); @@ -170,6 +170,74 @@ fn manager_consume_all_propagates_message_error_to_readers() -> Result<(), Shell } } +fn fake_engine_call( + manager: &mut EngineInterfaceManager, + id: EngineCallId, +) -> mpsc::Receiver> { + // Set up a fake engine call subscription + let (tx, rx) = mpsc::channel(); + + manager.engine_call_subscriptions.insert(id, tx); + + rx +} + +#[test] +fn manager_consume_all_propagates_io_error_to_engine_calls() -> Result<(), ShellError> { + let mut test = TestCase::new(); + let mut manager = test.engine(); + let interface = manager.get_interface(); + + test.set_read_error(test_io_error()); + + // Set up a fake engine call subscription + let rx = fake_engine_call(&mut manager, 0); + + manager + .consume_all(&mut test) + .expect_err("consume_all did not error"); + + // We have to hold interface until now otherwise consume_all won't try to process the message + drop(interface); + + let message = rx.try_recv().expect("failed to get engine call message"); + match message { + EngineCallResponse::Error(error) => { + check_test_io_error(&error); + Ok(()) + } + _ => panic!("received something other than an error: {message:?}"), + } +} + +#[test] +fn manager_consume_all_propagates_message_error_to_engine_calls() -> Result<(), ShellError> { + let mut test = TestCase::new(); + let mut manager = test.engine(); + let interface = manager.get_interface(); + + test.add(invalid_input()); + + // Set up a fake engine call subscription + let rx = fake_engine_call(&mut manager, 0); + + manager + .consume_all(&mut test) + .expect_err("consume_all did not error"); + + // We have to hold interface until now otherwise consume_all won't try to process the message + drop(interface); + + let message = rx.try_recv().expect("failed to get engine call message"); + match message { + EngineCallResponse::Error(error) => { + check_invalid_input_error(&error); + Ok(()) + } + _ => panic!("received something other than an error: {message:?}"), + } +} + #[test] fn manager_consume_sets_protocol_info_on_hello() -> Result<(), ShellError> { let mut manager = TestCase::new().engine(); @@ -275,7 +343,6 @@ fn manager_consume_call_run_forwards_to_receiver_with_context() -> Result<(), Sh named: vec![], }, input: PipelineDataHeader::Empty, - config: None, }), ))?; @@ -310,7 +377,6 @@ fn manager_consume_call_run_forwards_to_receiver_with_pipeline_data() -> Result< named: vec![], }, input: PipelineDataHeader::ListStream(ListStreamInfo { id: 6 }), - config: None, }), ))?; @@ -364,7 +430,6 @@ fn manager_consume_call_run_deserializes_custom_values_in_args() -> Result<(), S )], }, input: PipelineDataHeader::Empty, - config: None, }), ))?; @@ -443,6 +508,43 @@ fn manager_consume_call_custom_value_op_forwards_to_receiver_with_context() -> R Ok(()) } +#[test] +fn manager_consume_engine_call_response_forwards_to_subscriber_with_pipeline_data( +) -> Result<(), ShellError> { + let mut manager = TestCase::new().engine(); + manager.protocol_info = Some(ProtocolInfo::default()); + + let rx = fake_engine_call(&mut manager, 0); + + manager.consume(PluginInput::EngineCallResponse( + 0, + EngineCallResponse::PipelineData(PipelineDataHeader::ListStream(ListStreamInfo { id: 0 })), + ))?; + + for i in 0..2 { + manager.consume(PluginInput::Stream(StreamMessage::Data( + 0, + Value::test_int(i).into(), + )))?; + } + + manager.consume(PluginInput::Stream(StreamMessage::End(0)))?; + + // Make sure the streams end and we don't deadlock + drop(manager); + + let response = rx.try_recv().expect("failed to get engine call response"); + + match response { + EngineCallResponse::PipelineData(data) => { + // Ensure we manage to receive the stream messages + assert_eq!(2, data.into_iter().count()); + Ok(()) + } + _ => panic!("unexpected response: {response:?}"), + } +} + #[test] fn manager_prepare_pipeline_data_deserializes_custom_values() -> Result<(), ShellError> { let manager = TestCase::new().engine(); @@ -683,6 +785,166 @@ fn interface_write_signature() -> Result<(), ShellError> { Ok(()) } +#[test] +fn interface_write_engine_call_registers_subscription() -> Result<(), ShellError> { + let mut manager = TestCase::new().engine(); + assert!( + manager.engine_call_subscriptions.is_empty(), + "engine call subscriptions not empty before start of test" + ); + + let interface = manager.interface_for_context(0); + let _ = interface.write_engine_call(EngineCall::GetConfig)?; + + manager.receive_engine_call_subscriptions(); + assert!( + !manager.engine_call_subscriptions.is_empty(), + "not registered" + ); + Ok(()) +} + +#[test] +fn interface_write_engine_call_writes_with_correct_context() -> Result<(), ShellError> { + let test = TestCase::new(); + let manager = test.engine(); + let interface = manager.interface_for_context(32); + let _ = interface.write_engine_call(EngineCall::GetConfig)?; + + match test.next_written().expect("nothing written") { + PluginOutput::EngineCall { context, call, .. } => { + assert_eq!(32, context, "context incorrect"); + assert!( + matches!(call, EngineCall::GetConfig), + "incorrect engine call (expected GetConfig): {call:?}" + ); + } + other => panic!("incorrect output: {other:?}"), + } + + assert!(!test.has_unconsumed_write()); + Ok(()) +} + +/// Fake responses to requests for engine call messages +fn start_fake_plugin_call_responder( + manager: EngineInterfaceManager, + take: usize, + mut f: impl FnMut(EngineCallId) -> EngineCallResponse + Send + 'static, +) { + std::thread::Builder::new() + .name("fake engine call responder".into()) + .spawn(move || { + for (id, sub) in manager + .engine_call_subscription_receiver + .into_iter() + .take(take) + { + sub.send(f(id)).expect("failed to send"); + } + }) + .expect("failed to spawn thread"); +} + +#[test] +fn interface_get_config() -> Result<(), ShellError> { + let test = TestCase::new(); + let manager = test.engine(); + let interface = manager.interface_for_context(0); + + start_fake_plugin_call_responder(manager, 1, |_| { + EngineCallResponse::Config(Config::default().into()) + }); + + let _ = interface.get_config()?; + assert!(test.has_unconsumed_write()); + Ok(()) +} + +#[test] +fn interface_get_plugin_config() -> Result<(), ShellError> { + let test = TestCase::new(); + let manager = test.engine(); + let interface = manager.interface_for_context(0); + + start_fake_plugin_call_responder(manager, 2, |id| { + if id == 0 { + EngineCallResponse::PipelineData(PipelineData::Empty) + } else { + EngineCallResponse::PipelineData(PipelineData::Value(Value::test_int(2), None)) + } + }); + + let first_config = interface.get_plugin_config()?; + assert!(first_config.is_none(), "should be None: {first_config:?}"); + + let second_config = interface.get_plugin_config()?; + assert_eq!(Some(Value::test_int(2)), second_config); + + assert!(test.has_unconsumed_write()); + Ok(()) +} + +#[test] +fn interface_eval_closure_with_stream() -> Result<(), ShellError> { + let test = TestCase::new(); + let manager = test.engine(); + let interface = manager.interface_for_context(0); + + start_fake_plugin_call_responder(manager, 1, |_| { + EngineCallResponse::PipelineData(PipelineData::Value(Value::test_int(2), None)) + }); + + let result = interface + .eval_closure_with_stream( + &Spanned { + item: Closure { + block_id: 42, + captures: vec![(0, Value::test_int(5))], + }, + span: Span::test_data(), + }, + vec![Value::test_string("test")], + PipelineData::Empty, + true, + false, + )? + .into_value(Span::test_data()); + + assert_eq!(Value::test_int(2), result); + + // Double check the message that was written, as it's complicated + match test.next_written().expect("nothing written") { + PluginOutput::EngineCall { call, .. } => match call { + EngineCall::EvalClosure { + closure, + positional, + input, + redirect_stdout, + redirect_stderr, + } => { + assert_eq!(42, closure.item.block_id, "closure.item.block_id"); + assert_eq!(1, closure.item.captures.len(), "closure.item.captures.len"); + assert_eq!( + (0, Value::test_int(5)), + closure.item.captures[0], + "closure.item.captures[0]" + ); + assert_eq!(Span::test_data(), closure.span, "closure.span"); + assert_eq!(1, positional.len(), "positional.len"); + assert_eq!(Value::test_string("test"), positional[0], "positional[0]"); + assert!(matches!(input, PipelineDataHeader::Empty)); + assert!(redirect_stdout); + assert!(!redirect_stderr); + } + _ => panic!("wrong engine call: {call:?}"), + }, + other => panic!("wrong output: {other:?}"), + } + + Ok(()) +} + #[test] fn interface_prepare_pipeline_data_serializes_custom_values() -> Result<(), ShellError> { let interface = TestCase::new().engine().get_interface(); diff --git a/crates/nu-plugin/src/plugin/interface/plugin.rs b/crates/nu-plugin/src/plugin/interface/plugin.rs index 88eeb422bb..e54ee69bdc 100644 --- a/crates/nu-plugin/src/plugin/interface/plugin.rs +++ b/crates/nu-plugin/src/plugin/interface/plugin.rs @@ -13,8 +13,9 @@ use nu_protocol::{ use crate::{ plugin::{context::PluginExecutionContext, PluginIdentity}, protocol::{ - CallInfo, CustomValueOp, PluginCall, PluginCallId, PluginCallResponse, PluginCustomValue, - PluginInput, PluginOutput, ProtocolInfo, + CallInfo, CustomValueOp, EngineCall, EngineCallId, EngineCallResponse, PluginCall, + PluginCallId, PluginCallResponse, PluginCustomValue, PluginInput, PluginOutput, + ProtocolInfo, StreamId, StreamMessage, }, sequence::Sequence, }; @@ -34,6 +35,12 @@ enum ReceivedPluginCallMessage { /// An critical error with the interface Error(ShellError), + + /// An engine call that should be evaluated and responded to, but is not the final response + /// + /// We send this back to the thread that made the plugin call so we don't block the reader + /// thread + EngineCall(EngineCallId, EngineCall), } /// Context for plugin call execution @@ -87,9 +94,11 @@ impl std::fmt::Debug for PluginInterfaceState { #[derive(Debug)] struct PluginCallSubscription { /// The sender back to the thread that is waiting for the plugin call response - sender: mpsc::Sender, - /// Optional context for the environment of a plugin call + sender: Option>, + /// Optional context for the environment of a plugin call for servicing engine calls context: Option, + /// Number of streams that still need to be read from the plugin call response + remaining_streams_to_read: i32, } /// Manages reading and dispatching messages for [`PluginInterface`]s. @@ -105,6 +114,10 @@ pub(crate) struct PluginInterfaceManager { plugin_call_subscriptions: BTreeMap, /// Receiver for plugin call subscriptions plugin_call_subscription_receiver: mpsc::Receiver<(PluginCallId, PluginCallSubscription)>, + /// Tracker for which plugin call streams being read belong to + /// + /// This is necessary so we know when we can remove context for plugin calls + plugin_call_input_streams: BTreeMap, } impl PluginInterfaceManager { @@ -126,6 +139,7 @@ impl PluginInterfaceManager { protocol_info: None, plugin_call_subscriptions: BTreeMap::new(), plugin_call_subscription_receiver: subscription_rx, + plugin_call_input_streams: BTreeMap::new(), } } @@ -140,6 +154,29 @@ impl PluginInterfaceManager { } } + /// Track the start of stream(s) + fn recv_stream_started(&mut self, call_id: PluginCallId, stream_id: StreamId) { + self.receive_plugin_call_subscriptions(); + if let Some(sub) = self.plugin_call_subscriptions.get_mut(&call_id) { + self.plugin_call_input_streams.insert(stream_id, call_id); + sub.remaining_streams_to_read += 1; + } + } + + /// Track the end of a stream + fn recv_stream_ended(&mut self, stream_id: StreamId) { + if let Some(call_id) = self.plugin_call_input_streams.remove(&stream_id) { + if let btree_map::Entry::Occupied(mut e) = self.plugin_call_subscriptions.entry(call_id) + { + e.get_mut().remaining_streams_to_read -= 1; + // Remove the subscription if there are no more streams to be read. + if e.get().remaining_streams_to_read <= 0 { + e.remove(); + } + } + } + } + /// Find the context corresponding to the given plugin call id fn get_context(&mut self, id: PluginCallId) -> Result, ShellError> { // Make sure we're up to date @@ -162,15 +199,22 @@ impl PluginInterfaceManager { // Ensure we're caught up on the subscriptions made self.receive_plugin_call_subscriptions(); - // Remove the subscription, since this would be the last message - if let Some(subscription) = self.plugin_call_subscriptions.remove(&id) { - if subscription + if let btree_map::Entry::Occupied(mut e) = self.plugin_call_subscriptions.entry(id) { + // Remove the subscription sender, since this will be the last message. + // + // We can spawn a new one if we need it for engine calls. + if e.get_mut() .sender - .send(ReceivedPluginCallMessage::Response(response)) - .is_err() + .take() + .and_then(|s| s.send(ReceivedPluginCallMessage::Response(response)).ok()) + .is_none() { log::warn!("Received a plugin call response for id={id}, but the caller hung up"); } + // If there are no registered streams, just remove it + if e.get().remaining_streams_to_read <= 0 { + e.remove(); + } Ok(()) } else { Err(ShellError::PluginFailedToDecode { @@ -179,6 +223,106 @@ impl PluginInterfaceManager { } } + /// Spawn a handler for engine calls for a plugin, in case we need to handle engine calls + /// after the response has already been received (in which case we have nowhere to send them) + fn spawn_engine_call_handler( + &mut self, + id: PluginCallId, + ) -> Result<&mpsc::Sender, ShellError> { + let interface = self.get_interface(); + + if let Some(sub) = self.plugin_call_subscriptions.get_mut(&id) { + if sub.sender.is_none() { + let (tx, rx) = mpsc::channel(); + let context = sub.context.clone(); + let handler = move || { + for msg in rx { + // This thread only handles engine calls. + match msg { + ReceivedPluginCallMessage::EngineCall(engine_call_id, engine_call) => { + if let Err(err) = interface.handle_engine_call( + engine_call_id, + engine_call, + &context, + ) { + log::warn!( + "Error in plugin post-response engine call handler: \ + {err:?}" + ); + return; + } + } + other => log::warn!( + "Bad message received in plugin post-response \ + engine call handler: {other:?}" + ), + } + } + }; + std::thread::Builder::new() + .name("plugin engine call handler".into()) + .spawn(handler) + .expect("failed to spawn thread"); + sub.sender = Some(tx); + Ok(sub.sender.as_ref().unwrap_or_else(|| unreachable!())) + } else { + Err(ShellError::NushellFailed { + msg: "Tried to spawn the fallback engine call handler before the plugin call \ + response had been received" + .into(), + }) + } + } else { + Err(ShellError::NushellFailed { + msg: format!("Couldn't find plugin ID={id} in subscriptions"), + }) + } + } + + /// Send an [`EngineCall`] to the appropriate sender + fn send_engine_call( + &mut self, + plugin_call_id: PluginCallId, + engine_call_id: EngineCallId, + call: EngineCall, + ) -> Result<(), ShellError> { + // Ensure we're caught up on the subscriptions made + self.receive_plugin_call_subscriptions(); + + // Don't remove the sender, as there could be more calls or responses + if let Some(subscription) = self.plugin_call_subscriptions.get(&plugin_call_id) { + let msg = ReceivedPluginCallMessage::EngineCall(engine_call_id, call); + // Call if there's an error sending the engine call + let send_error = |this: &Self| { + log::warn!( + "Received an engine call for plugin_call_id={plugin_call_id}, \ + but the caller hung up" + ); + // We really have no choice here but to send the response ourselves and hope we + // don't block + this.state.writer.write(&PluginInput::EngineCallResponse( + engine_call_id, + EngineCallResponse::Error(ShellError::IOError { + msg: "Can't make engine call because the original caller hung up".into(), + }), + ))?; + this.state.writer.flush() + }; + // Try to send to the sender if it exists + if let Some(sender) = subscription.sender.as_ref() { + sender.send(msg).or_else(|_| send_error(self)) + } else { + // The sender no longer exists. Spawn a specific one just for engine calls + let sender = self.spawn_engine_call_handler(plugin_call_id)?; + sender.send(msg).or_else(|_| send_error(self)) + } + } else { + Err(ShellError::PluginFailedToDecode { + msg: format!("Unknown plugin call ID: {plugin_call_id}"), + }) + } + } + /// True if there are no other copies of the state (which would mean there are no interfaces /// and no stream readers/writers) pub(crate) fn is_finished(&self) -> bool { @@ -207,7 +351,8 @@ impl PluginInterfaceManager { { let _ = subscription .sender - .send(ReceivedPluginCallMessage::Error(err.clone())); + .as_ref() + .map(|s| s.send(ReceivedPluginCallMessage::Error(err.clone()))); } return Err(err); } @@ -268,6 +413,10 @@ impl InterfaceManager for PluginInterfaceManager { // error response, but send it anyway let exec_context = self.get_context(id)?; let ctrlc = exec_context.as_ref().and_then(|c| c.0.ctrlc()); + // Register the streams in the response + for stream_id in data.stream_ids() { + self.recv_stream_started(id, stream_id); + } match self.read_pipeline_data(data, ctrlc) { Ok(data) => PluginCallResponse::PipelineData(data), Err(err) => PluginCallResponse::Error(err.into()), @@ -276,6 +425,42 @@ impl InterfaceManager for PluginInterfaceManager { }; self.send_plugin_call_response(id, response) } + PluginOutput::EngineCall { context, id, call } => { + // Handle reading the pipeline data, if any + let exec_context = self.get_context(context)?; + let ctrlc = exec_context.as_ref().and_then(|c| c.0.ctrlc()); + let call = match call { + EngineCall::GetConfig => Ok(EngineCall::GetConfig), + EngineCall::GetPluginConfig => Ok(EngineCall::GetPluginConfig), + EngineCall::EvalClosure { + closure, + mut positional, + input, + redirect_stdout, + redirect_stderr, + } => { + // Add source to any plugin custom values in the arguments + for arg in positional.iter_mut() { + PluginCustomValue::add_source(arg, &self.state.identity); + } + self.read_pipeline_data(input, ctrlc) + .map(|input| EngineCall::EvalClosure { + closure, + positional, + input, + redirect_stdout, + redirect_stderr, + }) + } + }; + match call { + Ok(call) => self.send_engine_call(context, id, call), + // If there was an error with setting up the call, just write the error + Err(err) => self + .get_interface() + .write_engine_call_response(id, EngineCallResponse::Error(err)), + } + } } } @@ -302,6 +487,14 @@ impl InterfaceManager for PluginInterfaceManager { PipelineData::Empty | PipelineData::ExternalStream { .. } => Ok(data), } } + + fn consume_stream_message(&mut self, message: StreamMessage) -> Result<(), ShellError> { + // Keep track of streams that end so we know if we don't need the context anymore + if let StreamMessage::End(id) = message { + self.recv_stream_ended(id); + } + self.stream_manager.handle_message(message) + } } /// A reference through which a plugin can be interacted with during execution. @@ -330,8 +523,38 @@ impl PluginInterface { self.flush() } + /// Write an [`EngineCallResponse`]. Writes the full stream contained in any [`PipelineData`] + /// before returning. + pub(crate) fn write_engine_call_response( + &self, + id: EngineCallId, + response: EngineCallResponse, + ) -> Result<(), ShellError> { + // Set up any stream if necessary + let (response, writer) = match response { + EngineCallResponse::PipelineData(data) => { + let (header, writer) = self.init_write_pipeline_data(data)?; + (EngineCallResponse::PipelineData(header), Some(writer)) + } + // No pipeline data: + EngineCallResponse::Error(err) => (EngineCallResponse::Error(err), None), + EngineCallResponse::Config(config) => (EngineCallResponse::Config(config), None), + }; + + // Write the response, including the pipeline data header if present + self.write(PluginInput::EngineCallResponse(id, response))?; + self.flush()?; + + // If we have a stream to write, do it now + if let Some(writer) = writer { + writer.write_background()?; + } + + Ok(()) + } + /// Write a plugin call message. Returns the writer for the stream, and the receiver for - /// messages (e.g. response) related to the plugin call + /// messages - i.e. response and engine calls - related to the plugin call fn write_plugin_call( &self, call: PluginCall, @@ -354,17 +577,16 @@ impl PluginInterface { } PluginCall::Run(CallInfo { name, - call, + mut call, input, - config, }) => { + verify_call_args(&mut call, &self.state.identity)?; let (header, writer) = self.init_write_pipeline_data(input)?; ( PluginCall::Run(CallInfo { name, call, input: header, - config, }), writer, ) @@ -377,8 +599,9 @@ impl PluginInterface { .send(( id, PluginCallSubscription { - sender: tx, + sender: Some(tx), context, + remaining_streams_to_read: 0, }, )) .map_err(|_| ShellError::NushellFailed { @@ -397,22 +620,62 @@ impl PluginInterface { fn receive_plugin_call_response( &self, rx: mpsc::Receiver, + context: &Option, ) -> Result, ShellError> { - if let Ok(msg) = rx.recv() { - // Handle message from receiver + // Handle message from receiver + for msg in rx { match msg { - ReceivedPluginCallMessage::Response(resp) => Ok(resp), - ReceivedPluginCallMessage::Error(err) => Err(err), + ReceivedPluginCallMessage::Response(resp) => { + return Ok(resp); + } + ReceivedPluginCallMessage::Error(err) => { + return Err(err); + } + ReceivedPluginCallMessage::EngineCall(engine_call_id, engine_call) => { + self.handle_engine_call(engine_call_id, engine_call, context)?; + } } - } else { - // If we fail to get a response - Err(ShellError::PluginFailedToDecode { - msg: "Failed to receive response to plugin call".into(), - }) } + // If we fail to get a response + Err(ShellError::PluginFailedToDecode { + msg: "Failed to receive response to plugin call".into(), + }) } - /// Perform a plugin call. Input and output streams are handled automatically. + /// Handle an engine call and write the response. + fn handle_engine_call( + &self, + engine_call_id: EngineCallId, + engine_call: EngineCall, + context: &Option, + ) -> Result<(), ShellError> { + let resp = + handle_engine_call(engine_call, context).unwrap_or_else(EngineCallResponse::Error); + // Handle stream + let (resp, writer) = match resp { + EngineCallResponse::Error(error) => (EngineCallResponse::Error(error), None), + EngineCallResponse::Config(config) => (EngineCallResponse::Config(config), None), + EngineCallResponse::PipelineData(data) => { + match self.init_write_pipeline_data(data) { + Ok((header, writer)) => { + (EngineCallResponse::PipelineData(header), Some(writer)) + } + // just respond with the error if we fail to set it up + Err(err) => (EngineCallResponse::Error(err), None), + } + } + }; + // Write the response, then the stream + self.write(PluginInput::EngineCallResponse(engine_call_id, resp))?; + self.flush()?; + if let Some(writer) = writer { + writer.write_background()?; + } + Ok(()) + } + + /// Perform a plugin call. Input and output streams are handled, and engine calls are handled + /// too if there are any before the final response. fn plugin_call( &self, call: PluginCall, @@ -423,7 +686,7 @@ impl PluginInterface { // Finish writing stream in the background writer.write_background()?; - self.receive_plugin_call_response(rx) + self.receive_plugin_call_response(rx, context) } /// Get the command signatures from the plugin. @@ -471,6 +734,20 @@ impl PluginInterface { } } +/// Check that custom values in call arguments come from the right source +fn verify_call_args( + call: &mut crate::EvaluatedCall, + source: &Arc, +) -> Result<(), ShellError> { + for arg in call.positional.iter_mut() { + PluginCustomValue::verify_source(arg, source)?; + } + for arg in call.named.iter_mut().flat_map(|(_, arg)| arg.as_mut()) { + PluginCustomValue::verify_source(arg, source)?; + } + Ok(()) +} + impl Interface for PluginInterface { type Output = PluginInput; @@ -529,3 +806,44 @@ impl Drop for PluginInterface { } } } + +/// Handle an engine call. +pub(crate) fn handle_engine_call( + call: EngineCall, + context: &Option, +) -> Result, ShellError> { + let call_name = call.name(); + let require_context = || { + context.as_ref().ok_or_else(|| ShellError::GenericError { + error: "A plugin execution context is required for this engine call".into(), + msg: format!( + "attempted to call {} outside of a command invocation", + call_name + ), + span: None, + help: Some("this is probably a bug with the plugin".into()), + inner: vec![], + }) + }; + match call { + EngineCall::GetConfig => { + let context = require_context()?; + let config = Box::new(context.get_config()?); + Ok(EngineCallResponse::Config(config)) + } + EngineCall::GetPluginConfig => { + let context = require_context()?; + let plugin_config = context.get_plugin_config()?; + Ok(plugin_config.map_or_else(EngineCallResponse::empty, EngineCallResponse::value)) + } + EngineCall::EvalClosure { + closure, + positional, + input, + redirect_stdout, + redirect_stderr, + } => require_context()? + .eval_closure(closure, positional, input, redirect_stdout, redirect_stderr) + .map(EngineCallResponse::PipelineData), + } +} diff --git a/crates/nu-plugin/src/plugin/interface/plugin/tests.rs b/crates/nu-plugin/src/plugin/interface/plugin/tests.rs index 57b306f86a..68d5dfebf3 100644 --- a/crates/nu-plugin/src/plugin/interface/plugin/tests.rs +++ b/crates/nu-plugin/src/plugin/interface/plugin/tests.rs @@ -1,7 +1,11 @@ -use std::sync::mpsc; +use std::{ + sync::{mpsc, Arc}, + time::Duration, +}; use nu_protocol::{ - IntoInterruptiblePipelineData, PipelineData, PluginSignature, ShellError, Span, Spanned, Value, + engine::Closure, IntoInterruptiblePipelineData, PipelineData, PluginSignature, ShellError, + Span, Spanned, Value, }; use crate::{ @@ -12,15 +16,16 @@ use crate::{ }, protocol::{ test_util::{expected_test_custom_value, test_plugin_custom_value}, - CallInfo, CustomValueOp, ExternalStreamInfo, ListStreamInfo, PipelineDataHeader, - PluginCall, PluginCallId, PluginCustomValue, PluginInput, Protocol, ProtocolInfo, - RawStreamInfo, StreamData, StreamMessage, + CallInfo, CustomValueOp, EngineCall, EngineCallResponse, ExternalStreamInfo, + ListStreamInfo, PipelineDataHeader, PluginCall, PluginCallId, PluginCustomValue, + PluginInput, Protocol, ProtocolInfo, RawStreamInfo, StreamData, StreamMessage, }, EvaluatedCall, PluginCallResponse, PluginOutput, }; use super::{ - PluginCallSubscription, PluginInterface, PluginInterfaceManager, ReceivedPluginCallMessage, + Context, PluginCallSubscription, PluginInterface, PluginInterfaceManager, + ReceivedPluginCallMessage, }; #[test] @@ -185,8 +190,9 @@ fn fake_plugin_call( manager.plugin_call_subscriptions.insert( id, PluginCallSubscription { - sender: tx, + sender: Some(tx), context: None, + remaining_streams_to_read: 0, }, ); @@ -338,6 +344,282 @@ fn manager_consume_call_response_forwards_to_subscriber_with_pipeline_data( } } +#[test] +fn manager_consume_call_response_registers_streams() -> Result<(), ShellError> { + let mut manager = TestCase::new().plugin("test"); + manager.protocol_info = Some(ProtocolInfo::default()); + + for n in [0, 1] { + fake_plugin_call(&mut manager, n); + } + + // Check list streams, external streams + manager.consume(PluginOutput::CallResponse( + 0, + PluginCallResponse::PipelineData(PipelineDataHeader::ListStream(ListStreamInfo { id: 0 })), + ))?; + manager.consume(PluginOutput::CallResponse( + 1, + PluginCallResponse::PipelineData(PipelineDataHeader::ExternalStream(ExternalStreamInfo { + span: Span::test_data(), + stdout: Some(RawStreamInfo { + id: 1, + is_binary: false, + known_size: None, + }), + stderr: Some(RawStreamInfo { + id: 2, + is_binary: false, + known_size: None, + }), + exit_code: Some(ListStreamInfo { id: 3 }), + trim_end_newline: false, + })), + ))?; + + // ListStream should have one + if let Some(sub) = manager.plugin_call_subscriptions.get(&0) { + assert_eq!( + 1, sub.remaining_streams_to_read, + "ListStream remaining_streams_to_read should be 1" + ); + } else { + panic!("failed to find subscription for ListStream (0), maybe it was removed"); + } + assert_eq!( + Some(&0), + manager.plugin_call_input_streams.get(&0), + "plugin_call_input_streams[0] should be Some(0)" + ); + + // ExternalStream should have three + if let Some(sub) = manager.plugin_call_subscriptions.get(&1) { + assert_eq!( + 3, sub.remaining_streams_to_read, + "ExternalStream remaining_streams_to_read should be 3" + ); + } else { + panic!("failed to find subscription for ExternalStream (1), maybe it was removed"); + } + for n in [1, 2, 3] { + assert_eq!( + Some(&1), + manager.plugin_call_input_streams.get(&n), + "plugin_call_input_streams[{n}] should be Some(1)" + ); + } + + Ok(()) +} + +#[test] +fn manager_consume_engine_call_forwards_to_subscriber_with_pipeline_data() -> Result<(), ShellError> +{ + let mut manager = TestCase::new().plugin("test"); + manager.protocol_info = Some(ProtocolInfo::default()); + + let rx = fake_plugin_call(&mut manager, 37); + + manager.consume(PluginOutput::EngineCall { + context: 37, + id: 46, + call: EngineCall::EvalClosure { + closure: Spanned { + item: Closure { + block_id: 0, + captures: vec![], + }, + span: Span::test_data(), + }, + positional: vec![], + input: PipelineDataHeader::ListStream(ListStreamInfo { id: 2 }), + redirect_stdout: false, + redirect_stderr: false, + }, + })?; + + for i in 0..2 { + manager.consume(PluginOutput::Stream(StreamMessage::Data( + 2, + Value::test_int(i).into(), + )))?; + } + manager.consume(PluginOutput::Stream(StreamMessage::End(2)))?; + + // Make sure the streams end and we don't deadlock + drop(manager); + + let message = rx.try_recv().expect("failed to get plugin call message"); + + match message { + ReceivedPluginCallMessage::EngineCall(id, call) => { + assert_eq!(46, id, "id"); + match call { + EngineCall::EvalClosure { input, .. } => { + // Count the stream messages + assert_eq!(2, input.into_iter().count()); + Ok(()) + } + _ => panic!("unexpected call: {call:?}"), + } + } + _ => panic!("unexpected response message: {message:?}"), + } +} + +#[test] +fn manager_handle_engine_call_after_response_received() -> Result<(), ShellError> { + let test = TestCase::new(); + let mut manager = test.plugin("test"); + manager.protocol_info = Some(ProtocolInfo::default()); + + let bogus = Context(Arc::new(PluginExecutionBogusContext)); + + // Set up a situation identical to what we would find if the response had been read, but there + // was still a stream being processed. We have nowhere to send the engine call in that case, + // so the manager has to create a place to handle it. + manager.plugin_call_subscriptions.insert( + 0, + PluginCallSubscription { + sender: None, + context: Some(bogus), + remaining_streams_to_read: 1, + }, + ); + + manager.send_engine_call(0, 0, EngineCall::GetConfig)?; + + // Not really much choice but to wait here, as the thread will have been spawned in the + // background; we don't have a way to know if it's executed + let mut waited = 0; + while !test.has_unconsumed_write() { + if waited > 100 { + panic!("nothing written before timeout, expected engine call response"); + } else { + std::thread::sleep(Duration::from_millis(1)); + waited += 1; + } + } + + // The GetConfig call on bogus should result in an error response being written + match test.next_written().expect("nothing written") { + PluginInput::EngineCallResponse(id, resp) => { + assert_eq!(0, id, "id"); + match resp { + EngineCallResponse::Error(err) => { + assert!(err.to_string().contains("bogus"), "wrong error: {err}"); + } + _ => panic!("unexpected engine call response, expected error: {resp:?}"), + } + } + other => panic!("unexpected message, not engine call response: {other:?}"), + } + + // Whatever was used to make this happen should have been held onto, since spawning a thread + // is expensive + let sender = &manager + .plugin_call_subscriptions + .get(&0) + .expect("missing subscription 0") + .sender; + + assert!( + sender.is_some(), + "failed to keep spawned engine call handler channel" + ); + Ok(()) +} + +#[test] +fn manager_send_plugin_call_response_removes_context_only_if_no_streams_to_read( +) -> Result<(), ShellError> { + let mut manager = TestCase::new().plugin("test"); + + for n in [0, 1] { + manager.plugin_call_subscriptions.insert( + n, + PluginCallSubscription { + sender: None, + context: None, + remaining_streams_to_read: n as i32, + }, + ); + } + + for n in [0, 1] { + manager.send_plugin_call_response(n, PluginCallResponse::Signature(vec![]))?; + } + + // 0 should not still be present, but 1 should be + assert!( + !manager.plugin_call_subscriptions.contains_key(&0), + "didn't clean up when there weren't remaining streams" + ); + assert!( + manager.plugin_call_subscriptions.contains_key(&1), + "clean up even though there were remaining streams" + ); + Ok(()) +} + +#[test] +fn manager_consume_stream_end_removes_context_only_if_last_stream() -> Result<(), ShellError> { + let mut manager = TestCase::new().plugin("test"); + manager.protocol_info = Some(ProtocolInfo::default()); + + for n in [1, 2] { + manager.plugin_call_subscriptions.insert( + n, + PluginCallSubscription { + sender: None, + context: None, + remaining_streams_to_read: n as i32, + }, + ); + } + + // 1 owns [10], 2 owns [21, 22] + manager.plugin_call_input_streams.insert(10, 1); + manager.plugin_call_input_streams.insert(21, 2); + manager.plugin_call_input_streams.insert(22, 2); + + // Register the streams so we don't have errors + let streams: Vec<_> = [10, 21, 22] + .into_iter() + .map(|id| { + let interface = manager.get_interface(); + manager + .stream_manager + .get_handle() + .read_stream::(id, interface) + }) + .collect(); + + // Ending 10 should cause 1 to be removed + manager.consume(StreamMessage::End(10).into())?; + assert!( + !manager.plugin_call_subscriptions.contains_key(&1), + "contains(1) after End(10)" + ); + + // Ending 21 should not cause 2 to be removed + manager.consume(StreamMessage::End(21).into())?; + assert!( + manager.plugin_call_subscriptions.contains_key(&2), + "!contains(2) after End(21)" + ); + + // Ending 22 should cause 2 to be removed + manager.consume(StreamMessage::End(22).into())?; + assert!( + !manager.plugin_call_subscriptions.contains_key(&2), + "contains(2) after End(22)" + ); + + drop(streams); + Ok(()) +} + #[test] fn manager_prepare_pipeline_data_adds_source_to_values() -> Result<(), ShellError> { let manager = TestCase::new().plugin("test"); @@ -518,7 +800,6 @@ fn interface_write_plugin_call_writes_run_with_value_input() -> Result<(), Shell named: vec![], }, input: PipelineData::Value(Value::test_int(-1), None), - config: None, }), None, )?; @@ -557,7 +838,6 @@ fn interface_write_plugin_call_writes_run_with_stream_input() -> Result<(), Shel named: vec![], }, input: values.clone().into_pipeline_data(None), - config: None, }), None, )?; @@ -622,7 +902,7 @@ fn interface_receive_plugin_call_receives_response() -> Result<(), ShellError> { .expect("failed to send on new channel"); drop(tx); // so we don't deadlock on recv() - let response = interface.receive_plugin_call_response(rx)?; + let response = interface.receive_plugin_call_response(rx, &None)?; assert!( matches!(response, PluginCallResponse::Signature(_)), "wrong response: {response:?}" @@ -645,7 +925,7 @@ fn interface_receive_plugin_call_receives_error() -> Result<(), ShellError> { drop(tx); // so we don't deadlock on recv() let error = interface - .receive_plugin_call_response(rx) + .receive_plugin_call_response(rx, &None) .expect_err("did not receive error"); assert!( matches!(error, ShellError::ExternalNotSupported { .. }), @@ -654,6 +934,49 @@ fn interface_receive_plugin_call_receives_error() -> Result<(), ShellError> { Ok(()) } +#[test] +fn interface_receive_plugin_call_handles_engine_call() -> Result<(), ShellError> { + let test = TestCase::new(); + let interface = test.plugin("test").get_interface(); + + // Set up a fake channel just for the engine call + let (tx, rx) = mpsc::channel(); + tx.send(ReceivedPluginCallMessage::EngineCall( + 0, + EngineCall::GetConfig, + )) + .expect("failed to send on new channel"); + + // The context should be a bogus context, which will return an error for GetConfig + let context = Some(Context(Arc::new(PluginExecutionBogusContext))); + + // We don't actually send a response, so `receive_plugin_call_response` should actually return + // an error, but it should still do the engine call + drop(tx); + interface + .receive_plugin_call_response(rx, &context) + .expect_err("no error even though there was no response"); + + // Check for the engine call response output + match test + .next_written() + .expect("no engine call response written") + { + PluginInput::EngineCallResponse(id, resp) => { + assert_eq!(0, id, "id"); + match resp { + EngineCallResponse::Error(err) => { + assert!(err.to_string().contains("bogus"), "wrong error: {err}"); + } + _ => panic!("unexpected engine call response, maybe bogus is wrong: {resp:?}"), + } + } + other => panic!("unexpected message: {other:?}"), + } + assert!(!test.has_unconsumed_write()); + Ok(()) +} + /// Fake responses to requests for plugin call messages fn start_fake_plugin_call_responder( manager: PluginInterfaceManager, @@ -669,7 +992,11 @@ fn start_fake_plugin_call_responder( .take(take) { for message in f(id) { - sub.sender.send(message).expect("failed to send"); + sub.sender + .as_ref() + .expect("sender is None") + .send(message) + .expect("failed to send"); } } }) @@ -717,7 +1044,6 @@ fn interface_run() -> Result<(), ShellError> { named: vec![], }, input: PipelineData::Empty, - config: None, }, PluginExecutionBogusContext.into(), )?; diff --git a/crates/nu-plugin/src/plugin/mod.rs b/crates/nu-plugin/src/plugin/mod.rs index d0787d536f..43a58ab587 100644 --- a/crates/nu-plugin/src/plugin/mod.rs +++ b/crates/nu-plugin/src/plugin/mod.rs @@ -3,20 +3,24 @@ pub use declaration::PluginDeclaration; use nu_engine::documentation::get_flags_section; use std::collections::HashMap; use std::ffi::OsStr; -use std::sync::{Arc, Mutex}; +use std::sync::mpsc::TrySendError; +use std::sync::{mpsc, Arc, Mutex}; use crate::plugin::interface::{EngineInterfaceManager, ReceivedPluginCall}; -use crate::protocol::{CallInfo, CustomValueOp, LabeledError, PluginInput, PluginOutput}; +use crate::protocol::{ + CallInfo, CustomValueOp, LabeledError, PluginCustomValue, PluginInput, PluginOutput, +}; use crate::EncodingType; -use std::env; use std::fmt::Write; use std::io::{BufReader, Read, Write as WriteTrait}; use std::path::Path; use std::process::{Child, ChildStdout, Command as CommandSys, Stdio}; +use std::{env, thread}; -use nu_protocol::{PipelineData, PluginSignature, ShellError, Value}; +use nu_protocol::{PipelineData, PluginSignature, ShellError, Spanned, Value}; mod interface; +pub use interface::EngineInterface; pub(crate) use interface::PluginInterface; mod context; @@ -184,6 +188,10 @@ pub fn get_signature( /// If large amounts of data are expected to need to be received or produced, it may be more /// appropriate to implement [StreamingPlugin] instead. /// +/// The plugin must be able to be safely shared between threads, so that multiple invocations can +/// be run in parallel. If interior mutability is desired, consider synchronization primitives such +/// as [mutexes](std::sync::Mutex) and [channels](std::sync::mpsc). +/// /// # Examples /// Basic usage: /// ``` @@ -200,9 +208,9 @@ pub fn get_signature( /// } /// /// fn run( -/// &mut self, +/// &self, /// name: &str, -/// config: &Option, +/// engine: &EngineInterface, /// call: &EvaluatedCall, /// input: &Value, /// ) -> Result { @@ -211,10 +219,10 @@ pub fn get_signature( /// } /// /// # fn main() { -/// # serve_plugin(&mut HelloPlugin{}, MsgPackSerializer) +/// # serve_plugin(&HelloPlugin{}, MsgPackSerializer) /// # } /// ``` -pub trait Plugin { +pub trait Plugin: Sync { /// The signature of the plugin /// /// This method returns the [PluginSignature]s that describe the capabilities @@ -234,12 +242,15 @@ pub trait Plugin { /// metadata describing how the plugin was invoked and `input` contains the structured /// data passed to the command implemented by this [Plugin]. /// + /// `engine` provides an interface back to the Nushell engine. See [`EngineInterface`] docs for + /// details on what methods are available. + /// /// This variant does not support streaming. Consider implementing [StreamingPlugin] instead /// if streaming is desired. fn run( - &mut self, + &self, name: &str, - config: &Option, + engine: &EngineInterface, call: &EvaluatedCall, input: &Value, ) -> Result; @@ -270,9 +281,9 @@ pub trait Plugin { /// } /// /// fn run( -/// &mut self, +/// &self, /// name: &str, -/// config: &Option, +/// engine: &EngineInterface, /// call: &EvaluatedCall, /// input: PipelineData, /// ) -> Result { @@ -287,10 +298,10 @@ pub trait Plugin { /// } /// /// # fn main() { -/// # serve_plugin(&mut LowercasePlugin{}, MsgPackSerializer) +/// # serve_plugin(&LowercasePlugin{}, MsgPackSerializer) /// # } /// ``` -pub trait StreamingPlugin { +pub trait StreamingPlugin: Sync { /// The signature of the plugin /// /// This method returns the [PluginSignature]s that describe the capabilities @@ -315,9 +326,9 @@ pub trait StreamingPlugin { /// potentially large quantities of bytes. The API is more complex however, and [Plugin] is /// recommended instead if this is not a concern. fn run( - &mut self, + &self, name: &str, - config: &Option, + engine: &EngineInterface, call: &EvaluatedCall, input: PipelineData, ) -> Result; @@ -331,9 +342,9 @@ impl StreamingPlugin for T { } fn run( - &mut self, + &self, name: &str, - config: &Option, + engine: &EngineInterface, call: &EvaluatedCall, input: PipelineData, ) -> Result { @@ -342,7 +353,7 @@ impl StreamingPlugin for T { let span = input.span().unwrap_or(call.head); let input_value = input.into_value(span); // Wrap the output in PipelineData::Value - ::run(self, name, config, call, &input_value) + ::run(self, name, engine, call, &input_value) .map(|value| PipelineData::Value(value, None)) } } @@ -360,14 +371,14 @@ impl StreamingPlugin for T { /// # impl MyPlugin { fn new() -> Self { Self }} /// # impl Plugin for MyPlugin { /// # fn signature(&self) -> Vec {todo!();} -/// # fn run(&mut self, name: &str, config: &Option, call: &EvaluatedCall, input: &Value) +/// # fn run(&self, name: &str, engine: &EngineInterface, call: &EvaluatedCall, input: &Value) /// # -> Result {todo!();} /// # } /// fn main() { -/// serve_plugin(&mut MyPlugin::new(), MsgPackSerializer) +/// serve_plugin(&MyPlugin::new(), MsgPackSerializer) /// } /// ``` -pub fn serve_plugin(plugin: &mut impl StreamingPlugin, encoder: impl PluginEncoder + 'static) { +pub fn serve_plugin(plugin: &impl StreamingPlugin, encoder: impl PluginEncoder + 'static) { let mut args = env::args().skip(1); let number_of_args = args.len(); let first_arg = args.next(); @@ -487,61 +498,95 @@ pub fn serve_plugin(plugin: &mut impl StreamingPlugin, encoder: impl PluginEncod std::process::exit(1); }); - for plugin_call in call_receiver { - match plugin_call { - // Sending the signature back to nushell to create the declaration definition - ReceivedPluginCall::Signature { engine } => { - try_or_report!(engine, engine.write_signature(plugin.signature())); - } - // Run the plugin, handling any input or output streams - ReceivedPluginCall::Run { - engine, - call: - CallInfo { - name, - config, - call, - input, - }, - } => { - let result = plugin.run(&name, &config, &call, input); - let write_result = engine - .write_response(result) - .and_then(|writer| writer.write_background()); - try_or_report!(engine, write_result); - } - // Do an operation on a custom value - ReceivedPluginCall::CustomValueOp { - engine, - custom_value, - op, - } => { - let local_value = try_or_report!( - engine, - custom_value - .item - .deserialize_to_custom_value(custom_value.span) - ); - match op { - CustomValueOp::ToBaseValue => { - let result = local_value - .to_base_value(custom_value.span) - .map(|value| PipelineData::Value(value, None)); - let write_result = engine - .write_response(result) - .and_then(|writer| writer.write_background()); - try_or_report!(engine, write_result); + // Handle each Run plugin call on a thread + thread::scope(|scope| { + let run = |engine, call_info| { + let CallInfo { name, call, input } = call_info; + let result = plugin.run(&name, &engine, &call, input); + let write_result = engine + .write_response(result) + .and_then(|writer| writer.write()); + try_or_report!(engine, write_result); + }; + + // As an optimization: create one thread that can be reused for Run calls in sequence + let (run_tx, run_rx) = mpsc::sync_channel(0); + thread::Builder::new() + .name("plugin runner (primary)".into()) + .spawn_scoped(scope, move || { + for (engine, call) in run_rx { + run(engine, call); + } + }) + .unwrap_or_else(|err| { + // If we fail to spawn the runner thread, we should exit + eprintln!("Plugin `{plugin_name}` failed to launch: {err}"); + std::process::exit(1); + }); + + for plugin_call in call_receiver { + match plugin_call { + // Sending the signature back to nushell to create the declaration definition + ReceivedPluginCall::Signature { engine } => { + try_or_report!(engine, engine.write_signature(plugin.signature())); + } + // Run the plugin on a background thread, handling any input or output streams + ReceivedPluginCall::Run { engine, call } => { + // Try to run it on the primary thread + match run_tx.try_send((engine, call)) { + Ok(()) => (), + // If the primary thread isn't ready, spawn a secondary thread to do it + Err(TrySendError::Full((engine, call))) + | Err(TrySendError::Disconnected((engine, call))) => { + let engine_clone = engine.clone(); + try_or_report!( + engine_clone, + thread::Builder::new() + .name("plugin runner (secondary)".into()) + .spawn_scoped(scope, move || run(engine, call)) + .map_err(ShellError::from) + ); + } } } + // Do an operation on a custom value + ReceivedPluginCall::CustomValueOp { + engine, + custom_value, + op, + } => { + try_or_report!(engine, custom_value_op(&engine, custom_value, op)); + } } } - } + }); // This will stop the manager drop(interface); } -fn print_help(plugin: &mut impl StreamingPlugin, encoder: impl PluginEncoder) { +fn custom_value_op( + engine: &EngineInterface, + custom_value: Spanned, + op: CustomValueOp, +) -> Result<(), ShellError> { + let local_value = custom_value + .item + .deserialize_to_custom_value(custom_value.span)?; + match op { + CustomValueOp::ToBaseValue => { + let result = local_value + .to_base_value(custom_value.span) + .map(|value| PipelineData::Value(value, None)); + engine + .write_response(result) + .and_then(|writer| writer.write_background())?; + Ok(()) + } + } +} + +fn print_help(plugin: &impl StreamingPlugin, encoder: impl PluginEncoder) { println!("Nushell Plugin"); println!("Encoder: {}", encoder.name()); diff --git a/crates/nu-plugin/src/protocol/mod.rs b/crates/nu-plugin/src/protocol/mod.rs index a5e054ed15..07504a8e2c 100644 --- a/crates/nu-plugin/src/protocol/mod.rs +++ b/crates/nu-plugin/src/protocol/mod.rs @@ -9,13 +9,15 @@ mod tests; pub(crate) mod test_util; pub use evaluated_call::EvaluatedCall; -use nu_protocol::{PluginSignature, RawStream, ShellError, Span, Spanned, Value}; +use nu_protocol::{ + engine::Closure, Config, PipelineData, PluginSignature, RawStream, ShellError, Span, Spanned, + Value, +}; pub use plugin_custom_value::PluginCustomValue; -pub(crate) use protocol_info::ProtocolInfo; -use serde::{Deserialize, Serialize}; - +pub use protocol_info::ProtocolInfo; #[cfg(test)] -pub(crate) use protocol_info::Protocol; +pub use protocol_info::{Feature, Protocol}; +use serde::{Deserialize, Serialize}; /// A sequential identifier for a stream pub type StreamId = usize; @@ -23,6 +25,9 @@ pub type StreamId = usize; /// A sequential identifier for a [`PluginCall`] pub type PluginCallId = usize; +/// A sequential identifier for an [`EngineCall`] +pub type EngineCallId = usize; + /// Information about a plugin command invocation. This includes an [`EvaluatedCall`] as a /// serializable representation of [`nu_protocol::ast::Call`]. The type parameter determines /// the input type. @@ -34,8 +39,6 @@ pub struct CallInfo { pub call: EvaluatedCall, /// Pipeline input. This is usually [`nu_protocol::PipelineData`] or [`PipelineDataHeader`] pub input: D, - /// Plugin configuration, if available - pub config: Option, } /// The initial (and perhaps only) part of any [`nu_protocol::PipelineData`] sent over the wire. @@ -57,6 +60,30 @@ pub enum PipelineDataHeader { ExternalStream(ExternalStreamInfo), } +impl PipelineDataHeader { + /// Return a list of stream IDs embedded in the header + pub(crate) fn stream_ids(&self) -> Vec { + match self { + PipelineDataHeader::Empty => vec![], + PipelineDataHeader::Value(_) => vec![], + PipelineDataHeader::ListStream(info) => vec![info.id], + PipelineDataHeader::ExternalStream(info) => { + let mut out = vec![]; + if let Some(stdout) = &info.stdout { + out.push(stdout.id); + } + if let Some(stderr) = &info.stderr { + out.push(stderr.id); + } + if let Some(exit_code) = &info.exit_code { + out.push(exit_code.id); + } + out + } + } + } +} + /// Additional information about list (value) streams #[derive(Serialize, Deserialize, Debug, PartialEq, Eq, Clone)] pub struct ListStreamInfo { @@ -117,6 +144,9 @@ pub enum PluginInput { /// Don't expect any more plugin calls. Exit after all currently executing plugin calls are /// finished. Goodbye, + /// Response to an [`EngineCall`]. The ID should be the same one sent with the engine call this + /// is responding to + EngineCallResponse(EngineCallId, EngineCallResponse), /// Stream control or data message. Untagged to keep them as small as possible. /// /// For example, `Stream(Ack(0))` is encoded as `{"Ack": 0}` @@ -301,6 +331,15 @@ pub enum PluginOutput { /// A response to a [`PluginCall`]. The ID should be the same sent with the plugin call this /// is a response to CallResponse(PluginCallId, PluginCallResponse), + /// Execute an [`EngineCall`]. Engine calls must be executed within the `context` of a plugin + /// call, and the `id` should not have been used before + EngineCall { + /// The plugin call (by ID) to execute in the context of + context: PluginCallId, + /// A new identifier for this engine call. The response will reference this ID + id: EngineCallId, + call: EngineCall, + }, /// Stream control or data message. Untagged to keep them as small as possible. /// /// For example, `Stream(Ack(0))` is encoded as `{"Ack": 0}` @@ -324,3 +363,61 @@ impl From for PluginOutput { PluginOutput::Stream(stream_msg) } } + +/// A remote call back to the engine during the plugin's execution. +/// +/// The type parameter determines the input type, for calls that take pipeline data. +#[derive(Serialize, Deserialize, Debug, Clone)] +pub enum EngineCall { + /// Get the full engine configuration + GetConfig, + /// Get the plugin-specific configuration (`$env.config.plugins.NAME`) + GetPluginConfig, + /// Evaluate a closure with stream input/output + EvalClosure { + /// The closure to call. + /// + /// This may come from a [`Value::Closure`] passed in as an argument to the plugin. + closure: Spanned, + /// Positional arguments to add to the closure call + positional: Vec, + /// Input to the closure + input: D, + /// Whether to redirect stdout from external commands + redirect_stdout: bool, + /// Whether to redirect stderr from external commands + redirect_stderr: bool, + }, +} + +impl EngineCall { + /// Get the name of the engine call so it can be embedded in things like error messages + pub fn name(&self) -> &'static str { + match self { + EngineCall::GetConfig => "GetConfig", + EngineCall::GetPluginConfig => "GetPluginConfig", + EngineCall::EvalClosure { .. } => "EvalClosure", + } + } +} + +/// The response to an [EngineCall]. The type parameter determines the output type for pipeline +/// data. +#[derive(Serialize, Deserialize, Debug, Clone)] +pub enum EngineCallResponse { + Error(ShellError), + PipelineData(D), + Config(Box), +} + +impl EngineCallResponse { + /// Build an [`EngineCallResponse::PipelineData`] from a [`Value`] + pub(crate) fn value(value: Value) -> EngineCallResponse { + EngineCallResponse::PipelineData(PipelineData::Value(value, None)) + } + + /// An [`EngineCallResponse::PipelineData`] with [`PipelineData::Empty`] + pub(crate) const fn empty() -> EngineCallResponse { + EngineCallResponse::PipelineData(PipelineData::Empty) + } +} diff --git a/crates/nu-plugin/src/protocol/plugin_custom_value.rs b/crates/nu-plugin/src/protocol/plugin_custom_value.rs index 3b97070e15..439eb77580 100644 --- a/crates/nu-plugin/src/protocol/plugin_custom_value.rs +++ b/crates/nu-plugin/src/protocol/plugin_custom_value.rs @@ -146,6 +146,11 @@ impl PluginCustomValue { Self::add_source(list_value, source); } } + Value::Closure { ref mut val, .. } => { + for (_, captured_value) in val.captures.iter_mut() { + Self::add_source(captured_value, source); + } + } // All of these don't contain other values Value::Bool { .. } | Value::Int { .. } @@ -156,7 +161,6 @@ impl PluginCustomValue { | Value::String { .. } | Value::Glob { .. } | Value::Block { .. } - | Value::Closure { .. } | Value::Nothing { .. } | Value::Error { .. } | Value::Binary { .. } @@ -214,6 +218,10 @@ impl PluginCustomValue { Value::List { ref mut vals, .. } => vals .iter_mut() .try_for_each(|list_value| Self::verify_source(list_value, source)), + Value::Closure { ref mut val, .. } => val + .captures + .iter_mut() + .try_for_each(|(_, captured_value)| Self::verify_source(captured_value, source)), // All of these don't contain other values Value::Bool { .. } | Value::Int { .. } @@ -224,7 +232,6 @@ impl PluginCustomValue { | Value::String { .. } | Value::Glob { .. } | Value::Block { .. } - | Value::Closure { .. } | Value::Nothing { .. } | Value::Error { .. } | Value::Binary { .. } @@ -266,6 +273,11 @@ impl PluginCustomValue { Value::List { ref mut vals, .. } => vals .iter_mut() .try_for_each(Self::serialize_custom_values_in), + Value::Closure { ref mut val, .. } => val + .captures + .iter_mut() + .map(|(_, captured_value)| captured_value) + .try_for_each(Self::serialize_custom_values_in), // All of these don't contain other values Value::Bool { .. } | Value::Int { .. } @@ -276,7 +288,6 @@ impl PluginCustomValue { | Value::String { .. } | Value::Glob { .. } | Value::Block { .. } - | Value::Closure { .. } | Value::Nothing { .. } | Value::Error { .. } | Value::Binary { .. } @@ -316,6 +327,11 @@ impl PluginCustomValue { Value::List { ref mut vals, .. } => vals .iter_mut() .try_for_each(Self::deserialize_custom_values_in), + Value::Closure { ref mut val, .. } => val + .captures + .iter_mut() + .map(|(_, captured_value)| captured_value) + .try_for_each(Self::deserialize_custom_values_in), // All of these don't contain other values Value::Bool { .. } | Value::Int { .. } @@ -326,7 +342,6 @@ impl PluginCustomValue { | Value::String { .. } | Value::Glob { .. } | Value::Block { .. } - | Value::Closure { .. } | Value::Nothing { .. } | Value::Error { .. } | Value::Binary { .. } diff --git a/crates/nu-plugin/src/protocol/plugin_custom_value/tests.rs b/crates/nu-plugin/src/protocol/plugin_custom_value/tests.rs index 73a683ab91..4b798ab264 100644 --- a/crates/nu-plugin/src/protocol/plugin_custom_value/tests.rs +++ b/crates/nu-plugin/src/protocol/plugin_custom_value/tests.rs @@ -1,4 +1,6 @@ -use nu_protocol::{ast::RangeInclusion, record, CustomValue, Range, ShellError, Span, Value}; +use nu_protocol::{ + ast::RangeInclusion, engine::Closure, record, CustomValue, Range, ShellError, Span, Value, +}; use crate::{ plugin::PluginIdentity, @@ -180,6 +182,50 @@ fn add_source_nested_list() -> Result<(), ShellError> { }) } +fn check_closure_custom_values( + val: &Value, + indices: impl IntoIterator, + mut f: impl FnMut(usize, &dyn CustomValue) -> Result<(), ShellError>, +) -> Result<(), ShellError> { + let closure = val.as_closure()?; + for index in indices { + let val = closure + .captures + .get(index) + .unwrap_or_else(|| panic!("[{index}] not present in closure")); + let custom_value = val + .1 + .as_custom_value() + .unwrap_or_else(|_| panic!("[{index}] not custom value")); + f(index, custom_value)?; + } + Ok(()) +} + +#[test] +fn add_source_nested_closure() -> Result<(), ShellError> { + let orig_custom_val = Value::test_custom_value(Box::new(test_plugin_custom_value())); + let mut val = Value::test_closure(Closure { + block_id: 0, + captures: vec![(0, orig_custom_val.clone()), (1, orig_custom_val.clone())], + }); + let source = PluginIdentity::new_fake("foo"); + PluginCustomValue::add_source(&mut val, &source); + + check_closure_custom_values(&val, 0..=1, |index, custom_value| { + let plugin_custom_value: &PluginCustomValue = custom_value + .as_any() + .downcast_ref() + .unwrap_or_else(|| panic!("[{index}] not PluginCustomValue")); + assert_eq!( + Some(&source), + plugin_custom_value.source.as_ref(), + "[{index}] source not set correctly" + ); + Ok(()) + }) +} + #[test] fn verify_source_error_message() -> Result<(), ShellError> { let span = Span::new(5, 7); @@ -322,6 +368,40 @@ fn verify_source_nested_list() -> Result<(), ShellError> { Ok(()) } +#[test] +fn verify_source_nested_closure() -> Result<(), ShellError> { + let native_val = Value::test_custom_value(Box::new(TestCustomValue(32))); + let source = PluginIdentity::new_fake("test"); + for (name, mut val) in [ + ( + "first capture", + Value::test_closure(Closure { + block_id: 0, + captures: vec![(0, native_val.clone()), (1, Value::test_nothing())], + }), + ), + ( + "second capture", + Value::test_closure(Closure { + block_id: 0, + captures: vec![(0, Value::test_nothing()), (1, native_val.clone())], + }), + ), + ] { + PluginCustomValue::verify_source(&mut val, &source) + .expect_err(&format!("error not generated on {name}")); + } + + let mut ok_closure = Value::test_closure(Closure { + block_id: 0, + captures: vec![(0, Value::test_nothing())], + }); + PluginCustomValue::verify_source(&mut ok_closure, &source) + .expect("ok_closure should not generate error"); + + Ok(()) +} + #[test] fn serialize_in_root() -> Result<(), ShellError> { let span = Span::new(4, 10); @@ -406,6 +486,28 @@ fn serialize_in_list() -> Result<(), ShellError> { }) } +#[test] +fn serialize_in_closure() -> Result<(), ShellError> { + let orig_custom_val = Value::test_custom_value(Box::new(TestCustomValue(24))); + let mut val = Value::test_closure(Closure { + block_id: 0, + captures: vec![(0, orig_custom_val.clone()), (1, orig_custom_val.clone())], + }); + PluginCustomValue::serialize_custom_values_in(&mut val)?; + + check_closure_custom_values(&val, 0..=1, |index, custom_value| { + let plugin_custom_value: &PluginCustomValue = custom_value + .as_any() + .downcast_ref() + .unwrap_or_else(|| panic!("[{index}] not PluginCustomValue")); + assert_eq!( + "TestCustomValue", plugin_custom_value.name, + "[{index}] name not set correctly" + ); + Ok(()) + }) +} + #[test] fn deserialize_in_root() -> Result<(), ShellError> { let span = Span::new(4, 10); @@ -490,3 +592,26 @@ fn deserialize_in_list() -> Result<(), ShellError> { Ok(()) }) } + +#[test] +fn deserialize_in_closure() -> Result<(), ShellError> { + let orig_custom_val = Value::test_custom_value(Box::new(test_plugin_custom_value())); + let mut val = Value::test_closure(Closure { + block_id: 0, + captures: vec![(0, orig_custom_val.clone()), (1, orig_custom_val.clone())], + }); + PluginCustomValue::deserialize_custom_values_in(&mut val)?; + + check_closure_custom_values(&val, 0..=1, |index, custom_value| { + let test_custom_value: &TestCustomValue = custom_value + .as_any() + .downcast_ref() + .unwrap_or_else(|| panic!("[{index}] not TestCustomValue")); + assert_eq!( + expected_test_custom_value(), + *test_custom_value, + "[{index}] name not deserialized correctly" + ); + Ok(()) + }) +} diff --git a/crates/nu-plugin/src/serializers/tests.rs b/crates/nu-plugin/src/serializers/tests.rs index 4bac0b8b85..be727c3a6f 100644 --- a/crates/nu-plugin/src/serializers/tests.rs +++ b/crates/nu-plugin/src/serializers/tests.rs @@ -125,7 +125,6 @@ macro_rules! generate_tests { name: name.clone(), call: call.clone(), input: PipelineDataHeader::Value(input.clone()), - config: None, }); let plugin_input = PluginInput::Call(1, plugin_call); diff --git a/crates/nu-protocol/src/engine/engine_state.rs b/crates/nu-protocol/src/engine/engine_state.rs index 90157caed0..099753527f 100644 --- a/crates/nu-protocol/src/engine/engine_state.rs +++ b/crates/nu-protocol/src/engine/engine_state.rs @@ -855,6 +855,15 @@ impl EngineState { .expect("internal error: missing block") } + /// Optionally get a block by id, if it exists + /// + /// Prefer to use [`.get_block()`] in most cases - `BlockId`s that don't exist are normally a + /// compiler error. This only exists to stop plugins from crashing the engine if they send us + /// something invalid. + pub fn try_get_block(&self, block_id: BlockId) -> Option<&Block> { + self.blocks.get(block_id) + } + pub fn get_module(&self, module_id: ModuleId) -> &Module { self.modules .get(module_id) diff --git a/crates/nu_plugin_custom_values/src/main.rs b/crates/nu_plugin_custom_values/src/main.rs index a97274c4fd..0d1df03b0a 100644 --- a/crates/nu_plugin_custom_values/src/main.rs +++ b/crates/nu_plugin_custom_values/src/main.rs @@ -2,7 +2,7 @@ mod cool_custom_value; mod second_custom_value; use cool_custom_value::CoolCustomValue; -use nu_plugin::{serve_plugin, MsgPackSerializer, Plugin}; +use nu_plugin::{serve_plugin, EngineInterface, MsgPackSerializer, Plugin}; use nu_plugin::{EvaluatedCall, LabeledError}; use nu_protocol::{Category, PluginSignature, ShellError, SyntaxShape, Value}; use second_custom_value::SecondCustomValue; @@ -17,6 +17,11 @@ impl Plugin for CustomValuePlugin { .category(Category::Experimental), PluginSignature::build("custom-value generate2") .usage("PluginSignature for a plugin that generates a different custom value") + .optional( + "closure", + SyntaxShape::Closure(Some(vec![SyntaxShape::Any])), + "An optional closure to pass the custom value to", + ) .category(Category::Experimental), PluginSignature::build("custom-value update") .usage("PluginSignature for a plugin that updates a custom value") @@ -33,15 +38,15 @@ impl Plugin for CustomValuePlugin { } fn run( - &mut self, + &self, name: &str, - _config: &Option, + engine: &EngineInterface, call: &EvaluatedCall, input: &Value, ) -> Result { match name { "custom-value generate" => self.generate(call, input), - "custom-value generate2" => self.generate2(call, input), + "custom-value generate2" => self.generate2(engine, call), "custom-value update" => self.update(call, input), "custom-value update-arg" => self.update(call, &call.req(0)?), _ => Err(LabeledError { @@ -54,15 +59,30 @@ impl Plugin for CustomValuePlugin { } impl CustomValuePlugin { - fn generate(&mut self, call: &EvaluatedCall, _input: &Value) -> Result { + fn generate(&self, call: &EvaluatedCall, _input: &Value) -> Result { Ok(CoolCustomValue::new("abc").into_value(call.head)) } - fn generate2(&mut self, call: &EvaluatedCall, _input: &Value) -> Result { - Ok(SecondCustomValue::new("xyz").into_value(call.head)) + fn generate2( + &self, + engine: &EngineInterface, + call: &EvaluatedCall, + ) -> Result { + let second_custom_value = SecondCustomValue::new("xyz").into_value(call.head); + // If we were passed a closure, execute that instead + if let Some(closure) = call.opt(0)? { + let result = engine.eval_closure( + &closure, + vec![second_custom_value.clone()], + Some(second_custom_value), + )?; + Ok(result) + } else { + Ok(second_custom_value) + } } - fn update(&mut self, call: &EvaluatedCall, input: &Value) -> Result { + fn update(&self, call: &EvaluatedCall, input: &Value) -> Result { if let Ok(mut value) = CoolCustomValue::try_from_value(input) { value.cool += "xyz"; return Ok(value.into_value(call.head)); @@ -84,5 +104,5 @@ impl CustomValuePlugin { } fn main() { - serve_plugin(&mut CustomValuePlugin, MsgPackSerializer {}) + serve_plugin(&CustomValuePlugin, MsgPackSerializer {}) } diff --git a/crates/nu_plugin_example/src/example.rs b/crates/nu_plugin_example/src/example.rs index b9b8a2de51..784f79eb34 100644 --- a/crates/nu_plugin_example/src/example.rs +++ b/crates/nu_plugin_example/src/example.rs @@ -1,13 +1,14 @@ -use nu_plugin::{EvaluatedCall, LabeledError}; +use nu_plugin::{EngineInterface, EvaluatedCall, LabeledError}; use nu_protocol::{record, Value}; pub struct Example; impl Example { pub fn config( &self, - config: &Option, + engine: &EngineInterface, call: &EvaluatedCall, ) -> Result { + let config = engine.get_plugin_config()?; match config { Some(config) => Ok(config.clone()), None => Err(LabeledError { diff --git a/crates/nu_plugin_example/src/main.rs b/crates/nu_plugin_example/src/main.rs index 2effdfe780..1be1469316 100644 --- a/crates/nu_plugin_example/src/main.rs +++ b/crates/nu_plugin_example/src/main.rs @@ -6,7 +6,7 @@ fn main() { // used to encode and decode the messages. The available options are // MsgPackSerializer and JsonSerializer. Both are defined in the serializer // folder in nu-plugin. - serve_plugin(&mut Example {}, MsgPackSerializer {}) + serve_plugin(&Example {}, MsgPackSerializer {}) // Note // When creating plugins in other languages one needs to consider how a plugin diff --git a/crates/nu_plugin_example/src/nu/mod.rs b/crates/nu_plugin_example/src/nu/mod.rs index d8b7893d83..84117861f0 100644 --- a/crates/nu_plugin_example/src/nu/mod.rs +++ b/crates/nu_plugin_example/src/nu/mod.rs @@ -1,5 +1,5 @@ use crate::Example; -use nu_plugin::{EvaluatedCall, LabeledError, Plugin}; +use nu_plugin::{EngineInterface, EvaluatedCall, LabeledError, Plugin}; use nu_protocol::{Category, PluginExample, PluginSignature, SyntaxShape, Type, Value}; impl Plugin for Example { @@ -52,9 +52,9 @@ impl Plugin for Example { } fn run( - &mut self, + &self, name: &str, - config: &Option, + engine: &EngineInterface, call: &EvaluatedCall, input: &Value, ) -> Result { @@ -63,7 +63,7 @@ impl Plugin for Example { "nu-example-1" => self.test1(call, input), "nu-example-2" => self.test2(call, input), "nu-example-3" => self.test3(call, input), - "nu-example-config" => self.config(config, call), + "nu-example-config" => self.config(engine, call), _ => Err(LabeledError { label: "Plugin call with wrong name signature".into(), msg: "the signature used to call the plugin does not match any name in the plugin signature vector".into(), diff --git a/crates/nu_plugin_formats/src/lib.rs b/crates/nu_plugin_formats/src/lib.rs index 26710e6abe..0f29dd9d7a 100644 --- a/crates/nu_plugin_formats/src/lib.rs +++ b/crates/nu_plugin_formats/src/lib.rs @@ -1,7 +1,7 @@ mod from; use from::{eml, ics, ini, vcf}; -use nu_plugin::{EvaluatedCall, LabeledError, Plugin}; +use nu_plugin::{EngineInterface, EvaluatedCall, LabeledError, Plugin}; use nu_protocol::{Category, PluginSignature, SyntaxShape, Type, Value}; pub struct FromCmds; @@ -39,9 +39,9 @@ impl Plugin for FromCmds { } fn run( - &mut self, + &self, name: &str, - _config: &Option, + _engine: &EngineInterface, call: &EvaluatedCall, input: &Value, ) -> Result { diff --git a/crates/nu_plugin_formats/src/main.rs b/crates/nu_plugin_formats/src/main.rs index daa64bbfba..e6c7179781 100644 --- a/crates/nu_plugin_formats/src/main.rs +++ b/crates/nu_plugin_formats/src/main.rs @@ -2,5 +2,5 @@ use nu_plugin::{serve_plugin, MsgPackSerializer}; use nu_plugin_formats::FromCmds; fn main() { - serve_plugin(&mut FromCmds, MsgPackSerializer {}) + serve_plugin(&FromCmds, MsgPackSerializer {}) } diff --git a/crates/nu_plugin_gstat/src/main.rs b/crates/nu_plugin_gstat/src/main.rs index ecd10f2a5b..d28d6d7e52 100644 --- a/crates/nu_plugin_gstat/src/main.rs +++ b/crates/nu_plugin_gstat/src/main.rs @@ -2,5 +2,5 @@ use nu_plugin::{serve_plugin, MsgPackSerializer}; use nu_plugin_gstat::GStat; fn main() { - serve_plugin(&mut GStat::new(), MsgPackSerializer {}) + serve_plugin(&GStat::new(), MsgPackSerializer {}) } diff --git a/crates/nu_plugin_gstat/src/nu/mod.rs b/crates/nu_plugin_gstat/src/nu/mod.rs index 8e99ca2fa5..2cf3ddb813 100644 --- a/crates/nu_plugin_gstat/src/nu/mod.rs +++ b/crates/nu_plugin_gstat/src/nu/mod.rs @@ -1,5 +1,5 @@ use crate::GStat; -use nu_plugin::{EvaluatedCall, LabeledError, Plugin}; +use nu_plugin::{EngineInterface, EvaluatedCall, LabeledError, Plugin}; use nu_protocol::{Category, PluginSignature, Spanned, SyntaxShape, Value}; impl Plugin for GStat { @@ -11,9 +11,9 @@ impl Plugin for GStat { } fn run( - &mut self, + &self, name: &str, - _config: &Option, + _engine: &EngineInterface, call: &EvaluatedCall, input: &Value, ) -> Result { diff --git a/crates/nu_plugin_inc/src/inc.rs b/crates/nu_plugin_inc/src/inc.rs index 894dd1368c..1aa9e59481 100644 --- a/crates/nu_plugin_inc/src/inc.rs +++ b/crates/nu_plugin_inc/src/inc.rs @@ -2,20 +2,20 @@ use nu_plugin::LabeledError; use nu_protocol::{ast::CellPath, Span, Value}; use semver::{BuildMetadata, Prerelease, Version}; -#[derive(Debug, Eq, PartialEq)] +#[derive(Debug, Eq, PartialEq, Clone, Copy)] pub enum Action { SemVerAction(SemVerAction), Default, } -#[derive(Debug, Eq, PartialEq)] +#[derive(Debug, Eq, PartialEq, Clone, Copy)] pub enum SemVerAction { Major, Minor, Patch, } -#[derive(Default)] +#[derive(Default, Clone)] pub struct Inc { pub error: Option, pub cell_path: Option, diff --git a/crates/nu_plugin_inc/src/main.rs b/crates/nu_plugin_inc/src/main.rs index a6b6ff0617..47bcb3f950 100644 --- a/crates/nu_plugin_inc/src/main.rs +++ b/crates/nu_plugin_inc/src/main.rs @@ -2,5 +2,5 @@ use nu_plugin::{serve_plugin, JsonSerializer}; use nu_plugin_inc::Inc; fn main() { - serve_plugin(&mut Inc::new(), JsonSerializer {}) + serve_plugin(&Inc::new(), JsonSerializer {}) } diff --git a/crates/nu_plugin_inc/src/nu/mod.rs b/crates/nu_plugin_inc/src/nu/mod.rs index 5d2b6fa0a1..0dc7b078ac 100644 --- a/crates/nu_plugin_inc/src/nu/mod.rs +++ b/crates/nu_plugin_inc/src/nu/mod.rs @@ -1,6 +1,6 @@ use crate::inc::SemVerAction; use crate::Inc; -use nu_plugin::{EvaluatedCall, LabeledError, Plugin}; +use nu_plugin::{EngineInterface, EvaluatedCall, LabeledError, Plugin}; use nu_protocol::{ast::CellPath, PluginSignature, SyntaxShape, Value}; impl Plugin for Inc { @@ -26,9 +26,9 @@ impl Plugin for Inc { } fn run( - &mut self, + &self, name: &str, - _config: &Option, + _engine: &EngineInterface, call: &EvaluatedCall, input: &Value, ) -> Result { @@ -36,20 +36,22 @@ impl Plugin for Inc { return Ok(Value::nothing(call.head)); } + let mut inc = self.clone(); + let cell_path: Option = call.opt(0)?; - self.cell_path = cell_path; + inc.cell_path = cell_path; if call.has_flag("major")? { - self.for_semver(SemVerAction::Major); + inc.for_semver(SemVerAction::Major); } if call.has_flag("minor")? { - self.for_semver(SemVerAction::Minor); + inc.for_semver(SemVerAction::Minor); } if call.has_flag("patch")? { - self.for_semver(SemVerAction::Patch); + inc.for_semver(SemVerAction::Patch); } - self.inc(call.head, input) + inc.inc(call.head, input) } } diff --git a/crates/nu_plugin_query/src/main.rs b/crates/nu_plugin_query/src/main.rs index 96cdde9f1b..e65bd29c6f 100644 --- a/crates/nu_plugin_query/src/main.rs +++ b/crates/nu_plugin_query/src/main.rs @@ -2,5 +2,5 @@ use nu_plugin::{serve_plugin, JsonSerializer}; use nu_plugin_query::Query; fn main() { - serve_plugin(&mut Query {}, JsonSerializer {}) + serve_plugin(&Query {}, JsonSerializer {}) } diff --git a/crates/nu_plugin_query/src/nu/mod.rs b/crates/nu_plugin_query/src/nu/mod.rs index fcd5eaa4c3..b726ce1961 100644 --- a/crates/nu_plugin_query/src/nu/mod.rs +++ b/crates/nu_plugin_query/src/nu/mod.rs @@ -1,5 +1,5 @@ use crate::Query; -use nu_plugin::{EvaluatedCall, LabeledError, Plugin}; +use nu_plugin::{EngineInterface, EvaluatedCall, LabeledError, Plugin}; use nu_protocol::{Category, PluginExample, PluginSignature, Spanned, SyntaxShape, Value}; impl Plugin for Query { @@ -46,9 +46,9 @@ impl Plugin for Query { } fn run( - &mut self, + &self, name: &str, - _config: &Option, + _engine: &EngineInterface, call: &EvaluatedCall, input: &Value, ) -> Result { diff --git a/crates/nu_plugin_stream_example/README.md b/crates/nu_plugin_stream_example/README.md index cf1a8dc971..7957319bb0 100644 --- a/crates/nu_plugin_stream_example/README.md +++ b/crates/nu_plugin_stream_example/README.md @@ -46,3 +46,18 @@ strings on input will be concatenated into an external stream (raw input) on std Hello worldhowareyou + +## `stream_example for-each` + +This command demonstrates executing closures on values in streams. Each value received on the input +will be printed to the plugin's stderr. This works even with external commands. + +> ```nushell +> ls | get name | stream_example for-each { |f| ^file $f } +> ``` + + CODE_OF_CONDUCT.md: ASCII text + + CONTRIBUTING.md: ASCII text, with very long lines (303) + + ... diff --git a/crates/nu_plugin_stream_example/src/example.rs b/crates/nu_plugin_stream_example/src/example.rs index cd57165f79..4c432d03bd 100644 --- a/crates/nu_plugin_stream_example/src/example.rs +++ b/crates/nu_plugin_stream_example/src/example.rs @@ -1,5 +1,5 @@ -use nu_plugin::{EvaluatedCall, LabeledError}; -use nu_protocol::{ListStream, PipelineData, RawStream, Value}; +use nu_plugin::{EngineInterface, EvaluatedCall, LabeledError}; +use nu_protocol::{IntoInterruptiblePipelineData, ListStream, PipelineData, RawStream, Value}; pub struct Example; @@ -64,4 +64,52 @@ impl Example { trim_end_newline: false, }) } + + pub fn for_each( + &self, + engine: &EngineInterface, + call: &EvaluatedCall, + input: PipelineData, + ) -> Result { + let closure = call.req(0)?; + let config = engine.get_config()?; + for value in input { + let result = engine.eval_closure(&closure, vec![value.clone()], Some(value))?; + eprintln!("{}", result.to_expanded_string(", ", &config)); + } + Ok(PipelineData::Empty) + } + + pub fn generate( + &self, + engine: &EngineInterface, + call: &EvaluatedCall, + ) -> Result { + let engine = engine.clone(); + let call = call.clone(); + let initial: Value = call.req(0)?; + let closure = call.req(1)?; + + let mut next = (!initial.is_nothing()).then_some(initial); + + Ok(std::iter::from_fn(move || { + next.take() + .and_then(|value| { + engine + .eval_closure(&closure, vec![value.clone()], Some(value)) + .and_then(|record| { + if record.is_nothing() { + Ok(None) + } else { + let record = record.as_record()?; + next = record.get("next").cloned(); + Ok(record.get("out").cloned()) + } + }) + .transpose() + }) + .map(|result| result.unwrap_or_else(|err| Value::error(err, call.head))) + }) + .into_pipeline_data(None)) + } } diff --git a/crates/nu_plugin_stream_example/src/main.rs b/crates/nu_plugin_stream_example/src/main.rs index f40146790e..538a0283aa 100644 --- a/crates/nu_plugin_stream_example/src/main.rs +++ b/crates/nu_plugin_stream_example/src/main.rs @@ -6,7 +6,7 @@ fn main() { // used to encode and decode the messages. The available options are // MsgPackSerializer and JsonSerializer. Both are defined in the serializer // folder in nu-plugin. - serve_plugin(&mut Example {}, MsgPackSerializer {}) + serve_plugin(&Example {}, MsgPackSerializer {}) // Note // When creating plugins in other languages one needs to consider how a plugin diff --git a/crates/nu_plugin_stream_example/src/nu/mod.rs b/crates/nu_plugin_stream_example/src/nu/mod.rs index 1422de5d8c..6592f7dba9 100644 --- a/crates/nu_plugin_stream_example/src/nu/mod.rs +++ b/crates/nu_plugin_stream_example/src/nu/mod.rs @@ -1,5 +1,5 @@ use crate::Example; -use nu_plugin::{EvaluatedCall, LabeledError, StreamingPlugin}; +use nu_plugin::{EngineInterface, EvaluatedCall, LabeledError, StreamingPlugin}; use nu_protocol::{ Category, PipelineData, PluginExample, PluginSignature, Span, SyntaxShape, Type, Value, }; @@ -57,13 +57,50 @@ impl StreamingPlugin for Example { result: Some(Value::string("ab", span)), }]) .category(Category::Experimental), + PluginSignature::build("stream_example for-each") + .usage("Example execution of a closure with a stream") + .extra_usage("Prints each value the closure returns to stderr") + .input_output_type(Type::ListStream, Type::Nothing) + .required( + "closure", + SyntaxShape::Closure(Some(vec![SyntaxShape::Any])), + "The closure to run for each input value", + ) + .plugin_examples(vec![PluginExample { + example: "ls | get name | stream_example for-each { |f| ^file $f }".into(), + description: "example with an external command".into(), + result: None, + }]) + .category(Category::Experimental), + PluginSignature::build("stream_example generate") + .usage("Example execution of a closure to produce a stream") + .extra_usage("See the builtin `generate` command") + .input_output_type(Type::Nothing, Type::ListStream) + .required( + "initial", + SyntaxShape::Any, + "The initial value to pass to the closure" + ) + .required( + "closure", + SyntaxShape::Closure(Some(vec![SyntaxShape::Any])), + "The closure to run to generate values", + ) + .plugin_examples(vec![PluginExample { + example: "stream_example generate 0 { |i| if $i <= 10 { {out: $i, next: ($i + 2)} } }".into(), + description: "Generate a sequence of numbers".into(), + result: Some(Value::test_list( + [0, 2, 4, 6, 8, 10].into_iter().map(Value::test_int).collect(), + )), + }]) + .category(Category::Experimental), ] } fn run( - &mut self, + &self, name: &str, - _config: &Option, + engine: &EngineInterface, call: &EvaluatedCall, input: PipelineData, ) -> Result { @@ -76,6 +113,8 @@ impl StreamingPlugin for Example { "stream_example seq" => self.seq(call, input), "stream_example sum" => self.sum(call, input), "stream_example collect-external" => self.collect_external(call, input), + "stream_example for-each" => self.for_each(engine, call, input), + "stream_example generate" => self.generate(engine, call), _ => Err(LabeledError { label: "Plugin call with wrong name signature".into(), msg: "the signature used to call the plugin does not match any name in the plugin signature vector".into(), diff --git a/tests/plugins/custom_values.rs b/tests/plugins/custom_values.rs index 0e58ab632f..6db4883921 100644 --- a/tests/plugins/custom_values.rs +++ b/tests/plugins/custom_values.rs @@ -54,6 +54,20 @@ fn can_generate_and_updated_multiple_types_of_custom_values() { ); } +#[test] +fn can_generate_custom_value_and_pass_through_closure() { + let actual = nu_with_plugins!( + cwd: "tests", + plugin: ("nu_plugin_custom_values"), + "custom-value generate2 { custom-value update }" + ); + + assert_eq!( + actual.out, + "I used to be a DIFFERENT custom value! (xyzabc)" + ); +} + #[test] fn can_get_describe_plugin_custom_values() { let actual = nu_with_plugins!( diff --git a/tests/plugins/stream.rs b/tests/plugins/stream.rs index 732fd7c6a9..6850893215 100644 --- a/tests/plugins/stream.rs +++ b/tests/plugins/stream.rs @@ -164,3 +164,25 @@ fn collect_external_big_stream() { assert_eq!(actual.out, "10000"); } + +#[test] +fn for_each_prints_on_stderr() { + let actual = nu_with_plugins!( + cwd: "tests/fixtures/formats", + plugin: ("nu_plugin_stream_example"), + "[a b c] | stream_example for-each { $in }" + ); + + assert_eq!(actual.err, "a\nb\nc\n"); +} + +#[test] +fn generate_sequence() { + let actual = nu_with_plugins!( + cwd: "tests/fixtures/formats", + plugin: ("nu_plugin_stream_example"), + "stream_example generate 0 { |i| if $i <= 10 { {out: $i, next: ($i + 2)} } } | to json --raw" + ); + + assert_eq!(actual.out, "[0,2,4,6,8,10]"); +}