Allow plugins to set environment variables in their caller's scope (#12204)

# Description

Adds the `AddEnvVar` plugin call, which allows plugins to set
environment variables in the caller's scope. This is the first engine
call that mutates the caller's stack, and opens the door to more
operations like this if needed.

This also comes with an extra benefit: in doing this, I needed to
refactor how context was handled, and I was able to avoid cloning
`EngineInterface` / `Stack` / `Call` in most cases that plugin calls are
used. They now only need to be cloned if the plugin call returns a
stream. The performance increase is welcome (5.5x faster on `inc`!):

```nushell
# Before
> timeit { 1..100 | each { |i| $"2.0.($i)" | inc -p } }
405ms 941µs 952ns
# After
> timeit { 1..100 | each { |i| $"2.0.($i)" | inc -p } }
73ms 68µs 749ns
```

# User-Facing Changes
- New engine call: `add_env_var()`
- Performance enhancement for plugin calls

# Tests + Formatting
- 🟢 `toolkit fmt`
- 🟢 `toolkit clippy`
- 🟢 `toolkit test`
- 🟢 `toolkit test stdlib`

# After Submitting
- [x] Document env manipulation in plugins guide
- [x] Document `AddEnvVar` in plugin protocol
This commit is contained in:
Devyn Cairns 2024-03-15 04:45:45 -07:00 committed by GitHub
parent 687fbc49c8
commit f6faf73e02
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
12 changed files with 329 additions and 142 deletions

View File

@ -59,6 +59,7 @@ mod plugin;
mod protocol;
mod sequence;
mod serializers;
mod util;
pub use plugin::{
serve_plugin, EngineInterface, Plugin, PluginCommand, PluginEncoder, SimplePluginCommand,

View File

@ -1,4 +1,5 @@
use std::{
borrow::Cow,
collections::HashMap,
sync::{atomic::AtomicBool, Arc},
};
@ -10,6 +11,8 @@ use nu_protocol::{
Config, IntoSpanned, IoStream, PipelineData, PluginIdentity, ShellError, Span, Spanned, Value,
};
use crate::util::MutableCow;
/// Object safe trait for abstracting operations required of the plugin context.
pub(crate) trait PluginExecutionContext: Send + Sync {
/// The [Span] for the command execution (`call.head`)
@ -26,8 +29,10 @@ pub(crate) trait PluginExecutionContext: Send + Sync {
fn get_env_var(&self, name: &str) -> Result<Option<Value>, ShellError>;
/// Get all environment variables
fn get_env_vars(&self) -> Result<HashMap<String, Value>, ShellError>;
// Get current working directory
/// Get current working directory
fn get_current_dir(&self) -> Result<Spanned<String>, ShellError>;
/// Set an environment variable
fn add_env_var(&mut self, name: String, value: Value) -> Result<(), ShellError>;
/// Evaluate a closure passed to the plugin
fn eval_closure(
&self,
@ -37,33 +42,35 @@ pub(crate) trait PluginExecutionContext: Send + Sync {
redirect_stdout: bool,
redirect_stderr: bool,
) -> Result<PipelineData, ShellError>;
/// Create an owned version of the context with `'static` lifetime
fn boxed(&self) -> Box<dyn PluginExecutionContext>;
}
/// The execution context of a plugin command.
pub(crate) struct PluginExecutionCommandContext {
/// The execution context of a plugin command. Can be borrowed.
pub(crate) struct PluginExecutionCommandContext<'a> {
identity: Arc<PluginIdentity>,
engine_state: EngineState,
stack: Stack,
call: Call,
engine_state: Cow<'a, EngineState>,
stack: MutableCow<'a, Stack>,
call: Cow<'a, Call>,
}
impl PluginExecutionCommandContext {
impl<'a> PluginExecutionCommandContext<'a> {
pub fn new(
identity: Arc<PluginIdentity>,
engine_state: &EngineState,
stack: &Stack,
call: &Call,
) -> PluginExecutionCommandContext {
engine_state: &'a EngineState,
stack: &'a mut Stack,
call: &'a Call,
) -> PluginExecutionCommandContext<'a> {
PluginExecutionCommandContext {
identity,
engine_state: engine_state.clone(),
stack: stack.clone(),
call: call.clone(),
engine_state: Cow::Borrowed(engine_state),
stack: MutableCow::Borrowed(stack),
call: Cow::Borrowed(call),
}
}
}
impl PluginExecutionContext for PluginExecutionCommandContext {
impl<'a> PluginExecutionContext for PluginExecutionCommandContext<'a> {
fn command_span(&self) -> Span {
self.call.head
}
@ -131,6 +138,11 @@ impl PluginExecutionContext for PluginExecutionCommandContext {
Ok(cwd.into_spanned(self.call.head))
}
fn add_env_var(&mut self, name: String, value: Value) -> Result<(), ShellError> {
self.stack.add_env_var(name, value);
Ok(())
}
fn eval_closure(
&self,
closure: Spanned<Closure>,
@ -191,6 +203,15 @@ impl PluginExecutionContext for PluginExecutionCommandContext {
eval_block_with_early_return(&self.engine_state, stack, block, input)
}
fn boxed(&self) -> Box<dyn PluginExecutionContext + 'static> {
Box::new(PluginExecutionCommandContext {
identity: self.identity.clone(),
engine_state: Cow::Owned(self.engine_state.clone().into_owned()),
stack: self.stack.owned(),
call: Cow::Owned(self.call.clone().into_owned()),
})
}
}
/// A bogus execution context for testing that doesn't really implement anything properly
@ -239,6 +260,12 @@ impl PluginExecutionContext for PluginExecutionBogusContext {
})
}
fn add_env_var(&mut self, _name: String, _value: Value) -> Result<(), ShellError> {
Err(ShellError::NushellFailed {
msg: "add_env_var not implemented on bogus".into(),
})
}
fn eval_closure(
&self,
_closure: Spanned<Closure>,
@ -251,4 +278,8 @@ impl PluginExecutionContext for PluginExecutionBogusContext {
msg: "eval_closure not implemented on bogus".into(),
})
}
fn boxed(&self) -> Box<dyn PluginExecutionContext + 'static> {
Box::new(PluginExecutionBogusContext)
}
}

View File

@ -108,12 +108,12 @@ impl Command for PluginDeclaration {
})?;
// Create the context to execute in - this supports engine calls and custom values
let context = Arc::new(PluginExecutionCommandContext::new(
let mut context = PluginExecutionCommandContext::new(
self.source.identity.clone(),
engine_state,
stack,
call,
));
);
plugin.run(
CallInfo {
@ -121,7 +121,7 @@ impl Command for PluginDeclaration {
call: evaluated_call,
input,
},
context,
&mut context,
)
}

View File

@ -458,6 +458,9 @@ impl EngineInterface {
EngineCall::GetEnvVar(name) => (EngineCall::GetEnvVar(name), Default::default()),
EngineCall::GetEnvVars => (EngineCall::GetEnvVars, Default::default()),
EngineCall::GetCurrentDir => (EngineCall::GetCurrentDir, Default::default()),
EngineCall::AddEnvVar(name, value) => {
(EngineCall::AddEnvVar(name, value), Default::default())
}
};
// Register the channel
@ -622,6 +625,30 @@ impl EngineInterface {
}
}
/// Set an environment variable in the caller's scope.
///
/// If called after the plugin response has already been sent (i.e. during a stream), this will
/// only affect the environment for engine calls related to this plugin call, and will not be
/// propagated to the environment of the caller.
///
/// # Example
/// ```rust,no_run
/// # use nu_protocol::{Value, ShellError};
/// # use nu_plugin::EngineInterface;
/// # fn example(engine: &EngineInterface) -> Result<(), ShellError> {
/// engine.add_env_var("FOO", Value::test_string("bar"))
/// # }
/// ```
pub fn add_env_var(&self, name: impl Into<String>, value: Value) -> Result<(), ShellError> {
match self.engine_call(EngineCall::AddEnvVar(name.into(), value))? {
EngineCallResponse::PipelineData(_) => Ok(()),
EngineCallResponse::Error(err) => Err(err),
_ => Err(ShellError::PluginFailedToDecode {
msg: "Received unexpected response type for EngineCall::AddEnvVar".into(),
}),
}
}
/// Ask the engine to evaluate a closure. Input to the closure is passed as a stream, and the
/// output is available as a stream.
///

View File

@ -953,6 +953,20 @@ fn interface_get_env_vars() -> Result<(), ShellError> {
Ok(())
}
#[test]
fn interface_add_env_var() -> Result<(), ShellError> {
let test = TestCase::new();
let manager = test.engine();
let interface = manager.interface_for_context(0);
start_fake_plugin_call_responder(manager, 1, move |_| EngineCallResponse::empty());
interface.add_env_var("FOO", Value::test_string("bar"))?;
assert!(test.has_unconsumed_write());
Ok(())
}
#[test]
fn interface_eval_closure_with_stream() -> Result<(), ShellError> {
let test = TestCase::new();

View File

@ -2,7 +2,7 @@
use std::{
collections::{btree_map, BTreeMap},
sync::{mpsc, Arc, OnceLock},
sync::{atomic::AtomicBool, mpsc, Arc, OnceLock},
};
use nu_protocol::{
@ -44,8 +44,7 @@ enum ReceivedPluginCallMessage {
}
/// Context for plugin call execution
#[derive(Clone)]
pub(crate) struct Context(Arc<dyn PluginExecutionContext>);
pub(crate) struct Context(Box<dyn PluginExecutionContext>);
impl std::fmt::Debug for Context {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
@ -70,7 +69,7 @@ struct PluginInterfaceState {
/// Sequence for generating stream ids
stream_id_sequence: Sequence,
/// Sender to subscribe to a plugin call response
plugin_call_subscription_sender: mpsc::Sender<(PluginCallId, PluginCallSubscription)>,
plugin_call_subscription_sender: mpsc::Sender<(PluginCallId, PluginCallState)>,
/// An error that should be propagated to further plugin calls
error: OnceLock<ShellError>,
/// The synchronized output writer
@ -91,14 +90,15 @@ impl std::fmt::Debug for PluginInterfaceState {
}
}
/// Sent to the [`PluginInterfaceManager`] before making a plugin call to indicate interest in its
/// response.
/// State that the manager keeps for each plugin call during its lifetime.
#[derive(Debug)]
struct PluginCallSubscription {
struct PluginCallState {
/// The sender back to the thread that is waiting for the plugin call response
sender: Option<mpsc::Sender<ReceivedPluginCallMessage>>,
/// Optional context for the environment of a plugin call for servicing engine calls
context: Option<Context>,
/// Interrupt signal to be used for stream iterators
ctrlc: Option<Arc<AtomicBool>>,
/// Channel to receive context on to be used if needed
context_rx: Option<mpsc::Receiver<Context>>,
/// Number of streams that still need to be read from the plugin call response
remaining_streams_to_read: i32,
}
@ -112,10 +112,10 @@ pub(crate) struct PluginInterfaceManager {
stream_manager: StreamManager,
/// Protocol version info, set after `Hello` received
protocol_info: Option<ProtocolInfo>,
/// Subscriptions for messages related to plugin calls
plugin_call_subscriptions: BTreeMap<PluginCallId, PluginCallSubscription>,
/// State related to plugin calls
plugin_call_states: BTreeMap<PluginCallId, PluginCallState>,
/// Receiver for plugin call subscriptions
plugin_call_subscription_receiver: mpsc::Receiver<(PluginCallId, PluginCallSubscription)>,
plugin_call_subscription_receiver: mpsc::Receiver<(PluginCallId, PluginCallState)>,
/// Tracker for which plugin call streams being read belong to
///
/// This is necessary so we know when we can remove context for plugin calls
@ -142,7 +142,7 @@ impl PluginInterfaceManager {
}),
stream_manager: StreamManager::new(),
protocol_info: None,
plugin_call_subscriptions: BTreeMap::new(),
plugin_call_states: BTreeMap::new(),
plugin_call_subscription_receiver: subscription_rx,
plugin_call_input_streams: BTreeMap::new(),
gc: None,
@ -158,9 +158,9 @@ impl PluginInterfaceManager {
/// Consume pending messages in the `plugin_call_subscription_receiver`
fn receive_plugin_call_subscriptions(&mut self) {
while let Ok((id, subscription)) = self.plugin_call_subscription_receiver.try_recv() {
if let btree_map::Entry::Vacant(e) = self.plugin_call_subscriptions.entry(id) {
e.insert(subscription);
while let Ok((id, state)) = self.plugin_call_subscription_receiver.try_recv() {
if let btree_map::Entry::Vacant(e) = self.plugin_call_states.entry(id) {
e.insert(state);
} else {
log::warn!("Duplicate plugin call ID ignored: {id}");
}
@ -172,8 +172,8 @@ impl PluginInterfaceManager {
self.plugin_call_input_streams.insert(stream_id, call_id);
// Increment the number of streams on the subscription so context stays alive
self.receive_plugin_call_subscriptions();
if let Some(sub) = self.plugin_call_subscriptions.get_mut(&call_id) {
sub.remaining_streams_to_read += 1;
if let Some(state) = self.plugin_call_states.get_mut(&call_id) {
state.remaining_streams_to_read += 1;
}
// Add a lock to the garbage collector for each stream
if let Some(ref gc) = self.gc {
@ -184,8 +184,7 @@ impl PluginInterfaceManager {
/// Track the end of an incoming stream
fn recv_stream_ended(&mut self, stream_id: StreamId) {
if let Some(call_id) = self.plugin_call_input_streams.remove(&stream_id) {
if let btree_map::Entry::Occupied(mut e) = self.plugin_call_subscriptions.entry(call_id)
{
if let btree_map::Entry::Occupied(mut e) = self.plugin_call_states.entry(call_id) {
e.get_mut().remaining_streams_to_read -= 1;
// Remove the subscription if there are no more streams to be read.
if e.get().remaining_streams_to_read <= 0 {
@ -200,14 +199,14 @@ impl PluginInterfaceManager {
}
}
/// Find the context corresponding to the given plugin call id
fn get_context(&mut self, id: PluginCallId) -> Result<Option<Context>, ShellError> {
/// Find the ctrlc signal corresponding to the given plugin call id
fn get_ctrlc(&mut self, id: PluginCallId) -> Result<Option<Arc<AtomicBool>>, ShellError> {
// Make sure we're up to date
self.receive_plugin_call_subscriptions();
// Find the subscription and return the context
self.plugin_call_subscriptions
self.plugin_call_states
.get(&id)
.map(|sub| sub.context.clone())
.map(|state| state.ctrlc.clone())
.ok_or_else(|| ShellError::PluginFailedToDecode {
msg: format!("Unknown plugin call ID: {id}"),
})
@ -222,7 +221,7 @@ impl PluginInterfaceManager {
// Ensure we're caught up on the subscriptions made
self.receive_plugin_call_subscriptions();
if let btree_map::Entry::Occupied(mut e) = self.plugin_call_subscriptions.entry(id) {
if let btree_map::Entry::Occupied(mut e) = self.plugin_call_states.entry(id) {
// Remove the subscription sender, since this will be the last message.
//
// We can spawn a new one if we need it for engine calls.
@ -254,11 +253,23 @@ impl PluginInterfaceManager {
) -> Result<&mpsc::Sender<ReceivedPluginCallMessage>, ShellError> {
let interface = self.get_interface();
if let Some(sub) = self.plugin_call_subscriptions.get_mut(&id) {
if sub.sender.is_none() {
if let Some(state) = self.plugin_call_states.get_mut(&id) {
if state.sender.is_none() {
let (tx, rx) = mpsc::channel();
let context = sub.context.clone();
let context_rx =
state
.context_rx
.take()
.ok_or_else(|| ShellError::NushellFailed {
msg: "Tried to spawn the fallback engine call handler more than once"
.into(),
})?;
let handler = move || {
// We receive on the thread so that we don't block the reader thread
let mut context = context_rx
.recv()
.ok() // The plugin call won't send context if it's not required.
.map(|c| c.0);
for msg in rx {
// This thread only handles engine calls.
match msg {
@ -266,7 +277,7 @@ impl PluginInterfaceManager {
if let Err(err) = interface.handle_engine_call(
engine_call_id,
engine_call,
&context,
context.as_deref_mut(),
) {
log::warn!(
"Error in plugin post-response engine call handler: \
@ -286,8 +297,8 @@ impl PluginInterfaceManager {
.name("plugin engine call handler".into())
.spawn(handler)
.expect("failed to spawn thread");
sub.sender = Some(tx);
Ok(sub.sender.as_ref().unwrap_or_else(|| unreachable!()))
state.sender = Some(tx);
Ok(state.sender.as_ref().unwrap_or_else(|| unreachable!()))
} else {
Err(ShellError::NushellFailed {
msg: "Tried to spawn the fallback engine call handler before the plugin call \
@ -313,7 +324,7 @@ impl PluginInterfaceManager {
self.receive_plugin_call_subscriptions();
// Don't remove the sender, as there could be more calls or responses
if let Some(subscription) = self.plugin_call_subscriptions.get(&plugin_call_id) {
if let Some(subscription) = self.plugin_call_states.get(&plugin_call_id) {
let msg = ReceivedPluginCallMessage::EngineCall(engine_call_id, call);
// Call if there's an error sending the engine call
let send_error = |this: &Self| {
@ -374,9 +385,7 @@ impl PluginInterfaceManager {
let _ = self.stream_manager.broadcast_read_error(err.clone());
// Error to call waiters
self.receive_plugin_call_subscriptions();
for subscription in
std::mem::take(&mut self.plugin_call_subscriptions).into_values()
{
for subscription in std::mem::take(&mut self.plugin_call_states).into_values() {
let _ = subscription
.sender
.as_ref()
@ -460,15 +469,14 @@ impl InterfaceManager for PluginInterfaceManager {
PluginCallResponse::PipelineData(data) => {
// If there's an error with initializing this stream, change it to a plugin
// error response, but send it anyway
let exec_context = self.get_context(id)?;
let ctrlc = exec_context.as_ref().and_then(|c| c.0.ctrlc());
let ctrlc = self.get_ctrlc(id)?;
// Register the streams in the response
for stream_id in data.stream_ids() {
self.recv_stream_started(id, stream_id);
}
match self.read_pipeline_data(data, ctrlc) {
match self.read_pipeline_data(data, ctrlc.as_ref()) {
Ok(data) => PluginCallResponse::PipelineData(data),
Err(err) => PluginCallResponse::Error(err.into()),
}
@ -485,14 +493,14 @@ impl InterfaceManager for PluginInterfaceManager {
}
PluginOutput::EngineCall { context, id, call } => {
// Handle reading the pipeline data, if any
let exec_context = self.get_context(context)?;
let ctrlc = exec_context.as_ref().and_then(|c| c.0.ctrlc());
let ctrlc = self.get_ctrlc(context)?;
let call = match call {
EngineCall::GetConfig => Ok(EngineCall::GetConfig),
EngineCall::GetPluginConfig => Ok(EngineCall::GetPluginConfig),
EngineCall::GetEnvVar(name) => Ok(EngineCall::GetEnvVar(name)),
EngineCall::GetEnvVars => Ok(EngineCall::GetEnvVars),
EngineCall::GetCurrentDir => Ok(EngineCall::GetCurrentDir),
EngineCall::AddEnvVar(name, value) => Ok(EngineCall::AddEnvVar(name, value)),
EngineCall::EvalClosure {
closure,
mut positional,
@ -504,14 +512,15 @@ impl InterfaceManager for PluginInterfaceManager {
for arg in positional.iter_mut() {
PluginCustomValue::add_source(arg, &self.state.source);
}
self.read_pipeline_data(input, ctrlc)
.map(|input| EngineCall::EvalClosure {
self.read_pipeline_data(input, ctrlc.as_ref()).map(|input| {
EngineCall::EvalClosure {
closure,
positional,
input,
redirect_stdout,
redirect_stderr,
})
}
})
}
};
match call {
@ -622,7 +631,8 @@ impl PluginInterface {
fn write_plugin_call(
&self,
call: PluginCall<PipelineData>,
context: Option<Context>,
ctrlc: Option<Arc<AtomicBool>>,
context_rx: mpsc::Receiver<Context>,
) -> Result<
(
PipelineDataWriter<Self>,
@ -662,9 +672,10 @@ impl PluginInterface {
.plugin_call_subscription_sender
.send((
id,
PluginCallSubscription {
PluginCallState {
sender: Some(tx),
context,
ctrlc,
context_rx: Some(context_rx),
remaining_streams_to_read: 0,
},
))
@ -703,19 +714,26 @@ impl PluginInterface {
fn receive_plugin_call_response(
&self,
rx: mpsc::Receiver<ReceivedPluginCallMessage>,
context: &Option<Context>,
mut context: Option<&mut (dyn PluginExecutionContext + '_)>,
context_tx: mpsc::Sender<Context>,
) -> Result<PluginCallResponse<PipelineData>, ShellError> {
// Handle message from receiver
for msg in rx {
match msg {
ReceivedPluginCallMessage::Response(resp) => {
if resp.has_stream() {
// If the response has a stream, we need to register the context
if let Some(context) = context {
let _ = context_tx.send(Context(context.boxed()));
}
}
return Ok(resp);
}
ReceivedPluginCallMessage::Error(err) => {
return Err(err);
}
ReceivedPluginCallMessage::EngineCall(engine_call_id, engine_call) => {
self.handle_engine_call(engine_call_id, engine_call, context)?;
self.handle_engine_call(engine_call_id, engine_call, context.as_deref_mut())?;
}
}
}
@ -730,7 +748,7 @@ impl PluginInterface {
&self,
engine_call_id: EngineCallId,
engine_call: EngineCall<PipelineData>,
context: &Option<Context>,
context: Option<&mut (dyn PluginExecutionContext + '_)>,
) -> Result<(), ShellError> {
let resp =
handle_engine_call(engine_call, context).unwrap_or_else(EngineCallResponse::Error);
@ -763,7 +781,7 @@ impl PluginInterface {
fn plugin_call(
&self,
call: PluginCall<PipelineData>,
context: &Option<Context>,
context: Option<&mut dyn PluginExecutionContext>,
) -> Result<PluginCallResponse<PipelineData>, ShellError> {
// Check for an error in the state first, and return it if set.
if let Some(error) = self.state.error.get() {
@ -777,17 +795,24 @@ impl PluginInterface {
gc.increment_locks(1);
}
let (writer, rx) = self.write_plugin_call(call, context.clone())?;
// Create the channel to send context on if needed
let (context_tx, context_rx) = mpsc::channel();
let (writer, rx) = self.write_plugin_call(
call,
context.as_ref().and_then(|c| c.ctrlc().cloned()),
context_rx,
)?;
// Finish writing stream in the background
writer.write_background()?;
self.receive_plugin_call_response(rx, context)
self.receive_plugin_call_response(rx, context, context_tx)
}
/// Get the command signatures from the plugin.
pub(crate) fn get_signature(&self) -> Result<Vec<PluginSignature>, ShellError> {
match self.plugin_call(PluginCall::Signature, &None)? {
match self.plugin_call(PluginCall::Signature, None)? {
PluginCallResponse::Signature(sigs) => Ok(sigs),
PluginCallResponse::Error(err) => Err(err.into()),
_ => Err(ShellError::PluginFailedToDecode {
@ -800,10 +825,9 @@ impl PluginInterface {
pub(crate) fn run(
&self,
call: CallInfo<PipelineData>,
context: Arc<impl PluginExecutionContext + 'static>,
context: &mut dyn PluginExecutionContext,
) -> Result<PipelineData, ShellError> {
let context = Some(Context(context));
match self.plugin_call(PluginCall::Run(call), &context)? {
match self.plugin_call(PluginCall::Run(call), Some(context))? {
PluginCallResponse::PipelineData(data) => Ok(data),
PluginCallResponse::Error(err) => Err(err.into()),
_ => Err(ShellError::PluginFailedToDecode {
@ -821,7 +845,7 @@ impl PluginInterface {
let op_name = op.name();
let span = value.span;
let call = PluginCall::CustomValueOp(value, op);
match self.plugin_call(call, &None)? {
match self.plugin_call(call, None)? {
PluginCallResponse::PipelineData(out_data) => Ok(out_data.into_value(span)),
PluginCallResponse::Error(err) => Err(err.into()),
_ => Err(ShellError::PluginFailedToDecode {
@ -869,7 +893,7 @@ impl PluginInterface {
value.into_spanned(Span::unknown()),
CustomValueOp::PartialCmp(other_value),
);
match self.plugin_call(call, &None)? {
match self.plugin_call(call, None)? {
PluginCallResponse::Ordering(ordering) => Ok(ordering),
PluginCallResponse::Error(err) => Err(err.into()),
_ => Err(ShellError::PluginFailedToDecode {
@ -977,56 +1001,53 @@ impl Drop for PluginInterface {
/// Handle an engine call.
pub(crate) fn handle_engine_call(
call: EngineCall<PipelineData>,
context: &Option<Context>,
context: Option<&mut (dyn PluginExecutionContext + '_)>,
) -> Result<EngineCallResponse<PipelineData>, ShellError> {
let call_name = call.name();
let require_context = || {
context.as_ref().ok_or_else(|| ShellError::GenericError {
error: "A plugin execution context is required for this engine call".into(),
msg: format!(
"attempted to call {} outside of a command invocation",
call_name
),
span: None,
help: Some("this is probably a bug with the plugin".into()),
inner: vec![],
})
};
let context = context.ok_or_else(|| ShellError::GenericError {
error: "A plugin execution context is required for this engine call".into(),
msg: format!(
"attempted to call {} outside of a command invocation",
call_name
),
span: None,
help: Some("this is probably a bug with the plugin".into()),
inner: vec![],
})?;
match call {
EngineCall::GetConfig => {
let context = require_context()?;
let config = Box::new(context.get_config()?);
Ok(EngineCallResponse::Config(config))
}
EngineCall::GetPluginConfig => {
let context = require_context()?;
let plugin_config = context.get_plugin_config()?;
Ok(plugin_config.map_or_else(EngineCallResponse::empty, EngineCallResponse::value))
}
EngineCall::GetEnvVar(name) => {
let context = require_context()?;
let value = context.get_env_var(&name)?;
Ok(value.map_or_else(EngineCallResponse::empty, EngineCallResponse::value))
}
EngineCall::GetEnvVars => {
let context = require_context()?;
context.get_env_vars().map(EngineCallResponse::ValueMap)
}
EngineCall::GetEnvVars => context.get_env_vars().map(EngineCallResponse::ValueMap),
EngineCall::GetCurrentDir => {
let context = require_context()?;
let current_dir = context.get_current_dir()?;
Ok(EngineCallResponse::value(Value::string(
current_dir.item,
current_dir.span,
)))
}
EngineCall::AddEnvVar(name, value) => {
context.add_env_var(name, value)?;
Ok(EngineCallResponse::empty())
}
EngineCall::EvalClosure {
closure,
positional,
input,
redirect_stdout,
redirect_stderr,
} => require_context()?
} => context
.eval_closure(closure, positional, input, redirect_stdout, redirect_stderr)
.map(EngineCallResponse::PipelineData),
}

View File

@ -1,7 +1,4 @@
use std::{
sync::{mpsc, Arc},
time::Duration,
};
use std::{sync::mpsc, time::Duration};
use nu_protocol::{
engine::Closure, IntoInterruptiblePipelineData, PipelineData, PluginSignature, ShellError,
@ -24,8 +21,7 @@ use crate::{
};
use super::{
Context, PluginCallSubscription, PluginInterface, PluginInterfaceManager,
ReceivedPluginCallMessage,
Context, PluginCallState, PluginInterface, PluginInterfaceManager, ReceivedPluginCallMessage,
};
#[test]
@ -187,11 +183,12 @@ fn fake_plugin_call(
// Set up a fake plugin call subscription
let (tx, rx) = mpsc::channel();
manager.plugin_call_subscriptions.insert(
manager.plugin_call_states.insert(
id,
PluginCallSubscription {
PluginCallState {
sender: Some(tx),
context: None,
ctrlc: None,
context_rx: None,
remaining_streams_to_read: 0,
},
);
@ -388,7 +385,7 @@ fn manager_consume_call_response_registers_streams() -> Result<(), ShellError> {
))?;
// ListStream should have one
if let Some(sub) = manager.plugin_call_subscriptions.get(&0) {
if let Some(sub) = manager.plugin_call_states.get(&0) {
assert_eq!(
1, sub.remaining_streams_to_read,
"ListStream remaining_streams_to_read should be 1"
@ -403,7 +400,7 @@ fn manager_consume_call_response_registers_streams() -> Result<(), ShellError> {
);
// ExternalStream should have three
if let Some(sub) = manager.plugin_call_subscriptions.get(&1) {
if let Some(sub) = manager.plugin_call_states.get(&1) {
assert_eq!(
3, sub.remaining_streams_to_read,
"ExternalStream remaining_streams_to_read should be 3"
@ -483,20 +480,25 @@ fn manager_handle_engine_call_after_response_received() -> Result<(), ShellError
let mut manager = test.plugin("test");
manager.protocol_info = Some(ProtocolInfo::default());
let bogus = Context(Arc::new(PluginExecutionBogusContext));
let (context_tx, context_rx) = mpsc::channel();
// Set up a situation identical to what we would find if the response had been read, but there
// was still a stream being processed. We have nowhere to send the engine call in that case,
// so the manager has to create a place to handle it.
manager.plugin_call_subscriptions.insert(
manager.plugin_call_states.insert(
0,
PluginCallSubscription {
PluginCallState {
sender: None,
context: Some(bogus),
ctrlc: None,
context_rx: Some(context_rx),
remaining_streams_to_read: 1,
},
);
// The engine will get the context from the channel
let bogus = Context(Box::new(PluginExecutionBogusContext));
context_tx.send(bogus).expect("failed to send");
manager.send_engine_call(0, 0, EngineCall::GetConfig)?;
// Not really much choice but to wait here, as the thread will have been spawned in the
@ -528,7 +530,7 @@ fn manager_handle_engine_call_after_response_received() -> Result<(), ShellError
// Whatever was used to make this happen should have been held onto, since spawning a thread
// is expensive
let sender = &manager
.plugin_call_subscriptions
.plugin_call_states
.get(&0)
.expect("missing subscription 0")
.sender;
@ -546,11 +548,12 @@ fn manager_send_plugin_call_response_removes_context_only_if_no_streams_to_read(
let mut manager = TestCase::new().plugin("test");
for n in [0, 1] {
manager.plugin_call_subscriptions.insert(
manager.plugin_call_states.insert(
n,
PluginCallSubscription {
PluginCallState {
sender: None,
context: None,
ctrlc: None,
context_rx: None,
remaining_streams_to_read: n as i32,
},
);
@ -562,11 +565,11 @@ fn manager_send_plugin_call_response_removes_context_only_if_no_streams_to_read(
// 0 should not still be present, but 1 should be
assert!(
!manager.plugin_call_subscriptions.contains_key(&0),
!manager.plugin_call_states.contains_key(&0),
"didn't clean up when there weren't remaining streams"
);
assert!(
manager.plugin_call_subscriptions.contains_key(&1),
manager.plugin_call_states.contains_key(&1),
"clean up even though there were remaining streams"
);
Ok(())
@ -578,11 +581,12 @@ fn manager_consume_stream_end_removes_context_only_if_last_stream() -> Result<()
manager.protocol_info = Some(ProtocolInfo::default());
for n in [1, 2] {
manager.plugin_call_subscriptions.insert(
manager.plugin_call_states.insert(
n,
PluginCallSubscription {
PluginCallState {
sender: None,
context: None,
ctrlc: None,
context_rx: None,
remaining_streams_to_read: n as i32,
},
);
@ -608,21 +612,21 @@ fn manager_consume_stream_end_removes_context_only_if_last_stream() -> Result<()
// Ending 10 should cause 1 to be removed
manager.consume(StreamMessage::End(10).into())?;
assert!(
!manager.plugin_call_subscriptions.contains_key(&1),
!manager.plugin_call_states.contains_key(&1),
"contains(1) after End(10)"
);
// Ending 21 should not cause 2 to be removed
manager.consume(StreamMessage::End(21).into())?;
assert!(
manager.plugin_call_subscriptions.contains_key(&2),
manager.plugin_call_states.contains_key(&2),
"!contains(2) after End(21)"
);
// Ending 22 should cause 2 to be removed
manager.consume(StreamMessage::End(22).into())?;
assert!(
!manager.plugin_call_subscriptions.contains_key(&2),
!manager.plugin_call_states.contains_key(&2),
"contains(2) after End(22)"
);
@ -728,18 +732,15 @@ fn interface_goodbye() -> Result<(), ShellError> {
fn interface_write_plugin_call_registers_subscription() -> Result<(), ShellError> {
let mut manager = TestCase::new().plugin("test");
assert!(
manager.plugin_call_subscriptions.is_empty(),
manager.plugin_call_states.is_empty(),
"plugin call subscriptions not empty before start of test"
);
let interface = manager.get_interface();
let _ = interface.write_plugin_call(PluginCall::Signature, None)?;
let _ = interface.write_plugin_call(PluginCall::Signature, None, mpsc::channel().1)?;
manager.receive_plugin_call_subscriptions();
assert!(
!manager.plugin_call_subscriptions.is_empty(),
"not registered"
);
assert!(!manager.plugin_call_states.is_empty(), "not registered");
Ok(())
}
@ -749,7 +750,8 @@ fn interface_write_plugin_call_writes_signature() -> Result<(), ShellError> {
let manager = test.plugin("test");
let interface = manager.get_interface();
let (writer, _) = interface.write_plugin_call(PluginCall::Signature, None)?;
let (writer, _) =
interface.write_plugin_call(PluginCall::Signature, None, mpsc::channel().1)?;
writer.write()?;
let written = test.next_written().expect("nothing written");
@ -778,6 +780,7 @@ fn interface_write_plugin_call_writes_custom_value_op() -> Result<(), ShellError
CustomValueOp::ToBaseValue,
),
None,
mpsc::channel().1,
)?;
writer.write()?;
@ -812,6 +815,7 @@ fn interface_write_plugin_call_writes_run_with_value_input() -> Result<(), Shell
input: PipelineData::Value(Value::test_int(-1), None),
}),
None,
mpsc::channel().1,
)?;
writer.write()?;
@ -850,6 +854,7 @@ fn interface_write_plugin_call_writes_run_with_stream_input() -> Result<(), Shel
input: values.clone().into_pipeline_data(None),
}),
None,
mpsc::channel().1,
)?;
writer.write()?;
@ -912,7 +917,7 @@ fn interface_receive_plugin_call_receives_response() -> Result<(), ShellError> {
.expect("failed to send on new channel");
drop(tx); // so we don't deadlock on recv()
let response = interface.receive_plugin_call_response(rx, &None)?;
let response = interface.receive_plugin_call_response(rx, None, mpsc::channel().0)?;
assert!(
matches!(response, PluginCallResponse::Signature(_)),
"wrong response: {response:?}"
@ -935,7 +940,7 @@ fn interface_receive_plugin_call_receives_error() -> Result<(), ShellError> {
drop(tx); // so we don't deadlock on recv()
let error = interface
.receive_plugin_call_response(rx, &None)
.receive_plugin_call_response(rx, None, mpsc::channel().0)
.expect_err("did not receive error");
assert!(
matches!(error, ShellError::ExternalNotSupported { .. }),
@ -958,13 +963,13 @@ fn interface_receive_plugin_call_handles_engine_call() -> Result<(), ShellError>
.expect("failed to send on new channel");
// The context should be a bogus context, which will return an error for GetConfig
let context = Some(Context(Arc::new(PluginExecutionBogusContext)));
let mut context = PluginExecutionBogusContext;
// We don't actually send a response, so `receive_plugin_call_response` should actually return
// an error, but it should still do the engine call
drop(tx);
interface
.receive_plugin_call_response(rx, &context)
.receive_plugin_call_response(rx, Some(&mut context), mpsc::channel().0)
.expect_err("no error even though there was no response");
// Check for the engine call response output
@ -996,15 +1001,16 @@ fn start_fake_plugin_call_responder(
std::thread::Builder::new()
.name("fake plugin call responder".into())
.spawn(move || {
for (id, sub) in manager
for (id, state) in manager
.plugin_call_subscription_receiver
.into_iter()
.take(take)
{
for message in f(id) {
sub.sender
state
.sender
.as_ref()
.expect("sender is None")
.expect("sender was not set")
.send(message)
.expect("failed to send");
}
@ -1055,7 +1061,7 @@ fn interface_run() -> Result<(), ShellError> {
},
input: PipelineData::Empty,
},
PluginExecutionBogusContext.into(),
&mut PluginExecutionBogusContext,
)?;
assert_eq!(

View File

@ -348,6 +348,21 @@ impl PluginCallResponse<PipelineDataHeader> {
}
}
impl PluginCallResponse<PipelineData> {
/// Does this response have a stream?
pub(crate) fn has_stream(&self) -> bool {
match self {
PluginCallResponse::PipelineData(data) => match data {
PipelineData::Empty => false,
PipelineData::Value(..) => false,
PipelineData::ListStream(..) => true,
PipelineData::ExternalStream { .. } => true,
},
_ => false,
}
}
}
/// Options that can be changed to affect how the engine treats the plugin
#[derive(Serialize, Deserialize, Debug, Clone)]
pub enum PluginOption {
@ -447,6 +462,8 @@ pub enum EngineCall<D> {
GetEnvVars,
/// Get current working directory
GetCurrentDir,
/// Set an environment variable in the caller's scope
AddEnvVar(String, Value),
/// Evaluate a closure with stream input/output
EvalClosure {
/// The closure to call.
@ -473,6 +490,7 @@ impl<D> EngineCall<D> {
EngineCall::GetEnvVar(_) => "GetEnv",
EngineCall::GetEnvVars => "GetEnvs",
EngineCall::GetCurrentDir => "GetCurrentDir",
EngineCall::AddEnvVar(..) => "AddEnvVar",
EngineCall::EvalClosure { .. } => "EvalClosure",
}
}

View File

@ -0,0 +1,3 @@
mod mutable_cow;
pub(crate) use mutable_cow::*;

View File

@ -0,0 +1,35 @@
/// Like [`Cow`] but with a mutable reference instead. So not exactly clone-on-write, but can be
/// made owned.
pub enum MutableCow<'a, T> {
Borrowed(&'a mut T),
Owned(T),
}
impl<'a, T: Clone> MutableCow<'a, T> {
pub fn owned(&self) -> MutableCow<'static, T> {
match self {
MutableCow::Borrowed(r) => MutableCow::Owned((*r).clone()),
MutableCow::Owned(o) => MutableCow::Owned(o.clone()),
}
}
}
impl<'a, T> std::ops::Deref for MutableCow<'a, T> {
type Target = T;
fn deref(&self) -> &T {
match self {
MutableCow::Borrowed(r) => r,
MutableCow::Owned(o) => o,
}
}
}
impl<'a, T> std::ops::DerefMut for MutableCow<'a, T> {
fn deref_mut(&mut self) -> &mut Self::Target {
match self {
MutableCow::Borrowed(r) => r,
MutableCow::Owned(o) => o,
}
}
}

View File

@ -19,6 +19,12 @@ impl SimplePluginCommand for NuExampleEnv {
"The name of the environment variable to get",
)
.switch("cwd", "Get current working directory instead", None)
.named(
"set",
SyntaxShape::Any,
"Set an environment variable to the value",
None,
)
.search_terms(vec!["example".into(), "env".into()])
.input_output_type(Type::Nothing, Type::Any)
}
@ -31,8 +37,22 @@ impl SimplePluginCommand for NuExampleEnv {
_input: &Value,
) -> Result<Value, LabeledError> {
if call.has_flag("cwd")? {
// Get working directory
Ok(Value::string(engine.get_current_dir()?, call.head))
match call.get_flag_value("set") {
None => {
// Get working directory
Ok(Value::string(engine.get_current_dir()?, call.head))
}
Some(value) => Err(LabeledError {
label: "Invalid arguments".into(),
msg: "--cwd can't be used with --set".into(),
span: Some(value.span()),
}),
}
} else if let Some(value) = call.get_flag_value("set") {
// Set single env var
let name = call.req::<String>(0)?;
engine.add_env_var(name, value)?;
Ok(Value::nothing(call.head))
} else if let Some(name) = call.opt::<String>(0)? {
// Get single env var
Ok(engine

View File

@ -42,3 +42,14 @@ fn get_current_dir() {
assert!(result.status.success());
assert_eq!(cwd, result.out);
}
#[test]
fn set_env() {
let result = nu_with_plugins!(
cwd: ".",
plugin: ("nu_plugin_example"),
"nu-example-env NUSHELL_OPINION --set=rocks; $env.NUSHELL_OPINION"
);
assert!(result.status.success());
assert_eq!("rocks", result.out);
}