Allow plugins to set environment variables in their caller's scope (#12204)

# Description

Adds the `AddEnvVar` plugin call, which allows plugins to set
environment variables in the caller's scope. This is the first engine
call that mutates the caller's stack, and opens the door to more
operations like this if needed.

This also comes with an extra benefit: in doing this, I needed to
refactor how context was handled, and I was able to avoid cloning
`EngineInterface` / `Stack` / `Call` in most cases that plugin calls are
used. They now only need to be cloned if the plugin call returns a
stream. The performance increase is welcome (5.5x faster on `inc`!):

```nushell
# Before
> timeit { 1..100 | each { |i| $"2.0.($i)" | inc -p } }
405ms 941µs 952ns
# After
> timeit { 1..100 | each { |i| $"2.0.($i)" | inc -p } }
73ms 68µs 749ns
```

# User-Facing Changes
- New engine call: `add_env_var()`
- Performance enhancement for plugin calls

# Tests + Formatting
- 🟢 `toolkit fmt`
- 🟢 `toolkit clippy`
- 🟢 `toolkit test`
- 🟢 `toolkit test stdlib`

# After Submitting
- [x] Document env manipulation in plugins guide
- [x] Document `AddEnvVar` in plugin protocol
This commit is contained in:
Devyn Cairns 2024-03-15 04:45:45 -07:00 committed by GitHub
parent 687fbc49c8
commit f6faf73e02
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
12 changed files with 329 additions and 142 deletions

View File

@ -59,6 +59,7 @@ mod plugin;
mod protocol; mod protocol;
mod sequence; mod sequence;
mod serializers; mod serializers;
mod util;
pub use plugin::{ pub use plugin::{
serve_plugin, EngineInterface, Plugin, PluginCommand, PluginEncoder, SimplePluginCommand, serve_plugin, EngineInterface, Plugin, PluginCommand, PluginEncoder, SimplePluginCommand,

View File

@ -1,4 +1,5 @@
use std::{ use std::{
borrow::Cow,
collections::HashMap, collections::HashMap,
sync::{atomic::AtomicBool, Arc}, sync::{atomic::AtomicBool, Arc},
}; };
@ -10,6 +11,8 @@ use nu_protocol::{
Config, IntoSpanned, IoStream, PipelineData, PluginIdentity, ShellError, Span, Spanned, Value, Config, IntoSpanned, IoStream, PipelineData, PluginIdentity, ShellError, Span, Spanned, Value,
}; };
use crate::util::MutableCow;
/// Object safe trait for abstracting operations required of the plugin context. /// Object safe trait for abstracting operations required of the plugin context.
pub(crate) trait PluginExecutionContext: Send + Sync { pub(crate) trait PluginExecutionContext: Send + Sync {
/// The [Span] for the command execution (`call.head`) /// The [Span] for the command execution (`call.head`)
@ -26,8 +29,10 @@ pub(crate) trait PluginExecutionContext: Send + Sync {
fn get_env_var(&self, name: &str) -> Result<Option<Value>, ShellError>; fn get_env_var(&self, name: &str) -> Result<Option<Value>, ShellError>;
/// Get all environment variables /// Get all environment variables
fn get_env_vars(&self) -> Result<HashMap<String, Value>, ShellError>; fn get_env_vars(&self) -> Result<HashMap<String, Value>, ShellError>;
// Get current working directory /// Get current working directory
fn get_current_dir(&self) -> Result<Spanned<String>, ShellError>; fn get_current_dir(&self) -> Result<Spanned<String>, ShellError>;
/// Set an environment variable
fn add_env_var(&mut self, name: String, value: Value) -> Result<(), ShellError>;
/// Evaluate a closure passed to the plugin /// Evaluate a closure passed to the plugin
fn eval_closure( fn eval_closure(
&self, &self,
@ -37,33 +42,35 @@ pub(crate) trait PluginExecutionContext: Send + Sync {
redirect_stdout: bool, redirect_stdout: bool,
redirect_stderr: bool, redirect_stderr: bool,
) -> Result<PipelineData, ShellError>; ) -> Result<PipelineData, ShellError>;
/// Create an owned version of the context with `'static` lifetime
fn boxed(&self) -> Box<dyn PluginExecutionContext>;
} }
/// The execution context of a plugin command. /// The execution context of a plugin command. Can be borrowed.
pub(crate) struct PluginExecutionCommandContext { pub(crate) struct PluginExecutionCommandContext<'a> {
identity: Arc<PluginIdentity>, identity: Arc<PluginIdentity>,
engine_state: EngineState, engine_state: Cow<'a, EngineState>,
stack: Stack, stack: MutableCow<'a, Stack>,
call: Call, call: Cow<'a, Call>,
} }
impl PluginExecutionCommandContext { impl<'a> PluginExecutionCommandContext<'a> {
pub fn new( pub fn new(
identity: Arc<PluginIdentity>, identity: Arc<PluginIdentity>,
engine_state: &EngineState, engine_state: &'a EngineState,
stack: &Stack, stack: &'a mut Stack,
call: &Call, call: &'a Call,
) -> PluginExecutionCommandContext { ) -> PluginExecutionCommandContext<'a> {
PluginExecutionCommandContext { PluginExecutionCommandContext {
identity, identity,
engine_state: engine_state.clone(), engine_state: Cow::Borrowed(engine_state),
stack: stack.clone(), stack: MutableCow::Borrowed(stack),
call: call.clone(), call: Cow::Borrowed(call),
} }
} }
} }
impl PluginExecutionContext for PluginExecutionCommandContext { impl<'a> PluginExecutionContext for PluginExecutionCommandContext<'a> {
fn command_span(&self) -> Span { fn command_span(&self) -> Span {
self.call.head self.call.head
} }
@ -131,6 +138,11 @@ impl PluginExecutionContext for PluginExecutionCommandContext {
Ok(cwd.into_spanned(self.call.head)) Ok(cwd.into_spanned(self.call.head))
} }
fn add_env_var(&mut self, name: String, value: Value) -> Result<(), ShellError> {
self.stack.add_env_var(name, value);
Ok(())
}
fn eval_closure( fn eval_closure(
&self, &self,
closure: Spanned<Closure>, closure: Spanned<Closure>,
@ -191,6 +203,15 @@ impl PluginExecutionContext for PluginExecutionCommandContext {
eval_block_with_early_return(&self.engine_state, stack, block, input) eval_block_with_early_return(&self.engine_state, stack, block, input)
} }
fn boxed(&self) -> Box<dyn PluginExecutionContext + 'static> {
Box::new(PluginExecutionCommandContext {
identity: self.identity.clone(),
engine_state: Cow::Owned(self.engine_state.clone().into_owned()),
stack: self.stack.owned(),
call: Cow::Owned(self.call.clone().into_owned()),
})
}
} }
/// A bogus execution context for testing that doesn't really implement anything properly /// A bogus execution context for testing that doesn't really implement anything properly
@ -239,6 +260,12 @@ impl PluginExecutionContext for PluginExecutionBogusContext {
}) })
} }
fn add_env_var(&mut self, _name: String, _value: Value) -> Result<(), ShellError> {
Err(ShellError::NushellFailed {
msg: "add_env_var not implemented on bogus".into(),
})
}
fn eval_closure( fn eval_closure(
&self, &self,
_closure: Spanned<Closure>, _closure: Spanned<Closure>,
@ -251,4 +278,8 @@ impl PluginExecutionContext for PluginExecutionBogusContext {
msg: "eval_closure not implemented on bogus".into(), msg: "eval_closure not implemented on bogus".into(),
}) })
} }
fn boxed(&self) -> Box<dyn PluginExecutionContext + 'static> {
Box::new(PluginExecutionBogusContext)
}
} }

View File

@ -108,12 +108,12 @@ impl Command for PluginDeclaration {
})?; })?;
// Create the context to execute in - this supports engine calls and custom values // Create the context to execute in - this supports engine calls and custom values
let context = Arc::new(PluginExecutionCommandContext::new( let mut context = PluginExecutionCommandContext::new(
self.source.identity.clone(), self.source.identity.clone(),
engine_state, engine_state,
stack, stack,
call, call,
)); );
plugin.run( plugin.run(
CallInfo { CallInfo {
@ -121,7 +121,7 @@ impl Command for PluginDeclaration {
call: evaluated_call, call: evaluated_call,
input, input,
}, },
context, &mut context,
) )
} }

View File

@ -458,6 +458,9 @@ impl EngineInterface {
EngineCall::GetEnvVar(name) => (EngineCall::GetEnvVar(name), Default::default()), EngineCall::GetEnvVar(name) => (EngineCall::GetEnvVar(name), Default::default()),
EngineCall::GetEnvVars => (EngineCall::GetEnvVars, Default::default()), EngineCall::GetEnvVars => (EngineCall::GetEnvVars, Default::default()),
EngineCall::GetCurrentDir => (EngineCall::GetCurrentDir, Default::default()), EngineCall::GetCurrentDir => (EngineCall::GetCurrentDir, Default::default()),
EngineCall::AddEnvVar(name, value) => {
(EngineCall::AddEnvVar(name, value), Default::default())
}
}; };
// Register the channel // Register the channel
@ -622,6 +625,30 @@ impl EngineInterface {
} }
} }
/// Set an environment variable in the caller's scope.
///
/// If called after the plugin response has already been sent (i.e. during a stream), this will
/// only affect the environment for engine calls related to this plugin call, and will not be
/// propagated to the environment of the caller.
///
/// # Example
/// ```rust,no_run
/// # use nu_protocol::{Value, ShellError};
/// # use nu_plugin::EngineInterface;
/// # fn example(engine: &EngineInterface) -> Result<(), ShellError> {
/// engine.add_env_var("FOO", Value::test_string("bar"))
/// # }
/// ```
pub fn add_env_var(&self, name: impl Into<String>, value: Value) -> Result<(), ShellError> {
match self.engine_call(EngineCall::AddEnvVar(name.into(), value))? {
EngineCallResponse::PipelineData(_) => Ok(()),
EngineCallResponse::Error(err) => Err(err),
_ => Err(ShellError::PluginFailedToDecode {
msg: "Received unexpected response type for EngineCall::AddEnvVar".into(),
}),
}
}
/// Ask the engine to evaluate a closure. Input to the closure is passed as a stream, and the /// Ask the engine to evaluate a closure. Input to the closure is passed as a stream, and the
/// output is available as a stream. /// output is available as a stream.
/// ///

View File

@ -953,6 +953,20 @@ fn interface_get_env_vars() -> Result<(), ShellError> {
Ok(()) Ok(())
} }
#[test]
fn interface_add_env_var() -> Result<(), ShellError> {
let test = TestCase::new();
let manager = test.engine();
let interface = manager.interface_for_context(0);
start_fake_plugin_call_responder(manager, 1, move |_| EngineCallResponse::empty());
interface.add_env_var("FOO", Value::test_string("bar"))?;
assert!(test.has_unconsumed_write());
Ok(())
}
#[test] #[test]
fn interface_eval_closure_with_stream() -> Result<(), ShellError> { fn interface_eval_closure_with_stream() -> Result<(), ShellError> {
let test = TestCase::new(); let test = TestCase::new();

View File

@ -2,7 +2,7 @@
use std::{ use std::{
collections::{btree_map, BTreeMap}, collections::{btree_map, BTreeMap},
sync::{mpsc, Arc, OnceLock}, sync::{atomic::AtomicBool, mpsc, Arc, OnceLock},
}; };
use nu_protocol::{ use nu_protocol::{
@ -44,8 +44,7 @@ enum ReceivedPluginCallMessage {
} }
/// Context for plugin call execution /// Context for plugin call execution
#[derive(Clone)] pub(crate) struct Context(Box<dyn PluginExecutionContext>);
pub(crate) struct Context(Arc<dyn PluginExecutionContext>);
impl std::fmt::Debug for Context { impl std::fmt::Debug for Context {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
@ -70,7 +69,7 @@ struct PluginInterfaceState {
/// Sequence for generating stream ids /// Sequence for generating stream ids
stream_id_sequence: Sequence, stream_id_sequence: Sequence,
/// Sender to subscribe to a plugin call response /// Sender to subscribe to a plugin call response
plugin_call_subscription_sender: mpsc::Sender<(PluginCallId, PluginCallSubscription)>, plugin_call_subscription_sender: mpsc::Sender<(PluginCallId, PluginCallState)>,
/// An error that should be propagated to further plugin calls /// An error that should be propagated to further plugin calls
error: OnceLock<ShellError>, error: OnceLock<ShellError>,
/// The synchronized output writer /// The synchronized output writer
@ -91,14 +90,15 @@ impl std::fmt::Debug for PluginInterfaceState {
} }
} }
/// Sent to the [`PluginInterfaceManager`] before making a plugin call to indicate interest in its /// State that the manager keeps for each plugin call during its lifetime.
/// response.
#[derive(Debug)] #[derive(Debug)]
struct PluginCallSubscription { struct PluginCallState {
/// The sender back to the thread that is waiting for the plugin call response /// The sender back to the thread that is waiting for the plugin call response
sender: Option<mpsc::Sender<ReceivedPluginCallMessage>>, sender: Option<mpsc::Sender<ReceivedPluginCallMessage>>,
/// Optional context for the environment of a plugin call for servicing engine calls /// Interrupt signal to be used for stream iterators
context: Option<Context>, ctrlc: Option<Arc<AtomicBool>>,
/// Channel to receive context on to be used if needed
context_rx: Option<mpsc::Receiver<Context>>,
/// Number of streams that still need to be read from the plugin call response /// Number of streams that still need to be read from the plugin call response
remaining_streams_to_read: i32, remaining_streams_to_read: i32,
} }
@ -112,10 +112,10 @@ pub(crate) struct PluginInterfaceManager {
stream_manager: StreamManager, stream_manager: StreamManager,
/// Protocol version info, set after `Hello` received /// Protocol version info, set after `Hello` received
protocol_info: Option<ProtocolInfo>, protocol_info: Option<ProtocolInfo>,
/// Subscriptions for messages related to plugin calls /// State related to plugin calls
plugin_call_subscriptions: BTreeMap<PluginCallId, PluginCallSubscription>, plugin_call_states: BTreeMap<PluginCallId, PluginCallState>,
/// Receiver for plugin call subscriptions /// Receiver for plugin call subscriptions
plugin_call_subscription_receiver: mpsc::Receiver<(PluginCallId, PluginCallSubscription)>, plugin_call_subscription_receiver: mpsc::Receiver<(PluginCallId, PluginCallState)>,
/// Tracker for which plugin call streams being read belong to /// Tracker for which plugin call streams being read belong to
/// ///
/// This is necessary so we know when we can remove context for plugin calls /// This is necessary so we know when we can remove context for plugin calls
@ -142,7 +142,7 @@ impl PluginInterfaceManager {
}), }),
stream_manager: StreamManager::new(), stream_manager: StreamManager::new(),
protocol_info: None, protocol_info: None,
plugin_call_subscriptions: BTreeMap::new(), plugin_call_states: BTreeMap::new(),
plugin_call_subscription_receiver: subscription_rx, plugin_call_subscription_receiver: subscription_rx,
plugin_call_input_streams: BTreeMap::new(), plugin_call_input_streams: BTreeMap::new(),
gc: None, gc: None,
@ -158,9 +158,9 @@ impl PluginInterfaceManager {
/// Consume pending messages in the `plugin_call_subscription_receiver` /// Consume pending messages in the `plugin_call_subscription_receiver`
fn receive_plugin_call_subscriptions(&mut self) { fn receive_plugin_call_subscriptions(&mut self) {
while let Ok((id, subscription)) = self.plugin_call_subscription_receiver.try_recv() { while let Ok((id, state)) = self.plugin_call_subscription_receiver.try_recv() {
if let btree_map::Entry::Vacant(e) = self.plugin_call_subscriptions.entry(id) { if let btree_map::Entry::Vacant(e) = self.plugin_call_states.entry(id) {
e.insert(subscription); e.insert(state);
} else { } else {
log::warn!("Duplicate plugin call ID ignored: {id}"); log::warn!("Duplicate plugin call ID ignored: {id}");
} }
@ -172,8 +172,8 @@ impl PluginInterfaceManager {
self.plugin_call_input_streams.insert(stream_id, call_id); self.plugin_call_input_streams.insert(stream_id, call_id);
// Increment the number of streams on the subscription so context stays alive // Increment the number of streams on the subscription so context stays alive
self.receive_plugin_call_subscriptions(); self.receive_plugin_call_subscriptions();
if let Some(sub) = self.plugin_call_subscriptions.get_mut(&call_id) { if let Some(state) = self.plugin_call_states.get_mut(&call_id) {
sub.remaining_streams_to_read += 1; state.remaining_streams_to_read += 1;
} }
// Add a lock to the garbage collector for each stream // Add a lock to the garbage collector for each stream
if let Some(ref gc) = self.gc { if let Some(ref gc) = self.gc {
@ -184,8 +184,7 @@ impl PluginInterfaceManager {
/// Track the end of an incoming stream /// Track the end of an incoming stream
fn recv_stream_ended(&mut self, stream_id: StreamId) { fn recv_stream_ended(&mut self, stream_id: StreamId) {
if let Some(call_id) = self.plugin_call_input_streams.remove(&stream_id) { if let Some(call_id) = self.plugin_call_input_streams.remove(&stream_id) {
if let btree_map::Entry::Occupied(mut e) = self.plugin_call_subscriptions.entry(call_id) if let btree_map::Entry::Occupied(mut e) = self.plugin_call_states.entry(call_id) {
{
e.get_mut().remaining_streams_to_read -= 1; e.get_mut().remaining_streams_to_read -= 1;
// Remove the subscription if there are no more streams to be read. // Remove the subscription if there are no more streams to be read.
if e.get().remaining_streams_to_read <= 0 { if e.get().remaining_streams_to_read <= 0 {
@ -200,14 +199,14 @@ impl PluginInterfaceManager {
} }
} }
/// Find the context corresponding to the given plugin call id /// Find the ctrlc signal corresponding to the given plugin call id
fn get_context(&mut self, id: PluginCallId) -> Result<Option<Context>, ShellError> { fn get_ctrlc(&mut self, id: PluginCallId) -> Result<Option<Arc<AtomicBool>>, ShellError> {
// Make sure we're up to date // Make sure we're up to date
self.receive_plugin_call_subscriptions(); self.receive_plugin_call_subscriptions();
// Find the subscription and return the context // Find the subscription and return the context
self.plugin_call_subscriptions self.plugin_call_states
.get(&id) .get(&id)
.map(|sub| sub.context.clone()) .map(|state| state.ctrlc.clone())
.ok_or_else(|| ShellError::PluginFailedToDecode { .ok_or_else(|| ShellError::PluginFailedToDecode {
msg: format!("Unknown plugin call ID: {id}"), msg: format!("Unknown plugin call ID: {id}"),
}) })
@ -222,7 +221,7 @@ impl PluginInterfaceManager {
// Ensure we're caught up on the subscriptions made // Ensure we're caught up on the subscriptions made
self.receive_plugin_call_subscriptions(); self.receive_plugin_call_subscriptions();
if let btree_map::Entry::Occupied(mut e) = self.plugin_call_subscriptions.entry(id) { if let btree_map::Entry::Occupied(mut e) = self.plugin_call_states.entry(id) {
// Remove the subscription sender, since this will be the last message. // Remove the subscription sender, since this will be the last message.
// //
// We can spawn a new one if we need it for engine calls. // We can spawn a new one if we need it for engine calls.
@ -254,11 +253,23 @@ impl PluginInterfaceManager {
) -> Result<&mpsc::Sender<ReceivedPluginCallMessage>, ShellError> { ) -> Result<&mpsc::Sender<ReceivedPluginCallMessage>, ShellError> {
let interface = self.get_interface(); let interface = self.get_interface();
if let Some(sub) = self.plugin_call_subscriptions.get_mut(&id) { if let Some(state) = self.plugin_call_states.get_mut(&id) {
if sub.sender.is_none() { if state.sender.is_none() {
let (tx, rx) = mpsc::channel(); let (tx, rx) = mpsc::channel();
let context = sub.context.clone(); let context_rx =
state
.context_rx
.take()
.ok_or_else(|| ShellError::NushellFailed {
msg: "Tried to spawn the fallback engine call handler more than once"
.into(),
})?;
let handler = move || { let handler = move || {
// We receive on the thread so that we don't block the reader thread
let mut context = context_rx
.recv()
.ok() // The plugin call won't send context if it's not required.
.map(|c| c.0);
for msg in rx { for msg in rx {
// This thread only handles engine calls. // This thread only handles engine calls.
match msg { match msg {
@ -266,7 +277,7 @@ impl PluginInterfaceManager {
if let Err(err) = interface.handle_engine_call( if let Err(err) = interface.handle_engine_call(
engine_call_id, engine_call_id,
engine_call, engine_call,
&context, context.as_deref_mut(),
) { ) {
log::warn!( log::warn!(
"Error in plugin post-response engine call handler: \ "Error in plugin post-response engine call handler: \
@ -286,8 +297,8 @@ impl PluginInterfaceManager {
.name("plugin engine call handler".into()) .name("plugin engine call handler".into())
.spawn(handler) .spawn(handler)
.expect("failed to spawn thread"); .expect("failed to spawn thread");
sub.sender = Some(tx); state.sender = Some(tx);
Ok(sub.sender.as_ref().unwrap_or_else(|| unreachable!())) Ok(state.sender.as_ref().unwrap_or_else(|| unreachable!()))
} else { } else {
Err(ShellError::NushellFailed { Err(ShellError::NushellFailed {
msg: "Tried to spawn the fallback engine call handler before the plugin call \ msg: "Tried to spawn the fallback engine call handler before the plugin call \
@ -313,7 +324,7 @@ impl PluginInterfaceManager {
self.receive_plugin_call_subscriptions(); self.receive_plugin_call_subscriptions();
// Don't remove the sender, as there could be more calls or responses // Don't remove the sender, as there could be more calls or responses
if let Some(subscription) = self.plugin_call_subscriptions.get(&plugin_call_id) { if let Some(subscription) = self.plugin_call_states.get(&plugin_call_id) {
let msg = ReceivedPluginCallMessage::EngineCall(engine_call_id, call); let msg = ReceivedPluginCallMessage::EngineCall(engine_call_id, call);
// Call if there's an error sending the engine call // Call if there's an error sending the engine call
let send_error = |this: &Self| { let send_error = |this: &Self| {
@ -374,9 +385,7 @@ impl PluginInterfaceManager {
let _ = self.stream_manager.broadcast_read_error(err.clone()); let _ = self.stream_manager.broadcast_read_error(err.clone());
// Error to call waiters // Error to call waiters
self.receive_plugin_call_subscriptions(); self.receive_plugin_call_subscriptions();
for subscription in for subscription in std::mem::take(&mut self.plugin_call_states).into_values() {
std::mem::take(&mut self.plugin_call_subscriptions).into_values()
{
let _ = subscription let _ = subscription
.sender .sender
.as_ref() .as_ref()
@ -460,15 +469,14 @@ impl InterfaceManager for PluginInterfaceManager {
PluginCallResponse::PipelineData(data) => { PluginCallResponse::PipelineData(data) => {
// If there's an error with initializing this stream, change it to a plugin // If there's an error with initializing this stream, change it to a plugin
// error response, but send it anyway // error response, but send it anyway
let exec_context = self.get_context(id)?; let ctrlc = self.get_ctrlc(id)?;
let ctrlc = exec_context.as_ref().and_then(|c| c.0.ctrlc());
// Register the streams in the response // Register the streams in the response
for stream_id in data.stream_ids() { for stream_id in data.stream_ids() {
self.recv_stream_started(id, stream_id); self.recv_stream_started(id, stream_id);
} }
match self.read_pipeline_data(data, ctrlc) { match self.read_pipeline_data(data, ctrlc.as_ref()) {
Ok(data) => PluginCallResponse::PipelineData(data), Ok(data) => PluginCallResponse::PipelineData(data),
Err(err) => PluginCallResponse::Error(err.into()), Err(err) => PluginCallResponse::Error(err.into()),
} }
@ -485,14 +493,14 @@ impl InterfaceManager for PluginInterfaceManager {
} }
PluginOutput::EngineCall { context, id, call } => { PluginOutput::EngineCall { context, id, call } => {
// Handle reading the pipeline data, if any // Handle reading the pipeline data, if any
let exec_context = self.get_context(context)?; let ctrlc = self.get_ctrlc(context)?;
let ctrlc = exec_context.as_ref().and_then(|c| c.0.ctrlc());
let call = match call { let call = match call {
EngineCall::GetConfig => Ok(EngineCall::GetConfig), EngineCall::GetConfig => Ok(EngineCall::GetConfig),
EngineCall::GetPluginConfig => Ok(EngineCall::GetPluginConfig), EngineCall::GetPluginConfig => Ok(EngineCall::GetPluginConfig),
EngineCall::GetEnvVar(name) => Ok(EngineCall::GetEnvVar(name)), EngineCall::GetEnvVar(name) => Ok(EngineCall::GetEnvVar(name)),
EngineCall::GetEnvVars => Ok(EngineCall::GetEnvVars), EngineCall::GetEnvVars => Ok(EngineCall::GetEnvVars),
EngineCall::GetCurrentDir => Ok(EngineCall::GetCurrentDir), EngineCall::GetCurrentDir => Ok(EngineCall::GetCurrentDir),
EngineCall::AddEnvVar(name, value) => Ok(EngineCall::AddEnvVar(name, value)),
EngineCall::EvalClosure { EngineCall::EvalClosure {
closure, closure,
mut positional, mut positional,
@ -504,13 +512,14 @@ impl InterfaceManager for PluginInterfaceManager {
for arg in positional.iter_mut() { for arg in positional.iter_mut() {
PluginCustomValue::add_source(arg, &self.state.source); PluginCustomValue::add_source(arg, &self.state.source);
} }
self.read_pipeline_data(input, ctrlc) self.read_pipeline_data(input, ctrlc.as_ref()).map(|input| {
.map(|input| EngineCall::EvalClosure { EngineCall::EvalClosure {
closure, closure,
positional, positional,
input, input,
redirect_stdout, redirect_stdout,
redirect_stderr, redirect_stderr,
}
}) })
} }
}; };
@ -622,7 +631,8 @@ impl PluginInterface {
fn write_plugin_call( fn write_plugin_call(
&self, &self,
call: PluginCall<PipelineData>, call: PluginCall<PipelineData>,
context: Option<Context>, ctrlc: Option<Arc<AtomicBool>>,
context_rx: mpsc::Receiver<Context>,
) -> Result< ) -> Result<
( (
PipelineDataWriter<Self>, PipelineDataWriter<Self>,
@ -662,9 +672,10 @@ impl PluginInterface {
.plugin_call_subscription_sender .plugin_call_subscription_sender
.send(( .send((
id, id,
PluginCallSubscription { PluginCallState {
sender: Some(tx), sender: Some(tx),
context, ctrlc,
context_rx: Some(context_rx),
remaining_streams_to_read: 0, remaining_streams_to_read: 0,
}, },
)) ))
@ -703,19 +714,26 @@ impl PluginInterface {
fn receive_plugin_call_response( fn receive_plugin_call_response(
&self, &self,
rx: mpsc::Receiver<ReceivedPluginCallMessage>, rx: mpsc::Receiver<ReceivedPluginCallMessage>,
context: &Option<Context>, mut context: Option<&mut (dyn PluginExecutionContext + '_)>,
context_tx: mpsc::Sender<Context>,
) -> Result<PluginCallResponse<PipelineData>, ShellError> { ) -> Result<PluginCallResponse<PipelineData>, ShellError> {
// Handle message from receiver // Handle message from receiver
for msg in rx { for msg in rx {
match msg { match msg {
ReceivedPluginCallMessage::Response(resp) => { ReceivedPluginCallMessage::Response(resp) => {
if resp.has_stream() {
// If the response has a stream, we need to register the context
if let Some(context) = context {
let _ = context_tx.send(Context(context.boxed()));
}
}
return Ok(resp); return Ok(resp);
} }
ReceivedPluginCallMessage::Error(err) => { ReceivedPluginCallMessage::Error(err) => {
return Err(err); return Err(err);
} }
ReceivedPluginCallMessage::EngineCall(engine_call_id, engine_call) => { ReceivedPluginCallMessage::EngineCall(engine_call_id, engine_call) => {
self.handle_engine_call(engine_call_id, engine_call, context)?; self.handle_engine_call(engine_call_id, engine_call, context.as_deref_mut())?;
} }
} }
} }
@ -730,7 +748,7 @@ impl PluginInterface {
&self, &self,
engine_call_id: EngineCallId, engine_call_id: EngineCallId,
engine_call: EngineCall<PipelineData>, engine_call: EngineCall<PipelineData>,
context: &Option<Context>, context: Option<&mut (dyn PluginExecutionContext + '_)>,
) -> Result<(), ShellError> { ) -> Result<(), ShellError> {
let resp = let resp =
handle_engine_call(engine_call, context).unwrap_or_else(EngineCallResponse::Error); handle_engine_call(engine_call, context).unwrap_or_else(EngineCallResponse::Error);
@ -763,7 +781,7 @@ impl PluginInterface {
fn plugin_call( fn plugin_call(
&self, &self,
call: PluginCall<PipelineData>, call: PluginCall<PipelineData>,
context: &Option<Context>, context: Option<&mut dyn PluginExecutionContext>,
) -> Result<PluginCallResponse<PipelineData>, ShellError> { ) -> Result<PluginCallResponse<PipelineData>, ShellError> {
// Check for an error in the state first, and return it if set. // Check for an error in the state first, and return it if set.
if let Some(error) = self.state.error.get() { if let Some(error) = self.state.error.get() {
@ -777,17 +795,24 @@ impl PluginInterface {
gc.increment_locks(1); gc.increment_locks(1);
} }
let (writer, rx) = self.write_plugin_call(call, context.clone())?; // Create the channel to send context on if needed
let (context_tx, context_rx) = mpsc::channel();
let (writer, rx) = self.write_plugin_call(
call,
context.as_ref().and_then(|c| c.ctrlc().cloned()),
context_rx,
)?;
// Finish writing stream in the background // Finish writing stream in the background
writer.write_background()?; writer.write_background()?;
self.receive_plugin_call_response(rx, context) self.receive_plugin_call_response(rx, context, context_tx)
} }
/// Get the command signatures from the plugin. /// Get the command signatures from the plugin.
pub(crate) fn get_signature(&self) -> Result<Vec<PluginSignature>, ShellError> { pub(crate) fn get_signature(&self) -> Result<Vec<PluginSignature>, ShellError> {
match self.plugin_call(PluginCall::Signature, &None)? { match self.plugin_call(PluginCall::Signature, None)? {
PluginCallResponse::Signature(sigs) => Ok(sigs), PluginCallResponse::Signature(sigs) => Ok(sigs),
PluginCallResponse::Error(err) => Err(err.into()), PluginCallResponse::Error(err) => Err(err.into()),
_ => Err(ShellError::PluginFailedToDecode { _ => Err(ShellError::PluginFailedToDecode {
@ -800,10 +825,9 @@ impl PluginInterface {
pub(crate) fn run( pub(crate) fn run(
&self, &self,
call: CallInfo<PipelineData>, call: CallInfo<PipelineData>,
context: Arc<impl PluginExecutionContext + 'static>, context: &mut dyn PluginExecutionContext,
) -> Result<PipelineData, ShellError> { ) -> Result<PipelineData, ShellError> {
let context = Some(Context(context)); match self.plugin_call(PluginCall::Run(call), Some(context))? {
match self.plugin_call(PluginCall::Run(call), &context)? {
PluginCallResponse::PipelineData(data) => Ok(data), PluginCallResponse::PipelineData(data) => Ok(data),
PluginCallResponse::Error(err) => Err(err.into()), PluginCallResponse::Error(err) => Err(err.into()),
_ => Err(ShellError::PluginFailedToDecode { _ => Err(ShellError::PluginFailedToDecode {
@ -821,7 +845,7 @@ impl PluginInterface {
let op_name = op.name(); let op_name = op.name();
let span = value.span; let span = value.span;
let call = PluginCall::CustomValueOp(value, op); let call = PluginCall::CustomValueOp(value, op);
match self.plugin_call(call, &None)? { match self.plugin_call(call, None)? {
PluginCallResponse::PipelineData(out_data) => Ok(out_data.into_value(span)), PluginCallResponse::PipelineData(out_data) => Ok(out_data.into_value(span)),
PluginCallResponse::Error(err) => Err(err.into()), PluginCallResponse::Error(err) => Err(err.into()),
_ => Err(ShellError::PluginFailedToDecode { _ => Err(ShellError::PluginFailedToDecode {
@ -869,7 +893,7 @@ impl PluginInterface {
value.into_spanned(Span::unknown()), value.into_spanned(Span::unknown()),
CustomValueOp::PartialCmp(other_value), CustomValueOp::PartialCmp(other_value),
); );
match self.plugin_call(call, &None)? { match self.plugin_call(call, None)? {
PluginCallResponse::Ordering(ordering) => Ok(ordering), PluginCallResponse::Ordering(ordering) => Ok(ordering),
PluginCallResponse::Error(err) => Err(err.into()), PluginCallResponse::Error(err) => Err(err.into()),
_ => Err(ShellError::PluginFailedToDecode { _ => Err(ShellError::PluginFailedToDecode {
@ -977,11 +1001,11 @@ impl Drop for PluginInterface {
/// Handle an engine call. /// Handle an engine call.
pub(crate) fn handle_engine_call( pub(crate) fn handle_engine_call(
call: EngineCall<PipelineData>, call: EngineCall<PipelineData>,
context: &Option<Context>, context: Option<&mut (dyn PluginExecutionContext + '_)>,
) -> Result<EngineCallResponse<PipelineData>, ShellError> { ) -> Result<EngineCallResponse<PipelineData>, ShellError> {
let call_name = call.name(); let call_name = call.name();
let require_context = || {
context.as_ref().ok_or_else(|| ShellError::GenericError { let context = context.ok_or_else(|| ShellError::GenericError {
error: "A plugin execution context is required for this engine call".into(), error: "A plugin execution context is required for this engine call".into(),
msg: format!( msg: format!(
"attempted to call {} outside of a command invocation", "attempted to call {} outside of a command invocation",
@ -990,43 +1014,40 @@ pub(crate) fn handle_engine_call(
span: None, span: None,
help: Some("this is probably a bug with the plugin".into()), help: Some("this is probably a bug with the plugin".into()),
inner: vec![], inner: vec![],
}) })?;
};
match call { match call {
EngineCall::GetConfig => { EngineCall::GetConfig => {
let context = require_context()?;
let config = Box::new(context.get_config()?); let config = Box::new(context.get_config()?);
Ok(EngineCallResponse::Config(config)) Ok(EngineCallResponse::Config(config))
} }
EngineCall::GetPluginConfig => { EngineCall::GetPluginConfig => {
let context = require_context()?;
let plugin_config = context.get_plugin_config()?; let plugin_config = context.get_plugin_config()?;
Ok(plugin_config.map_or_else(EngineCallResponse::empty, EngineCallResponse::value)) Ok(plugin_config.map_or_else(EngineCallResponse::empty, EngineCallResponse::value))
} }
EngineCall::GetEnvVar(name) => { EngineCall::GetEnvVar(name) => {
let context = require_context()?;
let value = context.get_env_var(&name)?; let value = context.get_env_var(&name)?;
Ok(value.map_or_else(EngineCallResponse::empty, EngineCallResponse::value)) Ok(value.map_or_else(EngineCallResponse::empty, EngineCallResponse::value))
} }
EngineCall::GetEnvVars => { EngineCall::GetEnvVars => context.get_env_vars().map(EngineCallResponse::ValueMap),
let context = require_context()?;
context.get_env_vars().map(EngineCallResponse::ValueMap)
}
EngineCall::GetCurrentDir => { EngineCall::GetCurrentDir => {
let context = require_context()?;
let current_dir = context.get_current_dir()?; let current_dir = context.get_current_dir()?;
Ok(EngineCallResponse::value(Value::string( Ok(EngineCallResponse::value(Value::string(
current_dir.item, current_dir.item,
current_dir.span, current_dir.span,
))) )))
} }
EngineCall::AddEnvVar(name, value) => {
context.add_env_var(name, value)?;
Ok(EngineCallResponse::empty())
}
EngineCall::EvalClosure { EngineCall::EvalClosure {
closure, closure,
positional, positional,
input, input,
redirect_stdout, redirect_stdout,
redirect_stderr, redirect_stderr,
} => require_context()? } => context
.eval_closure(closure, positional, input, redirect_stdout, redirect_stderr) .eval_closure(closure, positional, input, redirect_stdout, redirect_stderr)
.map(EngineCallResponse::PipelineData), .map(EngineCallResponse::PipelineData),
} }

View File

@ -1,7 +1,4 @@
use std::{ use std::{sync::mpsc, time::Duration};
sync::{mpsc, Arc},
time::Duration,
};
use nu_protocol::{ use nu_protocol::{
engine::Closure, IntoInterruptiblePipelineData, PipelineData, PluginSignature, ShellError, engine::Closure, IntoInterruptiblePipelineData, PipelineData, PluginSignature, ShellError,
@ -24,8 +21,7 @@ use crate::{
}; };
use super::{ use super::{
Context, PluginCallSubscription, PluginInterface, PluginInterfaceManager, Context, PluginCallState, PluginInterface, PluginInterfaceManager, ReceivedPluginCallMessage,
ReceivedPluginCallMessage,
}; };
#[test] #[test]
@ -187,11 +183,12 @@ fn fake_plugin_call(
// Set up a fake plugin call subscription // Set up a fake plugin call subscription
let (tx, rx) = mpsc::channel(); let (tx, rx) = mpsc::channel();
manager.plugin_call_subscriptions.insert( manager.plugin_call_states.insert(
id, id,
PluginCallSubscription { PluginCallState {
sender: Some(tx), sender: Some(tx),
context: None, ctrlc: None,
context_rx: None,
remaining_streams_to_read: 0, remaining_streams_to_read: 0,
}, },
); );
@ -388,7 +385,7 @@ fn manager_consume_call_response_registers_streams() -> Result<(), ShellError> {
))?; ))?;
// ListStream should have one // ListStream should have one
if let Some(sub) = manager.plugin_call_subscriptions.get(&0) { if let Some(sub) = manager.plugin_call_states.get(&0) {
assert_eq!( assert_eq!(
1, sub.remaining_streams_to_read, 1, sub.remaining_streams_to_read,
"ListStream remaining_streams_to_read should be 1" "ListStream remaining_streams_to_read should be 1"
@ -403,7 +400,7 @@ fn manager_consume_call_response_registers_streams() -> Result<(), ShellError> {
); );
// ExternalStream should have three // ExternalStream should have three
if let Some(sub) = manager.plugin_call_subscriptions.get(&1) { if let Some(sub) = manager.plugin_call_states.get(&1) {
assert_eq!( assert_eq!(
3, sub.remaining_streams_to_read, 3, sub.remaining_streams_to_read,
"ExternalStream remaining_streams_to_read should be 3" "ExternalStream remaining_streams_to_read should be 3"
@ -483,20 +480,25 @@ fn manager_handle_engine_call_after_response_received() -> Result<(), ShellError
let mut manager = test.plugin("test"); let mut manager = test.plugin("test");
manager.protocol_info = Some(ProtocolInfo::default()); manager.protocol_info = Some(ProtocolInfo::default());
let bogus = Context(Arc::new(PluginExecutionBogusContext)); let (context_tx, context_rx) = mpsc::channel();
// Set up a situation identical to what we would find if the response had been read, but there // Set up a situation identical to what we would find if the response had been read, but there
// was still a stream being processed. We have nowhere to send the engine call in that case, // was still a stream being processed. We have nowhere to send the engine call in that case,
// so the manager has to create a place to handle it. // so the manager has to create a place to handle it.
manager.plugin_call_subscriptions.insert( manager.plugin_call_states.insert(
0, 0,
PluginCallSubscription { PluginCallState {
sender: None, sender: None,
context: Some(bogus), ctrlc: None,
context_rx: Some(context_rx),
remaining_streams_to_read: 1, remaining_streams_to_read: 1,
}, },
); );
// The engine will get the context from the channel
let bogus = Context(Box::new(PluginExecutionBogusContext));
context_tx.send(bogus).expect("failed to send");
manager.send_engine_call(0, 0, EngineCall::GetConfig)?; manager.send_engine_call(0, 0, EngineCall::GetConfig)?;
// Not really much choice but to wait here, as the thread will have been spawned in the // Not really much choice but to wait here, as the thread will have been spawned in the
@ -528,7 +530,7 @@ fn manager_handle_engine_call_after_response_received() -> Result<(), ShellError
// Whatever was used to make this happen should have been held onto, since spawning a thread // Whatever was used to make this happen should have been held onto, since spawning a thread
// is expensive // is expensive
let sender = &manager let sender = &manager
.plugin_call_subscriptions .plugin_call_states
.get(&0) .get(&0)
.expect("missing subscription 0") .expect("missing subscription 0")
.sender; .sender;
@ -546,11 +548,12 @@ fn manager_send_plugin_call_response_removes_context_only_if_no_streams_to_read(
let mut manager = TestCase::new().plugin("test"); let mut manager = TestCase::new().plugin("test");
for n in [0, 1] { for n in [0, 1] {
manager.plugin_call_subscriptions.insert( manager.plugin_call_states.insert(
n, n,
PluginCallSubscription { PluginCallState {
sender: None, sender: None,
context: None, ctrlc: None,
context_rx: None,
remaining_streams_to_read: n as i32, remaining_streams_to_read: n as i32,
}, },
); );
@ -562,11 +565,11 @@ fn manager_send_plugin_call_response_removes_context_only_if_no_streams_to_read(
// 0 should not still be present, but 1 should be // 0 should not still be present, but 1 should be
assert!( assert!(
!manager.plugin_call_subscriptions.contains_key(&0), !manager.plugin_call_states.contains_key(&0),
"didn't clean up when there weren't remaining streams" "didn't clean up when there weren't remaining streams"
); );
assert!( assert!(
manager.plugin_call_subscriptions.contains_key(&1), manager.plugin_call_states.contains_key(&1),
"clean up even though there were remaining streams" "clean up even though there were remaining streams"
); );
Ok(()) Ok(())
@ -578,11 +581,12 @@ fn manager_consume_stream_end_removes_context_only_if_last_stream() -> Result<()
manager.protocol_info = Some(ProtocolInfo::default()); manager.protocol_info = Some(ProtocolInfo::default());
for n in [1, 2] { for n in [1, 2] {
manager.plugin_call_subscriptions.insert( manager.plugin_call_states.insert(
n, n,
PluginCallSubscription { PluginCallState {
sender: None, sender: None,
context: None, ctrlc: None,
context_rx: None,
remaining_streams_to_read: n as i32, remaining_streams_to_read: n as i32,
}, },
); );
@ -608,21 +612,21 @@ fn manager_consume_stream_end_removes_context_only_if_last_stream() -> Result<()
// Ending 10 should cause 1 to be removed // Ending 10 should cause 1 to be removed
manager.consume(StreamMessage::End(10).into())?; manager.consume(StreamMessage::End(10).into())?;
assert!( assert!(
!manager.plugin_call_subscriptions.contains_key(&1), !manager.plugin_call_states.contains_key(&1),
"contains(1) after End(10)" "contains(1) after End(10)"
); );
// Ending 21 should not cause 2 to be removed // Ending 21 should not cause 2 to be removed
manager.consume(StreamMessage::End(21).into())?; manager.consume(StreamMessage::End(21).into())?;
assert!( assert!(
manager.plugin_call_subscriptions.contains_key(&2), manager.plugin_call_states.contains_key(&2),
"!contains(2) after End(21)" "!contains(2) after End(21)"
); );
// Ending 22 should cause 2 to be removed // Ending 22 should cause 2 to be removed
manager.consume(StreamMessage::End(22).into())?; manager.consume(StreamMessage::End(22).into())?;
assert!( assert!(
!manager.plugin_call_subscriptions.contains_key(&2), !manager.plugin_call_states.contains_key(&2),
"contains(2) after End(22)" "contains(2) after End(22)"
); );
@ -728,18 +732,15 @@ fn interface_goodbye() -> Result<(), ShellError> {
fn interface_write_plugin_call_registers_subscription() -> Result<(), ShellError> { fn interface_write_plugin_call_registers_subscription() -> Result<(), ShellError> {
let mut manager = TestCase::new().plugin("test"); let mut manager = TestCase::new().plugin("test");
assert!( assert!(
manager.plugin_call_subscriptions.is_empty(), manager.plugin_call_states.is_empty(),
"plugin call subscriptions not empty before start of test" "plugin call subscriptions not empty before start of test"
); );
let interface = manager.get_interface(); let interface = manager.get_interface();
let _ = interface.write_plugin_call(PluginCall::Signature, None)?; let _ = interface.write_plugin_call(PluginCall::Signature, None, mpsc::channel().1)?;
manager.receive_plugin_call_subscriptions(); manager.receive_plugin_call_subscriptions();
assert!( assert!(!manager.plugin_call_states.is_empty(), "not registered");
!manager.plugin_call_subscriptions.is_empty(),
"not registered"
);
Ok(()) Ok(())
} }
@ -749,7 +750,8 @@ fn interface_write_plugin_call_writes_signature() -> Result<(), ShellError> {
let manager = test.plugin("test"); let manager = test.plugin("test");
let interface = manager.get_interface(); let interface = manager.get_interface();
let (writer, _) = interface.write_plugin_call(PluginCall::Signature, None)?; let (writer, _) =
interface.write_plugin_call(PluginCall::Signature, None, mpsc::channel().1)?;
writer.write()?; writer.write()?;
let written = test.next_written().expect("nothing written"); let written = test.next_written().expect("nothing written");
@ -778,6 +780,7 @@ fn interface_write_plugin_call_writes_custom_value_op() -> Result<(), ShellError
CustomValueOp::ToBaseValue, CustomValueOp::ToBaseValue,
), ),
None, None,
mpsc::channel().1,
)?; )?;
writer.write()?; writer.write()?;
@ -812,6 +815,7 @@ fn interface_write_plugin_call_writes_run_with_value_input() -> Result<(), Shell
input: PipelineData::Value(Value::test_int(-1), None), input: PipelineData::Value(Value::test_int(-1), None),
}), }),
None, None,
mpsc::channel().1,
)?; )?;
writer.write()?; writer.write()?;
@ -850,6 +854,7 @@ fn interface_write_plugin_call_writes_run_with_stream_input() -> Result<(), Shel
input: values.clone().into_pipeline_data(None), input: values.clone().into_pipeline_data(None),
}), }),
None, None,
mpsc::channel().1,
)?; )?;
writer.write()?; writer.write()?;
@ -912,7 +917,7 @@ fn interface_receive_plugin_call_receives_response() -> Result<(), ShellError> {
.expect("failed to send on new channel"); .expect("failed to send on new channel");
drop(tx); // so we don't deadlock on recv() drop(tx); // so we don't deadlock on recv()
let response = interface.receive_plugin_call_response(rx, &None)?; let response = interface.receive_plugin_call_response(rx, None, mpsc::channel().0)?;
assert!( assert!(
matches!(response, PluginCallResponse::Signature(_)), matches!(response, PluginCallResponse::Signature(_)),
"wrong response: {response:?}" "wrong response: {response:?}"
@ -935,7 +940,7 @@ fn interface_receive_plugin_call_receives_error() -> Result<(), ShellError> {
drop(tx); // so we don't deadlock on recv() drop(tx); // so we don't deadlock on recv()
let error = interface let error = interface
.receive_plugin_call_response(rx, &None) .receive_plugin_call_response(rx, None, mpsc::channel().0)
.expect_err("did not receive error"); .expect_err("did not receive error");
assert!( assert!(
matches!(error, ShellError::ExternalNotSupported { .. }), matches!(error, ShellError::ExternalNotSupported { .. }),
@ -958,13 +963,13 @@ fn interface_receive_plugin_call_handles_engine_call() -> Result<(), ShellError>
.expect("failed to send on new channel"); .expect("failed to send on new channel");
// The context should be a bogus context, which will return an error for GetConfig // The context should be a bogus context, which will return an error for GetConfig
let context = Some(Context(Arc::new(PluginExecutionBogusContext))); let mut context = PluginExecutionBogusContext;
// We don't actually send a response, so `receive_plugin_call_response` should actually return // We don't actually send a response, so `receive_plugin_call_response` should actually return
// an error, but it should still do the engine call // an error, but it should still do the engine call
drop(tx); drop(tx);
interface interface
.receive_plugin_call_response(rx, &context) .receive_plugin_call_response(rx, Some(&mut context), mpsc::channel().0)
.expect_err("no error even though there was no response"); .expect_err("no error even though there was no response");
// Check for the engine call response output // Check for the engine call response output
@ -996,15 +1001,16 @@ fn start_fake_plugin_call_responder(
std::thread::Builder::new() std::thread::Builder::new()
.name("fake plugin call responder".into()) .name("fake plugin call responder".into())
.spawn(move || { .spawn(move || {
for (id, sub) in manager for (id, state) in manager
.plugin_call_subscription_receiver .plugin_call_subscription_receiver
.into_iter() .into_iter()
.take(take) .take(take)
{ {
for message in f(id) { for message in f(id) {
sub.sender state
.sender
.as_ref() .as_ref()
.expect("sender is None") .expect("sender was not set")
.send(message) .send(message)
.expect("failed to send"); .expect("failed to send");
} }
@ -1055,7 +1061,7 @@ fn interface_run() -> Result<(), ShellError> {
}, },
input: PipelineData::Empty, input: PipelineData::Empty,
}, },
PluginExecutionBogusContext.into(), &mut PluginExecutionBogusContext,
)?; )?;
assert_eq!( assert_eq!(

View File

@ -348,6 +348,21 @@ impl PluginCallResponse<PipelineDataHeader> {
} }
} }
impl PluginCallResponse<PipelineData> {
/// Does this response have a stream?
pub(crate) fn has_stream(&self) -> bool {
match self {
PluginCallResponse::PipelineData(data) => match data {
PipelineData::Empty => false,
PipelineData::Value(..) => false,
PipelineData::ListStream(..) => true,
PipelineData::ExternalStream { .. } => true,
},
_ => false,
}
}
}
/// Options that can be changed to affect how the engine treats the plugin /// Options that can be changed to affect how the engine treats the plugin
#[derive(Serialize, Deserialize, Debug, Clone)] #[derive(Serialize, Deserialize, Debug, Clone)]
pub enum PluginOption { pub enum PluginOption {
@ -447,6 +462,8 @@ pub enum EngineCall<D> {
GetEnvVars, GetEnvVars,
/// Get current working directory /// Get current working directory
GetCurrentDir, GetCurrentDir,
/// Set an environment variable in the caller's scope
AddEnvVar(String, Value),
/// Evaluate a closure with stream input/output /// Evaluate a closure with stream input/output
EvalClosure { EvalClosure {
/// The closure to call. /// The closure to call.
@ -473,6 +490,7 @@ impl<D> EngineCall<D> {
EngineCall::GetEnvVar(_) => "GetEnv", EngineCall::GetEnvVar(_) => "GetEnv",
EngineCall::GetEnvVars => "GetEnvs", EngineCall::GetEnvVars => "GetEnvs",
EngineCall::GetCurrentDir => "GetCurrentDir", EngineCall::GetCurrentDir => "GetCurrentDir",
EngineCall::AddEnvVar(..) => "AddEnvVar",
EngineCall::EvalClosure { .. } => "EvalClosure", EngineCall::EvalClosure { .. } => "EvalClosure",
} }
} }

View File

@ -0,0 +1,3 @@
mod mutable_cow;
pub(crate) use mutable_cow::*;

View File

@ -0,0 +1,35 @@
/// Like [`Cow`] but with a mutable reference instead. So not exactly clone-on-write, but can be
/// made owned.
pub enum MutableCow<'a, T> {
Borrowed(&'a mut T),
Owned(T),
}
impl<'a, T: Clone> MutableCow<'a, T> {
pub fn owned(&self) -> MutableCow<'static, T> {
match self {
MutableCow::Borrowed(r) => MutableCow::Owned((*r).clone()),
MutableCow::Owned(o) => MutableCow::Owned(o.clone()),
}
}
}
impl<'a, T> std::ops::Deref for MutableCow<'a, T> {
type Target = T;
fn deref(&self) -> &T {
match self {
MutableCow::Borrowed(r) => r,
MutableCow::Owned(o) => o,
}
}
}
impl<'a, T> std::ops::DerefMut for MutableCow<'a, T> {
fn deref_mut(&mut self) -> &mut Self::Target {
match self {
MutableCow::Borrowed(r) => r,
MutableCow::Owned(o) => o,
}
}
}

View File

@ -19,6 +19,12 @@ impl SimplePluginCommand for NuExampleEnv {
"The name of the environment variable to get", "The name of the environment variable to get",
) )
.switch("cwd", "Get current working directory instead", None) .switch("cwd", "Get current working directory instead", None)
.named(
"set",
SyntaxShape::Any,
"Set an environment variable to the value",
None,
)
.search_terms(vec!["example".into(), "env".into()]) .search_terms(vec!["example".into(), "env".into()])
.input_output_type(Type::Nothing, Type::Any) .input_output_type(Type::Nothing, Type::Any)
} }
@ -31,8 +37,22 @@ impl SimplePluginCommand for NuExampleEnv {
_input: &Value, _input: &Value,
) -> Result<Value, LabeledError> { ) -> Result<Value, LabeledError> {
if call.has_flag("cwd")? { if call.has_flag("cwd")? {
match call.get_flag_value("set") {
None => {
// Get working directory // Get working directory
Ok(Value::string(engine.get_current_dir()?, call.head)) Ok(Value::string(engine.get_current_dir()?, call.head))
}
Some(value) => Err(LabeledError {
label: "Invalid arguments".into(),
msg: "--cwd can't be used with --set".into(),
span: Some(value.span()),
}),
}
} else if let Some(value) = call.get_flag_value("set") {
// Set single env var
let name = call.req::<String>(0)?;
engine.add_env_var(name, value)?;
Ok(Value::nothing(call.head))
} else if let Some(name) = call.opt::<String>(0)? { } else if let Some(name) = call.opt::<String>(0)? {
// Get single env var // Get single env var
Ok(engine Ok(engine

View File

@ -42,3 +42,14 @@ fn get_current_dir() {
assert!(result.status.success()); assert!(result.status.success());
assert_eq!(cwd, result.out); assert_eq!(cwd, result.out);
} }
#[test]
fn set_env() {
let result = nu_with_plugins!(
cwd: ".",
plugin: ("nu_plugin_example"),
"nu-example-env NUSHELL_OPINION --set=rocks; $env.NUSHELL_OPINION"
);
assert!(result.status.success());
assert_eq!("rocks", result.out);
}