Bidirectional communication and streams for plugins (#11911)

This commit is contained in:
Devyn Cairns 2024-02-25 14:32:50 -08:00 committed by GitHub
parent 461f69ac5d
commit 88f1f386bb
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
47 changed files with 8025 additions and 1496 deletions

12
Cargo.lock generated
View File

@ -3153,11 +3153,15 @@ name = "nu-plugin"
version = "0.90.2"
dependencies = [
"bincode",
"log",
"miette",
"nu-engine",
"nu-protocol",
"rmp-serde",
"semver",
"serde",
"serde_json",
"typetag",
]
[[package]]
@ -3330,6 +3334,14 @@ dependencies = [
"sxd-xpath",
]
[[package]]
name = "nu_plugin_stream_example"
version = "0.90.2"
dependencies = [
"nu-plugin",
"nu-protocol",
]
[[package]]
name = "num"
version = "0.2.1"

View File

@ -43,6 +43,7 @@ members = [
"crates/nu_plugin_inc",
"crates/nu_plugin_gstat",
"crates/nu_plugin_example",
"crates/nu_plugin_stream_example",
"crates/nu_plugin_query",
"crates/nu_plugin_custom_values",
"crates/nu_plugin_formats",

View File

@ -1,7 +1,7 @@
use criterion::{criterion_group, criterion_main, BatchSize, Criterion};
use nu_cli::eval_source;
use nu_parser::parse;
use nu_plugin::{EncodingType, PluginResponse};
use nu_plugin::{Encoder, EncodingType, PluginCallResponse, PluginOutput};
use nu_protocol::{engine::EngineState, PipelineData, Span, Value};
use nu_utils::{get_default_config, get_default_env};
use std::path::{Path, PathBuf};
@ -148,10 +148,12 @@ fn encoding_benchmarks(c: &mut Criterion) {
for fmt in ["json", "msgpack"] {
group.bench_function(&format!("{fmt} encode {row_cnt} * {col_cnt}"), |b| {
let mut res = vec![];
let test_data =
PluginResponse::Value(Box::new(encoding_test_data(row_cnt, col_cnt)));
let test_data = PluginOutput::CallResponse(
0,
PluginCallResponse::value(encoding_test_data(row_cnt, col_cnt)),
);
let encoder = EncodingType::try_from_bytes(fmt.as_bytes()).unwrap();
b.iter(|| encoder.encode_response(&test_data, &mut res))
b.iter(|| encoder.encode(&test_data, &mut res))
});
}
}
@ -165,14 +167,16 @@ fn decoding_benchmarks(c: &mut Criterion) {
for fmt in ["json", "msgpack"] {
group.bench_function(&format!("{fmt} decode for {row_cnt} * {col_cnt}"), |b| {
let mut res = vec![];
let test_data =
PluginResponse::Value(Box::new(encoding_test_data(row_cnt, col_cnt)));
let test_data = PluginOutput::CallResponse(
0,
PluginCallResponse::value(encoding_test_data(row_cnt, col_cnt)),
);
let encoder = EncodingType::try_from_bytes(fmt.as_bytes()).unwrap();
encoder.encode_response(&test_data, &mut res).unwrap();
encoder.encode(&test_data, &mut res).unwrap();
let mut binary_data = std::io::Cursor::new(res);
b.iter(|| {
b.iter(|| -> Result<Option<PluginOutput>, _> {
binary_data.set_position(0);
encoder.decode_response(&mut binary_data)
encoder.decode(&mut binary_data)
})
});
}

View File

@ -18,3 +18,7 @@ bincode = "1.3"
rmp-serde = "1.1"
serde = { version = "1.0" }
serde_json = { version = "1.0" }
log = "0.4"
miette = "7.0"
semver = "1.0"
typetag = "0.2"

View File

@ -15,7 +15,7 @@
//! function, which will handle all of the input and output serialization when
//! invoked by Nushell.
//!
//! ```
//! ```rust,no_run
//! use nu_plugin::{EvaluatedCall, LabeledError, MsgPackSerializer, Plugin, serve_plugin};
//! use nu_protocol::{PluginSignature, Value};
//!
@ -46,8 +46,21 @@
//! that demonstrates the full range of plugin capabilities.
mod plugin;
mod protocol;
mod sequence;
mod serializers;
pub use plugin::{get_signature, serve_plugin, Plugin, PluginDeclaration};
pub use protocol::{EvaluatedCall, LabeledError, PluginResponse};
pub use serializers::{json::JsonSerializer, msgpack::MsgPackSerializer, EncodingType};
pub use plugin::{serve_plugin, Plugin, PluginEncoder, StreamingPlugin};
pub use protocol::{EvaluatedCall, LabeledError};
pub use serializers::{json::JsonSerializer, msgpack::MsgPackSerializer};
// Used by other nu crates.
#[doc(hidden)]
pub use plugin::{get_signature, PluginDeclaration};
#[doc(hidden)]
pub use serializers::EncodingType;
// Used by external benchmarks.
#[doc(hidden)]
pub use plugin::Encoder;
#[doc(hidden)]
pub use protocol::{PluginCallResponse, PluginOutput};

View File

@ -0,0 +1,46 @@
use std::sync::{atomic::AtomicBool, Arc};
use nu_protocol::{
ast::Call,
engine::{EngineState, Stack},
};
/// Object safe trait for abstracting operations required of the plugin context.
pub(crate) trait PluginExecutionContext: Send + Sync {
/// The interrupt signal, if present
fn ctrlc(&self) -> Option<&Arc<AtomicBool>>;
}
/// The execution context of a plugin command. May be extended with more fields in the future.
pub(crate) struct PluginExecutionCommandContext {
ctrlc: Option<Arc<AtomicBool>>,
}
impl PluginExecutionCommandContext {
pub fn new(
engine_state: &EngineState,
_stack: &Stack,
_call: &Call,
) -> PluginExecutionCommandContext {
PluginExecutionCommandContext {
ctrlc: engine_state.ctrlc.clone(),
}
}
}
impl PluginExecutionContext for PluginExecutionCommandContext {
fn ctrlc(&self) -> Option<&Arc<AtomicBool>> {
self.ctrlc.as_ref()
}
}
/// A bogus execution context for testing that doesn't really implement anything properly
#[cfg(test)]
pub(crate) struct PluginExecutionBogusContext;
#[cfg(test)]
impl PluginExecutionContext for PluginExecutionBogusContext {
fn ctrlc(&self) -> Option<&Arc<AtomicBool>> {
None
}
}

View File

@ -1,10 +1,7 @@
use crate::EvaluatedCall;
use super::{call_plugin, create_command, get_plugin_encoding};
use crate::protocol::{
CallInfo, CallInput, PluginCall, PluginCustomValue, PluginData, PluginResponse,
};
use super::{PluginExecutionCommandContext, PluginIdentity};
use crate::protocol::{CallInfo, EvaluatedCall};
use std::path::{Path, PathBuf};
use std::sync::Arc;
use nu_engine::eval_block;
use nu_protocol::engine::{Command, EngineState, Stack};
@ -16,8 +13,7 @@ use nu_protocol::{Example, PipelineData, ShellError, Value};
pub struct PluginDeclaration {
name: String,
signature: PluginSignature,
filename: PathBuf,
shell: Option<PathBuf>,
identity: Arc<PluginIdentity>,
}
impl PluginDeclaration {
@ -25,8 +21,7 @@ impl PluginDeclaration {
Self {
name: signature.sig.name.clone(),
signature,
filename,
shell,
identity: Arc::new(PluginIdentity::new(filename, shell)),
}
}
}
@ -76,76 +71,18 @@ impl Command for PluginDeclaration {
call: &Call,
input: PipelineData,
) -> Result<PipelineData, ShellError> {
// Call the command with self path
// Decode information from plugin
// Create PipelineData
let source_file = Path::new(&self.filename);
let mut plugin_cmd = create_command(source_file, self.shell.as_deref());
// We need the current environment variables for `python` based plugins
// Or we'll likely have a problem when a plugin is implemented in a virtual Python environment.
let current_envs = nu_engine::env::env_to_strings(engine_state, stack).unwrap_or_default();
plugin_cmd.envs(current_envs);
let mut child = plugin_cmd.spawn().map_err(|err| {
let decl = engine_state.get_decl(call.decl_id);
ShellError::GenericError {
error: format!("Unable to spawn plugin for {}", decl.name()),
msg: format!("{err}"),
span: Some(call.head),
help: None,
inner: vec![],
}
})?;
let input = input.into_value(call.head);
let span = input.span();
let input = match input {
Value::CustomValue { val, .. } => {
match val.as_any().downcast_ref::<PluginCustomValue>() {
Some(plugin_data) if plugin_data.filename == self.filename => {
CallInput::Data(PluginData {
data: plugin_data.data.clone(),
span,
})
}
_ => {
let custom_value_name = val.value_string();
return Err(ShellError::GenericError {
error: format!(
"Plugin {} can not handle the custom value {}",
self.name, custom_value_name
),
msg: format!("custom value {custom_value_name}"),
span: Some(span),
help: None,
inner: vec![],
});
}
}
}
Value::LazyRecord { val, .. } => CallInput::Value(val.collect()?),
value => CallInput::Value(value),
};
// Create the EvaluatedCall to send to the plugin first - it's best for this to fail early,
// before we actually try to run the plugin command
let evaluated_call = EvaluatedCall::try_from_call(call, engine_state, stack)?;
// Fetch the configuration for a plugin
//
// The `plugin` must match the registered name of a plugin. For
// `register nu_plugin_example` the plugin config lookup uses `"example"`
let config = self
.filename
.file_stem()
.and_then(|file| {
file.to_string_lossy()
.clone()
.strip_prefix("nu_plugin_")
.map(|name| {
nu_engine::get_config(engine_state, stack)
let config = nu_engine::get_config(engine_state, stack)
.plugins
.get(name)
.get(&self.identity.plugin_name)
.cloned()
})
})
.flatten()
.map(|value| {
let span = value.span();
match value {
@ -164,70 +101,41 @@ impl Command for PluginDeclaration {
}
});
let plugin_call = PluginCall::CallInfo(CallInfo {
name: self.name.clone(),
call: EvaluatedCall::try_from_call(call, engine_state, stack)?,
input,
config,
});
// We need the current environment variables for `python` based plugins
// Or we'll likely have a problem when a plugin is implemented in a virtual Python environment.
let current_envs = nu_engine::env::env_to_strings(engine_state, stack).unwrap_or_default();
let encoding = {
let stdout_reader = match &mut child.stdout {
Some(out) => out,
None => {
return Err(ShellError::PluginFailedToLoad {
msg: "Plugin missing stdout reader".into(),
})
}
};
get_plugin_encoding(stdout_reader)?
};
let response = call_plugin(&mut child, plugin_call, &encoding, call.head).map_err(|err| {
// Start the plugin
let plugin = self.identity.clone().spawn(current_envs).map_err(|err| {
let decl = engine_state.get_decl(call.decl_id);
ShellError::GenericError {
error: format!("Unable to decode call for {}", decl.name()),
error: format!("Unable to spawn plugin for `{}`", decl.name()),
msg: err.to_string(),
span: Some(call.head),
help: None,
inner: vec![],
}
});
})?;
let pipeline_data = match response {
Ok(PluginResponse::Value(value)) => {
Ok(PipelineData::Value(value.as_ref().clone(), None))
}
Ok(PluginResponse::PluginData(name, plugin_data)) => Ok(PipelineData::Value(
Value::custom_value(
Box::new(PluginCustomValue {
name,
data: plugin_data.data,
filename: self.filename.clone(),
shell: self.shell.clone(),
source: engine_state.get_decl(call.decl_id).name().to_owned(),
}),
plugin_data.span,
),
None,
)),
Ok(PluginResponse::Error(err)) => Err(err.into()),
Ok(PluginResponse::Signature(..)) => Err(ShellError::GenericError {
error: "Plugin missing value".into(),
msg: "Received a signature from plugin instead of value".into(),
span: Some(call.head),
help: None,
inner: vec![],
}),
Err(err) => Err(err),
};
// Create the context to execute in
let context = Arc::new(PluginExecutionCommandContext::new(
engine_state,
stack,
call,
));
// We need to call .wait() on the child, or we'll risk summoning the zombie horde
let _ = child.wait();
pipeline_data
plugin.run(
CallInfo {
name: self.name.clone(),
call: evaluated_call,
input,
config,
},
context,
)
}
fn is_plugin(&self) -> Option<(&Path, Option<&Path>)> {
Some((&self.filename, self.shell.as_deref()))
Some((&self.identity.filename, self.identity.shell.as_deref()))
}
}

View File

@ -0,0 +1,110 @@
use std::{
ffi::OsStr,
path::{Path, PathBuf},
sync::Arc,
};
use nu_protocol::ShellError;
use super::{create_command, make_plugin_interface, PluginInterface};
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct PluginIdentity {
/// The filename used to start the plugin
pub(crate) filename: PathBuf,
/// The shell used to start the plugin, if required
pub(crate) shell: Option<PathBuf>,
/// The friendly name of the plugin (e.g. `inc` for `C:\nu_plugin_inc.exe`)
pub(crate) plugin_name: String,
}
impl PluginIdentity {
pub(crate) fn new(filename: impl Into<PathBuf>, shell: Option<PathBuf>) -> PluginIdentity {
let filename = filename.into();
// `C:\nu_plugin_inc.exe` becomes `inc`
// `/home/nu/.cargo/bin/nu_plugin_inc` becomes `inc`
// any other path, including if it doesn't start with nu_plugin_, becomes
// `<invalid plugin name>`
let plugin_name = filename
.file_stem()
.map(|stem| stem.to_string_lossy().into_owned())
.and_then(|stem| stem.strip_prefix("nu_plugin_").map(|s| s.to_owned()))
.unwrap_or_else(|| {
log::warn!(
"filename `{}` is not a valid plugin name, must start with nu_plugin_",
filename.display()
);
"<invalid plugin name>".into()
});
PluginIdentity {
filename,
shell,
plugin_name,
}
}
#[cfg(all(test, windows))]
pub(crate) fn new_fake(name: &str) -> Arc<PluginIdentity> {
Arc::new(PluginIdentity::new(
format!(r"C:\fake\path\nu_plugin_{name}.exe"),
None,
))
}
#[cfg(all(test, not(windows)))]
pub(crate) fn new_fake(name: &str) -> Arc<PluginIdentity> {
Arc::new(PluginIdentity::new(
format!(r"/fake/path/nu_plugin_{name}"),
None,
))
}
/// Run the plugin command stored in this [`PluginIdentity`], then set up and return the
/// [`PluginInterface`] attached to it.
pub(crate) fn spawn(
self: Arc<Self>,
envs: impl IntoIterator<Item = (impl AsRef<OsStr>, impl AsRef<OsStr>)>,
) -> Result<PluginInterface, ShellError> {
let source_file = Path::new(&self.filename);
let mut plugin_cmd = create_command(source_file, self.shell.as_deref());
// We need the current environment variables for `python` based plugins
// Or we'll likely have a problem when a plugin is implemented in a virtual Python environment.
plugin_cmd.envs(envs);
let program_name = plugin_cmd.get_program().to_os_string().into_string();
// Run the plugin command
let child = plugin_cmd.spawn().map_err(|err| {
let error_msg = match err.kind() {
std::io::ErrorKind::NotFound => match program_name {
Ok(prog_name) => {
format!("Can't find {prog_name}, please make sure that {prog_name} is in PATH.")
}
_ => {
format!("Error spawning child process: {err}")
}
},
_ => {
format!("Error spawning child process: {err}")
}
};
ShellError::PluginFailedToLoad { msg: error_msg }
})?;
make_plugin_interface(child, self)
}
}
#[test]
fn parses_name_from_path() {
assert_eq!("test", PluginIdentity::new_fake("test").plugin_name);
assert_eq!(
"<invalid plugin name>",
PluginIdentity::new("other", None).plugin_name
);
assert_eq!(
"<invalid plugin name>",
PluginIdentity::new("", None).plugin_name
);
}

View File

@ -0,0 +1,437 @@
//! Implements the stream multiplexing interface for both the plugin side and the engine side.
use std::{
io::Write,
sync::{
atomic::{AtomicBool, Ordering::Relaxed},
Arc, Mutex,
},
thread,
};
use nu_protocol::{ListStream, PipelineData, RawStream, ShellError};
use crate::{
plugin::Encoder,
protocol::{
ExternalStreamInfo, ListStreamInfo, PipelineDataHeader, RawStreamInfo, StreamMessage,
},
sequence::Sequence,
};
mod stream;
mod engine;
pub(crate) use engine::{EngineInterfaceManager, ReceivedPluginCall};
mod plugin;
pub(crate) use plugin::{PluginInterface, PluginInterfaceManager};
use self::stream::{StreamManager, StreamManagerHandle, StreamWriter, WriteStreamMessage};
#[cfg(test)]
mod test_util;
#[cfg(test)]
mod tests;
/// The maximum number of list stream values to send without acknowledgement. This should be tuned
/// with consideration for memory usage.
const LIST_STREAM_HIGH_PRESSURE: i32 = 100;
/// The maximum number of raw stream buffers to send without acknowledgement. This should be tuned
/// with consideration for memory usage.
const RAW_STREAM_HIGH_PRESSURE: i32 = 50;
/// Read input/output from the stream.
pub(crate) trait PluginRead<T> {
/// Returns `Ok(None)` on end of stream.
fn read(&mut self) -> Result<Option<T>, ShellError>;
}
impl<R, E, T> PluginRead<T> for (R, E)
where
R: std::io::BufRead,
E: Encoder<T>,
{
fn read(&mut self) -> Result<Option<T>, ShellError> {
self.1.decode(&mut self.0)
}
}
impl<R, T> PluginRead<T> for &mut R
where
R: PluginRead<T>,
{
fn read(&mut self) -> Result<Option<T>, ShellError> {
(**self).read()
}
}
/// Write input/output to the stream.
///
/// The write should be atomic, without interference from other threads.
pub(crate) trait PluginWrite<T>: Send + Sync {
fn write(&self, data: &T) -> Result<(), ShellError>;
/// Flush any internal buffers, if applicable.
fn flush(&self) -> Result<(), ShellError>;
}
impl<E, T> PluginWrite<T> for (std::io::Stdout, E)
where
E: Encoder<T>,
{
fn write(&self, data: &T) -> Result<(), ShellError> {
let mut lock = self.0.lock();
self.1.encode(data, &mut lock)
}
fn flush(&self) -> Result<(), ShellError> {
self.0.lock().flush().map_err(|err| ShellError::IOError {
msg: err.to_string(),
})
}
}
impl<W, E, T> PluginWrite<T> for (Mutex<W>, E)
where
W: std::io::Write + Send,
E: Encoder<T>,
{
fn write(&self, data: &T) -> Result<(), ShellError> {
let mut lock = self.0.lock().map_err(|_| ShellError::NushellFailed {
msg: "writer mutex poisoned".into(),
})?;
self.1.encode(data, &mut *lock)
}
fn flush(&self) -> Result<(), ShellError> {
let mut lock = self.0.lock().map_err(|_| ShellError::NushellFailed {
msg: "writer mutex poisoned".into(),
})?;
lock.flush().map_err(|err| ShellError::IOError {
msg: err.to_string(),
})
}
}
impl<W, T> PluginWrite<T> for &W
where
W: PluginWrite<T>,
{
fn write(&self, data: &T) -> Result<(), ShellError> {
(**self).write(data)
}
fn flush(&self) -> Result<(), ShellError> {
(**self).flush()
}
}
/// An interface manager handles I/O and state management for communication between a plugin and the
/// engine. See [`PluginInterfaceManager`] for communication from the engine side to a plugin, or
/// [`EngineInterfaceManager`] for communication from the plugin side to the engine.
///
/// There is typically one [`InterfaceManager`] consuming input from a background thread, and
/// managing shared state.
pub(crate) trait InterfaceManager {
/// The corresponding interface type.
type Interface: Interface + 'static;
/// The input message type.
type Input;
/// Make a new interface that communicates with this [`InterfaceManager`].
fn get_interface(&self) -> Self::Interface;
/// Consume an input message.
///
/// When implementing, call [`.consume_stream_message()`] for any encapsulated
/// [`StreamMessage`]s received.
fn consume(&mut self, input: Self::Input) -> Result<(), ShellError>;
/// Get the [`StreamManager`] for handling operations related to stream messages.
fn stream_manager(&self) -> &StreamManager;
/// Prepare [`PipelineData`] after reading. This is called by `read_pipeline_data()` as
/// a hook so that values that need special handling can be taken care of.
fn prepare_pipeline_data(&self, data: PipelineData) -> Result<PipelineData, ShellError>;
/// Consume an input stream message.
///
/// This method is provided for implementors to use.
fn consume_stream_message(&mut self, message: StreamMessage) -> Result<(), ShellError> {
self.stream_manager().handle_message(message)
}
/// Generate `PipelineData` for reading a stream, given a [`PipelineDataHeader`] that was
/// received from the other side.
///
/// This method is provided for implementors to use.
fn read_pipeline_data(
&self,
header: PipelineDataHeader,
ctrlc: Option<&Arc<AtomicBool>>,
) -> Result<PipelineData, ShellError> {
self.prepare_pipeline_data(match header {
PipelineDataHeader::Empty => PipelineData::Empty,
PipelineDataHeader::Value(value) => PipelineData::Value(value, None),
PipelineDataHeader::ListStream(info) => {
let handle = self.stream_manager().get_handle();
let reader = handle.read_stream(info.id, self.get_interface())?;
PipelineData::ListStream(ListStream::from_stream(reader, ctrlc.cloned()), None)
}
PipelineDataHeader::ExternalStream(info) => {
let handle = self.stream_manager().get_handle();
let span = info.span;
let new_raw_stream = |raw_info: RawStreamInfo| {
let reader = handle.read_stream(raw_info.id, self.get_interface())?;
let mut stream =
RawStream::new(Box::new(reader), ctrlc.cloned(), span, raw_info.known_size);
stream.is_binary = raw_info.is_binary;
Ok::<_, ShellError>(stream)
};
PipelineData::ExternalStream {
stdout: info.stdout.map(new_raw_stream).transpose()?,
stderr: info.stderr.map(new_raw_stream).transpose()?,
exit_code: info
.exit_code
.map(|list_info| {
handle
.read_stream(list_info.id, self.get_interface())
.map(|reader| ListStream::from_stream(reader, ctrlc.cloned()))
})
.transpose()?,
span: info.span,
metadata: None,
trim_end_newline: info.trim_end_newline,
}
}
})
}
}
/// An interface provides an API for communicating with a plugin or the engine and facilitates
/// stream I/O. See [`PluginInterface`] for the API from the engine side to a plugin, or
/// [`EngineInterface`] for the API from the plugin side to the engine.
///
/// There can be multiple copies of the interface managed by a single [`InterfaceManager`].
pub(crate) trait Interface: Clone + Send {
/// The output message type, which must be capable of encapsulating a [`StreamMessage`].
type Output: From<StreamMessage>;
/// Write an output message.
fn write(&self, output: Self::Output) -> Result<(), ShellError>;
/// Flush the output buffer, so messages are visible to the other side.
fn flush(&self) -> Result<(), ShellError>;
/// Get the sequence for generating new [`StreamId`](crate::protocol::StreamId)s.
fn stream_id_sequence(&self) -> &Sequence;
/// Get the [`StreamManagerHandle`] for doing stream operations.
fn stream_manager_handle(&self) -> &StreamManagerHandle;
/// Prepare [`PipelineData`] to be written. This is called by `init_write_pipeline_data()` as
/// a hook so that values that need special handling can be taken care of.
fn prepare_pipeline_data(&self, data: PipelineData) -> Result<PipelineData, ShellError>;
/// Initialize a write for [`PipelineData`]. This returns two parts: the header, which can be
/// embedded in the particular message that references the stream, and a writer, which will
/// write out all of the data in the pipeline when `.write()` is called.
///
/// Note that not all [`PipelineData`] starts a stream. You should call `write()` anyway, as
/// it will automatically handle this case.
///
/// This method is provided for implementors to use.
fn init_write_pipeline_data(
&self,
data: PipelineData,
) -> Result<(PipelineDataHeader, PipelineDataWriter<Self>), ShellError> {
// Allocate a stream id and a writer
let new_stream = |high_pressure_mark: i32| {
// Get a free stream id
let id = self.stream_id_sequence().next()?;
// Create the writer
let writer =
self.stream_manager_handle()
.write_stream(id, self.clone(), high_pressure_mark)?;
Ok::<_, ShellError>((id, writer))
};
match self.prepare_pipeline_data(data)? {
PipelineData::Value(value, _) => {
Ok((PipelineDataHeader::Value(value), PipelineDataWriter::None))
}
PipelineData::Empty => Ok((PipelineDataHeader::Empty, PipelineDataWriter::None)),
PipelineData::ListStream(stream, _) => {
let (id, writer) = new_stream(LIST_STREAM_HIGH_PRESSURE)?;
Ok((
PipelineDataHeader::ListStream(ListStreamInfo { id }),
PipelineDataWriter::ListStream(writer, stream),
))
}
PipelineData::ExternalStream {
stdout,
stderr,
exit_code,
span,
metadata: _,
trim_end_newline,
} => {
// Create the writers and stream ids
let stdout_stream = stdout
.is_some()
.then(|| new_stream(RAW_STREAM_HIGH_PRESSURE))
.transpose()?;
let stderr_stream = stderr
.is_some()
.then(|| new_stream(RAW_STREAM_HIGH_PRESSURE))
.transpose()?;
let exit_code_stream = exit_code
.is_some()
.then(|| new_stream(LIST_STREAM_HIGH_PRESSURE))
.transpose()?;
// Generate the header, with the stream ids
let header = PipelineDataHeader::ExternalStream(ExternalStreamInfo {
span,
stdout: stdout
.as_ref()
.zip(stdout_stream.as_ref())
.map(|(stream, (id, _))| RawStreamInfo::new(*id, stream)),
stderr: stderr
.as_ref()
.zip(stderr_stream.as_ref())
.map(|(stream, (id, _))| RawStreamInfo::new(*id, stream)),
exit_code: exit_code_stream
.as_ref()
.map(|&(id, _)| ListStreamInfo { id }),
trim_end_newline,
});
// Collect the writers
let writer = PipelineDataWriter::ExternalStream {
stdout: stdout_stream.map(|(_, writer)| writer).zip(stdout),
stderr: stderr_stream.map(|(_, writer)| writer).zip(stderr),
exit_code: exit_code_stream.map(|(_, writer)| writer).zip(exit_code),
};
Ok((header, writer))
}
}
}
}
impl<T> WriteStreamMessage for T
where
T: Interface,
{
fn write_stream_message(&mut self, msg: StreamMessage) -> Result<(), ShellError> {
self.write(msg.into())
}
fn flush(&mut self) -> Result<(), ShellError> {
<Self as Interface>::flush(self)
}
}
/// Completes the write operation for a [`PipelineData`]. You must call
/// [`PipelineDataWriter::write()`] to write all of the data contained within the streams.
#[derive(Default)]
#[must_use]
pub(crate) enum PipelineDataWriter<W: WriteStreamMessage> {
#[default]
None,
ListStream(StreamWriter<W>, ListStream),
ExternalStream {
stdout: Option<(StreamWriter<W>, RawStream)>,
stderr: Option<(StreamWriter<W>, RawStream)>,
exit_code: Option<(StreamWriter<W>, ListStream)>,
},
}
impl<W> PipelineDataWriter<W>
where
W: WriteStreamMessage + Send + 'static,
{
/// Write all of the data in each of the streams. This method waits for completion.
pub(crate) fn write(self) -> Result<(), ShellError> {
match self {
// If no stream was contained in the PipelineData, do nothing.
PipelineDataWriter::None => Ok(()),
// Write a list stream.
PipelineDataWriter::ListStream(mut writer, stream) => {
writer.write_all(stream)?;
Ok(())
}
// Write all three possible streams of an ExternalStream on separate threads.
PipelineDataWriter::ExternalStream {
stdout,
stderr,
exit_code,
} => {
thread::scope(|scope| {
let stderr_thread = stderr.map(|(mut writer, stream)| {
thread::Builder::new()
.name("plugin stderr writer".into())
.spawn_scoped(scope, move || writer.write_all(raw_stream_iter(stream)))
.expect("failed to spawn thread")
});
let exit_code_thread = exit_code.map(|(mut writer, stream)| {
thread::Builder::new()
.name("plugin exit_code writer".into())
.spawn_scoped(scope, move || writer.write_all(stream))
.expect("failed to spawn thread")
});
// Optimize for stdout: if only stdout is present, don't spawn any other
// threads.
if let Some((mut writer, stream)) = stdout {
writer.write_all(raw_stream_iter(stream))?;
}
let panicked = |thread_name: &str| {
Err(ShellError::NushellFailed {
msg: format!(
"{thread_name} thread panicked in PipelineDataWriter::write"
),
})
};
stderr_thread
.map(|t| t.join().unwrap_or_else(|_| panicked("stderr")))
.transpose()?;
exit_code_thread
.map(|t| t.join().unwrap_or_else(|_| panicked("exit_code")))
.transpose()?;
Ok(())
})
}
}
}
/// Write all of the data in each of the streams. This method returns immediately; any necessary
/// write will happen in the background. If a thread was spawned, its handle is returned.
pub(crate) fn write_background(self) -> Option<thread::JoinHandle<Result<(), ShellError>>> {
match self {
PipelineDataWriter::None => None,
_ => Some(
thread::Builder::new()
.name("plugin stream background writer".into())
.spawn(move || {
let result = self.write();
if let Err(ref err) = result {
// Assume that the background thread error probably won't be handled and log it
// here just in case.
log::warn!("Error while writing pipeline in background: {err}");
}
result
})
.expect("failed to spawn thread"),
),
}
}
}
/// Custom iterator for [`RawStream`] that respects ctrlc, but still has binary chunks
fn raw_stream_iter(stream: RawStream) -> impl Iterator<Item = Result<Vec<u8>, ShellError>> {
let ctrlc = stream.ctrlc;
stream
.stream
.take_while(move |_| ctrlc.as_ref().map(|b| !b.load(Relaxed)).unwrap_or(true))
}

View File

@ -0,0 +1,375 @@
//! Interface used by the plugin to communicate with the engine.
use std::sync::{mpsc, Arc};
use nu_protocol::{
IntoInterruptiblePipelineData, ListStream, PipelineData, PluginSignature, ShellError, Spanned,
Value,
};
use crate::{
protocol::{
CallInfo, CustomValueOp, PluginCall, PluginCallId, PluginCallResponse, PluginCustomValue,
PluginInput, ProtocolInfo,
},
LabeledError, PluginOutput,
};
use super::{
stream::{StreamManager, StreamManagerHandle},
Interface, InterfaceManager, PipelineDataWriter, PluginRead, PluginWrite,
};
use crate::sequence::Sequence;
/// Plugin calls that are received by the [`EngineInterfaceManager`] for handling.
///
/// With each call, an [`EngineInterface`] is included that can be provided to the plugin code
/// and should be used to send the response. The interface sent includes the [`PluginCallId`] for
/// sending associated messages with the correct context.
#[derive(Debug)]
pub(crate) enum ReceivedPluginCall {
Signature {
engine: EngineInterface,
},
Run {
engine: EngineInterface,
call: CallInfo<PipelineData>,
},
CustomValueOp {
engine: EngineInterface,
custom_value: Spanned<PluginCustomValue>,
op: CustomValueOp,
},
}
#[cfg(test)]
mod tests;
/// Internal shared state between the manager and each interface.
struct EngineInterfaceState {
/// Sequence for generating stream ids
stream_id_sequence: Sequence,
/// The synchronized output writer
writer: Box<dyn PluginWrite<PluginOutput>>,
}
impl std::fmt::Debug for EngineInterfaceState {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("EngineInterfaceState")
.field("stream_id_sequence", &self.stream_id_sequence)
.finish_non_exhaustive()
}
}
/// Manages reading and dispatching messages for [`EngineInterface`]s.
#[derive(Debug)]
pub(crate) struct EngineInterfaceManager {
/// Shared state
state: Arc<EngineInterfaceState>,
/// Channel to send received PluginCalls to
plugin_call_sender: mpsc::Sender<ReceivedPluginCall>,
/// Receiver for PluginCalls. This is usually taken after initialization
plugin_call_receiver: Option<mpsc::Receiver<ReceivedPluginCall>>,
/// Manages stream messages and state
stream_manager: StreamManager,
/// Protocol version info, set after `Hello` received
protocol_info: Option<ProtocolInfo>,
}
impl EngineInterfaceManager {
pub(crate) fn new(writer: impl PluginWrite<PluginOutput> + 'static) -> EngineInterfaceManager {
let (plug_tx, plug_rx) = mpsc::channel();
EngineInterfaceManager {
state: Arc::new(EngineInterfaceState {
stream_id_sequence: Sequence::default(),
writer: Box::new(writer),
}),
plugin_call_sender: plug_tx,
plugin_call_receiver: Some(plug_rx),
stream_manager: StreamManager::new(),
protocol_info: None,
}
}
/// Get the receiving end of the plugin call channel. Plugin calls that need to be handled
/// will be sent here.
pub(crate) fn take_plugin_call_receiver(
&mut self,
) -> Option<mpsc::Receiver<ReceivedPluginCall>> {
self.plugin_call_receiver.take()
}
/// Create an [`EngineInterface`] associated with the given call id.
fn interface_for_context(&self, context: PluginCallId) -> EngineInterface {
EngineInterface {
state: self.state.clone(),
stream_manager_handle: self.stream_manager.get_handle(),
context: Some(context),
}
}
/// Send a [`ReceivedPluginCall`] to the channel
fn send_plugin_call(&self, plugin_call: ReceivedPluginCall) -> Result<(), ShellError> {
self.plugin_call_sender
.send(plugin_call)
.map_err(|_| ShellError::NushellFailed {
msg: "Received a plugin call, but there's nowhere to send it".into(),
})
}
/// True if there are no other copies of the state (which would mean there are no interfaces
/// and no stream readers/writers)
pub(crate) fn is_finished(&self) -> bool {
Arc::strong_count(&self.state) < 2
}
/// Loop on input from the given reader as long as `is_finished()` is false
///
/// Any errors will be propagated to all read streams automatically.
pub(crate) fn consume_all(
&mut self,
mut reader: impl PluginRead<PluginInput>,
) -> Result<(), ShellError> {
while let Some(msg) = reader.read().transpose() {
if self.is_finished() {
break;
}
if let Err(err) = msg.and_then(|msg| self.consume(msg)) {
let _ = self.stream_manager.broadcast_read_error(err.clone());
return Err(err);
}
}
Ok(())
}
}
impl InterfaceManager for EngineInterfaceManager {
type Interface = EngineInterface;
type Input = PluginInput;
fn get_interface(&self) -> Self::Interface {
EngineInterface {
state: self.state.clone(),
stream_manager_handle: self.stream_manager.get_handle(),
context: None,
}
}
fn consume(&mut self, input: Self::Input) -> Result<(), ShellError> {
log::trace!("from engine: {:?}", input);
match input {
PluginInput::Hello(info) => {
let local_info = ProtocolInfo::default();
if local_info.is_compatible_with(&info)? {
self.protocol_info = Some(info);
Ok(())
} else {
self.protocol_info = None;
Err(ShellError::PluginFailedToLoad {
msg: format!(
"Plugin is compiled for nushell version {}, \
which is not compatible with version {}",
local_info.version, info.version
),
})
}
}
_ if self.protocol_info.is_none() => {
// Must send protocol info first
Err(ShellError::PluginFailedToLoad {
msg: "Failed to receive initial Hello message. This engine might be too old"
.into(),
})
}
PluginInput::Stream(message) => self.consume_stream_message(message),
PluginInput::Call(id, call) => match call {
// We just let the receiver handle it rather than trying to store signature here
// or something
PluginCall::Signature => self.send_plugin_call(ReceivedPluginCall::Signature {
engine: self.interface_for_context(id),
}),
// Set up the streams from the input and reformat to a ReceivedPluginCall
PluginCall::Run(CallInfo {
name,
mut call,
input,
config,
}) => {
let interface = self.interface_for_context(id);
// If there's an error with initialization of the input stream, just send
// the error response rather than failing here
match self.read_pipeline_data(input, None) {
Ok(input) => {
// Deserialize custom values in the arguments
if let Err(err) = deserialize_call_args(&mut call) {
return interface.write_response(Err(err))?.write();
}
// Send the plugin call to the receiver
self.send_plugin_call(ReceivedPluginCall::Run {
engine: interface,
call: CallInfo {
name,
call,
input,
config,
},
})
}
err @ Err(_) => interface.write_response(err)?.write(),
}
}
// Send request with the custom value
PluginCall::CustomValueOp(custom_value, op) => {
self.send_plugin_call(ReceivedPluginCall::CustomValueOp {
engine: self.interface_for_context(id),
custom_value,
op,
})
}
},
}
}
fn stream_manager(&self) -> &StreamManager {
&self.stream_manager
}
fn prepare_pipeline_data(&self, mut data: PipelineData) -> Result<PipelineData, ShellError> {
// Deserialize custom values in the pipeline data
match data {
PipelineData::Value(ref mut value, _) => {
PluginCustomValue::deserialize_custom_values_in(value)?;
Ok(data)
}
PipelineData::ListStream(ListStream { stream, ctrlc, .. }, meta) => Ok(stream
.map(|mut value| {
let span = value.span();
PluginCustomValue::deserialize_custom_values_in(&mut value)
.map(|()| value)
.unwrap_or_else(|err| Value::error(err, span))
})
.into_pipeline_data_with_metadata(meta, ctrlc)),
PipelineData::Empty | PipelineData::ExternalStream { .. } => Ok(data),
}
}
}
/// Deserialize custom values in call arguments
fn deserialize_call_args(call: &mut crate::EvaluatedCall) -> Result<(), ShellError> {
call.positional
.iter_mut()
.try_for_each(PluginCustomValue::deserialize_custom_values_in)?;
call.named
.iter_mut()
.flat_map(|(_, value)| value.as_mut())
.try_for_each(PluginCustomValue::deserialize_custom_values_in)
}
/// A reference through which the nushell engine can be interacted with during execution.
#[derive(Debug, Clone)]
pub struct EngineInterface {
/// Shared state with the manager
state: Arc<EngineInterfaceState>,
/// Handle to stream manager
stream_manager_handle: StreamManagerHandle,
/// The plugin call this interface belongs to.
context: Option<PluginCallId>,
}
impl EngineInterface {
/// Write the protocol info. This should be done after initialization
pub(crate) fn hello(&self) -> Result<(), ShellError> {
self.write(PluginOutput::Hello(ProtocolInfo::default()))?;
self.flush()
}
fn context(&self) -> Result<PluginCallId, ShellError> {
self.context.ok_or_else(|| ShellError::NushellFailed {
msg: "Tried to call an EngineInterface method that requires a call context \
outside of one"
.into(),
})
}
/// Write a call response of either [`PipelineData`] or an error. Returns the stream writer
/// to finish writing the stream
pub(crate) fn write_response(
&self,
result: Result<PipelineData, impl Into<LabeledError>>,
) -> Result<PipelineDataWriter<Self>, ShellError> {
match result {
Ok(data) => {
let (header, writer) = match self.init_write_pipeline_data(data) {
Ok(tup) => tup,
// If we get an error while trying to construct the pipeline data, send that
// instead
Err(err) => return self.write_response(Err(err)),
};
// Write pipeline data header response, and the full stream
let response = PluginCallResponse::PipelineData(header);
self.write(PluginOutput::CallResponse(self.context()?, response))?;
self.flush()?;
Ok(writer)
}
Err(err) => {
let response = PluginCallResponse::Error(err.into());
self.write(PluginOutput::CallResponse(self.context()?, response))?;
self.flush()?;
Ok(Default::default())
}
}
}
/// Write a call response of plugin signatures.
pub(crate) fn write_signature(
&self,
signature: Vec<PluginSignature>,
) -> Result<(), ShellError> {
let response = PluginCallResponse::Signature(signature);
self.write(PluginOutput::CallResponse(self.context()?, response))?;
self.flush()
}
}
impl Interface for EngineInterface {
type Output = PluginOutput;
fn write(&self, output: PluginOutput) -> Result<(), ShellError> {
log::trace!("to engine: {:?}", output);
self.state.writer.write(&output)
}
fn flush(&self) -> Result<(), ShellError> {
self.state.writer.flush()
}
fn stream_id_sequence(&self) -> &Sequence {
&self.state.stream_id_sequence
}
fn stream_manager_handle(&self) -> &StreamManagerHandle {
&self.stream_manager_handle
}
fn prepare_pipeline_data(&self, mut data: PipelineData) -> Result<PipelineData, ShellError> {
// Serialize custom values in the pipeline data
match data {
PipelineData::Value(ref mut value, _) => {
PluginCustomValue::serialize_custom_values_in(value)?;
Ok(data)
}
PipelineData::ListStream(ListStream { stream, ctrlc, .. }, meta) => Ok(stream
.map(|mut value| {
let span = value.span();
PluginCustomValue::serialize_custom_values_in(&mut value)
.map(|_| value)
.unwrap_or_else(|err| Value::error(err, span))
})
.into_pipeline_data_with_metadata(meta, ctrlc)),
PipelineData::Empty | PipelineData::ExternalStream { .. } => Ok(data),
}
}
}

View File

@ -0,0 +1,779 @@
use nu_protocol::{
CustomValue, IntoInterruptiblePipelineData, PipelineData, PluginSignature, ShellError, Span,
Spanned, Value,
};
use crate::{
plugin::interface::{test_util::TestCase, Interface, InterfaceManager},
protocol::{
test_util::{expected_test_custom_value, test_plugin_custom_value, TestCustomValue},
CallInfo, CustomValueOp, ExternalStreamInfo, ListStreamInfo, PipelineDataHeader,
PluginCall, PluginCustomValue, PluginInput, Protocol, ProtocolInfo, RawStreamInfo,
StreamData, StreamMessage,
},
EvaluatedCall, LabeledError, PluginCallResponse, PluginOutput,
};
use super::ReceivedPluginCall;
#[test]
fn manager_consume_all_consumes_messages() -> Result<(), ShellError> {
let mut test = TestCase::new();
let mut manager = test.engine();
// This message should be non-problematic
test.add(PluginInput::Hello(ProtocolInfo::default()));
manager.consume_all(&mut test)?;
assert!(!test.has_unconsumed_read());
Ok(())
}
#[test]
fn manager_consume_all_exits_after_streams_and_interfaces_are_dropped() -> Result<(), ShellError> {
let mut test = TestCase::new();
let mut manager = test.engine();
// Add messages that won't cause errors
for _ in 0..5 {
test.add(PluginInput::Hello(ProtocolInfo::default()));
}
// Create a stream...
let stream = manager.read_pipeline_data(
PipelineDataHeader::ListStream(ListStreamInfo { id: 0 }),
None,
)?;
// and an interface...
let interface = manager.get_interface();
// Expect that is_finished is false
assert!(
!manager.is_finished(),
"is_finished is true even though active stream/interface exists"
);
// After dropping, it should be true
drop(stream);
drop(interface);
assert!(
manager.is_finished(),
"is_finished is false even though manager has no stream or interface"
);
// When it's true, consume_all shouldn't consume everything
manager.consume_all(&mut test)?;
assert!(
test.has_unconsumed_read(),
"consume_all consumed the messages"
);
Ok(())
}
fn test_io_error() -> ShellError {
ShellError::IOError {
msg: "test io error".into(),
}
}
fn check_test_io_error(error: &ShellError) {
assert!(
format!("{error:?}").contains("test io error"),
"error: {error}"
);
}
#[test]
fn manager_consume_all_propagates_error_to_readers() -> Result<(), ShellError> {
let mut test = TestCase::new();
let mut manager = test.engine();
test.set_read_error(test_io_error());
let stream = manager.read_pipeline_data(
PipelineDataHeader::ListStream(ListStreamInfo { id: 0 }),
None,
)?;
manager
.consume_all(&mut test)
.expect_err("consume_all did not error");
// Ensure end of stream
drop(manager);
let value = stream.into_iter().next().expect("stream is empty");
if let Value::Error { error, .. } = value {
check_test_io_error(&error);
Ok(())
} else {
panic!("did not get an error");
}
}
fn invalid_input() -> PluginInput {
// This should definitely cause an error, as 0.0.0 is not compatible with any version other than
// itself
PluginInput::Hello(ProtocolInfo {
protocol: Protocol::NuPlugin,
version: "0.0.0".into(),
features: vec![],
})
}
fn check_invalid_input_error(error: &ShellError) {
// the error message should include something about the version...
assert!(format!("{error:?}").contains("0.0.0"), "error: {error}");
}
#[test]
fn manager_consume_all_propagates_message_error_to_readers() -> Result<(), ShellError> {
let mut test = TestCase::new();
let mut manager = test.engine();
test.add(invalid_input());
let stream = manager.read_pipeline_data(
PipelineDataHeader::ExternalStream(ExternalStreamInfo {
span: Span::test_data(),
stdout: Some(RawStreamInfo {
id: 0,
is_binary: false,
known_size: None,
}),
stderr: None,
exit_code: None,
trim_end_newline: false,
}),
None,
)?;
manager
.consume_all(&mut test)
.expect_err("consume_all did not error");
// Ensure end of stream
drop(manager);
let value = stream.into_iter().next().expect("stream is empty");
if let Value::Error { error, .. } = value {
check_invalid_input_error(&error);
Ok(())
} else {
panic!("did not get an error");
}
}
#[test]
fn manager_consume_sets_protocol_info_on_hello() -> Result<(), ShellError> {
let mut manager = TestCase::new().engine();
let info = ProtocolInfo::default();
manager.consume(PluginInput::Hello(info.clone()))?;
let set_info = manager
.protocol_info
.as_ref()
.expect("protocol info not set");
assert_eq!(info.version, set_info.version);
Ok(())
}
#[test]
fn manager_consume_errors_on_wrong_nushell_version() -> Result<(), ShellError> {
let mut manager = TestCase::new().engine();
let info = ProtocolInfo {
protocol: Protocol::NuPlugin,
version: "0.0.0".into(),
features: vec![],
};
manager
.consume(PluginInput::Hello(info))
.expect_err("version 0.0.0 should cause an error");
Ok(())
}
#[test]
fn manager_consume_errors_on_sending_other_messages_before_hello() -> Result<(), ShellError> {
let mut manager = TestCase::new().engine();
// hello not set
assert!(manager.protocol_info.is_none());
let error = manager
.consume(PluginInput::Stream(StreamMessage::Drop(0)))
.expect_err("consume before Hello should cause an error");
assert!(format!("{error:?}").contains("Hello"));
Ok(())
}
#[test]
fn manager_consume_call_signature_forwards_to_receiver_with_context() -> Result<(), ShellError> {
let mut manager = TestCase::new().engine();
manager.protocol_info = Some(ProtocolInfo::default());
let rx = manager
.take_plugin_call_receiver()
.expect("couldn't take receiver");
manager.consume(PluginInput::Call(0, PluginCall::Signature))?;
match rx.try_recv().expect("call was not forwarded to receiver") {
ReceivedPluginCall::Signature { engine } => {
assert_eq!(Some(0), engine.context);
Ok(())
}
call => panic!("wrong call type: {call:?}"),
}
}
#[test]
fn manager_consume_call_run_forwards_to_receiver_with_context() -> Result<(), ShellError> {
let mut manager = TestCase::new().engine();
manager.protocol_info = Some(ProtocolInfo::default());
let rx = manager
.take_plugin_call_receiver()
.expect("couldn't take receiver");
manager.consume(PluginInput::Call(
17,
PluginCall::Run(CallInfo {
name: "bar".into(),
call: EvaluatedCall {
head: Span::test_data(),
positional: vec![],
named: vec![],
},
input: PipelineDataHeader::Empty,
config: None,
}),
))?;
// Make sure the streams end and we don't deadlock
drop(manager);
match rx.try_recv().expect("call was not forwarded to receiver") {
ReceivedPluginCall::Run { engine, call: _ } => {
assert_eq!(Some(17), engine.context, "context");
Ok(())
}
call => panic!("wrong call type: {call:?}"),
}
}
#[test]
fn manager_consume_call_run_forwards_to_receiver_with_pipeline_data() -> Result<(), ShellError> {
let mut manager = TestCase::new().engine();
manager.protocol_info = Some(ProtocolInfo::default());
let rx = manager
.take_plugin_call_receiver()
.expect("couldn't take receiver");
manager.consume(PluginInput::Call(
0,
PluginCall::Run(CallInfo {
name: "bar".into(),
call: EvaluatedCall {
head: Span::test_data(),
positional: vec![],
named: vec![],
},
input: PipelineDataHeader::ListStream(ListStreamInfo { id: 6 }),
config: None,
}),
))?;
for i in 0..10 {
manager.consume(PluginInput::Stream(StreamMessage::Data(
6,
Value::test_int(i).into(),
)))?;
}
manager.consume(PluginInput::Stream(StreamMessage::End(6)))?;
// Make sure the streams end and we don't deadlock
drop(manager);
match rx.try_recv().expect("call was not forwarded to receiver") {
ReceivedPluginCall::Run { engine: _, call } => {
assert_eq!("bar", call.name);
// Ensure we manage to receive the stream messages
assert_eq!(10, call.input.into_iter().count());
Ok(())
}
call => panic!("wrong call type: {call:?}"),
}
}
#[test]
fn manager_consume_call_run_deserializes_custom_values_in_args() -> Result<(), ShellError> {
let mut manager = TestCase::new().engine();
manager.protocol_info = Some(ProtocolInfo::default());
let rx = manager
.take_plugin_call_receiver()
.expect("couldn't take receiver");
let value = Value::test_custom_value(Box::new(test_plugin_custom_value()));
manager.consume(PluginInput::Call(
0,
PluginCall::Run(CallInfo {
name: "bar".into(),
call: EvaluatedCall {
head: Span::test_data(),
positional: vec![value.clone()],
named: vec![(
Spanned {
item: "flag".into(),
span: Span::test_data(),
},
Some(value),
)],
},
input: PipelineDataHeader::Empty,
config: None,
}),
))?;
// Make sure the streams end and we don't deadlock
drop(manager);
match rx.try_recv().expect("call was not forwarded to receiver") {
ReceivedPluginCall::Run { engine: _, call } => {
assert_eq!(1, call.call.positional.len());
assert_eq!(1, call.call.named.len());
for arg in call.call.positional {
let custom_value: &TestCustomValue = arg
.as_custom_value()?
.as_any()
.downcast_ref()
.expect("positional arg is not TestCustomValue");
assert_eq!(expected_test_custom_value(), *custom_value, "positional");
}
for (key, val) in call.call.named {
let key = &key.item;
let custom_value: &TestCustomValue = val
.as_ref()
.unwrap_or_else(|| panic!("found empty named argument: {key}"))
.as_custom_value()?
.as_any()
.downcast_ref()
.unwrap_or_else(|| panic!("named arg {key} is not TestCustomValue"));
assert_eq!(expected_test_custom_value(), *custom_value, "named: {key}");
}
Ok(())
}
call => panic!("wrong call type: {call:?}"),
}
}
#[test]
fn manager_consume_call_custom_value_op_forwards_to_receiver_with_context() -> Result<(), ShellError>
{
let mut manager = TestCase::new().engine();
manager.protocol_info = Some(ProtocolInfo::default());
let rx = manager
.take_plugin_call_receiver()
.expect("couldn't take receiver");
manager.consume(PluginInput::Call(
32,
PluginCall::CustomValueOp(
Spanned {
item: test_plugin_custom_value(),
span: Span::test_data(),
},
CustomValueOp::ToBaseValue,
),
))?;
match rx.try_recv().expect("call was not forwarded to receiver") {
ReceivedPluginCall::CustomValueOp {
engine,
custom_value,
op,
} => {
assert_eq!(Some(32), engine.context);
assert_eq!("TestCustomValue", custom_value.item.name);
assert!(
matches!(op, CustomValueOp::ToBaseValue),
"incorrect op: {op:?}"
);
}
call => panic!("wrong call type: {call:?}"),
}
Ok(())
}
#[test]
fn manager_prepare_pipeline_data_deserializes_custom_values() -> Result<(), ShellError> {
let manager = TestCase::new().engine();
let data = manager.prepare_pipeline_data(PipelineData::Value(
Value::test_custom_value(Box::new(test_plugin_custom_value())),
None,
))?;
let value = data
.into_iter()
.next()
.expect("prepared pipeline data is empty");
let custom_value: &TestCustomValue = value
.as_custom_value()?
.as_any()
.downcast_ref()
.expect("custom value is not a TestCustomValue, probably not deserialized");
assert_eq!(expected_test_custom_value(), *custom_value);
Ok(())
}
#[test]
fn manager_prepare_pipeline_data_deserializes_custom_values_in_streams() -> Result<(), ShellError> {
let manager = TestCase::new().engine();
let data = manager.prepare_pipeline_data(
[Value::test_custom_value(Box::new(
test_plugin_custom_value(),
))]
.into_pipeline_data(None),
)?;
let value = data
.into_iter()
.next()
.expect("prepared pipeline data is empty");
let custom_value: &TestCustomValue = value
.as_custom_value()?
.as_any()
.downcast_ref()
.expect("custom value is not a TestCustomValue, probably not deserialized");
assert_eq!(expected_test_custom_value(), *custom_value);
Ok(())
}
#[test]
fn manager_prepare_pipeline_data_embeds_deserialization_errors_in_streams() -> Result<(), ShellError>
{
let manager = TestCase::new().engine();
let invalid_custom_value = PluginCustomValue {
name: "Invalid".into(),
data: vec![0; 8], // should fail to decode to anything
source: None,
};
let span = Span::new(20, 30);
let data = manager.prepare_pipeline_data(
[Value::custom_value(Box::new(invalid_custom_value), span)].into_pipeline_data(None),
)?;
let value = data
.into_iter()
.next()
.expect("prepared pipeline data is empty");
match value {
Value::Error { error, .. } => match *error {
ShellError::CustomValueFailedToDecode {
span: error_span, ..
} => {
assert_eq!(span, error_span, "error span not the same as the value's");
}
_ => panic!("expected ShellError::CustomValueFailedToDecode, but got {error:?}"),
},
_ => panic!("unexpected value, not error: {value:?}"),
}
Ok(())
}
#[test]
fn interface_hello_sends_protocol_info() -> Result<(), ShellError> {
let test = TestCase::new();
let interface = test.engine().get_interface();
interface.hello()?;
let written = test.next_written().expect("nothing written");
match written {
PluginOutput::Hello(info) => {
assert_eq!(ProtocolInfo::default().version, info.version);
}
_ => panic!("unexpected message written: {written:?}"),
}
assert!(!test.has_unconsumed_write());
Ok(())
}
#[test]
fn interface_write_response_with_value() -> Result<(), ShellError> {
let test = TestCase::new();
let interface = test.engine().interface_for_context(33);
interface
.write_response(Ok::<_, ShellError>(PipelineData::Value(
Value::test_int(6),
None,
)))?
.write()?;
let written = test.next_written().expect("nothing written");
match written {
PluginOutput::CallResponse(id, response) => {
assert_eq!(33, id, "id");
match response {
PluginCallResponse::PipelineData(header) => match header {
PipelineDataHeader::Value(value) => assert_eq!(6, value.as_int()?),
_ => panic!("unexpected pipeline data header: {header:?}"),
},
_ => panic!("unexpected response: {response:?}"),
}
}
_ => panic!("unexpected message written: {written:?}"),
}
assert!(!test.has_unconsumed_write());
Ok(())
}
#[test]
fn interface_write_response_with_stream() -> Result<(), ShellError> {
let test = TestCase::new();
let manager = test.engine();
let interface = manager.interface_for_context(34);
interface
.write_response(Ok::<_, ShellError>(
[Value::test_int(3), Value::test_int(4), Value::test_int(5)].into_pipeline_data(None),
))?
.write()?;
let written = test.next_written().expect("nothing written");
let info = match written {
PluginOutput::CallResponse(_, response) => match response {
PluginCallResponse::PipelineData(header) => match header {
PipelineDataHeader::ListStream(info) => info,
_ => panic!("expected ListStream header: {header:?}"),
},
_ => panic!("wrong response: {response:?}"),
},
_ => panic!("wrong output written: {written:?}"),
};
for number in [3, 4, 5] {
match test.next_written().expect("missing stream Data message") {
PluginOutput::Stream(StreamMessage::Data(id, data)) => {
assert_eq!(info.id, id, "Data id");
match data {
StreamData::List(val) => assert_eq!(number, val.as_int()?),
_ => panic!("expected List data: {data:?}"),
}
}
message => panic!("expected Stream(Data(..)): {message:?}"),
}
}
match test.next_written().expect("missing stream End message") {
PluginOutput::Stream(StreamMessage::End(id)) => assert_eq!(info.id, id, "End id"),
message => panic!("expected Stream(Data(..)): {message:?}"),
}
assert!(!test.has_unconsumed_write());
Ok(())
}
#[test]
fn interface_write_response_with_error() -> Result<(), ShellError> {
let test = TestCase::new();
let interface = test.engine().interface_for_context(35);
let labeled_error = LabeledError {
label: "this is an error".into(),
msg: "a test error".into(),
span: None,
};
interface
.write_response(Err(labeled_error.clone()))?
.write()?;
let written = test.next_written().expect("nothing written");
match written {
PluginOutput::CallResponse(id, response) => {
assert_eq!(35, id, "id");
match response {
PluginCallResponse::Error(err) => assert_eq!(labeled_error, err),
_ => panic!("unexpected response: {response:?}"),
}
}
_ => panic!("unexpected message written: {written:?}"),
}
assert!(!test.has_unconsumed_write());
Ok(())
}
#[test]
fn interface_write_signature() -> Result<(), ShellError> {
let test = TestCase::new();
let interface = test.engine().interface_for_context(36);
let signatures = vec![PluginSignature::build("test command")];
interface.write_signature(signatures.clone())?;
let written = test.next_written().expect("nothing written");
match written {
PluginOutput::CallResponse(id, response) => {
assert_eq!(36, id, "id");
match response {
PluginCallResponse::Signature(sigs) => assert_eq!(1, sigs.len(), "sigs.len"),
_ => panic!("unexpected response: {response:?}"),
}
}
_ => panic!("unexpected message written: {written:?}"),
}
assert!(!test.has_unconsumed_write());
Ok(())
}
#[test]
fn interface_prepare_pipeline_data_serializes_custom_values() -> Result<(), ShellError> {
let interface = TestCase::new().engine().get_interface();
let data = interface.prepare_pipeline_data(PipelineData::Value(
Value::test_custom_value(Box::new(expected_test_custom_value())),
None,
))?;
let value = data
.into_iter()
.next()
.expect("prepared pipeline data is empty");
let custom_value: &PluginCustomValue = value
.as_custom_value()?
.as_any()
.downcast_ref()
.expect("custom value is not a PluginCustomValue, probably not serialized");
let expected = test_plugin_custom_value();
assert_eq!(expected.name, custom_value.name);
assert_eq!(expected.data, custom_value.data);
assert!(custom_value.source.is_none());
Ok(())
}
#[test]
fn interface_prepare_pipeline_data_serializes_custom_values_in_streams() -> Result<(), ShellError> {
let interface = TestCase::new().engine().get_interface();
let data = interface.prepare_pipeline_data(
[Value::test_custom_value(Box::new(
expected_test_custom_value(),
))]
.into_pipeline_data(None),
)?;
let value = data
.into_iter()
.next()
.expect("prepared pipeline data is empty");
let custom_value: &PluginCustomValue = value
.as_custom_value()?
.as_any()
.downcast_ref()
.expect("custom value is not a PluginCustomValue, probably not serialized");
let expected = test_plugin_custom_value();
assert_eq!(expected.name, custom_value.name);
assert_eq!(expected.data, custom_value.data);
assert!(custom_value.source.is_none());
Ok(())
}
/// A non-serializable custom value. Should cause a serialization error
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
enum CantSerialize {
#[serde(skip_serializing)]
BadVariant,
}
#[typetag::serde]
impl CustomValue for CantSerialize {
fn clone_value(&self, span: Span) -> Value {
Value::custom_value(Box::new(self.clone()), span)
}
fn value_string(&self) -> String {
"CantSerialize".into()
}
fn to_base_value(&self, _span: Span) -> Result<Value, ShellError> {
unimplemented!()
}
fn as_any(&self) -> &dyn std::any::Any {
self
}
}
#[test]
fn interface_prepare_pipeline_data_embeds_serialization_errors_in_streams() -> Result<(), ShellError>
{
let interface = TestCase::new().engine().get_interface();
let span = Span::new(40, 60);
let data = interface.prepare_pipeline_data(
[Value::custom_value(
Box::new(CantSerialize::BadVariant),
span,
)]
.into_pipeline_data(None),
)?;
let value = data
.into_iter()
.next()
.expect("prepared pipeline data is empty");
match value {
Value::Error { error, .. } => match *error {
ShellError::CustomValueFailedToEncode {
span: error_span, ..
} => {
assert_eq!(span, error_span, "error span not the same as the value's");
}
_ => panic!("expected ShellError::CustomValueFailedToEncode, but got {error:?}"),
},
_ => panic!("unexpected value, not error: {value:?}"),
}
Ok(())
}

View File

@ -0,0 +1,504 @@
//! Interface used by the engine to communicate with the plugin.
use std::{
collections::{btree_map, BTreeMap},
sync::{mpsc, Arc},
};
use nu_protocol::{
IntoInterruptiblePipelineData, ListStream, PipelineData, PluginSignature, ShellError, Spanned,
Value,
};
use crate::{
plugin::{context::PluginExecutionContext, PluginIdentity},
protocol::{
CallInfo, CustomValueOp, PluginCall, PluginCallId, PluginCallResponse, PluginCustomValue,
PluginInput, PluginOutput, ProtocolInfo,
},
sequence::Sequence,
};
use super::{
stream::{StreamManager, StreamManagerHandle},
Interface, InterfaceManager, PipelineDataWriter, PluginRead, PluginWrite,
};
#[cfg(test)]
mod tests;
#[derive(Debug)]
enum ReceivedPluginCallMessage {
/// The final response to send
Response(PluginCallResponse<PipelineData>),
/// An critical error with the interface
Error(ShellError),
}
/// Context for plugin call execution
#[derive(Clone)]
pub(crate) struct Context(Arc<dyn PluginExecutionContext>);
impl std::fmt::Debug for Context {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.write_str("Context")
}
}
impl std::ops::Deref for Context {
type Target = dyn PluginExecutionContext;
fn deref(&self) -> &Self::Target {
&*self.0
}
}
/// Internal shared state between the manager and each interface.
struct PluginInterfaceState {
/// The identity of the plugin being interfaced with
identity: Arc<PluginIdentity>,
/// Sequence for generating plugin call ids
plugin_call_id_sequence: Sequence,
/// Sequence for generating stream ids
stream_id_sequence: Sequence,
/// Sender to subscribe to a plugin call response
plugin_call_subscription_sender: mpsc::Sender<(PluginCallId, PluginCallSubscription)>,
/// The synchronized output writer
writer: Box<dyn PluginWrite<PluginInput>>,
}
impl std::fmt::Debug for PluginInterfaceState {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("PluginInterfaceState")
.field("identity", &self.identity)
.field("plugin_call_id_sequence", &self.plugin_call_id_sequence)
.field("stream_id_sequence", &self.stream_id_sequence)
.field(
"plugin_call_subscription_sender",
&self.plugin_call_subscription_sender,
)
.finish_non_exhaustive()
}
}
/// Sent to the [`PluginInterfaceManager`] before making a plugin call to indicate interest in its
/// response.
#[derive(Debug)]
struct PluginCallSubscription {
/// The sender back to the thread that is waiting for the plugin call response
sender: mpsc::Sender<ReceivedPluginCallMessage>,
/// Optional context for the environment of a plugin call
context: Option<Context>,
}
/// Manages reading and dispatching messages for [`PluginInterface`]s.
#[derive(Debug)]
pub(crate) struct PluginInterfaceManager {
/// Shared state
state: Arc<PluginInterfaceState>,
/// Manages stream messages and state
stream_manager: StreamManager,
/// Protocol version info, set after `Hello` received
protocol_info: Option<ProtocolInfo>,
/// Subscriptions for messages related to plugin calls
plugin_call_subscriptions: BTreeMap<PluginCallId, PluginCallSubscription>,
/// Receiver for plugin call subscriptions
plugin_call_subscription_receiver: mpsc::Receiver<(PluginCallId, PluginCallSubscription)>,
}
impl PluginInterfaceManager {
pub(crate) fn new(
identity: Arc<PluginIdentity>,
writer: impl PluginWrite<PluginInput> + 'static,
) -> PluginInterfaceManager {
let (subscription_tx, subscription_rx) = mpsc::channel();
PluginInterfaceManager {
state: Arc::new(PluginInterfaceState {
identity,
plugin_call_id_sequence: Sequence::default(),
stream_id_sequence: Sequence::default(),
plugin_call_subscription_sender: subscription_tx,
writer: Box::new(writer),
}),
stream_manager: StreamManager::new(),
protocol_info: None,
plugin_call_subscriptions: BTreeMap::new(),
plugin_call_subscription_receiver: subscription_rx,
}
}
/// Consume pending messages in the `plugin_call_subscription_receiver`
fn receive_plugin_call_subscriptions(&mut self) {
while let Ok((id, subscription)) = self.plugin_call_subscription_receiver.try_recv() {
if let btree_map::Entry::Vacant(e) = self.plugin_call_subscriptions.entry(id) {
e.insert(subscription);
} else {
log::warn!("Duplicate plugin call ID ignored: {id}");
}
}
}
/// Find the context corresponding to the given plugin call id
fn get_context(&mut self, id: PluginCallId) -> Result<Option<Context>, ShellError> {
// Make sure we're up to date
self.receive_plugin_call_subscriptions();
// Find the subscription and return the context
self.plugin_call_subscriptions
.get(&id)
.map(|sub| sub.context.clone())
.ok_or_else(|| ShellError::PluginFailedToDecode {
msg: format!("Unknown plugin call ID: {id}"),
})
}
/// Send a [`PluginCallResponse`] to the appropriate sender
fn send_plugin_call_response(
&mut self,
id: PluginCallId,
response: PluginCallResponse<PipelineData>,
) -> Result<(), ShellError> {
// Ensure we're caught up on the subscriptions made
self.receive_plugin_call_subscriptions();
// Remove the subscription, since this would be the last message
if let Some(subscription) = self.plugin_call_subscriptions.remove(&id) {
if subscription
.sender
.send(ReceivedPluginCallMessage::Response(response))
.is_err()
{
log::warn!("Received a plugin call response for id={id}, but the caller hung up");
}
Ok(())
} else {
Err(ShellError::PluginFailedToDecode {
msg: format!("Unknown plugin call ID: {id}"),
})
}
}
/// True if there are no other copies of the state (which would mean there are no interfaces
/// and no stream readers/writers)
pub(crate) fn is_finished(&self) -> bool {
Arc::strong_count(&self.state) < 2
}
/// Loop on input from the given reader as long as `is_finished()` is false
///
/// Any errors will be propagated to all read streams automatically.
pub(crate) fn consume_all(
&mut self,
mut reader: impl PluginRead<PluginOutput>,
) -> Result<(), ShellError> {
while let Some(msg) = reader.read().transpose() {
if self.is_finished() {
break;
}
if let Err(err) = msg.and_then(|msg| self.consume(msg)) {
// Error to streams
let _ = self.stream_manager.broadcast_read_error(err.clone());
// Error to call waiters
self.receive_plugin_call_subscriptions();
for subscription in
std::mem::take(&mut self.plugin_call_subscriptions).into_values()
{
let _ = subscription
.sender
.send(ReceivedPluginCallMessage::Error(err.clone()));
}
return Err(err);
}
}
Ok(())
}
}
impl InterfaceManager for PluginInterfaceManager {
type Interface = PluginInterface;
type Input = PluginOutput;
fn get_interface(&self) -> Self::Interface {
PluginInterface {
state: self.state.clone(),
stream_manager_handle: self.stream_manager.get_handle(),
}
}
fn consume(&mut self, input: Self::Input) -> Result<(), ShellError> {
log::trace!("from plugin: {:?}", input);
match input {
PluginOutput::Hello(info) => {
let local_info = ProtocolInfo::default();
if local_info.is_compatible_with(&info)? {
self.protocol_info = Some(info);
Ok(())
} else {
self.protocol_info = None;
Err(ShellError::PluginFailedToLoad {
msg: format!(
"Plugin is compiled for nushell version {}, \
which is not compatible with version {}",
info.version, local_info.version
),
})
}
}
_ if self.protocol_info.is_none() => {
// Must send protocol info first
Err(ShellError::PluginFailedToLoad {
msg: "Failed to receive initial Hello message. \
This plugin might be too old"
.into(),
})
}
PluginOutput::Stream(message) => self.consume_stream_message(message),
PluginOutput::CallResponse(id, response) => {
// Handle reading the pipeline data, if any
let response = match response {
PluginCallResponse::Error(err) => PluginCallResponse::Error(err),
PluginCallResponse::Signature(sigs) => PluginCallResponse::Signature(sigs),
PluginCallResponse::PipelineData(data) => {
// If there's an error with initializing this stream, change it to a plugin
// error response, but send it anyway
let exec_context = self.get_context(id)?;
let ctrlc = exec_context.as_ref().and_then(|c| c.0.ctrlc());
match self.read_pipeline_data(data, ctrlc) {
Ok(data) => PluginCallResponse::PipelineData(data),
Err(err) => PluginCallResponse::Error(err.into()),
}
}
};
self.send_plugin_call_response(id, response)
}
}
}
fn stream_manager(&self) -> &StreamManager {
&self.stream_manager
}
fn prepare_pipeline_data(&self, mut data: PipelineData) -> Result<PipelineData, ShellError> {
// Add source to any values
match data {
PipelineData::Value(ref mut value, _) => {
PluginCustomValue::add_source(value, &self.state.identity);
Ok(data)
}
PipelineData::ListStream(ListStream { stream, ctrlc, .. }, meta) => {
let identity = self.state.identity.clone();
Ok(stream
.map(move |mut value| {
PluginCustomValue::add_source(&mut value, &identity);
value
})
.into_pipeline_data_with_metadata(meta, ctrlc))
}
PipelineData::Empty | PipelineData::ExternalStream { .. } => Ok(data),
}
}
}
/// A reference through which a plugin can be interacted with during execution.
#[derive(Debug, Clone)]
pub(crate) struct PluginInterface {
/// Shared state
state: Arc<PluginInterfaceState>,
/// Handle to stream manager
stream_manager_handle: StreamManagerHandle,
}
impl PluginInterface {
/// Write the protocol info. This should be done after initialization
pub(crate) fn hello(&self) -> Result<(), ShellError> {
self.write(PluginInput::Hello(ProtocolInfo::default()))?;
self.flush()
}
/// Write a plugin call message. Returns the writer for the stream, and the receiver for
/// messages (e.g. response) related to the plugin call
fn write_plugin_call(
&self,
call: PluginCall<PipelineData>,
context: Option<Context>,
) -> Result<
(
PipelineDataWriter<Self>,
mpsc::Receiver<ReceivedPluginCallMessage>,
),
ShellError,
> {
let id = self.state.plugin_call_id_sequence.next()?;
let (tx, rx) = mpsc::channel();
// Convert the call into one with a header and handle the stream, if necessary
let (call, writer) = match call {
PluginCall::Signature => (PluginCall::Signature, Default::default()),
PluginCall::CustomValueOp(value, op) => {
(PluginCall::CustomValueOp(value, op), Default::default())
}
PluginCall::Run(CallInfo {
name,
call,
input,
config,
}) => {
let (header, writer) = self.init_write_pipeline_data(input)?;
(
PluginCall::Run(CallInfo {
name,
call,
input: header,
config,
}),
writer,
)
}
};
// Register the subscription to the response, and the context
self.state
.plugin_call_subscription_sender
.send((
id,
PluginCallSubscription {
sender: tx,
context,
},
))
.map_err(|_| ShellError::NushellFailed {
msg: "PluginInterfaceManager hung up and is no longer accepting plugin calls"
.into(),
})?;
// Write request
self.write(PluginInput::Call(id, call))?;
self.flush()?;
Ok((writer, rx))
}
/// Read the channel for plugin call messages and handle them until the response is received.
fn receive_plugin_call_response(
&self,
rx: mpsc::Receiver<ReceivedPluginCallMessage>,
) -> Result<PluginCallResponse<PipelineData>, ShellError> {
if let Ok(msg) = rx.recv() {
// Handle message from receiver
match msg {
ReceivedPluginCallMessage::Response(resp) => Ok(resp),
ReceivedPluginCallMessage::Error(err) => Err(err),
}
} else {
// If we fail to get a response
Err(ShellError::PluginFailedToDecode {
msg: "Failed to receive response to plugin call".into(),
})
}
}
/// Perform a plugin call. Input and output streams are handled automatically.
fn plugin_call(
&self,
call: PluginCall<PipelineData>,
context: &Option<Context>,
) -> Result<PluginCallResponse<PipelineData>, ShellError> {
let (writer, rx) = self.write_plugin_call(call, context.clone())?;
// Finish writing stream in the background
writer.write_background();
self.receive_plugin_call_response(rx)
}
/// Get the command signatures from the plugin.
pub(crate) fn get_signature(&self) -> Result<Vec<PluginSignature>, ShellError> {
match self.plugin_call(PluginCall::Signature, &None)? {
PluginCallResponse::Signature(sigs) => Ok(sigs),
PluginCallResponse::Error(err) => Err(err.into()),
_ => Err(ShellError::PluginFailedToDecode {
msg: "Received unexpected response to plugin Signature call".into(),
}),
}
}
/// Run the plugin with the given call and execution context.
pub(crate) fn run(
&self,
call: CallInfo<PipelineData>,
context: Arc<impl PluginExecutionContext + 'static>,
) -> Result<PipelineData, ShellError> {
let context = Some(Context(context));
match self.plugin_call(PluginCall::Run(call), &context)? {
PluginCallResponse::PipelineData(data) => Ok(data),
PluginCallResponse::Error(err) => Err(err.into()),
_ => Err(ShellError::PluginFailedToDecode {
msg: "Received unexpected response to plugin Run call".into(),
}),
}
}
/// Collapse a custom value to its base value.
pub(crate) fn custom_value_to_base_value(
&self,
value: Spanned<PluginCustomValue>,
) -> Result<Value, ShellError> {
let span = value.span;
let call = PluginCall::CustomValueOp(value, CustomValueOp::ToBaseValue);
match self.plugin_call(call, &None)? {
PluginCallResponse::PipelineData(out_data) => Ok(out_data.into_value(span)),
PluginCallResponse::Error(err) => Err(err.into()),
_ => Err(ShellError::PluginFailedToDecode {
msg: "Received unexpected response to plugin CustomValueOp::ToBaseValue call"
.into(),
}),
}
}
}
impl Interface for PluginInterface {
type Output = PluginInput;
fn write(&self, input: PluginInput) -> Result<(), ShellError> {
log::trace!("to plugin: {:?}", input);
self.state.writer.write(&input)
}
fn flush(&self) -> Result<(), ShellError> {
self.state.writer.flush()
}
fn stream_id_sequence(&self) -> &Sequence {
&self.state.stream_id_sequence
}
fn stream_manager_handle(&self) -> &StreamManagerHandle {
&self.stream_manager_handle
}
fn prepare_pipeline_data(&self, data: PipelineData) -> Result<PipelineData, ShellError> {
// Validate the destination of values in the pipeline data
match data {
PipelineData::Value(mut value, meta) => {
PluginCustomValue::verify_source(&mut value, &self.state.identity)?;
Ok(PipelineData::Value(value, meta))
}
PipelineData::ListStream(ListStream { stream, ctrlc, .. }, meta) => {
let identity = self.state.identity.clone();
Ok(stream
.map(move |mut value| {
match PluginCustomValue::verify_source(&mut value, &identity) {
Ok(()) => value,
// Put the error in the stream instead
Err(err) => Value::error(err, value.span()),
}
})
.into_pipeline_data_with_metadata(meta, ctrlc))
}
PipelineData::Empty | PipelineData::ExternalStream { .. } => Ok(data),
}
}
}

View File

@ -0,0 +1,842 @@
use std::sync::mpsc;
use nu_protocol::{
IntoInterruptiblePipelineData, PipelineData, PluginSignature, ShellError, Span, Spanned, Value,
};
use crate::{
plugin::{
context::PluginExecutionBogusContext,
interface::{test_util::TestCase, Interface, InterfaceManager},
PluginIdentity,
},
protocol::{
test_util::{expected_test_custom_value, test_plugin_custom_value},
CallInfo, CustomValueOp, ExternalStreamInfo, ListStreamInfo, PipelineDataHeader,
PluginCall, PluginCallId, PluginCustomValue, PluginInput, Protocol, ProtocolInfo,
RawStreamInfo, StreamData, StreamMessage,
},
EvaluatedCall, PluginCallResponse, PluginOutput,
};
use super::{
PluginCallSubscription, PluginInterface, PluginInterfaceManager, ReceivedPluginCallMessage,
};
#[test]
fn manager_consume_all_consumes_messages() -> Result<(), ShellError> {
let mut test = TestCase::new();
let mut manager = test.plugin("test");
// This message should be non-problematic
test.add(PluginOutput::Hello(ProtocolInfo::default()));
manager.consume_all(&mut test)?;
assert!(!test.has_unconsumed_read());
Ok(())
}
#[test]
fn manager_consume_all_exits_after_streams_and_interfaces_are_dropped() -> Result<(), ShellError> {
let mut test = TestCase::new();
let mut manager = test.plugin("test");
// Add messages that won't cause errors
for _ in 0..5 {
test.add(PluginOutput::Hello(ProtocolInfo::default()));
}
// Create a stream...
let stream = manager.read_pipeline_data(
PipelineDataHeader::ListStream(ListStreamInfo { id: 0 }),
None,
)?;
// and an interface...
let interface = manager.get_interface();
// Expect that is_finished is false
assert!(
!manager.is_finished(),
"is_finished is true even though active stream/interface exists"
);
// After dropping, it should be true
drop(stream);
drop(interface);
assert!(
manager.is_finished(),
"is_finished is false even though manager has no stream or interface"
);
// When it's true, consume_all shouldn't consume everything
manager.consume_all(&mut test)?;
assert!(
test.has_unconsumed_read(),
"consume_all consumed the messages"
);
Ok(())
}
fn test_io_error() -> ShellError {
ShellError::IOError {
msg: "test io error".into(),
}
}
fn check_test_io_error(error: &ShellError) {
assert!(
format!("{error:?}").contains("test io error"),
"error: {error}"
);
}
#[test]
fn manager_consume_all_propagates_io_error_to_readers() -> Result<(), ShellError> {
let mut test = TestCase::new();
let mut manager = test.plugin("test");
test.set_read_error(test_io_error());
let stream = manager.read_pipeline_data(
PipelineDataHeader::ListStream(ListStreamInfo { id: 0 }),
None,
)?;
manager
.consume_all(&mut test)
.expect_err("consume_all did not error");
// Ensure end of stream
drop(manager);
let value = stream.into_iter().next().expect("stream is empty");
if let Value::Error { error, .. } = value {
check_test_io_error(&error);
Ok(())
} else {
panic!("did not get an error");
}
}
fn invalid_output() -> PluginOutput {
// This should definitely cause an error, as 0.0.0 is not compatible with any version other than
// itself
PluginOutput::Hello(ProtocolInfo {
protocol: Protocol::NuPlugin,
version: "0.0.0".into(),
features: vec![],
})
}
fn check_invalid_output_error(error: &ShellError) {
// the error message should include something about the version...
assert!(format!("{error:?}").contains("0.0.0"), "error: {error}");
}
#[test]
fn manager_consume_all_propagates_message_error_to_readers() -> Result<(), ShellError> {
let mut test = TestCase::new();
let mut manager = test.plugin("test");
test.add(invalid_output());
let stream = manager.read_pipeline_data(
PipelineDataHeader::ExternalStream(ExternalStreamInfo {
span: Span::test_data(),
stdout: Some(RawStreamInfo {
id: 0,
is_binary: false,
known_size: None,
}),
stderr: None,
exit_code: None,
trim_end_newline: false,
}),
None,
)?;
manager
.consume_all(&mut test)
.expect_err("consume_all did not error");
// Ensure end of stream
drop(manager);
let value = stream.into_iter().next().expect("stream is empty");
if let Value::Error { error, .. } = value {
check_invalid_output_error(&error);
Ok(())
} else {
panic!("did not get an error");
}
}
fn fake_plugin_call(
manager: &mut PluginInterfaceManager,
id: PluginCallId,
) -> mpsc::Receiver<ReceivedPluginCallMessage> {
// Set up a fake plugin call subscription
let (tx, rx) = mpsc::channel();
manager.plugin_call_subscriptions.insert(
id,
PluginCallSubscription {
sender: tx,
context: None,
},
);
rx
}
#[test]
fn manager_consume_all_propagates_io_error_to_plugin_calls() -> Result<(), ShellError> {
let mut test = TestCase::new();
let mut manager = test.plugin("test");
let interface = manager.get_interface();
test.set_read_error(test_io_error());
// Set up a fake plugin call subscription
let rx = fake_plugin_call(&mut manager, 0);
manager
.consume_all(&mut test)
.expect_err("consume_all did not error");
// We have to hold interface until now otherwise consume_all won't try to process the message
drop(interface);
let message = rx.try_recv().expect("failed to get plugin call message");
match message {
ReceivedPluginCallMessage::Error(error) => {
check_test_io_error(&error);
Ok(())
}
_ => panic!("received something other than an error: {message:?}"),
}
}
#[test]
fn manager_consume_all_propagates_message_error_to_plugin_calls() -> Result<(), ShellError> {
let mut test = TestCase::new();
let mut manager = test.plugin("test");
let interface = manager.get_interface();
test.add(invalid_output());
// Set up a fake plugin call subscription
let rx = fake_plugin_call(&mut manager, 0);
manager
.consume_all(&mut test)
.expect_err("consume_all did not error");
// We have to hold interface until now otherwise consume_all won't try to process the message
drop(interface);
let message = rx.try_recv().expect("failed to get plugin call message");
match message {
ReceivedPluginCallMessage::Error(error) => {
check_invalid_output_error(&error);
Ok(())
}
_ => panic!("received something other than an error: {message:?}"),
}
}
#[test]
fn manager_consume_sets_protocol_info_on_hello() -> Result<(), ShellError> {
let mut manager = TestCase::new().plugin("test");
let info = ProtocolInfo::default();
manager.consume(PluginOutput::Hello(info.clone()))?;
let set_info = manager
.protocol_info
.as_ref()
.expect("protocol info not set");
assert_eq!(info.version, set_info.version);
Ok(())
}
#[test]
fn manager_consume_errors_on_wrong_nushell_version() -> Result<(), ShellError> {
let mut manager = TestCase::new().plugin("test");
let info = ProtocolInfo {
protocol: Protocol::NuPlugin,
version: "0.0.0".into(),
features: vec![],
};
manager
.consume(PluginOutput::Hello(info))
.expect_err("version 0.0.0 should cause an error");
Ok(())
}
#[test]
fn manager_consume_errors_on_sending_other_messages_before_hello() -> Result<(), ShellError> {
let mut manager = TestCase::new().plugin("test");
// hello not set
assert!(manager.protocol_info.is_none());
let error = manager
.consume(PluginOutput::Stream(StreamMessage::Drop(0)))
.expect_err("consume before Hello should cause an error");
assert!(format!("{error:?}").contains("Hello"));
Ok(())
}
#[test]
fn manager_consume_call_response_forwards_to_subscriber_with_pipeline_data(
) -> Result<(), ShellError> {
let mut manager = TestCase::new().plugin("test");
manager.protocol_info = Some(ProtocolInfo::default());
let rx = fake_plugin_call(&mut manager, 0);
manager.consume(PluginOutput::CallResponse(
0,
PluginCallResponse::PipelineData(PipelineDataHeader::ListStream(ListStreamInfo { id: 0 })),
))?;
for i in 0..2 {
manager.consume(PluginOutput::Stream(StreamMessage::Data(
0,
Value::test_int(i).into(),
)))?;
}
manager.consume(PluginOutput::Stream(StreamMessage::End(0)))?;
// Make sure the streams end and we don't deadlock
drop(manager);
let message = rx
.try_recv()
.expect("failed to get plugin call response message");
match message {
ReceivedPluginCallMessage::Response(response) => match response {
PluginCallResponse::PipelineData(data) => {
// Ensure we manage to receive the stream messages
assert_eq!(2, data.into_iter().count());
Ok(())
}
_ => panic!("unexpected response: {response:?}"),
},
_ => panic!("unexpected response message: {message:?}"),
}
}
#[test]
fn manager_prepare_pipeline_data_adds_source_to_values() -> Result<(), ShellError> {
let manager = TestCase::new().plugin("test");
let data = manager.prepare_pipeline_data(PipelineData::Value(
Value::test_custom_value(Box::new(test_plugin_custom_value())),
None,
))?;
let value = data
.into_iter()
.next()
.expect("prepared pipeline data is empty");
let custom_value: &PluginCustomValue = value
.as_custom_value()?
.as_any()
.downcast_ref()
.expect("custom value is not a PluginCustomValue");
if let Some(source) = &custom_value.source {
assert_eq!("test", source.plugin_name);
} else {
panic!("source was not set");
}
Ok(())
}
#[test]
fn manager_prepare_pipeline_data_adds_source_to_list_streams() -> Result<(), ShellError> {
let manager = TestCase::new().plugin("test");
let data = manager.prepare_pipeline_data(
[Value::test_custom_value(Box::new(
test_plugin_custom_value(),
))]
.into_pipeline_data(None),
)?;
let value = data
.into_iter()
.next()
.expect("prepared pipeline data is empty");
let custom_value: &PluginCustomValue = value
.as_custom_value()?
.as_any()
.downcast_ref()
.expect("custom value is not a PluginCustomValue");
if let Some(source) = &custom_value.source {
assert_eq!("test", source.plugin_name);
} else {
panic!("source was not set");
}
Ok(())
}
#[test]
fn interface_hello_sends_protocol_info() -> Result<(), ShellError> {
let test = TestCase::new();
let interface = test.plugin("test").get_interface();
interface.hello()?;
let written = test.next_written().expect("nothing written");
match written {
PluginInput::Hello(info) => {
assert_eq!(ProtocolInfo::default().version, info.version);
}
_ => panic!("unexpected message written: {written:?}"),
}
assert!(!test.has_unconsumed_write());
Ok(())
}
#[test]
fn interface_write_plugin_call_registers_subscription() -> Result<(), ShellError> {
let mut manager = TestCase::new().plugin("test");
assert!(
manager.plugin_call_subscriptions.is_empty(),
"plugin call subscriptions not empty before start of test"
);
let interface = manager.get_interface();
let _ = interface.write_plugin_call(PluginCall::Signature, None)?;
manager.receive_plugin_call_subscriptions();
assert!(
!manager.plugin_call_subscriptions.is_empty(),
"not registered"
);
Ok(())
}
#[test]
fn interface_write_plugin_call_writes_signature() -> Result<(), ShellError> {
let test = TestCase::new();
let manager = test.plugin("test");
let interface = manager.get_interface();
let (writer, _) = interface.write_plugin_call(PluginCall::Signature, None)?;
writer.write()?;
let written = test.next_written().expect("nothing written");
match written {
PluginInput::Call(_, call) => assert!(
matches!(call, PluginCall::Signature),
"not Signature: {call:?}"
),
_ => panic!("unexpected message written: {written:?}"),
}
Ok(())
}
#[test]
fn interface_write_plugin_call_writes_custom_value_op() -> Result<(), ShellError> {
let test = TestCase::new();
let manager = test.plugin("test");
let interface = manager.get_interface();
let (writer, _) = interface.write_plugin_call(
PluginCall::CustomValueOp(
Spanned {
item: test_plugin_custom_value(),
span: Span::test_data(),
},
CustomValueOp::ToBaseValue,
),
None,
)?;
writer.write()?;
let written = test.next_written().expect("nothing written");
match written {
PluginInput::Call(_, call) => assert!(
matches!(
call,
PluginCall::CustomValueOp(_, CustomValueOp::ToBaseValue)
),
"expected CustomValueOp(_, ToBaseValue), got {call:?}"
),
_ => panic!("unexpected message written: {written:?}"),
}
Ok(())
}
#[test]
fn interface_write_plugin_call_writes_run_with_value_input() -> Result<(), ShellError> {
let test = TestCase::new();
let manager = test.plugin("test");
let interface = manager.get_interface();
let (writer, _) = interface.write_plugin_call(
PluginCall::Run(CallInfo {
name: "foo".into(),
call: EvaluatedCall {
head: Span::test_data(),
positional: vec![],
named: vec![],
},
input: PipelineData::Value(Value::test_int(-1), None),
config: None,
}),
None,
)?;
writer.write()?;
let written = test.next_written().expect("nothing written");
match written {
PluginInput::Call(_, call) => match call {
PluginCall::Run(CallInfo { name, input, .. }) => {
assert_eq!("foo", name);
match input {
PipelineDataHeader::Value(value) => assert_eq!(-1, value.as_int()?),
_ => panic!("unexpected input header: {input:?}"),
}
}
_ => panic!("unexpected Call: {call:?}"),
},
_ => panic!("unexpected message written: {written:?}"),
}
Ok(())
}
#[test]
fn interface_write_plugin_call_writes_run_with_stream_input() -> Result<(), ShellError> {
let test = TestCase::new();
let manager = test.plugin("test");
let interface = manager.get_interface();
let values = vec![Value::test_int(1), Value::test_int(2)];
let (writer, _) = interface.write_plugin_call(
PluginCall::Run(CallInfo {
name: "foo".into(),
call: EvaluatedCall {
head: Span::test_data(),
positional: vec![],
named: vec![],
},
input: values.clone().into_pipeline_data(None),
config: None,
}),
None,
)?;
writer.write()?;
let written = test.next_written().expect("nothing written");
let info = match written {
PluginInput::Call(_, call) => match call {
PluginCall::Run(CallInfo { name, input, .. }) => {
assert_eq!("foo", name);
match input {
PipelineDataHeader::ListStream(info) => info,
_ => panic!("unexpected input header: {input:?}"),
}
}
_ => panic!("unexpected Call: {call:?}"),
},
_ => panic!("unexpected message written: {written:?}"),
};
// Expect stream messages
for value in values {
match test
.next_written()
.expect("failed to get Data stream message")
{
PluginInput::Stream(StreamMessage::Data(id, data)) => {
assert_eq!(info.id, id, "id");
match data {
StreamData::List(data_value) => {
assert_eq!(value, data_value, "wrong value in Data message");
}
_ => panic!("not List stream data: {data:?}"),
}
}
message => panic!("expected Stream(Data(..)) message: {message:?}"),
}
}
match test
.next_written()
.expect("failed to get End stream message")
{
PluginInput::Stream(StreamMessage::End(id)) => {
assert_eq!(info.id, id, "id");
}
message => panic!("expected Stream(End(_)) message: {message:?}"),
}
Ok(())
}
#[test]
fn interface_receive_plugin_call_receives_response() -> Result<(), ShellError> {
let interface = TestCase::new().plugin("test").get_interface();
// Set up a fake channel that has the response in it
let (tx, rx) = mpsc::channel();
tx.send(ReceivedPluginCallMessage::Response(
PluginCallResponse::Signature(vec![]),
))
.expect("failed to send on new channel");
drop(tx); // so we don't deadlock on recv()
let response = interface.receive_plugin_call_response(rx)?;
assert!(
matches!(response, PluginCallResponse::Signature(_)),
"wrong response: {response:?}"
);
Ok(())
}
#[test]
fn interface_receive_plugin_call_receives_error() -> Result<(), ShellError> {
let interface = TestCase::new().plugin("test").get_interface();
// Set up a fake channel that has the error in it
let (tx, rx) = mpsc::channel();
tx.send(ReceivedPluginCallMessage::Error(
ShellError::ExternalNotSupported {
span: Span::test_data(),
},
))
.expect("failed to send on new channel");
drop(tx); // so we don't deadlock on recv()
let error = interface
.receive_plugin_call_response(rx)
.expect_err("did not receive error");
assert!(
matches!(error, ShellError::ExternalNotSupported { .. }),
"wrong error: {error:?}"
);
Ok(())
}
/// Fake responses to requests for plugin call messages
fn start_fake_plugin_call_responder(
manager: PluginInterfaceManager,
take: usize,
mut f: impl FnMut(PluginCallId) -> Vec<ReceivedPluginCallMessage> + Send + 'static,
) {
std::thread::Builder::new()
.name("fake plugin call responder".into())
.spawn(move || {
for (id, sub) in manager
.plugin_call_subscription_receiver
.into_iter()
.take(take)
{
for message in f(id) {
sub.sender.send(message).expect("failed to send");
}
}
})
.expect("failed to spawn thread");
}
#[test]
fn interface_get_signature() -> Result<(), ShellError> {
let test = TestCase::new();
let manager = test.plugin("test");
let interface = manager.get_interface();
start_fake_plugin_call_responder(manager, 1, |_| {
vec![ReceivedPluginCallMessage::Response(
PluginCallResponse::Signature(vec![PluginSignature::build("test")]),
)]
});
let signatures = interface.get_signature()?;
assert_eq!(1, signatures.len());
assert!(test.has_unconsumed_write());
Ok(())
}
#[test]
fn interface_run() -> Result<(), ShellError> {
let test = TestCase::new();
let manager = test.plugin("test");
let interface = manager.get_interface();
let number = 64;
start_fake_plugin_call_responder(manager, 1, move |_| {
vec![ReceivedPluginCallMessage::Response(
PluginCallResponse::PipelineData(PipelineData::Value(Value::test_int(number), None)),
)]
});
let result = interface.run(
CallInfo {
name: "bogus".into(),
call: EvaluatedCall {
head: Span::test_data(),
positional: vec![],
named: vec![],
},
input: PipelineData::Empty,
config: None,
},
PluginExecutionBogusContext.into(),
)?;
assert_eq!(
Value::test_int(number),
result.into_value(Span::test_data())
);
assert!(test.has_unconsumed_write());
Ok(())
}
#[test]
fn interface_custom_value_to_base_value() -> Result<(), ShellError> {
let test = TestCase::new();
let manager = test.plugin("test");
let interface = manager.get_interface();
let string = "this is a test";
start_fake_plugin_call_responder(manager, 1, move |_| {
vec![ReceivedPluginCallMessage::Response(
PluginCallResponse::PipelineData(PipelineData::Value(Value::test_string(string), None)),
)]
});
let result = interface.custom_value_to_base_value(Spanned {
item: test_plugin_custom_value(),
span: Span::test_data(),
})?;
assert_eq!(Value::test_string(string), result);
assert!(test.has_unconsumed_write());
Ok(())
}
fn normal_values(interface: &PluginInterface) -> Vec<Value> {
vec![
Value::test_int(5),
Value::test_custom_value(Box::new(PluginCustomValue {
name: "SomeTest".into(),
data: vec![1, 2, 3],
// Has the same source, so it should be accepted
source: Some(interface.state.identity.clone()),
})),
]
}
#[test]
fn interface_prepare_pipeline_data_accepts_normal_values() -> Result<(), ShellError> {
let interface = TestCase::new().plugin("test").get_interface();
for value in normal_values(&interface) {
match interface.prepare_pipeline_data(PipelineData::Value(value.clone(), None)) {
Ok(data) => assert_eq!(
value.get_type(),
data.into_value(Span::test_data()).get_type()
),
Err(err) => panic!("failed to accept {value:?}: {err}"),
}
}
Ok(())
}
#[test]
fn interface_prepare_pipeline_data_accepts_normal_streams() -> Result<(), ShellError> {
let interface = TestCase::new().plugin("test").get_interface();
let values = normal_values(&interface);
let data = interface.prepare_pipeline_data(values.clone().into_pipeline_data(None))?;
let mut count = 0;
for (expected_value, actual_value) in values.iter().zip(data) {
assert!(
!actual_value.is_error(),
"error value instead of {expected_value:?} in stream: {actual_value:?}"
);
assert_eq!(expected_value.get_type(), actual_value.get_type());
count += 1;
}
assert_eq!(
values.len(),
count,
"didn't receive as many values as expected"
);
Ok(())
}
fn bad_custom_values() -> Vec<Value> {
// These shouldn't be accepted
vec![
// Native custom value (not PluginCustomValue) should be rejected
Value::test_custom_value(Box::new(expected_test_custom_value())),
// Has no source, so it should be rejected
Value::test_custom_value(Box::new(PluginCustomValue {
name: "SomeTest".into(),
data: vec![1, 2, 3],
source: None,
})),
// Has a different source, so it should be rejected
Value::test_custom_value(Box::new(PluginCustomValue {
name: "SomeTest".into(),
data: vec![1, 2, 3],
source: Some(PluginIdentity::new_fake("pluto")),
})),
]
}
#[test]
fn interface_prepare_pipeline_data_rejects_bad_custom_value() -> Result<(), ShellError> {
let interface = TestCase::new().plugin("test").get_interface();
for value in bad_custom_values() {
match interface.prepare_pipeline_data(PipelineData::Value(value.clone(), None)) {
Err(err) => match err {
ShellError::CustomValueIncorrectForPlugin { .. } => (),
_ => panic!("expected error type CustomValueIncorrectForPlugin, but got {err:?}"),
},
Ok(_) => panic!("mistakenly accepted {value:?}"),
}
}
Ok(())
}
#[test]
fn interface_prepare_pipeline_data_rejects_bad_custom_value_in_a_stream() -> Result<(), ShellError>
{
let interface = TestCase::new().plugin("test").get_interface();
let values = bad_custom_values();
let data = interface.prepare_pipeline_data(values.clone().into_pipeline_data(None))?;
let mut count = 0;
for value in data {
assert!(value.is_error(), "expected error value for {value:?}");
count += 1;
}
assert_eq!(
values.len(),
count,
"didn't receive as many values as expected"
);
Ok(())
}

View File

@ -0,0 +1,621 @@
use std::{
collections::{btree_map, BTreeMap},
iter::FusedIterator,
marker::PhantomData,
sync::{mpsc, Arc, Condvar, Mutex, MutexGuard, Weak},
};
use nu_protocol::{ShellError, Span, Value};
use crate::protocol::{StreamData, StreamId, StreamMessage};
#[cfg(test)]
mod tests;
/// Receives messages from a stream read from input by a [`StreamManager`].
///
/// The receiver reads for messages of type `Result<Option<StreamData>, ShellError>` from the
/// channel, which is managed by a [`StreamManager`]. Signalling for end-of-stream is explicit
/// through `Ok(Some)`.
///
/// Failing to receive is an error. When end-of-stream is received, the `receiver` is set to `None`
/// and all further calls to `next()` return `None`.
///
/// The type `T` must implement [`FromShellError`], so that errors in the stream can be represented,
/// and `TryFrom<StreamData>` to convert it to the correct type.
///
/// For each message read, it sends [`StreamMessage::Ack`] to the writer. When dropped,
/// it sends [`StreamMessage::Drop`].
#[derive(Debug)]
pub(crate) struct StreamReader<T, W>
where
W: WriteStreamMessage,
{
id: StreamId,
receiver: Option<mpsc::Receiver<Result<Option<StreamData>, ShellError>>>,
writer: W,
/// Iterator requires the item type to be fixed, so we have to keep it as part of the type,
/// even though we're actually receiving dynamic data.
marker: PhantomData<fn() -> T>,
}
impl<T, W> StreamReader<T, W>
where
T: TryFrom<StreamData, Error = ShellError>,
W: WriteStreamMessage,
{
/// Create a new StreamReader from parts
pub(crate) fn new(
id: StreamId,
receiver: mpsc::Receiver<Result<Option<StreamData>, ShellError>>,
writer: W,
) -> StreamReader<T, W> {
StreamReader {
id,
receiver: Some(receiver),
writer,
marker: PhantomData,
}
}
/// Receive a message from the channel, or return an error if:
///
/// * the channel couldn't be received from
/// * an error was sent on the channel
/// * the message received couldn't be converted to `T`
pub(crate) fn recv(&mut self) -> Result<Option<T>, ShellError> {
let connection_lost = || ShellError::GenericError {
error: "Stream ended unexpectedly".into(),
msg: "connection lost before explicit end of stream".into(),
span: None,
help: None,
inner: vec![],
};
if let Some(ref rx) = self.receiver {
// Try to receive a message first
let msg = match rx.try_recv() {
Ok(msg) => msg?,
Err(mpsc::TryRecvError::Empty) => {
// The receiver doesn't have any messages waiting for us. It's possible that the
// other side hasn't seen our acknowledgements. Let's flush the writer and then
// wait
self.writer.flush()?;
rx.recv().map_err(|_| connection_lost())??
}
Err(mpsc::TryRecvError::Disconnected) => return Err(connection_lost()),
};
if let Some(data) = msg {
// Acknowledge the message
self.writer
.write_stream_message(StreamMessage::Ack(self.id))?;
// Try to convert it into the correct type
Ok(Some(data.try_into()?))
} else {
// Remove the receiver, so that future recv() calls always return Ok(None)
self.receiver = None;
Ok(None)
}
} else {
// Closed already
Ok(None)
}
}
}
impl<T, W> Iterator for StreamReader<T, W>
where
T: FromShellError + TryFrom<StreamData, Error = ShellError>,
W: WriteStreamMessage,
{
type Item = T;
fn next(&mut self) -> Option<T> {
// Converting the error to the value here makes the implementation a lot easier
self.recv()
.unwrap_or_else(|err| Some(T::from_shell_error(err)))
}
}
// Guaranteed not to return anything after the end
impl<T, W> FusedIterator for StreamReader<T, W>
where
T: FromShellError + TryFrom<StreamData, Error = ShellError>,
W: WriteStreamMessage,
{
}
impl<T, W> Drop for StreamReader<T, W>
where
W: WriteStreamMessage,
{
fn drop(&mut self) {
if let Err(err) = self
.writer
.write_stream_message(StreamMessage::Drop(self.id))
.and_then(|_| self.writer.flush())
{
log::warn!("Failed to send message to drop stream: {err}");
}
}
}
/// Values that can contain a `ShellError` to signal an error has occurred.
pub(crate) trait FromShellError {
fn from_shell_error(err: ShellError) -> Self;
}
// For List streams.
impl FromShellError for Value {
fn from_shell_error(err: ShellError) -> Self {
Value::error(err, Span::unknown())
}
}
// For Raw streams, mostly.
impl<T> FromShellError for Result<T, ShellError> {
fn from_shell_error(err: ShellError) -> Self {
Err(err)
}
}
/// Writes messages to a stream, with flow control.
///
/// The `signal` contained
#[derive(Debug)]
pub(crate) struct StreamWriter<W: WriteStreamMessage> {
id: StreamId,
signal: Arc<StreamWriterSignal>,
writer: W,
ended: bool,
}
impl<W> StreamWriter<W>
where
W: WriteStreamMessage,
{
pub(crate) fn new(id: StreamId, signal: Arc<StreamWriterSignal>, writer: W) -> StreamWriter<W> {
StreamWriter {
id,
signal,
writer,
ended: false,
}
}
/// Check if the stream was dropped from the other end. Recommended to do this before calling
/// [`.write()`], especially in a loop.
pub(crate) fn is_dropped(&self) -> Result<bool, ShellError> {
self.signal.is_dropped()
}
/// Write a single piece of data to the stream.
///
/// Error if something failed with the write, or if [`.end()`] was already called
/// previously.
pub(crate) fn write(&mut self, data: impl Into<StreamData>) -> Result<(), ShellError> {
if !self.ended {
self.writer
.write_stream_message(StreamMessage::Data(self.id, data.into()))?;
// This implements flow control, so we don't write too many messages:
if !self.signal.notify_sent()? {
// Flush the output, and then wait for acknowledgements
self.writer.flush()?;
self.signal.wait_for_drain()
} else {
Ok(())
}
} else {
Err(ShellError::GenericError {
error: "Wrote to a stream after it ended".into(),
msg: format!(
"tried to write to stream {} after it was already ended",
self.id
),
span: None,
help: Some("this may be a bug in the nu-plugin crate".into()),
inner: vec![],
})
}
}
/// Write a full iterator to the stream. Note that this doesn't end the stream, so you should
/// still call [`.end()`].
///
/// If the stream is dropped from the other end, the iterator will not be fully consumed, and
/// writing will terminate.
///
/// Returns `Ok(true)` if the iterator was fully consumed, or `Ok(false)` if a drop interrupted
/// the stream from the other side.
pub(crate) fn write_all<T>(
&mut self,
data: impl IntoIterator<Item = T>,
) -> Result<bool, ShellError>
where
T: Into<StreamData>,
{
// Check before starting
if self.is_dropped()? {
return Ok(false);
}
for item in data {
// Check again after each item is consumed from the iterator, just in case the iterator
// takes a while to produce a value
if self.is_dropped()? {
return Ok(false);
}
self.write(item)?;
}
Ok(true)
}
/// End the stream. Recommend doing this instead of relying on `Drop` so that you can catch the
/// error.
pub(crate) fn end(&mut self) -> Result<(), ShellError> {
if !self.ended {
// Set the flag first so we don't double-report in the Drop
self.ended = true;
self.writer
.write_stream_message(StreamMessage::End(self.id))?;
self.writer.flush()
} else {
Ok(())
}
}
}
impl<W> Drop for StreamWriter<W>
where
W: WriteStreamMessage,
{
fn drop(&mut self) {
// Make sure we ended the stream
if let Err(err) = self.end() {
log::warn!("Error while ending stream in Drop for StreamWriter: {err}");
}
}
}
/// Stores stream state for a writer, and can be blocked on to wait for messages to be acknowledged.
/// A key part of managing stream lifecycle and flow control.
#[derive(Debug)]
pub(crate) struct StreamWriterSignal {
mutex: Mutex<StreamWriterSignalState>,
change_cond: Condvar,
}
#[derive(Debug)]
pub(crate) struct StreamWriterSignalState {
/// Stream has been dropped and consumer is no longer interested in any messages.
dropped: bool,
/// Number of messages that have been sent without acknowledgement.
unacknowledged: i32,
/// Max number of messages to send before waiting for acknowledgement.
high_pressure_mark: i32,
}
impl StreamWriterSignal {
/// Create a new signal.
///
/// If `notify_sent()` is called more than `high_pressure_mark` times, it will wait until
/// `notify_acknowledge()` is called by another thread enough times to bring the number of
/// unacknowledged sent messages below that threshold.
pub fn new(high_pressure_mark: i32) -> StreamWriterSignal {
assert!(high_pressure_mark > 0);
StreamWriterSignal {
mutex: Mutex::new(StreamWriterSignalState {
dropped: false,
unacknowledged: 0,
high_pressure_mark,
}),
change_cond: Condvar::new(),
}
}
fn lock(&self) -> Result<MutexGuard<StreamWriterSignalState>, ShellError> {
self.mutex.lock().map_err(|_| ShellError::NushellFailed {
msg: "StreamWriterSignal mutex poisoned due to panic".into(),
})
}
/// True if the stream was dropped and the consumer is no longer interested in it. Indicates
/// that no more messages should be sent, other than `End`.
pub fn is_dropped(&self) -> Result<bool, ShellError> {
Ok(self.lock()?.dropped)
}
/// Notify the writers that the stream has been dropped, so they can stop writing.
pub fn set_dropped(&self) -> Result<(), ShellError> {
let mut state = self.lock()?;
state.dropped = true;
// Unblock the writers so they can terminate
self.change_cond.notify_all();
Ok(())
}
/// Track that a message has been sent. Returns `Ok(true)` if more messages can be sent,
/// or `Ok(false)` if the high pressure mark has been reached and [`.wait_for_drain()`] should
/// be called to block.
pub fn notify_sent(&self) -> Result<bool, ShellError> {
let mut state = self.lock()?;
state.unacknowledged =
state
.unacknowledged
.checked_add(1)
.ok_or_else(|| ShellError::NushellFailed {
msg: "Overflow in counter: too many unacknowledged messages".into(),
})?;
Ok(state.unacknowledged < state.high_pressure_mark)
}
/// Wait for acknowledgements before sending more data. Also returns if the stream is dropped.
pub fn wait_for_drain(&self) -> Result<(), ShellError> {
let mut state = self.lock()?;
while !state.dropped && state.unacknowledged >= state.high_pressure_mark {
state = self
.change_cond
.wait(state)
.map_err(|_| ShellError::NushellFailed {
msg: "StreamWriterSignal mutex poisoned due to panic".into(),
})?;
}
Ok(())
}
/// Notify the writers that a message has been acknowledged, so they can continue to write
/// if they were waiting.
pub fn notify_acknowledged(&self) -> Result<(), ShellError> {
let mut state = self.lock()?;
state.unacknowledged =
state
.unacknowledged
.checked_sub(1)
.ok_or_else(|| ShellError::NushellFailed {
msg: "Underflow in counter: too many message acknowledgements".into(),
})?;
// Unblock the writer
self.change_cond.notify_one();
Ok(())
}
}
/// A sink for a [`StreamMessage`]
pub(crate) trait WriteStreamMessage {
fn write_stream_message(&mut self, msg: StreamMessage) -> Result<(), ShellError>;
fn flush(&mut self) -> Result<(), ShellError>;
}
#[derive(Debug, Default)]
struct StreamManagerState {
reading_streams: BTreeMap<StreamId, mpsc::Sender<Result<Option<StreamData>, ShellError>>>,
writing_streams: BTreeMap<StreamId, Weak<StreamWriterSignal>>,
}
impl StreamManagerState {
/// Lock the state, or return a [`ShellError`] if the mutex is poisoned.
fn lock(
state: &Mutex<StreamManagerState>,
) -> Result<MutexGuard<StreamManagerState>, ShellError> {
state.lock().map_err(|_| ShellError::NushellFailed {
msg: "StreamManagerState mutex poisoned due to a panic".into(),
})
}
}
#[derive(Debug)]
pub(crate) struct StreamManager {
state: Arc<Mutex<StreamManagerState>>,
}
impl StreamManager {
/// Create a new StreamManager.
pub(crate) fn new() -> StreamManager {
StreamManager {
state: Default::default(),
}
}
fn lock(&self) -> Result<MutexGuard<StreamManagerState>, ShellError> {
StreamManagerState::lock(&self.state)
}
/// Create a new handle to the StreamManager for registering streams.
pub(crate) fn get_handle(&self) -> StreamManagerHandle {
StreamManagerHandle {
state: Arc::downgrade(&self.state),
}
}
/// Process a stream message, and update internal state accordingly.
pub(crate) fn handle_message(&self, message: StreamMessage) -> Result<(), ShellError> {
let mut state = self.lock()?;
match message {
StreamMessage::Data(id, data) => {
if let Some(sender) = state.reading_streams.get(&id) {
// We should ignore the error on send. This just means the reader has dropped,
// but it will have sent a Drop message to the other side, and we will receive
// an End message at which point we can remove the channel.
let _ = sender.send(Ok(Some(data)));
Ok(())
} else {
Err(ShellError::PluginFailedToDecode {
msg: format!("received Data for unknown stream {id}"),
})
}
}
StreamMessage::End(id) => {
if let Some(sender) = state.reading_streams.remove(&id) {
// We should ignore the error on the send, because the reader might have dropped
// already
let _ = sender.send(Ok(None));
Ok(())
} else {
Err(ShellError::PluginFailedToDecode {
msg: format!("received End for unknown stream {id}"),
})
}
}
StreamMessage::Drop(id) => {
if let Some(signal) = state.writing_streams.remove(&id) {
if let Some(signal) = signal.upgrade() {
// This will wake blocked writers so they can stop writing, so it's ok
signal.set_dropped()?;
}
}
// It's possible that the stream has already finished writing and we don't have it
// anymore, so we fall through to Ok
Ok(())
}
StreamMessage::Ack(id) => {
if let Some(signal) = state.writing_streams.get(&id) {
if let Some(signal) = signal.upgrade() {
// This will wake up a blocked writer
signal.notify_acknowledged()?;
} else {
// We know it doesn't exist, so might as well remove it
state.writing_streams.remove(&id);
}
}
// It's possible that the stream has already finished writing and we don't have it
// anymore, so we fall through to Ok
Ok(())
}
}
}
/// Broadcast an error to all stream readers. This is useful for error propagation.
pub(crate) fn broadcast_read_error(&self, error: ShellError) -> Result<(), ShellError> {
let state = self.lock()?;
for channel in state.reading_streams.values() {
// Ignore send errors.
let _ = channel.send(Err(error.clone()));
}
Ok(())
}
// If the `StreamManager` is dropped, we should let all of the stream writers know that they
// won't be able to write anymore. We don't need to do anything about the readers though
// because they'll know when the `Sender` is dropped automatically
fn drop_all_writers(&self) -> Result<(), ShellError> {
let mut state = self.lock()?;
let writers = std::mem::take(&mut state.writing_streams);
for (_, signal) in writers {
if let Some(signal) = signal.upgrade() {
// more important that we send to all than handling an error
let _ = signal.set_dropped();
}
}
Ok(())
}
}
impl Drop for StreamManager {
fn drop(&mut self) {
if let Err(err) = self.drop_all_writers() {
log::warn!("error during Drop for StreamManager: {}", err)
}
}
}
/// A [`StreamManagerHandle`] supports operations for interacting with the [`StreamManager`].
///
/// Streams can be registered for reading, returning a [`StreamReader`], or for writing, returning
/// a [`StreamWriter`].
#[derive(Debug, Clone)]
pub(crate) struct StreamManagerHandle {
state: Weak<Mutex<StreamManagerState>>,
}
impl StreamManagerHandle {
/// Because the handle only has a weak reference to the [`StreamManager`] state, we have to
/// first try to upgrade to a strong reference and then lock. This function wraps those two
/// operations together, handling errors appropriately.
fn with_lock<T, F>(&self, f: F) -> Result<T, ShellError>
where
F: FnOnce(MutexGuard<StreamManagerState>) -> Result<T, ShellError>,
{
let upgraded = self
.state
.upgrade()
.ok_or_else(|| ShellError::NushellFailed {
msg: "StreamManager is no longer alive".into(),
})?;
let guard = upgraded.lock().map_err(|_| ShellError::NushellFailed {
msg: "StreamManagerState mutex poisoned due to a panic".into(),
})?;
f(guard)
}
/// Register a new stream for reading, and return a [`StreamReader`] that can be used to iterate
/// on the values received. A [`StreamMessage`] writer is required for writing control messages
/// back to the producer.
pub(crate) fn read_stream<T, W>(
&self,
id: StreamId,
writer: W,
) -> Result<StreamReader<T, W>, ShellError>
where
T: TryFrom<StreamData, Error = ShellError>,
W: WriteStreamMessage,
{
let (tx, rx) = mpsc::channel();
self.with_lock(|mut state| {
// Must be exclusive
if let btree_map::Entry::Vacant(e) = state.reading_streams.entry(id) {
e.insert(tx);
Ok(())
} else {
Err(ShellError::GenericError {
error: format!("Failed to acquire reader for stream {id}"),
msg: "tried to get a reader for a stream that's already being read".into(),
span: None,
help: Some("this may be a bug in the nu-plugin crate".into()),
inner: vec![],
})
}
})?;
Ok(StreamReader::new(id, rx, writer))
}
/// Register a new stream for writing, and return a [`StreamWriter`] that can be used to send
/// data to the stream.
///
/// The `high_pressure_mark` value controls how many messages can be written without receiving
/// an acknowledgement before any further attempts to write will wait for the consumer to
/// acknowledge them. This prevents overwhelming the reader.
pub(crate) fn write_stream<W>(
&self,
id: StreamId,
writer: W,
high_pressure_mark: i32,
) -> Result<StreamWriter<W>, ShellError>
where
W: WriteStreamMessage,
{
let signal = Arc::new(StreamWriterSignal::new(high_pressure_mark));
self.with_lock(|mut state| {
// Remove dead writing streams
state
.writing_streams
.retain(|_, signal| signal.strong_count() > 0);
// Must be exclusive
if let btree_map::Entry::Vacant(e) = state.writing_streams.entry(id) {
e.insert(Arc::downgrade(&signal));
Ok(())
} else {
Err(ShellError::GenericError {
error: format!("Failed to acquire writer for stream {id}"),
msg: "tried to get a writer for a stream that's already being written".into(),
span: None,
help: Some("this may be a bug in the nu-plugin crate".into()),
inner: vec![],
})
}
})?;
Ok(StreamWriter::new(id, signal, writer))
}
}

View File

@ -0,0 +1,508 @@
use std::{
sync::{
atomic::{AtomicBool, Ordering::Relaxed},
mpsc, Arc,
},
time::Duration,
};
use nu_protocol::{ShellError, Value};
use crate::protocol::{StreamData, StreamMessage};
use super::{StreamManager, StreamReader, StreamWriter, StreamWriterSignal, WriteStreamMessage};
// Should be long enough to definitely complete any quick operation, but not so long that tests are
// slow to complete. 10 ms is a pretty long time
const WAIT_DURATION: Duration = Duration::from_millis(10);
#[derive(Debug, Clone, Default)]
struct TestSink(Vec<StreamMessage>);
impl WriteStreamMessage for TestSink {
fn write_stream_message(&mut self, msg: StreamMessage) -> Result<(), ShellError> {
self.0.push(msg);
Ok(())
}
fn flush(&mut self) -> Result<(), ShellError> {
Ok(())
}
}
impl WriteStreamMessage for mpsc::Sender<StreamMessage> {
fn write_stream_message(&mut self, msg: StreamMessage) -> Result<(), ShellError> {
self.send(msg).map_err(|err| ShellError::NushellFailed {
msg: err.to_string(),
})
}
fn flush(&mut self) -> Result<(), ShellError> {
Ok(())
}
}
#[test]
fn reader_recv_list_messages() -> Result<(), ShellError> {
let (tx, rx) = mpsc::channel();
let mut reader = StreamReader::new(0, rx, TestSink::default());
tx.send(Ok(Some(StreamData::List(Value::test_int(5)))))
.unwrap();
drop(tx);
assert_eq!(Some(Value::test_int(5)), reader.recv()?);
Ok(())
}
#[test]
fn list_reader_recv_wrong_type() -> Result<(), ShellError> {
let (tx, rx) = mpsc::channel();
let mut reader = StreamReader::<Value, _>::new(0, rx, TestSink::default());
tx.send(Ok(Some(StreamData::Raw(Ok(vec![10, 20])))))
.unwrap();
tx.send(Ok(Some(StreamData::List(Value::test_nothing()))))
.unwrap();
drop(tx);
reader.recv().expect_err("should be an error");
reader.recv().expect("should be able to recover");
Ok(())
}
#[test]
fn reader_recv_raw_messages() -> Result<(), ShellError> {
let (tx, rx) = mpsc::channel();
let mut reader = StreamReader::new(0, rx, TestSink::default());
tx.send(Ok(Some(StreamData::Raw(Ok(vec![10, 20])))))
.unwrap();
drop(tx);
assert_eq!(Some(vec![10, 20]), reader.recv()?.transpose()?);
Ok(())
}
#[test]
fn raw_reader_recv_wrong_type() -> Result<(), ShellError> {
let (tx, rx) = mpsc::channel();
let mut reader =
StreamReader::<Result<Vec<u8>, ShellError>, _>::new(0, rx, TestSink::default());
tx.send(Ok(Some(StreamData::List(Value::test_nothing()))))
.unwrap();
tx.send(Ok(Some(StreamData::Raw(Ok(vec![10, 20])))))
.unwrap();
drop(tx);
reader.recv().expect_err("should be an error");
reader.recv().expect("should be able to recover");
Ok(())
}
#[test]
fn reader_recv_acknowledge() -> Result<(), ShellError> {
let (tx, rx) = mpsc::channel();
let mut reader = StreamReader::<Value, _>::new(0, rx, TestSink::default());
tx.send(Ok(Some(StreamData::List(Value::test_int(5)))))
.unwrap();
tx.send(Ok(Some(StreamData::List(Value::test_int(6)))))
.unwrap();
drop(tx);
reader.recv()?;
reader.recv()?;
let wrote = &reader.writer.0;
assert!(wrote.len() >= 2);
assert!(
matches!(wrote[0], StreamMessage::Ack(0)),
"0 = {:?}",
wrote[0]
);
assert!(
matches!(wrote[1], StreamMessage::Ack(0)),
"1 = {:?}",
wrote[1]
);
Ok(())
}
#[test]
fn reader_recv_end_of_stream() -> Result<(), ShellError> {
let (tx, rx) = mpsc::channel();
let mut reader = StreamReader::<Value, _>::new(0, rx, TestSink::default());
tx.send(Ok(Some(StreamData::List(Value::test_int(5)))))
.unwrap();
tx.send(Ok(None)).unwrap();
drop(tx);
assert!(reader.recv()?.is_some(), "actual message");
assert!(reader.recv()?.is_none(), "on close");
assert!(reader.recv()?.is_none(), "after close");
Ok(())
}
#[test]
fn reader_drop() {
let (_tx, rx) = mpsc::channel();
// Flag set if drop message is received.
struct Check(Arc<AtomicBool>);
impl WriteStreamMessage for Check {
fn write_stream_message(&mut self, msg: StreamMessage) -> Result<(), ShellError> {
assert!(matches!(msg, StreamMessage::Drop(1)), "got {:?}", msg);
self.0.store(true, Relaxed);
Ok(())
}
fn flush(&mut self) -> Result<(), ShellError> {
Ok(())
}
}
let flag = Arc::new(AtomicBool::new(false));
let reader = StreamReader::<Value, _>::new(1, rx, Check(flag.clone()));
drop(reader);
assert!(flag.load(Relaxed));
}
#[test]
fn writer_write_all_stops_if_dropped() -> Result<(), ShellError> {
let signal = Arc::new(StreamWriterSignal::new(20));
let id = 1337;
let mut writer = StreamWriter::new(id, signal.clone(), TestSink::default());
// Simulate this by having it consume a stream that will actually do the drop halfway through
let iter = (0..5).map(Value::test_int).chain({
let mut n = 5;
std::iter::from_fn(move || {
// produces numbers 5..10, but drops for the first one
if n == 5 {
signal.set_dropped().unwrap();
}
if n < 10 {
let value = Value::test_int(n);
n += 1;
Some(value)
} else {
None
}
})
});
writer.write_all(iter)?;
assert!(writer.is_dropped()?);
let wrote = &writer.writer.0;
assert_eq!(5, wrote.len(), "length wrong: {wrote:?}");
for (n, message) in (0..5).zip(wrote) {
match message {
StreamMessage::Data(msg_id, StreamData::List(value)) => {
assert_eq!(id, *msg_id, "id");
assert_eq!(Value::test_int(n), *value, "value");
}
other => panic!("unexpected message: {other:?}"),
}
}
Ok(())
}
#[test]
fn writer_end() -> Result<(), ShellError> {
let signal = Arc::new(StreamWriterSignal::new(20));
let mut writer = StreamWriter::new(9001, signal.clone(), TestSink::default());
writer.end()?;
writer
.write(Value::test_int(2))
.expect_err("shouldn't be able to write after end");
writer.end().expect("end twice should be ok");
let wrote = &writer.writer.0;
assert!(
matches!(wrote.last(), Some(StreamMessage::End(9001))),
"didn't write end message: {wrote:?}"
);
Ok(())
}
#[test]
fn signal_set_dropped() -> Result<(), ShellError> {
let signal = StreamWriterSignal::new(4);
assert!(!signal.is_dropped()?);
signal.set_dropped()?;
assert!(signal.is_dropped()?);
Ok(())
}
#[test]
fn signal_notify_sent_false_if_unacknowledged() -> Result<(), ShellError> {
let signal = StreamWriterSignal::new(2);
assert!(signal.notify_sent()?);
for _ in 0..100 {
assert!(!signal.notify_sent()?);
}
Ok(())
}
#[test]
fn signal_notify_sent_never_false_if_flowing() -> Result<(), ShellError> {
let signal = StreamWriterSignal::new(1);
for _ in 0..100 {
signal.notify_acknowledged()?;
}
for _ in 0..100 {
assert!(signal.notify_sent()?);
}
Ok(())
}
#[test]
fn signal_wait_for_drain_blocks_on_unacknowledged() -> Result<(), ShellError> {
let signal = StreamWriterSignal::new(50);
std::thread::scope(|scope| {
let spawned = scope.spawn(|| {
for _ in 0..100 {
if !signal.notify_sent()? {
signal.wait_for_drain()?;
}
}
Ok(())
});
std::thread::sleep(WAIT_DURATION);
assert!(!spawned.is_finished(), "didn't block");
for _ in 0..100 {
signal.notify_acknowledged()?;
}
std::thread::sleep(WAIT_DURATION);
assert!(spawned.is_finished(), "blocked at end");
spawned.join().unwrap()
})
}
#[test]
fn signal_wait_for_drain_unblocks_on_dropped() -> Result<(), ShellError> {
let signal = StreamWriterSignal::new(1);
std::thread::scope(|scope| {
let spawned = scope.spawn(|| {
while !signal.is_dropped()? {
if !signal.notify_sent()? {
signal.wait_for_drain()?;
}
}
Ok(())
});
std::thread::sleep(WAIT_DURATION);
assert!(!spawned.is_finished(), "didn't block");
signal.set_dropped()?;
std::thread::sleep(WAIT_DURATION);
assert!(spawned.is_finished(), "still blocked at end");
spawned.join().unwrap()
})
}
#[test]
fn stream_manager_single_stream_read_scenario() -> Result<(), ShellError> {
let manager = StreamManager::new();
let handle = manager.get_handle();
let (tx, rx) = mpsc::channel();
let readable = handle.read_stream::<Value, _>(2, tx)?;
let expected_values = vec![Value::test_int(40), Value::test_string("hello")];
for value in &expected_values {
manager.handle_message(StreamMessage::Data(2, value.clone().into()))?;
}
manager.handle_message(StreamMessage::End(2))?;
let values = readable.collect::<Vec<Value>>();
assert_eq!(expected_values, values);
// Now check the sent messages on consumption
// Should be Ack for each message, then Drop
for _ in &expected_values {
match rx.try_recv().expect("failed to receive Ack") {
StreamMessage::Ack(2) => (),
other => panic!("should have been an Ack: {other:?}"),
}
}
match rx.try_recv().expect("failed to receive Drop") {
StreamMessage::Drop(2) => (),
other => panic!("should have been a Drop: {other:?}"),
}
Ok(())
}
#[test]
fn stream_manager_multi_stream_read_scenario() -> Result<(), ShellError> {
let manager = StreamManager::new();
let handle = manager.get_handle();
let (tx, rx) = mpsc::channel();
let readable_list = handle.read_stream::<Value, _>(2, tx.clone())?;
let readable_raw = handle.read_stream::<Result<Vec<u8>, _>, _>(3, tx)?;
let expected_values = (1..100).map(Value::test_int).collect::<Vec<_>>();
let expected_raw_buffers = (1..100).map(|n| vec![n]).collect::<Vec<Vec<u8>>>();
for (value, buf) in expected_values.iter().zip(&expected_raw_buffers) {
manager.handle_message(StreamMessage::Data(2, value.clone().into()))?;
manager.handle_message(StreamMessage::Data(3, StreamData::Raw(Ok(buf.clone()))))?;
}
manager.handle_message(StreamMessage::End(2))?;
manager.handle_message(StreamMessage::End(3))?;
let values = readable_list.collect::<Vec<Value>>();
let bufs = readable_raw.collect::<Result<Vec<Vec<u8>>, _>>()?;
for (expected_value, value) in expected_values.iter().zip(&values) {
assert_eq!(expected_value, value, "in List stream");
}
for (expected_buf, buf) in expected_raw_buffers.iter().zip(&bufs) {
assert_eq!(expected_buf, buf, "in Raw stream");
}
// Now check the sent messages on consumption
// Should be Ack for each message, then Drop
for _ in &expected_values {
match rx.try_recv().expect("failed to receive Ack") {
StreamMessage::Ack(2) => (),
other => panic!("should have been an Ack(2): {other:?}"),
}
}
match rx.try_recv().expect("failed to receive Drop") {
StreamMessage::Drop(2) => (),
other => panic!("should have been a Drop(2): {other:?}"),
}
for _ in &expected_values {
match rx.try_recv().expect("failed to receive Ack") {
StreamMessage::Ack(3) => (),
other => panic!("should have been an Ack(3): {other:?}"),
}
}
match rx.try_recv().expect("failed to receive Drop") {
StreamMessage::Drop(3) => (),
other => panic!("should have been a Drop(3): {other:?}"),
}
// Should be end of stream
assert!(
rx.try_recv().is_err(),
"more messages written to stream than expected"
);
Ok(())
}
#[test]
fn stream_manager_write_scenario() -> Result<(), ShellError> {
let manager = StreamManager::new();
let handle = manager.get_handle();
let (tx, rx) = mpsc::channel();
let mut writable = handle.write_stream(4, tx, 100)?;
let expected_values = vec![b"hello".to_vec(), b"world".to_vec(), b"test".to_vec()];
for value in &expected_values {
writable.write(Ok(value.clone()))?;
}
// Now try signalling ack
assert_eq!(
expected_values.len() as i32,
writable.signal.lock()?.unacknowledged,
"unacknowledged initial count",
);
manager.handle_message(StreamMessage::Ack(4))?;
assert_eq!(
expected_values.len() as i32 - 1,
writable.signal.lock()?.unacknowledged,
"unacknowledged post-Ack count",
);
// ...and Drop
manager.handle_message(StreamMessage::Drop(4))?;
assert!(writable.is_dropped()?);
// Drop the StreamWriter...
drop(writable);
// now check what was actually written
for value in &expected_values {
match rx.try_recv().expect("failed to receive Data") {
StreamMessage::Data(4, StreamData::Raw(Ok(received))) => {
assert_eq!(*value, received);
}
other @ StreamMessage::Data(..) => panic!("wrong Data for {value:?}: {other:?}"),
other => panic!("should have been Data: {other:?}"),
}
}
match rx.try_recv().expect("failed to receive End") {
StreamMessage::End(4) => (),
other => panic!("should have been End: {other:?}"),
}
Ok(())
}
#[test]
fn stream_manager_broadcast_read_error() -> Result<(), ShellError> {
let manager = StreamManager::new();
let handle = manager.get_handle();
let mut readable0 = handle.read_stream::<Value, _>(0, TestSink::default())?;
let mut readable1 = handle.read_stream::<Result<Vec<u8>, _>, _>(1, TestSink::default())?;
let error = ShellError::PluginFailedToDecode {
msg: "test decode error".into(),
};
manager.broadcast_read_error(error.clone())?;
drop(manager);
assert_eq!(
error.to_string(),
readable0
.recv()
.transpose()
.expect("nothing received from readable0")
.expect_err("not an error received from readable0")
.to_string()
);
assert_eq!(
error.to_string(),
readable1
.next()
.expect("nothing received from readable1")
.expect_err("not an error received from readable1")
.to_string()
);
Ok(())
}
#[test]
fn stream_manager_drop_writers_on_drop() -> Result<(), ShellError> {
let manager = StreamManager::new();
let handle = manager.get_handle();
let writable = handle.write_stream(4, TestSink::default(), 100)?;
assert!(!writable.is_dropped()?);
drop(manager);
assert!(writable.is_dropped()?);
Ok(())
}

View File

@ -0,0 +1,143 @@
use std::{
collections::VecDeque,
sync::{Arc, Mutex},
};
use nu_protocol::ShellError;
use crate::{plugin::PluginIdentity, protocol::PluginInput, PluginOutput};
use super::{EngineInterfaceManager, PluginInterfaceManager, PluginRead, PluginWrite};
/// Mock read/write helper for the engine and plugin interfaces.
#[derive(Debug, Clone)]
pub(crate) struct TestCase<I, O> {
r#in: Arc<Mutex<TestData<I>>>,
out: Arc<Mutex<TestData<O>>>,
}
#[derive(Debug)]
pub(crate) struct TestData<T> {
data: VecDeque<T>,
error: Option<ShellError>,
flushed: bool,
}
impl<T> Default for TestData<T> {
fn default() -> Self {
TestData {
data: VecDeque::new(),
error: None,
flushed: false,
}
}
}
impl<I, O> PluginRead<I> for TestCase<I, O> {
fn read(&mut self) -> Result<Option<I>, ShellError> {
let mut lock = self.r#in.lock().unwrap();
if let Some(err) = lock.error.take() {
Err(err)
} else {
Ok(lock.data.pop_front())
}
}
}
impl<I, O> PluginWrite<O> for TestCase<I, O>
where
I: Send + Clone,
O: Send + Clone,
{
fn write(&self, data: &O) -> Result<(), ShellError> {
let mut lock = self.out.lock().unwrap();
lock.flushed = false;
if let Some(err) = lock.error.take() {
Err(err)
} else {
lock.data.push_back(data.clone());
Ok(())
}
}
fn flush(&self) -> Result<(), ShellError> {
let mut lock = self.out.lock().unwrap();
lock.flushed = true;
Ok(())
}
}
#[allow(dead_code)]
impl<I, O> TestCase<I, O> {
pub(crate) fn new() -> TestCase<I, O> {
TestCase {
r#in: Default::default(),
out: Default::default(),
}
}
/// Clear the read buffer.
pub(crate) fn clear(&self) {
self.r#in.lock().unwrap().data.truncate(0);
}
/// Add input that will be read by the interface.
pub(crate) fn add(&self, input: impl Into<I>) {
self.r#in.lock().unwrap().data.push_back(input.into());
}
/// Add multiple inputs that will be read by the interface.
pub(crate) fn extend(&self, inputs: impl IntoIterator<Item = I>) {
self.r#in.lock().unwrap().data.extend(inputs);
}
/// Return an error from the next read operation.
pub(crate) fn set_read_error(&self, err: ShellError) {
self.r#in.lock().unwrap().error = Some(err);
}
/// Return an error from the next write operation.
pub(crate) fn set_write_error(&self, err: ShellError) {
self.out.lock().unwrap().error = Some(err);
}
/// Get the next output that was written.
pub(crate) fn next_written(&self) -> Option<O> {
self.out.lock().unwrap().data.pop_front()
}
/// Iterator over written data.
pub(crate) fn written(&self) -> impl Iterator<Item = O> + '_ {
std::iter::from_fn(|| self.next_written())
}
/// Returns true if the writer was flushed after the last write operation.
pub(crate) fn was_flushed(&self) -> bool {
self.out.lock().unwrap().flushed
}
/// Returns true if the reader has unconsumed reads.
pub(crate) fn has_unconsumed_read(&self) -> bool {
!self.r#in.lock().unwrap().data.is_empty()
}
/// Returns true if the writer has unconsumed writes.
pub(crate) fn has_unconsumed_write(&self) -> bool {
!self.out.lock().unwrap().data.is_empty()
}
}
impl TestCase<PluginOutput, PluginInput> {
/// Create a new [`PluginInterfaceManager`] that writes to this test case.
pub(crate) fn plugin(&self, name: &str) -> PluginInterfaceManager {
PluginInterfaceManager::new(PluginIdentity::new_fake(name), self.clone())
}
}
impl TestCase<PluginInput, PluginOutput> {
/// Create a new [`EngineInterfaceManager`] that writes to this test case.
pub(crate) fn engine(&self) -> EngineInterfaceManager {
EngineInterfaceManager::new(self.clone())
}
}

View File

@ -0,0 +1,559 @@
use std::{path::Path, sync::Arc};
use nu_protocol::{
DataSource, ListStream, PipelineData, PipelineMetadata, RawStream, ShellError, Span, Value,
};
use crate::{
protocol::{
ExternalStreamInfo, ListStreamInfo, PipelineDataHeader, PluginInput, PluginOutput,
RawStreamInfo, StreamData, StreamMessage,
},
sequence::Sequence,
};
use super::{
stream::{StreamManager, StreamManagerHandle},
test_util::TestCase,
Interface, InterfaceManager, PluginRead, PluginWrite,
};
fn test_metadata() -> PipelineMetadata {
PipelineMetadata {
data_source: DataSource::FilePath("/test/path".into()),
}
}
#[derive(Debug)]
struct TestInterfaceManager {
stream_manager: StreamManager,
test: TestCase<PluginInput, PluginOutput>,
seq: Arc<Sequence>,
}
#[derive(Debug, Clone)]
struct TestInterface {
stream_manager_handle: StreamManagerHandle,
test: TestCase<PluginInput, PluginOutput>,
seq: Arc<Sequence>,
}
impl TestInterfaceManager {
fn new(test: &TestCase<PluginInput, PluginOutput>) -> TestInterfaceManager {
TestInterfaceManager {
stream_manager: StreamManager::new(),
test: test.clone(),
seq: Arc::new(Sequence::default()),
}
}
fn consume_all(&mut self) -> Result<(), ShellError> {
while let Some(msg) = self.test.read()? {
self.consume(msg)?;
}
Ok(())
}
}
impl InterfaceManager for TestInterfaceManager {
type Interface = TestInterface;
type Input = PluginInput;
fn get_interface(&self) -> Self::Interface {
TestInterface {
stream_manager_handle: self.stream_manager.get_handle(),
test: self.test.clone(),
seq: self.seq.clone(),
}
}
fn consume(&mut self, input: Self::Input) -> Result<(), ShellError> {
match input {
PluginInput::Stream(msg) => self.consume_stream_message(msg),
_ => unimplemented!(),
}
}
fn stream_manager(&self) -> &StreamManager {
&self.stream_manager
}
fn prepare_pipeline_data(&self, data: PipelineData) -> Result<PipelineData, ShellError> {
Ok(data.set_metadata(Some(test_metadata())))
}
}
impl Interface for TestInterface {
type Output = PluginOutput;
fn write(&self, output: Self::Output) -> Result<(), ShellError> {
self.test.write(&output)
}
fn flush(&self) -> Result<(), ShellError> {
Ok(())
}
fn stream_id_sequence(&self) -> &Sequence {
&self.seq
}
fn stream_manager_handle(&self) -> &StreamManagerHandle {
&self.stream_manager_handle
}
fn prepare_pipeline_data(&self, data: PipelineData) -> Result<PipelineData, ShellError> {
// Add an arbitrary check to the data to verify this is being called
match data {
PipelineData::Value(Value::Binary { .. }, None) => Err(ShellError::NushellFailed {
msg: "TEST can't send binary".into(),
}),
_ => Ok(data),
}
}
}
#[test]
fn read_pipeline_data_empty() -> Result<(), ShellError> {
let manager = TestInterfaceManager::new(&TestCase::new());
let header = PipelineDataHeader::Empty;
assert!(matches!(
manager.read_pipeline_data(header, None)?,
PipelineData::Empty
));
Ok(())
}
#[test]
fn read_pipeline_data_value() -> Result<(), ShellError> {
let manager = TestInterfaceManager::new(&TestCase::new());
let value = Value::test_int(4);
let header = PipelineDataHeader::Value(value.clone());
match manager.read_pipeline_data(header, None)? {
PipelineData::Value(read_value, _) => assert_eq!(value, read_value),
PipelineData::ListStream(_, _) => panic!("unexpected ListStream"),
PipelineData::ExternalStream { .. } => panic!("unexpected ExternalStream"),
PipelineData::Empty => panic!("unexpected Empty"),
}
Ok(())
}
#[test]
fn read_pipeline_data_list_stream() -> Result<(), ShellError> {
let test = TestCase::new();
let mut manager = TestInterfaceManager::new(&test);
let data = (0..100).map(Value::test_int).collect::<Vec<_>>();
for value in &data {
test.add(StreamMessage::Data(7, value.clone().into()));
}
test.add(StreamMessage::End(7));
let header = PipelineDataHeader::ListStream(ListStreamInfo { id: 7 });
let pipe = manager.read_pipeline_data(header, None)?;
assert!(
matches!(pipe, PipelineData::ListStream(..)),
"unexpected PipelineData: {pipe:?}"
);
// need to consume input
manager.consume_all()?;
let mut count = 0;
for (expected, read) in data.into_iter().zip(pipe) {
assert_eq!(expected, read);
count += 1;
}
assert_eq!(100, count);
assert!(test.has_unconsumed_write());
Ok(())
}
#[test]
fn read_pipeline_data_external_stream() -> Result<(), ShellError> {
let test = TestCase::new();
let mut manager = TestInterfaceManager::new(&test);
let iterations = 100;
let out_pattern = b"hello".to_vec();
let err_pattern = vec![5, 4, 3, 2];
test.add(StreamMessage::Data(14, Value::test_int(1).into()));
for _ in 0..iterations {
test.add(StreamMessage::Data(12, Ok(out_pattern.clone()).into()));
test.add(StreamMessage::Data(13, Ok(err_pattern.clone()).into()));
}
test.add(StreamMessage::End(12));
test.add(StreamMessage::End(13));
test.add(StreamMessage::End(14));
let test_span = Span::new(10, 13);
let header = PipelineDataHeader::ExternalStream(ExternalStreamInfo {
span: test_span,
stdout: Some(RawStreamInfo {
id: 12,
is_binary: false,
known_size: Some((out_pattern.len() * iterations) as u64),
}),
stderr: Some(RawStreamInfo {
id: 13,
is_binary: true,
known_size: None,
}),
exit_code: Some(ListStreamInfo { id: 14 }),
trim_end_newline: true,
});
let pipe = manager.read_pipeline_data(header, None)?;
// need to consume input
manager.consume_all()?;
match pipe {
PipelineData::ExternalStream {
stdout,
stderr,
exit_code,
span,
metadata,
trim_end_newline,
} => {
let stdout = stdout.expect("stdout is None");
let stderr = stderr.expect("stderr is None");
let exit_code = exit_code.expect("exit_code is None");
assert_eq!(test_span, span);
assert!(
metadata.is_some(),
"expected metadata to be Some due to prepare_pipeline_data()"
);
assert!(trim_end_newline);
assert!(!stdout.is_binary);
assert!(stderr.is_binary);
assert_eq!(
Some((out_pattern.len() * iterations) as u64),
stdout.known_size
);
assert_eq!(None, stderr.known_size);
// check the streams
let mut count = 0;
for chunk in stdout.stream {
assert_eq!(out_pattern, chunk?);
count += 1;
}
assert_eq!(iterations, count, "stdout length");
let mut count = 0;
for chunk in stderr.stream {
assert_eq!(err_pattern, chunk?);
count += 1;
}
assert_eq!(iterations, count, "stderr length");
assert_eq!(vec![Value::test_int(1)], exit_code.collect::<Vec<_>>());
}
_ => panic!("unexpected PipelineData: {pipe:?}"),
}
// Don't need to check exactly what was written, just be sure that there is some output
assert!(test.has_unconsumed_write());
Ok(())
}
#[test]
fn read_pipeline_data_ctrlc() -> Result<(), ShellError> {
let manager = TestInterfaceManager::new(&TestCase::new());
let header = PipelineDataHeader::ListStream(ListStreamInfo { id: 0 });
let ctrlc = Default::default();
match manager.read_pipeline_data(header, Some(&ctrlc))? {
PipelineData::ListStream(
ListStream {
ctrlc: stream_ctrlc,
..
},
_,
) => {
assert!(Arc::ptr_eq(&ctrlc, &stream_ctrlc.expect("ctrlc not set")));
Ok(())
}
_ => panic!("Unexpected PipelineData, should have been ListStream"),
}
}
#[test]
fn read_pipeline_data_prepared_properly() -> Result<(), ShellError> {
let manager = TestInterfaceManager::new(&TestCase::new());
let header = PipelineDataHeader::ListStream(ListStreamInfo { id: 0 });
match manager.read_pipeline_data(header, None)? {
PipelineData::ListStream(_, meta) => match meta {
Some(PipelineMetadata { data_source }) => match data_source {
DataSource::FilePath(path) => {
assert_eq!(Path::new("/test/path"), path);
Ok(())
}
_ => panic!("wrong metadata: {data_source:?}"),
},
None => panic!("metadata not set"),
},
_ => panic!("Unexpected PipelineData, should have been ListStream"),
}
}
#[test]
fn write_pipeline_data_empty() -> Result<(), ShellError> {
let test = TestCase::new();
let manager = TestInterfaceManager::new(&test);
let interface = manager.get_interface();
let (header, writer) = interface.init_write_pipeline_data(PipelineData::Empty)?;
assert!(matches!(header, PipelineDataHeader::Empty));
writer.write()?;
assert!(
!test.has_unconsumed_write(),
"Empty shouldn't write any stream messages, test: {test:#?}"
);
Ok(())
}
#[test]
fn write_pipeline_data_value() -> Result<(), ShellError> {
let test = TestCase::new();
let manager = TestInterfaceManager::new(&test);
let interface = manager.get_interface();
let value = Value::test_int(7);
let (header, writer) =
interface.init_write_pipeline_data(PipelineData::Value(value.clone(), None))?;
match header {
PipelineDataHeader::Value(read_value) => assert_eq!(value, read_value),
_ => panic!("unexpected header: {header:?}"),
}
writer.write()?;
assert!(
!test.has_unconsumed_write(),
"Value shouldn't write any stream messages, test: {test:#?}"
);
Ok(())
}
#[test]
fn write_pipeline_data_prepared_properly() {
let manager = TestInterfaceManager::new(&TestCase::new());
let interface = manager.get_interface();
// Sending a binary should be an error in our test scenario
let value = Value::test_binary(vec![7, 8]);
match interface.init_write_pipeline_data(PipelineData::Value(value, None)) {
Ok(_) => panic!("prepare_pipeline_data was not called"),
Err(err) => {
assert_eq!(
ShellError::NushellFailed {
msg: "TEST can't send binary".into()
}
.to_string(),
err.to_string()
);
}
}
}
#[test]
fn write_pipeline_data_list_stream() -> Result<(), ShellError> {
let test = TestCase::new();
let manager = TestInterfaceManager::new(&test);
let interface = manager.get_interface();
let values = vec![
Value::test_int(40),
Value::test_bool(false),
Value::test_string("this is a test"),
];
// Set up pipeline data for a list stream
let pipe = PipelineData::ListStream(
ListStream::from_stream(values.clone().into_iter(), None),
None,
);
let (header, writer) = interface.init_write_pipeline_data(pipe)?;
let info = match header {
PipelineDataHeader::ListStream(info) => info,
_ => panic!("unexpected header: {header:?}"),
};
writer.write()?;
// Now make sure the stream messages have been written
for value in values {
match test.next_written().expect("unexpected end of stream") {
PluginOutput::Stream(StreamMessage::Data(id, data)) => {
assert_eq!(info.id, id, "Data id");
match data {
StreamData::List(read_value) => assert_eq!(value, read_value, "Data value"),
_ => panic!("unexpected Data: {data:?}"),
}
}
other => panic!("unexpected output: {other:?}"),
}
}
match test.next_written().expect("unexpected end of stream") {
PluginOutput::Stream(StreamMessage::End(id)) => {
assert_eq!(info.id, id, "End id");
}
other => panic!("unexpected output: {other:?}"),
}
assert!(!test.has_unconsumed_write());
Ok(())
}
#[test]
fn write_pipeline_data_external_stream() -> Result<(), ShellError> {
let test = TestCase::new();
let manager = TestInterfaceManager::new(&test);
let interface = manager.get_interface();
let stdout_bufs = vec![
b"hello".to_vec(),
b"world".to_vec(),
b"these are tests".to_vec(),
];
let stdout_len = stdout_bufs.iter().map(|b| b.len() as u64).sum::<u64>();
let stderr_bufs = vec![b"error messages".to_vec(), b"go here".to_vec()];
let exit_code = Value::test_int(7);
let span = Span::new(400, 500);
// Set up pipeline data for an external stream
let pipe = PipelineData::ExternalStream {
stdout: Some(RawStream::new(
Box::new(stdout_bufs.clone().into_iter().map(Ok)),
None,
span,
Some(stdout_len),
)),
stderr: Some(RawStream::new(
Box::new(stderr_bufs.clone().into_iter().map(Ok)),
None,
span,
None,
)),
exit_code: Some(ListStream::from_stream(
std::iter::once(exit_code.clone()),
None,
)),
span,
metadata: None,
trim_end_newline: true,
};
let (header, writer) = interface.init_write_pipeline_data(pipe)?;
let info = match header {
PipelineDataHeader::ExternalStream(info) => info,
_ => panic!("unexpected header: {header:?}"),
};
writer.write()?;
let stdout_info = info.stdout.as_ref().expect("stdout info is None");
let stderr_info = info.stderr.as_ref().expect("stderr info is None");
let exit_code_info = info.exit_code.as_ref().expect("exit code info is None");
assert_eq!(span, info.span);
assert!(info.trim_end_newline);
assert_eq!(Some(stdout_len), stdout_info.known_size);
assert_eq!(None, stderr_info.known_size);
// Now make sure the stream messages have been written
let mut stdout_iter = stdout_bufs.into_iter();
let mut stderr_iter = stderr_bufs.into_iter();
let mut exit_code_iter = std::iter::once(exit_code);
let mut stdout_ended = false;
let mut stderr_ended = false;
let mut exit_code_ended = false;
// There's no specific order these messages must come in with respect to how the streams are
// interleaved, but all of the data for each stream must be in its original order, and the
// End must come after all Data
for msg in test.written() {
match msg {
PluginOutput::Stream(StreamMessage::Data(id, data)) => {
if id == stdout_info.id {
let result: Result<Vec<u8>, ShellError> =
data.try_into().expect("wrong data in stdout stream");
assert_eq!(
stdout_iter.next().expect("too much data in stdout"),
result.expect("unexpected error in stdout stream")
);
} else if id == stderr_info.id {
let result: Result<Vec<u8>, ShellError> =
data.try_into().expect("wrong data in stderr stream");
assert_eq!(
stderr_iter.next().expect("too much data in stderr"),
result.expect("unexpected error in stderr stream")
);
} else if id == exit_code_info.id {
let code: Value = data.try_into().expect("wrong data in stderr stream");
assert_eq!(
exit_code_iter.next().expect("too much data in stderr"),
code
);
} else {
panic!("unrecognized stream id: {id}");
}
}
PluginOutput::Stream(StreamMessage::End(id)) => {
if id == stdout_info.id {
assert!(!stdout_ended, "double End of stdout");
assert!(stdout_iter.next().is_none(), "unexpected end of stdout");
stdout_ended = true;
} else if id == stderr_info.id {
assert!(!stderr_ended, "double End of stderr");
assert!(stderr_iter.next().is_none(), "unexpected end of stderr");
stderr_ended = true;
} else if id == exit_code_info.id {
assert!(!exit_code_ended, "double End of exit_code");
assert!(
exit_code_iter.next().is_none(),
"unexpected end of exit_code"
);
exit_code_ended = true;
} else {
panic!("unrecognized stream id: {id}");
}
}
other => panic!("unexpected output: {other:?}"),
}
}
assert!(stdout_ended, "stdout did not End");
assert!(stderr_ended, "stderr did not End");
assert!(exit_code_ended, "exit_code did not End");
Ok(())
}

View File

@ -2,53 +2,69 @@ mod declaration;
pub use declaration::PluginDeclaration;
use nu_engine::documentation::get_flags_section;
use std::collections::HashMap;
use std::ffi::OsStr;
use std::sync::{Arc, Mutex};
use crate::protocol::{CallInput, LabeledError, PluginCall, PluginData, PluginResponse};
use crate::plugin::interface::{EngineInterfaceManager, ReceivedPluginCall};
use crate::protocol::{CallInfo, CustomValueOp, LabeledError, PluginInput, PluginOutput};
use crate::EncodingType;
use std::env;
use std::fmt::Write;
use std::io::{BufReader, ErrorKind, Read, Write as WriteTrait};
use std::io::{BufReader, Read, Write as WriteTrait};
use std::path::Path;
use std::process::{Child, ChildStdout, Command as CommandSys, Stdio};
use nu_protocol::{CustomValue, PluginSignature, ShellError, Span, Value};
use nu_protocol::{PipelineData, PluginSignature, ShellError, Value};
mod interface;
pub(crate) use interface::PluginInterface;
mod context;
pub(crate) use context::PluginExecutionCommandContext;
mod identity;
pub(crate) use identity::PluginIdentity;
use self::interface::{InterfaceManager, PluginInterfaceManager};
use super::EvaluatedCall;
pub(crate) const OUTPUT_BUFFER_SIZE: usize = 8192;
/// Encoding scheme that defines a plugin's communication protocol with Nu
pub trait PluginEncoder: Clone {
/// The name of the encoder (e.g., `json`)
fn name(&self) -> &str;
/// Encoder for a specific message type. Usually implemented on [`PluginInput`]
/// and [`PluginOutput`].
#[doc(hidden)]
pub trait Encoder<T>: Clone + Send + Sync {
/// Serialize a value in the [`PluginEncoder`]s format
///
/// Returns [ShellError::IOError] if there was a problem writing, or
/// [ShellError::PluginFailedToEncode] for a serialization error.
#[doc(hidden)]
fn encode(&self, data: &T, writer: &mut impl std::io::Write) -> Result<(), ShellError>;
/// Serialize a `PluginCall` in the `PluginEncoder`s format
fn encode_call(
&self,
plugin_call: &PluginCall,
writer: &mut impl std::io::Write,
) -> Result<(), ShellError>;
/// Deserialize a `PluginCall` from the `PluginEncoder`s format
fn decode_call(&self, reader: &mut impl std::io::BufRead) -> Result<PluginCall, ShellError>;
/// Serialize a `PluginResponse` from the plugin in this `PluginEncoder`'s preferred
/// format
fn encode_response(
&self,
plugin_response: &PluginResponse,
writer: &mut impl std::io::Write,
) -> Result<(), ShellError>;
/// Deserialize a `PluginResponse` from the plugin from this `PluginEncoder`'s
/// preferred format
fn decode_response(
&self,
reader: &mut impl std::io::BufRead,
) -> Result<PluginResponse, ShellError>;
/// Deserialize a value from the [`PluginEncoder`]'s format
///
/// Returns `None` if there is no more output to receive.
///
/// Returns [ShellError::IOError] if there was a problem reading, or
/// [ShellError::PluginFailedToDecode] for a deserialization error.
#[doc(hidden)]
fn decode(&self, reader: &mut impl std::io::BufRead) -> Result<Option<T>, ShellError>;
}
pub(crate) fn create_command(path: &Path, shell: Option<&Path>) -> CommandSys {
/// Encoding scheme that defines a plugin's communication protocol with Nu
pub trait PluginEncoder: Encoder<PluginInput> + Encoder<PluginOutput> {
/// The name of the encoder (e.g., `json`)
fn name(&self) -> &str;
}
fn create_command(path: &Path, shell: Option<&Path>) -> CommandSys {
log::trace!("Starting plugin: {path:?}, shell = {shell:?}");
// There is only one mode supported at the moment, but the idea is that future
// communication methods could be supported if desirable
let mut input_arg = Some("--stdio");
let mut process = match (path.extension(), shell) {
(_, Some(shell)) => {
let mut process = std::process::Command::new(shell);
@ -57,18 +73,25 @@ pub(crate) fn create_command(path: &Path, shell: Option<&Path>) -> CommandSys {
process
}
(Some(extension), None) => {
let (shell, separator) = match extension.to_str() {
let (shell, command_switch) = match extension.to_str() {
Some("cmd") | Some("bat") => (Some("cmd"), Some("/c")),
Some("sh") => (Some("sh"), Some("-c")),
Some("py") => (Some("python"), None),
_ => (None, None),
};
match (shell, separator) {
(Some(shell), Some(separator)) => {
match (shell, command_switch) {
(Some(shell), Some(command_switch)) => {
let mut process = std::process::Command::new(shell);
process.arg(separator);
process.arg(path);
process.arg(command_switch);
// If `command_switch` is set, we need to pass the path + arg as one argument
// e.g. sh -c "nu_plugin_inc --stdio"
let mut combined = path.as_os_str().to_owned();
if let Some(arg) = input_arg.take() {
combined.push(OsStr::new(" "));
combined.push(OsStr::new(arg));
}
process.arg(combined);
process
}
@ -84,41 +107,60 @@ pub(crate) fn create_command(path: &Path, shell: Option<&Path>) -> CommandSys {
(None, None) => std::process::Command::new(path),
};
// Pass input_arg, unless we consumed it already
if let Some(input_arg) = input_arg {
process.arg(input_arg);
}
// Both stdout and stdin are piped so we can receive information from the plugin
process.stdout(Stdio::piped()).stdin(Stdio::piped());
process
}
pub(crate) fn call_plugin(
child: &mut Child,
plugin_call: PluginCall,
encoding: &EncodingType,
span: Span,
) -> Result<PluginResponse, ShellError> {
if let Some(mut stdin_writer) = child.stdin.take() {
let encoding_clone = encoding.clone();
// If the child process fills its stdout buffer, it may end up waiting until the parent
// reads the stdout, and not be able to read stdin in the meantime, causing a deadlock.
// Writing from another thread ensures that stdout is being read at the same time, avoiding the problem.
std::thread::spawn(move || encoding_clone.encode_call(&plugin_call, &mut stdin_writer));
fn make_plugin_interface(
mut child: Child,
identity: Arc<PluginIdentity>,
) -> Result<PluginInterface, ShellError> {
let stdin = child
.stdin
.take()
.ok_or_else(|| ShellError::PluginFailedToLoad {
msg: "plugin missing stdin writer".into(),
})?;
let mut stdout = child
.stdout
.take()
.ok_or_else(|| ShellError::PluginFailedToLoad {
msg: "Plugin missing stdout writer".into(),
})?;
let encoder = get_plugin_encoding(&mut stdout)?;
let reader = BufReader::with_capacity(OUTPUT_BUFFER_SIZE, stdout);
let mut manager = PluginInterfaceManager::new(identity, (Mutex::new(stdin), encoder));
let interface = manager.get_interface();
interface.hello()?;
// Spawn the reader on a new thread. We need to be able to read messages at the same time that
// we write, because we are expected to be able to handle multiple messages coming in from the
// plugin at any time, including stream messages like `Drop`.
std::thread::Builder::new()
.name("plugin interface reader".into())
.spawn(move || {
if let Err(err) = manager.consume_all((reader, encoder)) {
log::warn!("Error in PluginInterfaceManager: {err}");
}
// Deserialize response from plugin to extract the resulting value
if let Some(stdout_reader) = &mut child.stdout {
let reader = stdout_reader;
let mut buf_read = BufReader::with_capacity(OUTPUT_BUFFER_SIZE, reader);
encoding.decode_response(&mut buf_read)
} else {
Err(ShellError::GenericError {
error: "Error with stdout reader".into(),
msg: "no stdout reader".into(),
span: Some(span),
help: None,
inner: vec![],
// If the loop has ended, drop the manager so everyone disconnects and then wait for the
// child to exit
drop(manager);
let _ = child.wait();
})
}
.expect("failed to spawn thread");
Ok(interface)
}
#[doc(hidden)] // Note: not for plugin authors / only used in nu-parser
@ -127,71 +169,9 @@ pub fn get_signature(
shell: Option<&Path>,
current_envs: &HashMap<String, String>,
) -> Result<Vec<PluginSignature>, ShellError> {
let mut plugin_cmd = create_command(path, shell);
let program_name = plugin_cmd.get_program().to_os_string().into_string();
plugin_cmd.envs(current_envs);
let mut child = plugin_cmd.spawn().map_err(|err| {
let error_msg = match err.kind() {
ErrorKind::NotFound => match program_name {
Ok(prog_name) => {
format!("Can't find {prog_name}, please make sure that {prog_name} is in PATH.")
}
_ => {
format!("Error spawning child process: {err}")
}
},
_ => {
format!("Error spawning child process: {err}")
}
};
ShellError::PluginFailedToLoad { msg: error_msg }
})?;
let mut stdin_writer = child
.stdin
.take()
.ok_or_else(|| ShellError::PluginFailedToLoad {
msg: "plugin missing stdin writer".into(),
})?;
let mut stdout_reader = child
.stdout
.take()
.ok_or_else(|| ShellError::PluginFailedToLoad {
msg: "Plugin missing stdout reader".into(),
})?;
let encoding = get_plugin_encoding(&mut stdout_reader)?;
// Create message to plugin to indicate that signature is required and
// send call to plugin asking for signature
let encoding_clone = encoding.clone();
// If the child process fills its stdout buffer, it may end up waiting until the parent
// reads the stdout, and not be able to read stdin in the meantime, causing a deadlock.
// Writing from another thread ensures that stdout is being read at the same time, avoiding the problem.
std::thread::spawn(move || {
encoding_clone.encode_call(&PluginCall::Signature, &mut stdin_writer)
});
// deserialize response from plugin to extract the signature
let reader = stdout_reader;
let mut buf_read = BufReader::with_capacity(OUTPUT_BUFFER_SIZE, reader);
let response = encoding.decode_response(&mut buf_read)?;
let signatures = match response {
PluginResponse::Signature(sign) => Ok(sign),
PluginResponse::Error(err) => Err(err.into()),
_ => Err(ShellError::PluginFailedToLoad {
msg: "Plugin missing signature".into(),
}),
}?;
match child.wait() {
Ok(_) => Ok(signatures),
Err(err) => Err(ShellError::PluginFailedToLoad {
msg: format!("{err}"),
}),
}
Arc::new(PluginIdentity::new(path, shell.map(|s| s.to_owned())))
.spawn(current_envs)?
.get_signature()
}
/// The basic API for a Nushell plugin
@ -199,6 +179,9 @@ pub fn get_signature(
/// This is the trait that Nushell plugins must implement. The methods defined on
/// `Plugin` are invoked by [serve_plugin] during plugin registration and execution.
///
/// If large amounts of data are expected to need to be received or produced, it may be more
/// appropriate to implement [StreamingPlugin] instead.
///
/// # Examples
/// Basic usage:
/// ```
@ -224,6 +207,10 @@ pub fn get_signature(
/// Ok(Value::string("Hello, World!".to_owned(), call.head))
/// }
/// }
///
/// # fn main() {
/// # serve_plugin(&mut HelloPlugin{}, MsgPackSerializer)
/// # }
/// ```
pub trait Plugin {
/// The signature of the plugin
@ -244,6 +231,9 @@ pub trait Plugin {
/// invoked command will be passed in via this argument. The `call` contains
/// metadata describing how the plugin was invoked and `input` contains the structured
/// data passed to the command implemented by this [Plugin].
///
/// This variant does not support streaming. Consider implementing [StreamingPlugin] instead
/// if streaming is desired.
fn run(
&mut self,
name: &str,
@ -253,13 +243,115 @@ pub trait Plugin {
) -> Result<Value, LabeledError>;
}
/// The streaming API for a Nushell plugin
///
/// This is a more low-level version of the [Plugin] trait that supports operating on streams of
/// data. If you don't need to operate on streams, consider using that trait instead.
///
/// The methods defined on `StreamingPlugin` are invoked by [serve_plugin] during plugin
/// registration and execution.
///
/// # Examples
/// Basic usage:
/// ```
/// # use nu_plugin::*;
/// # use nu_protocol::{PluginSignature, PipelineData, Type, Value};
/// struct LowercasePlugin;
///
/// impl StreamingPlugin for LowercasePlugin {
/// fn signature(&self) -> Vec<PluginSignature> {
/// let sig = PluginSignature::build("lowercase")
/// .usage("Convert each string in a stream to lowercase")
/// .input_output_type(Type::List(Type::String.into()), Type::List(Type::String.into()));
///
/// vec![sig]
/// }
///
/// fn run(
/// &mut self,
/// name: &str,
/// config: &Option<Value>,
/// call: &EvaluatedCall,
/// input: PipelineData,
/// ) -> Result<PipelineData, LabeledError> {
/// let span = call.head;
/// Ok(input.map(move |value| {
/// value.as_str()
/// .map(|string| Value::string(string.to_lowercase(), span))
/// // Errors in a stream should be returned as values.
/// .unwrap_or_else(|err| Value::error(err, span))
/// }, None)?)
/// }
/// }
///
/// # fn main() {
/// # serve_plugin(&mut LowercasePlugin{}, MsgPackSerializer)
/// # }
/// ```
pub trait StreamingPlugin {
/// The signature of the plugin
///
/// This method returns the [PluginSignature]s that describe the capabilities
/// of this plugin. Since a single plugin executable can support multiple invocation
/// patterns we return a `Vec` of signatures.
fn signature(&self) -> Vec<PluginSignature>;
/// Perform the actual behavior of the plugin
///
/// The behavior of the plugin is defined by the implementation of this method.
/// When Nushell invoked the plugin [serve_plugin] will call this method and
/// print the serialized returned value or error to stdout, which Nushell will
/// interpret.
///
/// The `name` is only relevant for plugins that implement multiple commands as the
/// invoked command will be passed in via this argument. The `call` contains
/// metadata describing how the plugin was invoked and `input` contains the structured
/// data passed to the command implemented by this [Plugin].
///
/// This variant expects to receive and produce [PipelineData], which allows for stream-based
/// handling of I/O. This is recommended if the plugin is expected to transform large lists or
/// potentially large quantities of bytes. The API is more complex however, and [Plugin] is
/// recommended instead if this is not a concern.
fn run(
&mut self,
name: &str,
config: &Option<Value>,
call: &EvaluatedCall,
input: PipelineData,
) -> Result<PipelineData, LabeledError>;
}
/// All [Plugin]s can be used as [StreamingPlugin]s, but input streams will be fully consumed
/// before the plugin runs.
impl<T: Plugin> StreamingPlugin for T {
fn signature(&self) -> Vec<PluginSignature> {
<Self as Plugin>::signature(self)
}
fn run(
&mut self,
name: &str,
config: &Option<Value>,
call: &EvaluatedCall,
input: PipelineData,
) -> Result<PipelineData, LabeledError> {
// Unwrap the PipelineData from input, consuming the potential stream, and pass it to the
// simpler signature in Plugin
let span = input.span().unwrap_or(call.head);
let input_value = input.into_value(span);
// Wrap the output in PipelineData::Value
<Self as Plugin>::run(self, name, config, call, &input_value)
.map(|value| PipelineData::Value(value, None))
}
}
/// Function used to implement the communication protocol between
/// nushell and an external plugin.
/// nushell and an external plugin. Both [Plugin] and [StreamingPlugin] are supported.
///
/// When creating a new plugin this function is typically used as the main entry
/// point for the plugin, e.g.
///
/// ```
/// ```rust,no_run
/// # use nu_plugin::*;
/// # use nu_protocol::{PluginSignature, Value};
/// # struct MyPlugin;
@ -273,22 +365,42 @@ pub trait Plugin {
/// serve_plugin(&mut MyPlugin::new(), MsgPackSerializer)
/// }
/// ```
///
/// The object that is expected to be received by nushell is the `PluginResponse` struct.
/// The `serve_plugin` function should ensure that it is encoded correctly and sent
/// to StdOut for nushell to decode and and present its result.
pub fn serve_plugin(plugin: &mut impl Plugin, encoder: impl PluginEncoder) {
if env::args().any(|arg| (arg == "-h") || (arg == "--help")) {
pub fn serve_plugin(plugin: &mut impl StreamingPlugin, encoder: impl PluginEncoder + 'static) {
let mut args = env::args().skip(1);
let number_of_args = args.len();
let first_arg = args.next();
if number_of_args == 0
|| first_arg
.as_ref()
.is_some_and(|arg| arg == "-h" || arg == "--help")
{
print_help(plugin, encoder);
std::process::exit(0)
}
// Must pass --stdio for plugin execution. Any other arg is an error to give us options in the
// future.
if number_of_args > 1 || !first_arg.is_some_and(|arg| arg == "--stdio") {
eprintln!(
"{}: This plugin must be run from within Nushell.",
env::current_exe()
.map(|path| path.display().to_string())
.unwrap_or_else(|_| "plugin".into())
);
eprintln!(
"If you are running from Nushell, this plugin may be incompatible with the \
version of nushell you are using."
);
std::process::exit(1)
}
// tell nushell encoding.
//
// 1 byte
// encoding format: | content-length | content |
{
let mut stdout = std::io::stdout();
{
let encoding = encoder.name();
let length = encoding.len() as u8;
let mut encoding_content: Vec<u8> = encoding.as_bytes().to_vec();
@ -301,91 +413,120 @@ pub fn serve_plugin(plugin: &mut impl Plugin, encoder: impl PluginEncoder) {
.expect("Failed to tell nushell my encoding when flushing stdout");
}
let mut stdin_buf = BufReader::with_capacity(OUTPUT_BUFFER_SIZE, std::io::stdin());
let plugin_call = encoder.decode_call(&mut stdin_buf);
let mut manager = EngineInterfaceManager::new((stdout, encoder.clone()));
let call_receiver = manager
.take_plugin_call_receiver()
// This expect should be totally safe, as we just created the manager
.expect("take_plugin_call_receiver returned None");
match plugin_call {
// We need to hold on to the interface to keep the manager alive. We can drop it at the end
let interface = manager.get_interface();
// Try an operation that could result in ShellError. Exit if an I/O error is encountered.
// Try to report the error to nushell otherwise, and failing that, panic.
macro_rules! try_or_report {
($interface:expr, $expr:expr) => (match $expr {
Ok(val) => val,
// Just exit if there is an I/O error. Most likely this just means that nushell
// interrupted us. If not, the error probably happened on the other side too, so we
// don't need to also report it.
Err(ShellError::IOError { .. }) => std::process::exit(1),
// If there is another error, try to send it to nushell and then exit.
Err(err) => {
let response = PluginResponse::Error(err.into());
encoder
.encode_response(&response, &mut std::io::stdout())
.expect("Error encoding response");
let _ = $interface.write_response(Err(err.clone())).unwrap_or_else(|_| {
// If we can't send it to nushell, panic with it so at least we get the output
panic!("{}", err)
});
std::process::exit(1)
}
Ok(plugin_call) => {
})
}
// Send Hello message
try_or_report!(interface, interface.hello());
// Spawn the reader thread
std::thread::Builder::new()
.name("engine interface reader".into())
.spawn(move || {
if let Err(err) = manager.consume_all((std::io::stdin().lock(), encoder)) {
// Do our best to report the read error. Most likely there is some kind of
// incompatibility between the plugin and nushell, so it makes more sense to try to
// report it on stderr than to send something.
let exe = std::env::current_exe().ok();
let plugin_name: String = exe
.as_ref()
.and_then(|path| path.file_stem())
.map(|stem| stem.to_string_lossy().into_owned())
.map(|stem| {
stem.strip_prefix("nu_plugin_")
.map(|s| s.to_owned())
.unwrap_or(stem)
})
.unwrap_or_else(|| "(unknown)".into());
eprintln!("Plugin `{plugin_name}` read error: {err}");
std::process::exit(1);
}
})
.expect("failed to spawn thread");
for plugin_call in call_receiver {
match plugin_call {
// Sending the signature back to nushell to create the declaration definition
PluginCall::Signature => {
let response = PluginResponse::Signature(plugin.signature());
encoder
.encode_response(&response, &mut std::io::stdout())
.expect("Error encoding response");
ReceivedPluginCall::Signature { engine } => {
try_or_report!(engine, engine.write_signature(plugin.signature()));
}
PluginCall::CallInfo(call_info) => {
let input = match call_info.input {
CallInput::Value(value) => Ok(value),
CallInput::Data(plugin_data) => {
bincode::deserialize::<Box<dyn CustomValue>>(&plugin_data.data)
.map(|custom_value| {
Value::custom_value(custom_value, plugin_data.span)
})
.map_err(|err| ShellError::PluginFailedToDecode {
msg: err.to_string(),
})
}
};
let value = match input {
Ok(input) => {
plugin.run(&call_info.name, &call_info.config, &call_info.call, &input)
}
Err(err) => Err(err.into()),
};
let response = match value {
Ok(value) => {
let span = value.span();
match value {
Value::CustomValue { val, .. } => match bincode::serialize(&val) {
Ok(data) => {
let name = val.value_string();
PluginResponse::PluginData(name, PluginData { data, span })
}
Err(err) => PluginResponse::Error(
ShellError::PluginFailedToEncode {
msg: err.to_string(),
}
.into(),
),
// Run the plugin, handling any input or output streams
ReceivedPluginCall::Run {
engine,
call:
CallInfo {
name,
config,
call,
input,
},
value => PluginResponse::Value(Box::new(value)),
} => {
let result = plugin.run(&name, &config, &call, input);
let write_result = engine
.write_response(result)
.map(|writer| writer.write_background());
try_or_report!(engine, write_result);
}
}
Err(err) => PluginResponse::Error(err),
};
encoder
.encode_response(&response, &mut std::io::stdout())
.expect("Error encoding response");
}
PluginCall::CollapseCustomValue(plugin_data) => {
let response = bincode::deserialize::<Box<dyn CustomValue>>(&plugin_data.data)
.map_err(|err| ShellError::PluginFailedToDecode {
msg: err.to_string(),
})
.and_then(|val| val.to_base_value(plugin_data.span))
.map(Box::new)
.map_err(LabeledError::from)
.map_or_else(PluginResponse::Error, PluginResponse::Value);
encoder
.encode_response(&response, &mut std::io::stdout())
.expect("Error encoding response");
// Do an operation on a custom value
ReceivedPluginCall::CustomValueOp {
engine,
custom_value,
op,
} => {
let local_value = try_or_report!(
engine,
custom_value
.item
.deserialize_to_custom_value(custom_value.span)
);
match op {
CustomValueOp::ToBaseValue => {
let result = local_value
.to_base_value(custom_value.span)
.map(|value| PipelineData::Value(value, None));
let write_result = engine
.write_response(result)
.map(|writer| writer.write_background());
try_or_report!(engine, write_result);
}
}
}
}
}
fn print_help(plugin: &mut impl Plugin, encoder: impl PluginEncoder) {
// This will stop the manager
drop(interface);
}
fn print_help(plugin: &mut impl StreamingPlugin, encoder: impl PluginEncoder) {
println!("Nushell Plugin");
println!("Encoder: {}", encoder.name());

View File

@ -1,33 +1,201 @@
mod evaluated_call;
mod plugin_custom_value;
mod plugin_data;
mod protocol_info;
#[cfg(test)]
mod tests;
#[cfg(test)]
pub(crate) mod test_util;
pub use evaluated_call::EvaluatedCall;
use nu_protocol::{PluginSignature, ShellError, Span, Value};
use nu_protocol::{PluginSignature, RawStream, ShellError, Span, Spanned, Value};
pub use plugin_custom_value::PluginCustomValue;
pub use plugin_data::PluginData;
pub(crate) use protocol_info::ProtocolInfo;
use serde::{Deserialize, Serialize};
#[derive(Serialize, Deserialize, Debug)]
pub struct CallInfo {
#[cfg(test)]
pub(crate) use protocol_info::Protocol;
/// A sequential identifier for a stream
pub type StreamId = usize;
/// A sequential identifier for a [`PluginCall`]
pub type PluginCallId = usize;
/// Information about a plugin command invocation. This includes an [`EvaluatedCall`] as a
/// serializable representation of [`nu_protocol::ast::Call`]. The type parameter determines
/// the input type.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct CallInfo<D> {
/// The name of the command to be run
pub name: String,
/// Information about the invocation, including arguments
pub call: EvaluatedCall,
pub input: CallInput,
/// Pipeline input. This is usually [`nu_protocol::PipelineData`] or [`PipelineDataHeader`]
pub input: D,
/// Plugin configuration, if available
pub config: Option<Value>,
}
#[derive(Serialize, Deserialize, Debug, PartialEq)]
pub enum CallInput {
/// The initial (and perhaps only) part of any [`nu_protocol::PipelineData`] sent over the wire.
///
/// This may contain a single value, or may initiate a stream with a [`StreamId`].
#[derive(Serialize, Deserialize, Debug, PartialEq, Clone)]
pub enum PipelineDataHeader {
/// No input
Empty,
/// A single value
Value(Value),
Data(PluginData),
/// Initiate [`nu_protocol::PipelineData::ListStream`].
///
/// Items are sent via [`StreamData`]
ListStream(ListStreamInfo),
/// Initiate [`nu_protocol::PipelineData::ExternalStream`].
///
/// Items are sent via [`StreamData`]
ExternalStream(ExternalStreamInfo),
}
// Information sent to the plugin
#[derive(Serialize, Deserialize, Debug)]
pub enum PluginCall {
/// Additional information about list (value) streams
#[derive(Serialize, Deserialize, Debug, PartialEq, Eq, Clone)]
pub struct ListStreamInfo {
pub id: StreamId,
}
/// Additional information about external streams
#[derive(Serialize, Deserialize, Debug, PartialEq, Eq, Clone)]
pub struct ExternalStreamInfo {
pub span: Span,
pub stdout: Option<RawStreamInfo>,
pub stderr: Option<RawStreamInfo>,
pub exit_code: Option<ListStreamInfo>,
pub trim_end_newline: bool,
}
/// Additional information about raw (byte) streams
#[derive(Serialize, Deserialize, Debug, PartialEq, Eq, Clone)]
pub struct RawStreamInfo {
pub id: StreamId,
pub is_binary: bool,
pub known_size: Option<u64>,
}
impl RawStreamInfo {
pub(crate) fn new(id: StreamId, stream: &RawStream) -> Self {
RawStreamInfo {
id,
is_binary: stream.is_binary,
known_size: stream.known_size,
}
}
}
/// Calls that a plugin can execute. The type parameter determines the input type.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub enum PluginCall<D> {
Signature,
CallInfo(CallInfo),
CollapseCustomValue(PluginData),
Run(CallInfo<D>),
CustomValueOp(Spanned<PluginCustomValue>, CustomValueOp),
}
/// Operations supported for custom values.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub enum CustomValueOp {
/// [`to_base_value()`](nu_protocol::CustomValue::to_base_value)
ToBaseValue,
}
/// Any data sent to the plugin
#[derive(Serialize, Deserialize, Debug, Clone)]
pub enum PluginInput {
/// This must be the first message. Indicates supported protocol
Hello(ProtocolInfo),
/// Execute a [`PluginCall`], such as `Run` or `Signature`. The ID should not have been used
/// before.
Call(PluginCallId, PluginCall<PipelineDataHeader>),
/// Stream control or data message. Untagged to keep them as small as possible.
///
/// For example, `Stream(Ack(0))` is encoded as `{"Ack": 0}`
#[serde(untagged)]
Stream(StreamMessage),
}
impl TryFrom<PluginInput> for StreamMessage {
type Error = PluginInput;
fn try_from(msg: PluginInput) -> Result<StreamMessage, PluginInput> {
match msg {
PluginInput::Stream(stream_msg) => Ok(stream_msg),
_ => Err(msg),
}
}
}
impl From<StreamMessage> for PluginInput {
fn from(stream_msg: StreamMessage) -> PluginInput {
PluginInput::Stream(stream_msg)
}
}
/// A single item of stream data for a stream.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub enum StreamData {
List(Value),
Raw(Result<Vec<u8>, ShellError>),
}
impl From<Value> for StreamData {
fn from(value: Value) -> Self {
StreamData::List(value)
}
}
impl From<Result<Vec<u8>, ShellError>> for StreamData {
fn from(value: Result<Vec<u8>, ShellError>) -> Self {
StreamData::Raw(value)
}
}
impl TryFrom<StreamData> for Value {
type Error = ShellError;
fn try_from(data: StreamData) -> Result<Value, ShellError> {
match data {
StreamData::List(value) => Ok(value),
StreamData::Raw(_) => Err(ShellError::PluginFailedToDecode {
msg: "expected list stream data, found raw data".into(),
}),
}
}
}
impl TryFrom<StreamData> for Result<Vec<u8>, ShellError> {
type Error = ShellError;
fn try_from(data: StreamData) -> Result<Result<Vec<u8>, ShellError>, ShellError> {
match data {
StreamData::Raw(value) => Ok(value),
StreamData::List(_) => Err(ShellError::PluginFailedToDecode {
msg: "expected raw stream data, found list data".into(),
}),
}
}
}
/// A stream control or data message.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub enum StreamMessage {
/// Append data to the stream. Sent by the stream producer.
Data(StreamId, StreamData),
/// End of stream. Sent by the stream producer.
End(StreamId),
/// Notify that the read end of the stream has closed, and further messages should not be
/// sent. Sent by the stream consumer.
Drop(StreamId),
/// Acknowledge that a message has been consumed. This is used to implement flow control by
/// the stream producer. Sent by the stream consumer.
Ack(StreamId),
}
/// An error message with debugging information that can be passed to Nushell from the plugin
@ -36,7 +204,7 @@ pub enum PluginCall {
/// a [Plugin](crate::Plugin)'s [`run`](crate::Plugin::run()) method. It contains
/// the error message along with optional [Span] data to support highlighting in the
/// shell.
#[derive(Clone, PartialEq, Eq, Serialize, Deserialize, Debug)]
#[derive(Serialize, Deserialize, PartialEq, Eq, Debug, Clone)]
pub struct LabeledError {
/// The name of the error
pub label: String,
@ -48,81 +216,108 @@ pub struct LabeledError {
impl From<LabeledError> for ShellError {
fn from(error: LabeledError) -> Self {
match error.span {
Some(span) => ShellError::GenericError {
if error.span.is_some() {
ShellError::GenericError {
error: error.label,
msg: error.msg,
span: Some(span),
span: error.span,
help: None,
inner: vec![],
},
None => ShellError::GenericError {
}
} else {
ShellError::GenericError {
error: error.label,
msg: "".into(),
span: None,
help: Some(error.msg),
help: (!error.msg.is_empty()).then_some(error.msg),
inner: vec![],
},
}
}
}
}
impl From<ShellError> for LabeledError {
fn from(error: ShellError) -> Self {
match error {
ShellError::GenericError {
error: label,
msg,
span,
..
} => LabeledError { label, msg, span },
ShellError::CantConvert {
to_type: expected,
from_type: input,
span,
help: _help,
} => LabeledError {
label: format!("Can't convert to {expected}"),
msg: format!("can't convert from {input} to {expected}"),
use miette::Diagnostic;
// This is not perfect - we can only take the first labeled span as that's all we have
// space for.
if let Some(labeled_span) = error.labels().and_then(|mut iter| iter.nth(0)) {
let offset = labeled_span.offset();
let span = Span::new(offset, offset + labeled_span.len());
LabeledError {
label: error.to_string(),
msg: labeled_span
.label()
.map(|label| label.to_owned())
.unwrap_or_else(|| "".into()),
span: Some(span),
},
ShellError::DidYouMean { suggestion, span } => LabeledError {
label: "Name not found".into(),
msg: format!("did you mean '{suggestion}'?"),
span: Some(span),
},
ShellError::PluginFailedToLoad { msg } => LabeledError {
label: "Plugin failed to load".into(),
msg,
}
} else {
LabeledError {
label: error.to_string(),
msg: error
.help()
.map(|help| help.to_string())
.unwrap_or_else(|| "".into()),
span: None,
},
ShellError::PluginFailedToEncode { msg } => LabeledError {
label: "Plugin failed to encode".into(),
msg,
span: None,
},
ShellError::PluginFailedToDecode { msg } => LabeledError {
label: "Plugin failed to decode".into(),
msg,
span: None,
},
err => LabeledError {
label: "Error - Add to LabeledError From<ShellError>".into(),
msg: err.to_string(),
span: None,
},
}
}
}
}
// Information received from the plugin
// Needs to be public to communicate with nu-parser but not typically
// used by Plugin authors
/// Response to a [`PluginCall`]. The type parameter determines the output type for pipeline data.
///
/// Note: exported for internal use, not public.
#[derive(Serialize, Deserialize, Debug, Clone)]
#[doc(hidden)]
#[derive(Serialize, Deserialize)]
pub enum PluginResponse {
pub enum PluginCallResponse<D> {
Error(LabeledError),
Signature(Vec<PluginSignature>),
Value(Box<Value>),
PluginData(String, PluginData),
PipelineData(D),
}
impl PluginCallResponse<PipelineDataHeader> {
/// Construct a plugin call response with a single value
pub fn value(value: Value) -> PluginCallResponse<PipelineDataHeader> {
if value.is_nothing() {
PluginCallResponse::PipelineData(PipelineDataHeader::Empty)
} else {
PluginCallResponse::PipelineData(PipelineDataHeader::Value(value))
}
}
}
/// Information received from the plugin
///
/// Note: exported for internal use, not public.
#[derive(Serialize, Deserialize, Debug, Clone)]
#[doc(hidden)]
pub enum PluginOutput {
/// This must be the first message. Indicates supported protocol
Hello(ProtocolInfo),
/// A response to a [`PluginCall`]. The ID should be the same sent with the plugin call this
/// is a response to
CallResponse(PluginCallId, PluginCallResponse<PipelineDataHeader>),
/// Stream control or data message. Untagged to keep them as small as possible.
///
/// For example, `Stream(Ack(0))` is encoded as `{"Ack": 0}`
#[serde(untagged)]
Stream(StreamMessage),
}
impl TryFrom<PluginOutput> for StreamMessage {
type Error = PluginOutput;
fn try_from(msg: PluginOutput) -> Result<StreamMessage, PluginOutput> {
match msg {
PluginOutput::Stream(stream_msg) => Ok(stream_msg),
_ => Err(msg),
}
}
}
impl From<StreamMessage> for PluginOutput {
fn from(stream_msg: StreamMessage) -> PluginOutput {
PluginOutput::Stream(stream_msg)
}
}

View File

@ -1,37 +1,39 @@
use std::path::PathBuf;
use std::sync::Arc;
use nu_protocol::{CustomValue, ShellError, Value};
use serde::Serialize;
use nu_protocol::{CustomValue, ShellError, Span, Spanned, Value};
use serde::{Deserialize, Serialize};
use crate::plugin::{call_plugin, create_command, get_plugin_encoding};
use crate::plugin::PluginIdentity;
use super::{PluginCall, PluginData, PluginResponse};
#[cfg(test)]
mod tests;
/// An opaque container for a custom value that is handled fully by a plugin
///
/// This is constructed by the main nushell engine when it receives [`PluginResponse::PluginData`]
/// it stores that data as well as metadata related to the plugin to be able to call the plugin
/// later.
/// Since the data in it is opaque to the engine, there are only two final destinations for it:
/// either it will be sent back to the plugin that generated it across a pipeline, or it will be
/// sent to the plugin with a request to collapse it into a base value
#[derive(Clone, Debug, Serialize)]
/// This is the only type of custom value that is allowed to cross the plugin serialization
/// boundary.
///
/// [`EngineInterface`](crate::interface::EngineInterface) is responsible for ensuring
/// that local plugin custom values are converted to and from [`PluginCustomData`] on the boundary.
///
/// [`PluginInterface`](crate::interface::PluginInterface) is responsible for adding the
/// appropriate [`PluginIdentity`](crate::plugin::PluginIdentity), ensuring that only
/// [`PluginCustomData`] is contained within any values sent, and that the `source` of any
/// values sent matches the plugin it is being sent to.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct PluginCustomValue {
/// The name of the custom value as defined by the plugin
/// The name of the custom value as defined by the plugin (`value_string()`)
pub name: String,
/// The bincoded representation of the custom value on the plugin side
pub data: Vec<u8>,
pub filename: PathBuf,
// PluginCustomValue must implement Serialize because all CustomValues must implement Serialize
// However, the main place where values are serialized and deserialized is when they are being
// sent between plugins and nushell's main engine. PluginCustomValue is never meant to be sent
// between that boundary
#[serde(skip)]
pub shell: Option<PathBuf>,
#[serde(skip)]
pub source: String,
/// Which plugin the custom value came from. This is not defined on the plugin side. The engine
/// side is responsible for maintaining it, and it is not sent over the serialization boundary.
#[serde(skip, default)]
pub source: Option<Arc<PluginIdentity>>,
}
#[typetag::serde]
impl CustomValue for PluginCustomValue {
fn clone_value(&self, span: nu_protocol::Span) -> nu_protocol::Value {
Value::custom_value(Box::new(self.clone()), span)
@ -45,83 +47,295 @@ impl CustomValue for PluginCustomValue {
&self,
span: nu_protocol::Span,
) -> Result<nu_protocol::Value, nu_protocol::ShellError> {
let mut plugin_cmd = create_command(&self.filename, self.shell.as_deref());
let mut child = plugin_cmd.spawn().map_err(|err| ShellError::GenericError {
let wrap_err = |err: ShellError| ShellError::GenericError {
error: format!(
"Unable to spawn plugin for {} to get base value",
"Unable to spawn plugin `{}` to get base value",
self.source
.as_ref()
.map(|s| s.plugin_name.as_str())
.unwrap_or("<unknown>")
),
msg: format!("{err}"),
msg: err.to_string(),
span: Some(span),
help: None,
inner: vec![],
inner: vec![err],
};
let identity = self.source.clone().ok_or_else(|| {
wrap_err(ShellError::NushellFailed {
msg: "The plugin source for the custom value was not set".into(),
})
})?;
let plugin_call = PluginCall::CollapseCustomValue(PluginData {
data: self.data.clone(),
let empty_env: Option<(String, String)> = None;
let plugin = identity.spawn(empty_env).map_err(wrap_err)?;
plugin
.custom_value_to_base_value(Spanned {
item: self.clone(),
span,
});
let encoding = {
let stdout_reader = match &mut child.stdout {
Some(out) => out,
None => {
return Err(ShellError::PluginFailedToLoad {
msg: "Plugin missing stdout reader".into(),
})
}
};
get_plugin_encoding(stdout_reader)?
};
let response = call_plugin(&mut child, plugin_call, &encoding, span).map_err(|err| {
ShellError::GenericError {
error: format!(
"Unable to decode call for {} to get base value",
self.source
),
msg: format!("{err}"),
span: Some(span),
help: None,
inner: vec![],
}
});
let value = match response {
Ok(PluginResponse::Value(value)) => Ok(*value),
Ok(PluginResponse::PluginData(..)) => Err(ShellError::GenericError {
error: "Plugin misbehaving".into(),
msg: "Plugin returned custom data as a response to a collapse call".into(),
span: Some(span),
help: None,
inner: vec![],
}),
Ok(PluginResponse::Error(err)) => Err(err.into()),
Ok(PluginResponse::Signature(..)) => Err(ShellError::GenericError {
error: "Plugin missing value".into(),
msg: "Received a signature from plugin instead of value".into(),
span: Some(span),
help: None,
inner: vec![],
}),
Err(err) => Err(err),
};
// We need to call .wait() on the child, or we'll risk summoning the zombie horde
let _ = child.wait();
value
.map_err(wrap_err)
}
fn as_any(&self) -> &dyn std::any::Any {
self
}
fn typetag_name(&self) -> &'static str {
"PluginCustomValue"
}
fn typetag_deserialize(&self) {
unimplemented!("typetag_deserialize")
impl PluginCustomValue {
/// Serialize a custom value into a [`PluginCustomValue`]. This should only be done on the
/// plugin side.
pub(crate) fn serialize_from_custom_value(
custom_value: &dyn CustomValue,
span: Span,
) -> Result<PluginCustomValue, ShellError> {
let name = custom_value.value_string();
bincode::serialize(custom_value)
.map(|data| PluginCustomValue {
name,
data,
source: None,
})
.map_err(|err| ShellError::CustomValueFailedToEncode {
msg: err.to_string(),
span,
})
}
/// Deserialize a [`PluginCustomValue`] into a `Box<dyn CustomValue>`. This should only be done
/// on the plugin side.
pub(crate) fn deserialize_to_custom_value(
&self,
span: Span,
) -> Result<Box<dyn CustomValue>, ShellError> {
bincode::deserialize::<Box<dyn CustomValue>>(&self.data).map_err(|err| {
ShellError::CustomValueFailedToDecode {
msg: err.to_string(),
span,
}
})
}
/// Add a [`PluginIdentity`] to all [`PluginCustomValue`]s within a value, recursively.
pub(crate) fn add_source(value: &mut Value, source: &Arc<PluginIdentity>) {
let span = value.span();
match value {
// Set source on custom value
Value::CustomValue { ref val, .. } => {
if let Some(custom_value) = val.as_any().downcast_ref::<PluginCustomValue>() {
// Since there's no `as_mut_any()`, we have to copy the whole thing
let mut custom_value = custom_value.clone();
custom_value.source = Some(source.clone());
*value = Value::custom_value(Box::new(custom_value), span);
}
}
// Any values that can contain other values need to be handled recursively
Value::Range { ref mut val, .. } => {
Self::add_source(&mut val.from, source);
Self::add_source(&mut val.to, source);
Self::add_source(&mut val.incr, source);
}
Value::Record { ref mut val, .. } => {
for (_, rec_value) in val.iter_mut() {
Self::add_source(rec_value, source);
}
}
Value::List { ref mut vals, .. } => {
for list_value in vals.iter_mut() {
Self::add_source(list_value, source);
}
}
// All of these don't contain other values
Value::Bool { .. }
| Value::Int { .. }
| Value::Float { .. }
| Value::Filesize { .. }
| Value::Duration { .. }
| Value::Date { .. }
| Value::String { .. }
| Value::Glob { .. }
| Value::Block { .. }
| Value::Closure { .. }
| Value::Nothing { .. }
| Value::Error { .. }
| Value::Binary { .. }
| Value::CellPath { .. } => (),
// LazyRecord could generate other values, but we shouldn't be receiving it anyway
//
// It's better to handle this as a bug
Value::LazyRecord { .. } => unimplemented!("add_source for LazyRecord"),
}
}
/// Check that all [`CustomValue`]s present within the `value` are [`PluginCustomValue`]s that
/// come from the given `source`, and return an error if not.
///
/// This method will collapse `LazyRecord` in-place as necessary to make the guarantee,
/// since `LazyRecord` could return something different the next time it is called.
pub(crate) fn verify_source(
value: &mut Value,
source: &PluginIdentity,
) -> Result<(), ShellError> {
let span = value.span();
match value {
// Set source on custom value
Value::CustomValue { val, .. } => {
if let Some(custom_value) = val.as_any().downcast_ref::<PluginCustomValue>() {
if custom_value.source.as_deref() == Some(source) {
Ok(())
} else {
Err(ShellError::CustomValueIncorrectForPlugin {
name: custom_value.name.clone(),
span,
dest_plugin: source.plugin_name.clone(),
src_plugin: custom_value.source.as_ref().map(|s| s.plugin_name.clone()),
})
}
} else {
// Only PluginCustomValues can be sent
Err(ShellError::CustomValueIncorrectForPlugin {
name: val.value_string(),
span,
dest_plugin: source.plugin_name.clone(),
src_plugin: None,
})
}
}
// Any values that can contain other values need to be handled recursively
Value::Range { val, .. } => {
Self::verify_source(&mut val.from, source)?;
Self::verify_source(&mut val.to, source)?;
Self::verify_source(&mut val.incr, source)
}
Value::Record { ref mut val, .. } => val
.iter_mut()
.try_for_each(|(_, rec_value)| Self::verify_source(rec_value, source)),
Value::List { ref mut vals, .. } => vals
.iter_mut()
.try_for_each(|list_value| Self::verify_source(list_value, source)),
// All of these don't contain other values
Value::Bool { .. }
| Value::Int { .. }
| Value::Float { .. }
| Value::Filesize { .. }
| Value::Duration { .. }
| Value::Date { .. }
| Value::String { .. }
| Value::Glob { .. }
| Value::Block { .. }
| Value::Closure { .. }
| Value::Nothing { .. }
| Value::Error { .. }
| Value::Binary { .. }
| Value::CellPath { .. } => Ok(()),
// LazyRecord would be a problem for us, since it could return something else the next
// time, and we have to collect it anyway to serialize it. Collect it in place, and then
// verify the source of the result
Value::LazyRecord { val, .. } => {
*value = val.collect()?;
Self::verify_source(value, source)
}
}
}
/// Convert all plugin-native custom values to [`PluginCustomValue`] within the given `value`,
/// recursively. This should only be done on the plugin side.
pub(crate) fn serialize_custom_values_in(value: &mut Value) -> Result<(), ShellError> {
let span = value.span();
match value {
Value::CustomValue { ref val, .. } => {
if val.as_any().downcast_ref::<PluginCustomValue>().is_some() {
// Already a PluginCustomValue
Ok(())
} else {
let serialized = Self::serialize_from_custom_value(&**val, span)?;
*value = Value::custom_value(Box::new(serialized), span);
Ok(())
}
}
// Any values that can contain other values need to be handled recursively
Value::Range { ref mut val, .. } => {
Self::serialize_custom_values_in(&mut val.from)?;
Self::serialize_custom_values_in(&mut val.to)?;
Self::serialize_custom_values_in(&mut val.incr)
}
Value::Record { ref mut val, .. } => val
.iter_mut()
.try_for_each(|(_, rec_value)| Self::serialize_custom_values_in(rec_value)),
Value::List { ref mut vals, .. } => vals
.iter_mut()
.try_for_each(Self::serialize_custom_values_in),
// All of these don't contain other values
Value::Bool { .. }
| Value::Int { .. }
| Value::Float { .. }
| Value::Filesize { .. }
| Value::Duration { .. }
| Value::Date { .. }
| Value::String { .. }
| Value::Glob { .. }
| Value::Block { .. }
| Value::Closure { .. }
| Value::Nothing { .. }
| Value::Error { .. }
| Value::Binary { .. }
| Value::CellPath { .. } => Ok(()),
// Collect any lazy records that exist and try again
Value::LazyRecord { val, .. } => {
*value = val.collect()?;
Self::serialize_custom_values_in(value)
}
}
}
/// Convert all [`PluginCustomValue`]s to plugin-native custom values within the given `value`,
/// recursively. This should only be done on the plugin side.
pub(crate) fn deserialize_custom_values_in(value: &mut Value) -> Result<(), ShellError> {
let span = value.span();
match value {
Value::CustomValue { ref val, .. } => {
if let Some(val) = val.as_any().downcast_ref::<PluginCustomValue>() {
let deserialized = val.deserialize_to_custom_value(span)?;
*value = Value::custom_value(deserialized, span);
Ok(())
} else {
// Already not a PluginCustomValue
Ok(())
}
}
// Any values that can contain other values need to be handled recursively
Value::Range { ref mut val, .. } => {
Self::deserialize_custom_values_in(&mut val.from)?;
Self::deserialize_custom_values_in(&mut val.to)?;
Self::deserialize_custom_values_in(&mut val.incr)
}
Value::Record { ref mut val, .. } => val
.iter_mut()
.try_for_each(|(_, rec_value)| Self::deserialize_custom_values_in(rec_value)),
Value::List { ref mut vals, .. } => vals
.iter_mut()
.try_for_each(Self::deserialize_custom_values_in),
// All of these don't contain other values
Value::Bool { .. }
| Value::Int { .. }
| Value::Float { .. }
| Value::Filesize { .. }
| Value::Duration { .. }
| Value::Date { .. }
| Value::String { .. }
| Value::Glob { .. }
| Value::Block { .. }
| Value::Closure { .. }
| Value::Nothing { .. }
| Value::Error { .. }
| Value::Binary { .. }
| Value::CellPath { .. } => Ok(()),
// Collect any lazy records that exist and try again
Value::LazyRecord { val, .. } => {
*value = val.collect()?;
Self::deserialize_custom_values_in(value)
}
}
}
}

View File

@ -0,0 +1,492 @@
use nu_protocol::{ast::RangeInclusion, record, CustomValue, Range, ShellError, Span, Value};
use crate::{
plugin::PluginIdentity,
protocol::test_util::{
expected_test_custom_value, test_plugin_custom_value, test_plugin_custom_value_with_source,
TestCustomValue,
},
};
use super::PluginCustomValue;
#[test]
fn serialize_deserialize() -> Result<(), ShellError> {
let original_value = TestCustomValue(32);
let span = Span::test_data();
let serialized = PluginCustomValue::serialize_from_custom_value(&original_value, span)?;
assert_eq!(original_value.value_string(), serialized.name);
assert!(serialized.source.is_none());
let deserialized = serialized.deserialize_to_custom_value(span)?;
let downcasted = deserialized
.as_any()
.downcast_ref::<TestCustomValue>()
.expect("failed to downcast: not TestCustomValue");
assert_eq!(original_value, *downcasted);
Ok(())
}
#[test]
fn expected_serialize_output() -> Result<(), ShellError> {
let original_value = expected_test_custom_value();
let span = Span::test_data();
let serialized = PluginCustomValue::serialize_from_custom_value(&original_value, span)?;
assert_eq!(
test_plugin_custom_value().data,
serialized.data,
"The bincode configuration is probably different from what we expected. \
Fix test_plugin_custom_value() to match it"
);
Ok(())
}
#[test]
fn add_source_at_root() -> Result<(), ShellError> {
let mut val = Value::test_custom_value(Box::new(test_plugin_custom_value()));
let source = PluginIdentity::new_fake("foo");
PluginCustomValue::add_source(&mut val, &source);
let custom_value = val.as_custom_value()?;
let plugin_custom_value: &PluginCustomValue = custom_value
.as_any()
.downcast_ref()
.expect("not PluginCustomValue");
assert_eq!(Some(source), plugin_custom_value.source);
Ok(())
}
fn check_range_custom_values(
val: &Value,
mut f: impl FnMut(&str, &dyn CustomValue) -> Result<(), ShellError>,
) -> Result<(), ShellError> {
let range = val.as_range()?;
for (name, val) in [
("from", &range.from),
("incr", &range.incr),
("to", &range.to),
] {
let custom_value = val
.as_custom_value()
.unwrap_or_else(|_| panic!("{name} not custom value"));
f(name, custom_value)?;
}
Ok(())
}
#[test]
fn add_source_nested_range() -> Result<(), ShellError> {
let orig_custom_val = Value::test_custom_value(Box::new(test_plugin_custom_value()));
let mut val = Value::test_range(Range {
from: orig_custom_val.clone(),
incr: orig_custom_val.clone(),
to: orig_custom_val.clone(),
inclusion: RangeInclusion::Inclusive,
});
let source = PluginIdentity::new_fake("foo");
PluginCustomValue::add_source(&mut val, &source);
check_range_custom_values(&val, |name, custom_value| {
let plugin_custom_value: &PluginCustomValue = custom_value
.as_any()
.downcast_ref()
.unwrap_or_else(|| panic!("{name} not PluginCustomValue"));
assert_eq!(
Some(&source),
plugin_custom_value.source.as_ref(),
"{name} source not set correctly"
);
Ok(())
})
}
fn check_record_custom_values(
val: &Value,
keys: &[&str],
mut f: impl FnMut(&str, &dyn CustomValue) -> Result<(), ShellError>,
) -> Result<(), ShellError> {
let record = val.as_record()?;
for key in keys {
let val = record
.get(key)
.unwrap_or_else(|| panic!("record does not contain '{key}'"));
let custom_value = val
.as_custom_value()
.unwrap_or_else(|_| panic!("'{key}' not custom value"));
f(key, custom_value)?;
}
Ok(())
}
#[test]
fn add_source_nested_record() -> Result<(), ShellError> {
let orig_custom_val = Value::test_custom_value(Box::new(test_plugin_custom_value()));
let mut val = Value::test_record(record! {
"foo" => orig_custom_val.clone(),
"bar" => orig_custom_val.clone(),
});
let source = PluginIdentity::new_fake("foo");
PluginCustomValue::add_source(&mut val, &source);
check_record_custom_values(&val, &["foo", "bar"], |key, custom_value| {
let plugin_custom_value: &PluginCustomValue = custom_value
.as_any()
.downcast_ref()
.unwrap_or_else(|| panic!("'{key}' not PluginCustomValue"));
assert_eq!(
Some(&source),
plugin_custom_value.source.as_ref(),
"'{key}' source not set correctly"
);
Ok(())
})
}
fn check_list_custom_values(
val: &Value,
indices: impl IntoIterator<Item = usize>,
mut f: impl FnMut(usize, &dyn CustomValue) -> Result<(), ShellError>,
) -> Result<(), ShellError> {
let list = val.as_list()?;
for index in indices {
let val = list
.get(index)
.unwrap_or_else(|| panic!("[{index}] not present in list"));
let custom_value = val
.as_custom_value()
.unwrap_or_else(|_| panic!("[{index}] not custom value"));
f(index, custom_value)?;
}
Ok(())
}
#[test]
fn add_source_nested_list() -> Result<(), ShellError> {
let orig_custom_val = Value::test_custom_value(Box::new(test_plugin_custom_value()));
let mut val = Value::test_list(vec![orig_custom_val.clone(), orig_custom_val.clone()]);
let source = PluginIdentity::new_fake("foo");
PluginCustomValue::add_source(&mut val, &source);
check_list_custom_values(&val, 0..=1, |index, custom_value| {
let plugin_custom_value: &PluginCustomValue = custom_value
.as_any()
.downcast_ref()
.unwrap_or_else(|| panic!("[{index}] not PluginCustomValue"));
assert_eq!(
Some(&source),
plugin_custom_value.source.as_ref(),
"[{index}] source not set correctly"
);
Ok(())
})
}
#[test]
fn verify_source_error_message() -> Result<(), ShellError> {
let span = Span::new(5, 7);
let mut ok_val = Value::custom_value(Box::new(test_plugin_custom_value_with_source()), span);
let mut native_val = Value::custom_value(Box::new(TestCustomValue(32)), span);
let mut foreign_val = {
let mut val = test_plugin_custom_value();
val.source = Some(PluginIdentity::new_fake("other"));
Value::custom_value(Box::new(val), span)
};
let source = PluginIdentity::new_fake("test");
PluginCustomValue::verify_source(&mut ok_val, &source).expect("ok_val should be verified ok");
for (val, src_plugin) in [(&mut native_val, None), (&mut foreign_val, Some("other"))] {
let error = PluginCustomValue::verify_source(val, &source).expect_err(&format!(
"a custom value from {src_plugin:?} should result in an error"
));
if let ShellError::CustomValueIncorrectForPlugin {
name,
span: err_span,
dest_plugin,
src_plugin: err_src_plugin,
} = error
{
assert_eq!("TestCustomValue", name, "error.name from {src_plugin:?}");
assert_eq!(span, err_span, "error.span from {src_plugin:?}");
assert_eq!("test", dest_plugin, "error.dest_plugin from {src_plugin:?}");
assert_eq!(src_plugin, err_src_plugin.as_deref(), "error.src_plugin");
} else {
panic!("the error returned should be CustomValueIncorrectForPlugin");
}
}
Ok(())
}
#[test]
fn verify_source_nested_range() -> Result<(), ShellError> {
let native_val = Value::test_custom_value(Box::new(TestCustomValue(32)));
let source = PluginIdentity::new_fake("test");
for (name, mut val) in [
(
"from",
Value::test_range(Range {
from: native_val.clone(),
incr: Value::test_nothing(),
to: Value::test_nothing(),
inclusion: RangeInclusion::RightExclusive,
}),
),
(
"incr",
Value::test_range(Range {
from: Value::test_nothing(),
incr: native_val.clone(),
to: Value::test_nothing(),
inclusion: RangeInclusion::RightExclusive,
}),
),
(
"to",
Value::test_range(Range {
from: Value::test_nothing(),
incr: Value::test_nothing(),
to: native_val.clone(),
inclusion: RangeInclusion::RightExclusive,
}),
),
] {
PluginCustomValue::verify_source(&mut val, &source)
.expect_err(&format!("error not generated on {name}"));
}
let mut ok_range = Value::test_range(Range {
from: Value::test_nothing(),
incr: Value::test_nothing(),
to: Value::test_nothing(),
inclusion: RangeInclusion::RightExclusive,
});
PluginCustomValue::verify_source(&mut ok_range, &source)
.expect("ok_range should not generate error");
Ok(())
}
#[test]
fn verify_source_nested_record() -> Result<(), ShellError> {
let native_val = Value::test_custom_value(Box::new(TestCustomValue(32)));
let source = PluginIdentity::new_fake("test");
for (name, mut val) in [
(
"first element foo",
Value::test_record(record! {
"foo" => native_val.clone(),
"bar" => Value::test_nothing(),
}),
),
(
"second element bar",
Value::test_record(record! {
"foo" => Value::test_nothing(),
"bar" => native_val.clone(),
}),
),
] {
PluginCustomValue::verify_source(&mut val, &source)
.expect_err(&format!("error not generated on {name}"));
}
let mut ok_record = Value::test_record(record! {"foo" => Value::test_nothing()});
PluginCustomValue::verify_source(&mut ok_record, &source)
.expect("ok_record should not generate error");
Ok(())
}
#[test]
fn verify_source_nested_list() -> Result<(), ShellError> {
let native_val = Value::test_custom_value(Box::new(TestCustomValue(32)));
let source = PluginIdentity::new_fake("test");
for (name, mut val) in [
(
"first element",
Value::test_list(vec![native_val.clone(), Value::test_nothing()]),
),
(
"second element",
Value::test_list(vec![Value::test_nothing(), native_val.clone()]),
),
] {
PluginCustomValue::verify_source(&mut val, &source)
.expect_err(&format!("error not generated on {name}"));
}
let mut ok_list = Value::test_list(vec![Value::test_nothing()]);
PluginCustomValue::verify_source(&mut ok_list, &source)
.expect("ok_list should not generate error");
Ok(())
}
#[test]
fn serialize_in_root() -> Result<(), ShellError> {
let span = Span::new(4, 10);
let mut val = Value::custom_value(Box::new(expected_test_custom_value()), span);
PluginCustomValue::serialize_custom_values_in(&mut val)?;
assert_eq!(span, val.span());
let custom_value = val.as_custom_value()?;
if let Some(plugin_custom_value) = custom_value.as_any().downcast_ref::<PluginCustomValue>() {
assert_eq!("TestCustomValue", plugin_custom_value.name);
assert_eq!(test_plugin_custom_value().data, plugin_custom_value.data);
assert!(plugin_custom_value.source.is_none());
} else {
panic!("Failed to downcast to PluginCustomValue");
}
Ok(())
}
#[test]
fn serialize_in_range() -> Result<(), ShellError> {
let orig_custom_val = Value::test_custom_value(Box::new(TestCustomValue(-1)));
let mut val = Value::test_range(Range {
from: orig_custom_val.clone(),
incr: orig_custom_val.clone(),
to: orig_custom_val.clone(),
inclusion: RangeInclusion::Inclusive,
});
PluginCustomValue::serialize_custom_values_in(&mut val)?;
check_range_custom_values(&val, |name, custom_value| {
let plugin_custom_value: &PluginCustomValue = custom_value
.as_any()
.downcast_ref()
.unwrap_or_else(|| panic!("{name} not PluginCustomValue"));
assert_eq!(
"TestCustomValue", plugin_custom_value.name,
"{name} name not set correctly"
);
Ok(())
})
}
#[test]
fn serialize_in_record() -> Result<(), ShellError> {
let orig_custom_val = Value::test_custom_value(Box::new(TestCustomValue(32)));
let mut val = Value::test_record(record! {
"foo" => orig_custom_val.clone(),
"bar" => orig_custom_val.clone(),
});
PluginCustomValue::serialize_custom_values_in(&mut val)?;
check_record_custom_values(&val, &["foo", "bar"], |key, custom_value| {
let plugin_custom_value: &PluginCustomValue = custom_value
.as_any()
.downcast_ref()
.unwrap_or_else(|| panic!("'{key}' not PluginCustomValue"));
assert_eq!(
"TestCustomValue", plugin_custom_value.name,
"'{key}' name not set correctly"
);
Ok(())
})
}
#[test]
fn serialize_in_list() -> Result<(), ShellError> {
let orig_custom_val = Value::test_custom_value(Box::new(TestCustomValue(24)));
let mut val = Value::test_list(vec![orig_custom_val.clone(), orig_custom_val.clone()]);
PluginCustomValue::serialize_custom_values_in(&mut val)?;
check_list_custom_values(&val, 0..=1, |index, custom_value| {
let plugin_custom_value: &PluginCustomValue = custom_value
.as_any()
.downcast_ref()
.unwrap_or_else(|| panic!("[{index}] not PluginCustomValue"));
assert_eq!(
"TestCustomValue", plugin_custom_value.name,
"[{index}] name not set correctly"
);
Ok(())
})
}
#[test]
fn deserialize_in_root() -> Result<(), ShellError> {
let span = Span::new(4, 10);
let mut val = Value::custom_value(Box::new(test_plugin_custom_value()), span);
PluginCustomValue::deserialize_custom_values_in(&mut val)?;
assert_eq!(span, val.span());
let custom_value = val.as_custom_value()?;
if let Some(test_custom_value) = custom_value.as_any().downcast_ref::<TestCustomValue>() {
assert_eq!(expected_test_custom_value(), *test_custom_value);
} else {
panic!("Failed to downcast to TestCustomValue");
}
Ok(())
}
#[test]
fn deserialize_in_range() -> Result<(), ShellError> {
let orig_custom_val = Value::test_custom_value(Box::new(test_plugin_custom_value()));
let mut val = Value::test_range(Range {
from: orig_custom_val.clone(),
incr: orig_custom_val.clone(),
to: orig_custom_val.clone(),
inclusion: RangeInclusion::Inclusive,
});
PluginCustomValue::deserialize_custom_values_in(&mut val)?;
check_range_custom_values(&val, |name, custom_value| {
let test_custom_value: &TestCustomValue = custom_value
.as_any()
.downcast_ref()
.unwrap_or_else(|| panic!("{name} not TestCustomValue"));
assert_eq!(
expected_test_custom_value(),
*test_custom_value,
"{name} not deserialized correctly"
);
Ok(())
})
}
#[test]
fn deserialize_in_record() -> Result<(), ShellError> {
let orig_custom_val = Value::test_custom_value(Box::new(test_plugin_custom_value()));
let mut val = Value::test_record(record! {
"foo" => orig_custom_val.clone(),
"bar" => orig_custom_val.clone(),
});
PluginCustomValue::deserialize_custom_values_in(&mut val)?;
check_record_custom_values(&val, &["foo", "bar"], |key, custom_value| {
let test_custom_value: &TestCustomValue = custom_value
.as_any()
.downcast_ref()
.unwrap_or_else(|| panic!("'{key}' not TestCustomValue"));
assert_eq!(
expected_test_custom_value(),
*test_custom_value,
"{key} not deserialized correctly"
);
Ok(())
})
}
#[test]
fn deserialize_in_list() -> Result<(), ShellError> {
let orig_custom_val = Value::test_custom_value(Box::new(test_plugin_custom_value()));
let mut val = Value::test_list(vec![orig_custom_val.clone(), orig_custom_val.clone()]);
PluginCustomValue::deserialize_custom_values_in(&mut val)?;
check_list_custom_values(&val, 0..=1, |index, custom_value| {
let test_custom_value: &TestCustomValue = custom_value
.as_any()
.downcast_ref()
.unwrap_or_else(|| panic!("[{index}] not TestCustomValue"));
assert_eq!(
expected_test_custom_value(),
*test_custom_value,
"[{index}] name not deserialized correctly"
);
Ok(())
})
}

View File

@ -1,8 +0,0 @@
use nu_protocol::Span;
use serde::{Deserialize, Serialize};
#[derive(Serialize, Deserialize, Debug, PartialEq, Eq)]
pub struct PluginData {
pub data: Vec<u8>,
pub span: Span,
}

View File

@ -0,0 +1,80 @@
use nu_protocol::ShellError;
use serde::{Deserialize, Serialize};
/// Protocol information, sent as a `Hello` message on initialization. This determines the
/// compatibility of the plugin and engine. They are considered to be compatible if the lower
/// version is semver compatible with the higher one.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct ProtocolInfo {
/// The name of the protocol being implemented. Only one protocol is supported. This field
/// can be safely ignored, because not matching is a deserialization error
pub protocol: Protocol,
/// The semantic version of the protocol. This should be the version of the `nu-plugin`
/// crate
pub version: String,
/// Supported optional features. This helps to maintain semver compatibility when adding new
/// features
pub features: Vec<Feature>,
}
impl Default for ProtocolInfo {
fn default() -> ProtocolInfo {
ProtocolInfo {
protocol: Protocol::NuPlugin,
version: env!("CARGO_PKG_VERSION").into(),
features: vec![],
}
}
}
impl ProtocolInfo {
pub fn is_compatible_with(&self, other: &ProtocolInfo) -> Result<bool, ShellError> {
fn parse_failed(error: semver::Error) -> ShellError {
ShellError::PluginFailedToLoad {
msg: format!("Failed to parse protocol version: {error}"),
}
}
let mut versions = [
semver::Version::parse(&self.version).map_err(parse_failed)?,
semver::Version::parse(&other.version).map_err(parse_failed)?,
];
versions.sort();
// For example, if the lower version is 1.1.0, and the higher version is 1.2.3, the
// requirement is that 1.2.3 matches ^1.1.0 (which it does)
Ok(semver::Comparator {
op: semver::Op::Caret,
major: versions[0].major,
minor: Some(versions[0].minor),
patch: Some(versions[0].patch),
pre: versions[0].pre.clone(),
}
.matches(&versions[1]))
}
}
/// Indicates the protocol in use. Only one protocol is supported.
#[derive(Serialize, Deserialize, Debug, Clone, Default)]
pub enum Protocol {
/// Serializes to the value `"nu-plugin"`
#[serde(rename = "nu-plugin")]
#[default]
NuPlugin,
}
/// Indicates optional protocol features. This can help to make non-breaking-change additions to
/// the protocol. Features are not restricted to plain strings and can contain additional
/// configuration data.
///
/// Optional features should not be used by the protocol if they are not present in the
/// [`ProtocolInfo`] sent by the other side.
#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(tag = "name")]
pub enum Feature {
/// A feature that was not recognized on deserialization. Attempting to serialize this feature
/// is an error. Matching against it may only be used if necessary to determine whether
/// unsupported features are present.
#[serde(other, skip_serializing)]
Unknown,
}

View File

@ -0,0 +1,50 @@
use nu_protocol::{CustomValue, ShellError, Span, Value};
use serde::{Deserialize, Serialize};
use crate::plugin::PluginIdentity;
use super::PluginCustomValue;
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub(crate) struct TestCustomValue(pub i32);
#[typetag::serde]
impl CustomValue for TestCustomValue {
fn clone_value(&self, span: Span) -> Value {
Value::custom_value(Box::new(self.clone()), span)
}
fn value_string(&self) -> String {
"TestCustomValue".into()
}
fn to_base_value(&self, span: Span) -> Result<Value, ShellError> {
Ok(Value::int(self.0 as i64, span))
}
fn as_any(&self) -> &dyn std::any::Any {
self
}
}
pub(crate) fn test_plugin_custom_value() -> PluginCustomValue {
let data = bincode::serialize(&expected_test_custom_value() as &dyn CustomValue)
.expect("bincode serialization of the expected_test_custom_value() failed");
PluginCustomValue {
name: "TestCustomValue".into(),
data,
source: None,
}
}
pub(crate) fn expected_test_custom_value() -> TestCustomValue {
TestCustomValue(-1)
}
pub(crate) fn test_plugin_custom_value_with_source() -> PluginCustomValue {
PluginCustomValue {
source: Some(PluginIdentity::new_fake("test")),
..test_plugin_custom_value()
}
}

View File

@ -0,0 +1,35 @@
use super::*;
#[test]
fn protocol_info_compatible() -> Result<(), ShellError> {
let ver_1_2_3 = ProtocolInfo {
protocol: Protocol::NuPlugin,
version: "1.2.3".into(),
features: vec![],
};
let ver_1_1_0 = ProtocolInfo {
protocol: Protocol::NuPlugin,
version: "1.1.0".into(),
features: vec![],
};
assert!(ver_1_1_0.is_compatible_with(&ver_1_2_3)?);
assert!(ver_1_2_3.is_compatible_with(&ver_1_1_0)?);
Ok(())
}
#[test]
fn protocol_info_incompatible() -> Result<(), ShellError> {
let ver_2_0_0 = ProtocolInfo {
protocol: Protocol::NuPlugin,
version: "2.0.0".into(),
features: vec![],
};
let ver_1_1_0 = ProtocolInfo {
protocol: Protocol::NuPlugin,
version: "1.1.0".into(),
features: vec![],
};
assert!(!ver_2_0_0.is_compatible_with(&ver_1_1_0)?);
assert!(!ver_1_1_0.is_compatible_with(&ver_2_0_0)?);
Ok(())
}

View File

@ -0,0 +1,65 @@
use std::sync::atomic::{AtomicUsize, Ordering::Relaxed};
use nu_protocol::ShellError;
/// Implements an atomically incrementing sequential series of numbers
#[derive(Debug, Default)]
pub(crate) struct Sequence(AtomicUsize);
impl Sequence {
/// Return the next available id from a sequence, returning an error on overflow
#[track_caller]
pub(crate) fn next(&self) -> Result<usize, ShellError> {
// It's totally safe to use Relaxed ordering here, as there aren't other memory operations
// that depend on this value having been set for safety
//
// We're only not using `fetch_add` so that we can check for overflow, as wrapping with the
// identifier would lead to a serious bug - however unlikely that is.
self.0
.fetch_update(Relaxed, Relaxed, |current| current.checked_add(1))
.map_err(|_| ShellError::NushellFailedHelp {
msg: "an accumulator for identifiers overflowed".into(),
help: format!("see {}", std::panic::Location::caller()),
})
}
}
#[test]
fn output_is_sequential() {
let sequence = Sequence::default();
for (expected, generated) in (0..1000).zip(std::iter::repeat_with(|| sequence.next())) {
assert_eq!(expected, generated.expect("error in sequence"));
}
}
#[test]
fn output_is_unique_even_under_contention() {
let sequence = Sequence::default();
std::thread::scope(|scope| {
// Spawn four threads, all advancing the sequence simultaneously
let threads = (0..4)
.map(|_| {
scope.spawn(|| {
(0..100000)
.map(|_| sequence.next())
.collect::<Result<Vec<_>, _>>()
})
})
.collect::<Vec<_>>();
// Collect all of the results into a single flat vec
let mut results = threads
.into_iter()
.flat_map(|thread| thread.join().expect("panicked").expect("error"))
.collect::<Vec<usize>>();
// Check uniqueness
results.sort();
let initial_length = results.len();
results.dedup();
let deduplicated_length = results.len();
assert_eq!(initial_length, deduplicated_length);
})
}

View File

@ -1,53 +1,94 @@
use crate::{
plugin::{Encoder, PluginEncoder},
protocol::{PluginInput, PluginOutput},
};
use nu_protocol::ShellError;
use serde::Deserialize;
use crate::{plugin::PluginEncoder, protocol::PluginResponse};
/// A `PluginEncoder` that enables the plugin to communicate with Nushel with JSON
/// A `PluginEncoder` that enables the plugin to communicate with Nushell with JSON
/// serialized data.
#[derive(Clone, Debug)]
///
/// Each message in the stream is followed by a newline when serializing, but is not required for
/// deserialization. The output is not pretty printed and each object does not contain newlines.
/// If it is more convenient, a plugin may choose to separate messages by newline.
#[derive(Clone, Copy, Debug)]
pub struct JsonSerializer;
impl PluginEncoder for JsonSerializer {
fn name(&self) -> &str {
"json"
}
}
fn encode_call(
impl Encoder<PluginInput> for JsonSerializer {
fn encode(
&self,
plugin_call: &crate::protocol::PluginCall,
plugin_input: &PluginInput,
writer: &mut impl std::io::Write,
) -> Result<(), nu_protocol::ShellError> {
serde_json::to_writer(writer, plugin_call).map_err(|err| ShellError::PluginFailedToEncode {
serde_json::to_writer(&mut *writer, plugin_input).map_err(json_encode_err)?;
writer.write_all(b"\n").map_err(|err| ShellError::IOError {
msg: err.to_string(),
})
}
fn decode_call(
fn decode(
&self,
reader: &mut impl std::io::BufRead,
) -> Result<crate::protocol::PluginCall, nu_protocol::ShellError> {
serde_json::from_reader(reader).map_err(|err| ShellError::PluginFailedToEncode {
) -> Result<Option<PluginInput>, nu_protocol::ShellError> {
let mut de = serde_json::Deserializer::from_reader(reader);
PluginInput::deserialize(&mut de)
.map(Some)
.or_else(json_decode_err)
}
}
impl Encoder<PluginOutput> for JsonSerializer {
fn encode(
&self,
plugin_output: &PluginOutput,
writer: &mut impl std::io::Write,
) -> Result<(), ShellError> {
serde_json::to_writer(&mut *writer, plugin_output).map_err(json_encode_err)?;
writer.write_all(b"\n").map_err(|err| ShellError::IOError {
msg: err.to_string(),
})
}
fn encode_response(
fn decode(
&self,
plugin_response: &PluginResponse,
writer: &mut impl std::io::Write,
) -> Result<(), ShellError> {
serde_json::to_writer(writer, plugin_response).map_err(|err| {
reader: &mut impl std::io::BufRead,
) -> Result<Option<PluginOutput>, ShellError> {
let mut de = serde_json::Deserializer::from_reader(reader);
PluginOutput::deserialize(&mut de)
.map(Some)
.or_else(json_decode_err)
}
}
/// Handle a `serde_json` encode error.
fn json_encode_err(err: serde_json::Error) -> ShellError {
if err.is_io() {
ShellError::IOError {
msg: err.to_string(),
}
} else {
ShellError::PluginFailedToEncode {
msg: err.to_string(),
}
})
}
}
fn decode_response(
&self,
reader: &mut impl std::io::BufRead,
) -> Result<PluginResponse, ShellError> {
serde_json::from_reader(reader).map_err(|err| ShellError::PluginFailedToEncode {
/// Handle a `serde_json` decode error. Returns `Ok(None)` on eof.
fn json_decode_err<T>(err: serde_json::Error) -> Result<Option<T>, ShellError> {
if err.is_eof() {
Ok(None)
} else if err.is_io() {
Err(ShellError::IOError {
msg: err.to_string(),
})
} else {
Err(ShellError::PluginFailedToDecode {
msg: err.to_string(),
})
}
@ -56,306 +97,38 @@ impl PluginEncoder for JsonSerializer {
#[cfg(test)]
mod tests {
use super::*;
use crate::protocol::{
CallInfo, CallInput, EvaluatedCall, LabeledError, PluginCall, PluginData, PluginResponse,
};
use nu_protocol::{PluginSignature, Span, Spanned, SyntaxShape, Value};
crate::serializers::tests::generate_tests!(JsonSerializer {});
#[test]
fn callinfo_round_trip_signature() {
let plugin_call = PluginCall::Signature;
let encoder = JsonSerializer {};
let mut buffer: Vec<u8> = Vec::new();
encoder
.encode_call(&plugin_call, &mut buffer)
.expect("unable to serialize message");
let returned = encoder
.decode_call(&mut buffer.as_slice())
.expect("unable to deserialize message");
match returned {
PluginCall::Signature => {}
PluginCall::CallInfo(_) => panic!("decoded into wrong value"),
PluginCall::CollapseCustomValue(_) => panic!("decoded into wrong value"),
}
}
#[test]
fn callinfo_round_trip_callinfo() {
let name = "test".to_string();
let input = Value::bool(false, Span::new(1, 20));
let call = EvaluatedCall {
head: Span::new(0, 10),
positional: vec![
Value::float(1.0, Span::new(0, 10)),
Value::string("something", Span::new(0, 10)),
],
named: vec![(
Spanned {
item: "name".to_string(),
span: Span::new(0, 10),
},
Some(Value::float(1.0, Span::new(0, 10))),
)],
};
let plugin_call = PluginCall::CallInfo(CallInfo {
name: name.clone(),
call: call.clone(),
input: CallInput::Value(input.clone()),
config: None,
});
let encoder = JsonSerializer {};
let mut buffer: Vec<u8> = Vec::new();
encoder
.encode_call(&plugin_call, &mut buffer)
.expect("unable to serialize message");
let returned = encoder
.decode_call(&mut buffer.as_slice())
.expect("unable to deserialize message");
match returned {
PluginCall::Signature => panic!("returned wrong call type"),
PluginCall::CallInfo(call_info) => {
assert_eq!(name, call_info.name);
assert_eq!(CallInput::Value(input), call_info.input);
assert_eq!(call.head, call_info.call.head);
assert_eq!(call.positional.len(), call_info.call.positional.len());
call.positional
.iter()
.zip(call_info.call.positional.iter())
.for_each(|(lhs, rhs)| assert_eq!(lhs, rhs));
call.named
.iter()
.zip(call_info.call.named.iter())
.for_each(|(lhs, rhs)| {
// Comparing the keys
assert_eq!(lhs.0.item, rhs.0.item);
match (&lhs.1, &rhs.1) {
(None, None) => {}
(Some(a), Some(b)) => assert_eq!(a, b),
_ => panic!("not matching values"),
}
});
}
PluginCall::CollapseCustomValue(_) => panic!("returned wrong call type"),
}
}
#[test]
fn callinfo_round_trip_collapsecustomvalue() {
let data = vec![1, 2, 3, 4, 5, 6, 7];
let span = Span::new(0, 20);
let collapse_custom_value = PluginCall::CollapseCustomValue(PluginData {
data: data.clone(),
span,
});
let encoder = JsonSerializer {};
let mut buffer: Vec<u8> = Vec::new();
encoder
.encode_call(&collapse_custom_value, &mut buffer)
.expect("unable to serialize message");
let returned = encoder
.decode_call(&mut buffer.as_slice())
.expect("unable to deserialize message");
match returned {
PluginCall::Signature => panic!("returned wrong call type"),
PluginCall::CallInfo(_) => panic!("returned wrong call type"),
PluginCall::CollapseCustomValue(plugin_data) => {
assert_eq!(data, plugin_data.data);
assert_eq!(span, plugin_data.span);
}
}
}
#[test]
fn response_round_trip_signature() {
let signature = PluginSignature::build("nu-plugin")
.required("first", SyntaxShape::String, "first required")
.required("second", SyntaxShape::Int, "second required")
.required_named("first-named", SyntaxShape::String, "first named", Some('f'))
.required_named(
"second-named",
SyntaxShape::String,
"second named",
Some('s'),
)
.rest("remaining", SyntaxShape::Int, "remaining");
let response = PluginResponse::Signature(vec![signature.clone()]);
let encoder = JsonSerializer {};
let mut buffer: Vec<u8> = Vec::new();
encoder
.encode_response(&response, &mut buffer)
.expect("unable to serialize message");
let returned = encoder
.decode_response(&mut buffer.as_slice())
.expect("unable to deserialize message");
match returned {
PluginResponse::Error(_) => panic!("returned wrong call type"),
PluginResponse::Value(_) => panic!("returned wrong call type"),
PluginResponse::PluginData(..) => panic!("returned wrong call type"),
PluginResponse::Signature(returned_signature) => {
assert_eq!(returned_signature.len(), 1);
assert_eq!(signature.sig.name, returned_signature[0].sig.name);
assert_eq!(signature.sig.usage, returned_signature[0].sig.usage);
assert_eq!(
signature.sig.extra_usage,
returned_signature[0].sig.extra_usage
);
assert_eq!(signature.sig.is_filter, returned_signature[0].sig.is_filter);
signature
.sig
.required_positional
.iter()
.zip(returned_signature[0].sig.required_positional.iter())
.for_each(|(lhs, rhs)| assert_eq!(lhs, rhs));
signature
.sig
.optional_positional
.iter()
.zip(returned_signature[0].sig.optional_positional.iter())
.for_each(|(lhs, rhs)| assert_eq!(lhs, rhs));
signature
.sig
.named
.iter()
.zip(returned_signature[0].sig.named.iter())
.for_each(|(lhs, rhs)| assert_eq!(lhs, rhs));
assert_eq!(
signature.sig.rest_positional,
returned_signature[0].sig.rest_positional,
fn json_ends_in_newline() {
let mut out = vec![];
JsonSerializer {}
.encode(&PluginInput::Call(0, PluginCall::Signature), &mut out)
.expect("serialization error");
let string = std::str::from_utf8(&out).expect("utf-8 error");
assert!(
string.ends_with('\n'),
"doesn't end with newline: {:?}",
string
);
}
}
}
#[test]
fn response_round_trip_value() {
let value = Value::int(10, Span::new(2, 30));
let response = PluginResponse::Value(Box::new(value.clone()));
let encoder = JsonSerializer {};
let mut buffer: Vec<u8> = Vec::new();
encoder
.encode_response(&response, &mut buffer)
.expect("unable to serialize message");
let returned = encoder
.decode_response(&mut buffer.as_slice())
.expect("unable to deserialize message");
match returned {
PluginResponse::Error(_) => panic!("returned wrong call type"),
PluginResponse::Signature(_) => panic!("returned wrong call type"),
PluginResponse::PluginData(..) => panic!("returned wrong call type"),
PluginResponse::Value(returned_value) => {
assert_eq!(&value, returned_value.as_ref())
}
}
}
#[test]
fn response_round_trip_plugin_data() {
let name = "test".to_string();
let data = vec![1, 2, 3, 4, 5];
let span = Span::new(2, 30);
let response = PluginResponse::PluginData(
name.clone(),
PluginData {
data: data.clone(),
span,
},
);
let encoder = JsonSerializer {};
let mut buffer: Vec<u8> = Vec::new();
encoder
.encode_response(&response, &mut buffer)
.expect("unable to serialize message");
let returned = encoder
.decode_response(&mut buffer.as_slice())
.expect("unable to deserialize message");
match returned {
PluginResponse::Error(_) => panic!("returned wrong call type"),
PluginResponse::Signature(_) => panic!("returned wrong call type"),
PluginResponse::Value(_) => panic!("returned wrong call type"),
PluginResponse::PluginData(returned_name, returned_plugin_data) => {
assert_eq!(name, returned_name);
assert_eq!(data, returned_plugin_data.data);
assert_eq!(span, returned_plugin_data.span);
}
}
}
#[test]
fn response_round_trip_error() {
let error = LabeledError {
label: "label".into(),
msg: "msg".into(),
span: Some(Span::new(2, 30)),
};
let response = PluginResponse::Error(error.clone());
let encoder = JsonSerializer {};
let mut buffer: Vec<u8> = Vec::new();
encoder
.encode_response(&response, &mut buffer)
.expect("unable to serialize message");
let returned = encoder
.decode_response(&mut buffer.as_slice())
.expect("unable to deserialize message");
match returned {
PluginResponse::Error(msg) => assert_eq!(error, msg),
PluginResponse::Signature(_) => panic!("returned wrong call type"),
PluginResponse::Value(_) => panic!("returned wrong call type"),
PluginResponse::PluginData(..) => panic!("returned wrong call type"),
}
}
#[test]
fn response_round_trip_error_none() {
let error = LabeledError {
label: "label".into(),
msg: "msg".into(),
span: None,
};
let response = PluginResponse::Error(error.clone());
let encoder = JsonSerializer {};
let mut buffer: Vec<u8> = Vec::new();
encoder
.encode_response(&response, &mut buffer)
.expect("unable to serialize message");
let returned = encoder
.decode_response(&mut buffer.as_slice())
.expect("unable to deserialize message");
match returned {
PluginResponse::Error(msg) => assert_eq!(error, msg),
PluginResponse::Signature(_) => panic!("returned wrong call type"),
PluginResponse::Value(_) => panic!("returned wrong call type"),
PluginResponse::PluginData(..) => panic!("returned wrong call type"),
}
fn json_has_no_other_newlines() {
let mut out = vec![];
// use something deeply nested, to try to trigger any pretty printing
let output = PluginOutput::Stream(StreamMessage::Data(
0,
StreamData::List(Value::test_list(vec![
Value::test_int(4),
// in case escaping failed
Value::test_string("newline\ncontaining\nstring"),
])),
));
JsonSerializer {}
.encode(&output, &mut out)
.expect("serialization error");
let string = std::str::from_utf8(&out).expect("utf-8 error");
assert_eq!(1, string.chars().filter(|ch| *ch == '\n').count());
}
}

View File

@ -1,14 +1,14 @@
use crate::{
plugin::PluginEncoder,
protocol::{PluginCall, PluginResponse},
};
use crate::plugin::{Encoder, PluginEncoder};
use nu_protocol::ShellError;
pub mod json;
pub mod msgpack;
#[cfg(test)]
mod tests;
#[doc(hidden)]
#[derive(Clone, Debug)]
#[derive(Clone, Copy, Debug)]
pub enum EncodingType {
Json(json::JsonSerializer),
MsgPack(msgpack::MsgPackSerializer),
@ -23,48 +23,6 @@ impl EncodingType {
}
}
pub fn encode_call(
&self,
plugin_call: &PluginCall,
writer: &mut impl std::io::Write,
) -> Result<(), ShellError> {
match self {
EncodingType::Json(encoder) => encoder.encode_call(plugin_call, writer),
EncodingType::MsgPack(encoder) => encoder.encode_call(plugin_call, writer),
}
}
pub fn decode_call(
&self,
reader: &mut impl std::io::BufRead,
) -> Result<PluginCall, ShellError> {
match self {
EncodingType::Json(encoder) => encoder.decode_call(reader),
EncodingType::MsgPack(encoder) => encoder.decode_call(reader),
}
}
pub fn encode_response(
&self,
plugin_response: &PluginResponse,
writer: &mut impl std::io::Write,
) -> Result<(), ShellError> {
match self {
EncodingType::Json(encoder) => encoder.encode_response(plugin_response, writer),
EncodingType::MsgPack(encoder) => encoder.encode_response(plugin_response, writer),
}
}
pub fn decode_response(
&self,
reader: &mut impl std::io::BufRead,
) -> Result<PluginResponse, ShellError> {
match self {
EncodingType::Json(encoder) => encoder.decode_response(reader),
EncodingType::MsgPack(encoder) => encoder.decode_response(reader),
}
}
pub fn to_str(&self) -> &'static str {
match self {
Self::Json(_) => "json",
@ -72,3 +30,29 @@ impl EncodingType {
}
}
}
impl PluginEncoder for EncodingType {
fn name(&self) -> &str {
self.to_str()
}
}
impl<T> Encoder<T> for EncodingType
where
json::JsonSerializer: Encoder<T>,
msgpack::MsgPackSerializer: Encoder<T>,
{
fn encode(&self, data: &T, writer: &mut impl std::io::Write) -> Result<(), ShellError> {
match self {
EncodingType::Json(encoder) => encoder.encode(data, writer),
EncodingType::MsgPack(encoder) => encoder.encode(data, writer),
}
}
fn decode(&self, reader: &mut impl std::io::BufRead) -> Result<Option<T>, ShellError> {
match self {
EncodingType::Json(encoder) => encoder.decode(reader),
EncodingType::MsgPack(encoder) => encoder.decode(reader),
}
}
}

View File

@ -1,362 +1,110 @@
use crate::{plugin::PluginEncoder, protocol::PluginResponse};
use nu_protocol::ShellError;
use std::io::ErrorKind;
/// A `PluginEncoder` that enables the plugin to communicate with Nushel with MsgPack
use crate::{
plugin::{Encoder, PluginEncoder},
protocol::{PluginInput, PluginOutput},
};
use nu_protocol::ShellError;
use serde::Deserialize;
/// A `PluginEncoder` that enables the plugin to communicate with Nushell with MsgPack
/// serialized data.
#[derive(Clone, Debug)]
///
/// Each message is written as a MessagePack object. There is no message envelope or separator.
#[derive(Clone, Copy, Debug)]
pub struct MsgPackSerializer;
impl PluginEncoder for MsgPackSerializer {
fn name(&self) -> &str {
"msgpack"
}
}
fn encode_call(
impl Encoder<PluginInput> for MsgPackSerializer {
fn encode(
&self,
plugin_call: &crate::protocol::PluginCall,
plugin_input: &PluginInput,
writer: &mut impl std::io::Write,
) -> Result<(), nu_protocol::ShellError> {
rmp_serde::encode::write(writer, plugin_call).map_err(|err| {
ShellError::PluginFailedToEncode {
msg: err.to_string(),
}
})
rmp_serde::encode::write(writer, plugin_input).map_err(rmp_encode_err)
}
fn decode_call(
fn decode(
&self,
reader: &mut impl std::io::BufRead,
) -> Result<crate::protocol::PluginCall, nu_protocol::ShellError> {
rmp_serde::from_read(reader).map_err(|err| ShellError::PluginFailedToDecode {
msg: err.to_string(),
})
) -> Result<Option<PluginInput>, ShellError> {
let mut de = rmp_serde::Deserializer::new(reader);
PluginInput::deserialize(&mut de)
.map(Some)
.or_else(rmp_decode_err)
}
}
fn encode_response(
impl Encoder<PluginOutput> for MsgPackSerializer {
fn encode(
&self,
plugin_response: &PluginResponse,
plugin_output: &PluginOutput,
writer: &mut impl std::io::Write,
) -> Result<(), ShellError> {
rmp_serde::encode::write(writer, plugin_response).map_err(|err| {
rmp_serde::encode::write(writer, plugin_output).map_err(rmp_encode_err)
}
fn decode(
&self,
reader: &mut impl std::io::BufRead,
) -> Result<Option<PluginOutput>, ShellError> {
let mut de = rmp_serde::Deserializer::new(reader);
PluginOutput::deserialize(&mut de)
.map(Some)
.or_else(rmp_decode_err)
}
}
/// Handle a msgpack encode error
fn rmp_encode_err(err: rmp_serde::encode::Error) -> ShellError {
match err {
rmp_serde::encode::Error::InvalidValueWrite(_) => {
// I/O error
ShellError::IOError {
msg: err.to_string(),
}
}
_ => {
// Something else
ShellError::PluginFailedToEncode {
msg: err.to_string(),
}
})
}
}
}
fn decode_response(
&self,
reader: &mut impl std::io::BufRead,
) -> Result<PluginResponse, ShellError> {
rmp_serde::from_read(reader).map_err(|err| ShellError::PluginFailedToDecode {
/// Handle a msgpack decode error. Returns `Ok(None)` on eof
fn rmp_decode_err<T>(err: rmp_serde::decode::Error) -> Result<Option<T>, ShellError> {
match err {
rmp_serde::decode::Error::InvalidMarkerRead(err)
if matches!(err.kind(), ErrorKind::UnexpectedEof) =>
{
// EOF
Ok(None)
}
rmp_serde::decode::Error::InvalidMarkerRead(_)
| rmp_serde::decode::Error::InvalidDataRead(_) => {
// I/O error
Err(ShellError::IOError {
msg: err.to_string(),
})
}
_ => {
// Something else
Err(ShellError::PluginFailedToDecode {
msg: err.to_string(),
})
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::protocol::{
CallInfo, CallInput, EvaluatedCall, LabeledError, PluginCall, PluginData, PluginResponse,
};
use nu_protocol::{PluginSignature, Span, Spanned, SyntaxShape, Value};
#[test]
fn callinfo_round_trip_signature() {
let plugin_call = PluginCall::Signature;
let encoder = MsgPackSerializer {};
let mut buffer: Vec<u8> = Vec::new();
encoder
.encode_call(&plugin_call, &mut buffer)
.expect("unable to serialize message");
let returned = encoder
.decode_call(&mut buffer.as_slice())
.expect("unable to deserialize message");
match returned {
PluginCall::Signature => {}
PluginCall::CallInfo(_) => panic!("decoded into wrong value"),
PluginCall::CollapseCustomValue(_) => panic!("decoded into wrong value"),
}
}
#[test]
fn callinfo_round_trip_callinfo() {
let name = "test".to_string();
let input = Value::bool(false, Span::new(1, 20));
let call = EvaluatedCall {
head: Span::new(0, 10),
positional: vec![
Value::float(1.0, Span::new(0, 10)),
Value::string("something", Span::new(0, 10)),
],
named: vec![(
Spanned {
item: "name".to_string(),
span: Span::new(0, 10),
},
Some(Value::float(1.0, Span::new(0, 10))),
)],
};
let plugin_call = PluginCall::CallInfo(CallInfo {
name: name.clone(),
call: call.clone(),
input: CallInput::Value(input.clone()),
config: None,
});
let encoder = MsgPackSerializer {};
let mut buffer: Vec<u8> = Vec::new();
encoder
.encode_call(&plugin_call, &mut buffer)
.expect("unable to serialize message");
let returned = encoder
.decode_call(&mut buffer.as_slice())
.expect("unable to deserialize message");
match returned {
PluginCall::Signature => panic!("returned wrong call type"),
PluginCall::CallInfo(call_info) => {
assert_eq!(name, call_info.name);
assert_eq!(CallInput::Value(input), call_info.input);
assert_eq!(call.head, call_info.call.head);
assert_eq!(call.positional.len(), call_info.call.positional.len());
call.positional
.iter()
.zip(call_info.call.positional.iter())
.for_each(|(lhs, rhs)| assert_eq!(lhs, rhs));
call.named
.iter()
.zip(call_info.call.named.iter())
.for_each(|(lhs, rhs)| {
// Comparing the keys
assert_eq!(lhs.0.item, rhs.0.item);
match (&lhs.1, &rhs.1) {
(None, None) => {}
(Some(a), Some(b)) => assert_eq!(a, b),
_ => panic!("not matching values"),
}
});
}
PluginCall::CollapseCustomValue(_) => panic!("returned wrong call type"),
}
}
#[test]
fn callinfo_round_trip_collapsecustomvalue() {
let data = vec![1, 2, 3, 4, 5, 6, 7];
let span = Span::new(0, 20);
let collapse_custom_value = PluginCall::CollapseCustomValue(PluginData {
data: data.clone(),
span,
});
let encoder = MsgPackSerializer {};
let mut buffer: Vec<u8> = Vec::new();
encoder
.encode_call(&collapse_custom_value, &mut buffer)
.expect("unable to serialize message");
let returned = encoder
.decode_call(&mut buffer.as_slice())
.expect("unable to deserialize message");
match returned {
PluginCall::Signature => panic!("returned wrong call type"),
PluginCall::CallInfo(_) => panic!("returned wrong call type"),
PluginCall::CollapseCustomValue(plugin_data) => {
assert_eq!(data, plugin_data.data);
assert_eq!(span, plugin_data.span);
}
}
}
#[test]
fn response_round_trip_signature() {
let signature = PluginSignature::build("nu-plugin")
.required("first", SyntaxShape::String, "first required")
.required("second", SyntaxShape::Int, "second required")
.required_named("first-named", SyntaxShape::String, "first named", Some('f'))
.required_named(
"second-named",
SyntaxShape::String,
"second named",
Some('s'),
)
.rest("remaining", SyntaxShape::Int, "remaining");
let response = PluginResponse::Signature(vec![signature.clone()]);
let encoder = MsgPackSerializer {};
let mut buffer: Vec<u8> = Vec::new();
encoder
.encode_response(&response, &mut buffer)
.expect("unable to serialize message");
let returned = encoder
.decode_response(&mut buffer.as_slice())
.expect("unable to deserialize message");
match returned {
PluginResponse::Error(_) => panic!("returned wrong call type"),
PluginResponse::Value(_) => panic!("returned wrong call type"),
PluginResponse::PluginData(..) => panic!("returned wrong call type"),
PluginResponse::Signature(returned_signature) => {
assert_eq!(returned_signature.len(), 1);
assert_eq!(signature.sig.name, returned_signature[0].sig.name);
assert_eq!(signature.sig.usage, returned_signature[0].sig.usage);
assert_eq!(
signature.sig.extra_usage,
returned_signature[0].sig.extra_usage
);
assert_eq!(signature.sig.is_filter, returned_signature[0].sig.is_filter);
signature
.sig
.required_positional
.iter()
.zip(returned_signature[0].sig.required_positional.iter())
.for_each(|(lhs, rhs)| assert_eq!(lhs, rhs));
signature
.sig
.optional_positional
.iter()
.zip(returned_signature[0].sig.optional_positional.iter())
.for_each(|(lhs, rhs)| assert_eq!(lhs, rhs));
signature
.sig
.named
.iter()
.zip(returned_signature[0].sig.named.iter())
.for_each(|(lhs, rhs)| assert_eq!(lhs, rhs));
assert_eq!(
signature.sig.rest_positional,
returned_signature[0].sig.rest_positional,
);
}
}
}
#[test]
fn response_round_trip_value() {
let value = Value::int(10, Span::new(2, 30));
let response = PluginResponse::Value(Box::new(value.clone()));
let encoder = MsgPackSerializer {};
let mut buffer: Vec<u8> = Vec::new();
encoder
.encode_response(&response, &mut buffer)
.expect("unable to serialize message");
let returned = encoder
.decode_response(&mut buffer.as_slice())
.expect("unable to deserialize message");
match returned {
PluginResponse::Error(_) => panic!("returned wrong call type"),
PluginResponse::Signature(_) => panic!("returned wrong call type"),
PluginResponse::PluginData(..) => panic!("returned wrong call type"),
PluginResponse::Value(returned_value) => {
assert_eq!(&value, returned_value.as_ref())
}
}
}
#[test]
fn response_round_trip_plugin_data() {
let name = "test".to_string();
let data = vec![1, 2, 3, 4, 5];
let span = Span::new(2, 30);
let response = PluginResponse::PluginData(
name.clone(),
PluginData {
data: data.clone(),
span,
},
);
let encoder = MsgPackSerializer {};
let mut buffer: Vec<u8> = Vec::new();
encoder
.encode_response(&response, &mut buffer)
.expect("unable to serialize message");
let returned = encoder
.decode_response(&mut buffer.as_slice())
.expect("unable to deserialize message");
match returned {
PluginResponse::Error(_) => panic!("returned wrong call type"),
PluginResponse::Signature(_) => panic!("returned wrong call type"),
PluginResponse::Value(_) => panic!("returned wrong call type"),
PluginResponse::PluginData(returned_name, returned_plugin_data) => {
assert_eq!(name, returned_name);
assert_eq!(data, returned_plugin_data.data);
assert_eq!(span, returned_plugin_data.span);
}
}
}
#[test]
fn response_round_trip_error() {
let error = LabeledError {
label: "label".into(),
msg: "msg".into(),
span: Some(Span::new(2, 30)),
};
let response = PluginResponse::Error(error.clone());
let encoder = MsgPackSerializer {};
let mut buffer: Vec<u8> = Vec::new();
encoder
.encode_response(&response, &mut buffer)
.expect("unable to serialize message");
let returned = encoder
.decode_response(&mut buffer.as_slice())
.expect("unable to deserialize message");
match returned {
PluginResponse::Error(msg) => assert_eq!(error, msg),
PluginResponse::Signature(_) => panic!("returned wrong call type"),
PluginResponse::Value(_) => panic!("returned wrong call type"),
PluginResponse::PluginData(..) => panic!("returned wrong call type"),
}
}
#[test]
fn response_round_trip_error_none() {
let error = LabeledError {
label: "label".into(),
msg: "msg".into(),
span: None,
};
let response = PluginResponse::Error(error.clone());
let encoder = MsgPackSerializer {};
let mut buffer: Vec<u8> = Vec::new();
encoder
.encode_response(&response, &mut buffer)
.expect("unable to serialize message");
let returned = encoder
.decode_response(&mut buffer.as_slice())
.expect("unable to deserialize message");
match returned {
PluginResponse::Error(msg) => assert_eq!(error, msg),
PluginResponse::Signature(_) => panic!("returned wrong call type"),
PluginResponse::Value(_) => panic!("returned wrong call type"),
PluginResponse::PluginData(..) => panic!("returned wrong call type"),
}
}
crate::serializers::tests::generate_tests!(MsgPackSerializer {});
}

View File

@ -0,0 +1,538 @@
macro_rules! generate_tests {
($encoder:expr) => {
use crate::protocol::{
CallInfo, CustomValueOp, EvaluatedCall, LabeledError, PipelineDataHeader, PluginCall,
PluginCallResponse, PluginCustomValue, PluginInput, PluginOutput, StreamData,
StreamMessage,
};
use nu_protocol::{PluginSignature, Span, Spanned, SyntaxShape, Value};
#[test]
fn decode_eof() {
let mut buffer: &[u8] = &[];
let encoder = $encoder;
let result: Option<PluginInput> = encoder
.decode(&mut buffer)
.expect("eof should not result in an error");
assert!(result.is_none(), "decode result: {result:?}");
let result: Option<PluginOutput> = encoder
.decode(&mut buffer)
.expect("eof should not result in an error");
assert!(result.is_none(), "decode result: {result:?}");
}
#[test]
fn decode_io_error() {
struct ErrorProducer;
impl std::io::Read for ErrorProducer {
fn read(&mut self, _buf: &mut [u8]) -> std::io::Result<usize> {
Err(std::io::Error::from(std::io::ErrorKind::ConnectionReset))
}
}
let encoder = $encoder;
let mut buffered = std::io::BufReader::new(ErrorProducer);
match Encoder::<PluginInput>::decode(&encoder, &mut buffered) {
Ok(_) => panic!("decode: i/o error was not passed through"),
Err(ShellError::IOError { .. }) => (), // okay
Err(other) => panic!(
"decode: got other error, should have been a \
ShellError::IOError: {other:?}"
),
}
match Encoder::<PluginOutput>::decode(&encoder, &mut buffered) {
Ok(_) => panic!("decode: i/o error was not passed through"),
Err(ShellError::IOError { .. }) => (), // okay
Err(other) => panic!(
"decode: got other error, should have been a \
ShellError::IOError: {other:?}"
),
}
}
#[test]
fn decode_gibberish() {
// just a sequence of bytes that shouldn't be valid in anything we use
let gibberish: &[u8] = &[
0, 80, 74, 85, 117, 122, 86, 100, 74, 115, 20, 104, 55, 98, 67, 203, 83, 85, 77,
112, 74, 79, 254, 71, 80,
];
let encoder = $encoder;
let mut buffered = std::io::BufReader::new(&gibberish[..]);
match Encoder::<PluginInput>::decode(&encoder, &mut buffered) {
Ok(value) => panic!("decode: parsed successfully => {value:?}"),
Err(ShellError::PluginFailedToDecode { .. }) => (), // okay
Err(other) => panic!(
"decode: got other error, should have been a \
ShellError::PluginFailedToDecode: {other:?}"
),
}
let mut buffered = std::io::BufReader::new(&gibberish[..]);
match Encoder::<PluginOutput>::decode(&encoder, &mut buffered) {
Ok(value) => panic!("decode: parsed successfully => {value:?}"),
Err(ShellError::PluginFailedToDecode { .. }) => (), // okay
Err(other) => panic!(
"decode: got other error, should have been a \
ShellError::PluginFailedToDecode: {other:?}"
),
}
}
#[test]
fn call_round_trip_signature() {
let plugin_call = PluginCall::Signature;
let plugin_input = PluginInput::Call(0, plugin_call);
let encoder = $encoder;
let mut buffer: Vec<u8> = Vec::new();
encoder
.encode(&plugin_input, &mut buffer)
.expect("unable to serialize message");
let returned = encoder
.decode(&mut buffer.as_slice())
.expect("unable to deserialize message")
.expect("eof");
match returned {
PluginInput::Call(0, PluginCall::Signature) => {}
_ => panic!("decoded into wrong value: {returned:?}"),
}
}
#[test]
fn call_round_trip_run() {
let name = "test".to_string();
let input = Value::bool(false, Span::new(1, 20));
let call = EvaluatedCall {
head: Span::new(0, 10),
positional: vec![
Value::float(1.0, Span::new(0, 10)),
Value::string("something", Span::new(0, 10)),
],
named: vec![(
Spanned {
item: "name".to_string(),
span: Span::new(0, 10),
},
Some(Value::float(1.0, Span::new(0, 10))),
)],
};
let plugin_call = PluginCall::Run(CallInfo {
name: name.clone(),
call: call.clone(),
input: PipelineDataHeader::Value(input.clone()),
config: None,
});
let plugin_input = PluginInput::Call(1, plugin_call);
let encoder = $encoder;
let mut buffer: Vec<u8> = Vec::new();
encoder
.encode(&plugin_input, &mut buffer)
.expect("unable to serialize message");
let returned = encoder
.decode(&mut buffer.as_slice())
.expect("unable to deserialize message")
.expect("eof");
match returned {
PluginInput::Call(1, PluginCall::Run(call_info)) => {
assert_eq!(name, call_info.name);
assert_eq!(PipelineDataHeader::Value(input), call_info.input);
assert_eq!(call.head, call_info.call.head);
assert_eq!(call.positional.len(), call_info.call.positional.len());
call.positional
.iter()
.zip(call_info.call.positional.iter())
.for_each(|(lhs, rhs)| assert_eq!(lhs, rhs));
call.named
.iter()
.zip(call_info.call.named.iter())
.for_each(|(lhs, rhs)| {
// Comparing the keys
assert_eq!(lhs.0.item, rhs.0.item);
match (&lhs.1, &rhs.1) {
(None, None) => {}
(Some(a), Some(b)) => assert_eq!(a, b),
_ => panic!("not matching values"),
}
});
}
_ => panic!("decoded into wrong value: {returned:?}"),
}
}
#[test]
fn call_round_trip_customvalueop() {
let data = vec![1, 2, 3, 4, 5, 6, 7];
let span = Span::new(0, 20);
let custom_value_op = PluginCall::CustomValueOp(
Spanned {
item: PluginCustomValue {
name: "Foo".into(),
data: data.clone(),
source: None,
},
span,
},
CustomValueOp::ToBaseValue,
);
let plugin_input = PluginInput::Call(2, custom_value_op);
let encoder = $encoder;
let mut buffer: Vec<u8> = Vec::new();
encoder
.encode(&plugin_input, &mut buffer)
.expect("unable to serialize message");
let returned = encoder
.decode(&mut buffer.as_slice())
.expect("unable to deserialize message")
.expect("eof");
match returned {
PluginInput::Call(2, PluginCall::CustomValueOp(val, op)) => {
assert_eq!("Foo", val.item.name);
assert_eq!(data, val.item.data);
assert_eq!(span, val.span);
#[allow(unreachable_patterns)]
match op {
CustomValueOp::ToBaseValue => (),
_ => panic!("wrong op: {op:?}"),
}
}
_ => panic!("decoded into wrong value: {returned:?}"),
}
}
#[test]
fn response_round_trip_signature() {
let signature = PluginSignature::build("nu-plugin")
.required("first", SyntaxShape::String, "first required")
.required("second", SyntaxShape::Int, "second required")
.required_named("first-named", SyntaxShape::String, "first named", Some('f'))
.required_named(
"second-named",
SyntaxShape::String,
"second named",
Some('s'),
)
.rest("remaining", SyntaxShape::Int, "remaining");
let response = PluginCallResponse::Signature(vec![signature.clone()]);
let output = PluginOutput::CallResponse(3, response);
let encoder = $encoder;
let mut buffer: Vec<u8> = Vec::new();
encoder
.encode(&output, &mut buffer)
.expect("unable to serialize message");
let returned = encoder
.decode(&mut buffer.as_slice())
.expect("unable to deserialize message")
.expect("eof");
match returned {
PluginOutput::CallResponse(
3,
PluginCallResponse::Signature(returned_signature),
) => {
assert_eq!(returned_signature.len(), 1);
assert_eq!(signature.sig.name, returned_signature[0].sig.name);
assert_eq!(signature.sig.usage, returned_signature[0].sig.usage);
assert_eq!(
signature.sig.extra_usage,
returned_signature[0].sig.extra_usage
);
assert_eq!(signature.sig.is_filter, returned_signature[0].sig.is_filter);
signature
.sig
.required_positional
.iter()
.zip(returned_signature[0].sig.required_positional.iter())
.for_each(|(lhs, rhs)| assert_eq!(lhs, rhs));
signature
.sig
.optional_positional
.iter()
.zip(returned_signature[0].sig.optional_positional.iter())
.for_each(|(lhs, rhs)| assert_eq!(lhs, rhs));
signature
.sig
.named
.iter()
.zip(returned_signature[0].sig.named.iter())
.for_each(|(lhs, rhs)| assert_eq!(lhs, rhs));
assert_eq!(
signature.sig.rest_positional,
returned_signature[0].sig.rest_positional,
);
}
_ => panic!("decoded into wrong value: {returned:?}"),
}
}
#[test]
fn response_round_trip_value() {
let value = Value::int(10, Span::new(2, 30));
let response = PluginCallResponse::value(value.clone());
let output = PluginOutput::CallResponse(4, response);
let encoder = $encoder;
let mut buffer: Vec<u8> = Vec::new();
encoder
.encode(&output, &mut buffer)
.expect("unable to serialize message");
let returned = encoder
.decode(&mut buffer.as_slice())
.expect("unable to deserialize message")
.expect("eof");
match returned {
PluginOutput::CallResponse(
4,
PluginCallResponse::PipelineData(PipelineDataHeader::Value(returned_value)),
) => {
assert_eq!(value, returned_value)
}
_ => panic!("decoded into wrong value: {returned:?}"),
}
}
#[test]
fn response_round_trip_plugin_custom_value() {
let name = "test";
let data = vec![1, 2, 3, 4, 5];
let span = Span::new(2, 30);
let value = Value::custom_value(
Box::new(PluginCustomValue {
name: name.into(),
data: data.clone(),
source: None,
}),
span,
);
let response = PluginCallResponse::PipelineData(PipelineDataHeader::Value(value));
let output = PluginOutput::CallResponse(5, response);
let encoder = $encoder;
let mut buffer: Vec<u8> = Vec::new();
encoder
.encode(&output, &mut buffer)
.expect("unable to serialize message");
let returned = encoder
.decode(&mut buffer.as_slice())
.expect("unable to deserialize message")
.expect("eof");
match returned {
PluginOutput::CallResponse(
5,
PluginCallResponse::PipelineData(PipelineDataHeader::Value(returned_value)),
) => {
assert_eq!(span, returned_value.span());
if let Some(plugin_val) = returned_value
.as_custom_value()
.unwrap()
.as_any()
.downcast_ref::<PluginCustomValue>()
{
assert_eq!(name, plugin_val.name);
assert_eq!(data, plugin_val.data);
} else {
panic!("returned CustomValue is not a PluginCustomValue");
}
}
_ => panic!("decoded into wrong value: {returned:?}"),
}
}
#[test]
fn response_round_trip_error() {
let error = LabeledError {
label: "label".into(),
msg: "msg".into(),
span: Some(Span::new(2, 30)),
};
let response = PluginCallResponse::Error(error.clone());
let output = PluginOutput::CallResponse(6, response);
let encoder = $encoder;
let mut buffer: Vec<u8> = Vec::new();
encoder
.encode(&output, &mut buffer)
.expect("unable to serialize message");
let returned = encoder
.decode(&mut buffer.as_slice())
.expect("unable to deserialize message")
.expect("eof");
match returned {
PluginOutput::CallResponse(6, PluginCallResponse::Error(msg)) => {
assert_eq!(error, msg)
}
_ => panic!("decoded into wrong value: {returned:?}"),
}
}
#[test]
fn response_round_trip_error_none() {
let error = LabeledError {
label: "label".into(),
msg: "msg".into(),
span: None,
};
let response = PluginCallResponse::Error(error.clone());
let output = PluginOutput::CallResponse(7, response);
let encoder = $encoder;
let mut buffer: Vec<u8> = Vec::new();
encoder
.encode(&output, &mut buffer)
.expect("unable to serialize message");
let returned = encoder
.decode(&mut buffer.as_slice())
.expect("unable to deserialize message")
.expect("eof");
match returned {
PluginOutput::CallResponse(7, PluginCallResponse::Error(msg)) => {
assert_eq!(error, msg)
}
_ => panic!("decoded into wrong value: {returned:?}"),
}
}
#[test]
fn input_round_trip_stream_data_list() {
let span = Span::new(12, 30);
let item = Value::int(1, span);
let stream_data = StreamData::List(item.clone());
let plugin_input = PluginInput::Stream(StreamMessage::Data(0, stream_data));
let encoder = $encoder;
let mut buffer: Vec<u8> = Vec::new();
encoder
.encode(&plugin_input, &mut buffer)
.expect("unable to serialize message");
let returned = encoder
.decode(&mut buffer.as_slice())
.expect("unable to deserialize message")
.expect("eof");
match returned {
PluginInput::Stream(StreamMessage::Data(id, StreamData::List(list_data))) => {
assert_eq!(0, id);
assert_eq!(item, list_data);
}
_ => panic!("decoded into wrong value: {returned:?}"),
}
}
#[test]
fn input_round_trip_stream_data_raw() {
let data = b"Hello world";
let stream_data = StreamData::Raw(Ok(data.to_vec()));
let plugin_input = PluginInput::Stream(StreamMessage::Data(1, stream_data));
let encoder = $encoder;
let mut buffer: Vec<u8> = Vec::new();
encoder
.encode(&plugin_input, &mut buffer)
.expect("unable to serialize message");
let returned = encoder
.decode(&mut buffer.as_slice())
.expect("unable to deserialize message")
.expect("eof");
match returned {
PluginInput::Stream(StreamMessage::Data(id, StreamData::Raw(bytes))) => {
assert_eq!(1, id);
match bytes {
Ok(bytes) => assert_eq!(data, &bytes[..]),
Err(err) => panic!("decoded into error variant: {err:?}"),
}
}
_ => panic!("decoded into wrong value: {returned:?}"),
}
}
#[test]
fn output_round_trip_stream_data_list() {
let span = Span::new(12, 30);
let item = Value::int(1, span);
let stream_data = StreamData::List(item.clone());
let plugin_output = PluginOutput::Stream(StreamMessage::Data(4, stream_data));
let encoder = $encoder;
let mut buffer: Vec<u8> = Vec::new();
encoder
.encode(&plugin_output, &mut buffer)
.expect("unable to serialize message");
let returned = encoder
.decode(&mut buffer.as_slice())
.expect("unable to deserialize message")
.expect("eof");
match returned {
PluginOutput::Stream(StreamMessage::Data(id, StreamData::List(list_data))) => {
assert_eq!(4, id);
assert_eq!(item, list_data);
}
_ => panic!("decoded into wrong value: {returned:?}"),
}
}
#[test]
fn output_round_trip_stream_data_raw() {
let data = b"Hello world";
let stream_data = StreamData::Raw(Ok(data.to_vec()));
let plugin_output = PluginOutput::Stream(StreamMessage::Data(5, stream_data));
let encoder = $encoder;
let mut buffer: Vec<u8> = Vec::new();
encoder
.encode(&plugin_output, &mut buffer)
.expect("unable to serialize message");
let returned = encoder
.decode(&mut buffer.as_slice())
.expect("unable to deserialize message")
.expect("eof");
match returned {
PluginOutput::Stream(StreamMessage::Data(id, StreamData::Raw(bytes))) => {
assert_eq!(5, id);
match bytes {
Ok(bytes) => assert_eq!(data, &bytes[..]),
Err(err) => panic!("decoded into error variant: {err:?}"),
}
}
_ => panic!("decoded into wrong value: {returned:?}"),
}
}
};
}
pub(crate) use generate_tests;

View File

@ -6,7 +6,7 @@ use crate::engine::Command;
use crate::{BlockId, Category, Flag, PositionalArg, SyntaxShape, Type};
/// A simple wrapper for Signature that includes examples.
#[derive(Clone, Serialize, Deserialize)]
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PluginSignature {
pub sig: Signature,
pub examples: Vec<PluginExample>,

View File

@ -774,6 +774,54 @@ pub enum ShellError {
#[diagnostic(code(nu::shell::plugin_failed_to_decode))]
PluginFailedToDecode { msg: String },
/// A custom value cannot be sent to the given plugin.
///
/// ## Resolution
///
/// Custom values can only be used with the plugin they came from. Use a command from that
/// plugin instead.
#[error("Custom value `{name}` cannot be sent to plugin")]
#[diagnostic(code(nu::shell::custom_value_incorrect_for_plugin))]
CustomValueIncorrectForPlugin {
name: String,
#[label("the `{dest_plugin}` plugin does not support this kind of value")]
span: Span,
dest_plugin: String,
#[help("this value came from the `{}` plugin")]
src_plugin: Option<String>,
},
/// The plugin failed to encode a custom value.
///
/// ## Resolution
///
/// This is likely a bug with the plugin itself. The plugin may have tried to send a custom
/// value that is not serializable.
#[error("Custom value failed to encode")]
#[diagnostic(code(nu::shell::custom_value_failed_to_encode))]
CustomValueFailedToEncode {
msg: String,
#[label("{msg}")]
span: Span,
},
/// The plugin failed to encode a custom value.
///
/// ## Resolution
///
/// This may be a bug within the plugin, or the plugin may have been updated in between the
/// creation of the custom value and its use.
#[error("Custom value failed to decode")]
#[diagnostic(code(nu::shell::custom_value_failed_to_decode))]
#[diagnostic(help(
"the plugin may have been updated and no longer support this custom value"
))]
CustomValueFailedToDecode {
msg: String,
#[label("{msg}")]
span: Span,
},
/// I/O operation interrupted.
///
/// ## Resolution

View File

@ -164,7 +164,6 @@ pub enum Value {
#[serde(rename = "span")]
internal_span: Span,
},
#[serde(skip_serializing)]
CustomValue {
val: Box<dyn CustomValue>,
// note: spans are being refactored out of Value

View File

@ -4,7 +4,7 @@ mod second_custom_value;
use cool_custom_value::CoolCustomValue;
use nu_plugin::{serve_plugin, MsgPackSerializer, Plugin};
use nu_plugin::{EvaluatedCall, LabeledError};
use nu_protocol::{Category, PluginSignature, ShellError, Value};
use nu_protocol::{Category, PluginSignature, ShellError, SyntaxShape, Value};
use second_custom_value::SecondCustomValue;
struct CustomValuePlugin;
@ -21,6 +21,14 @@ impl Plugin for CustomValuePlugin {
PluginSignature::build("custom-value update")
.usage("PluginSignature for a plugin that updates a custom value")
.category(Category::Experimental),
PluginSignature::build("custom-value update-arg")
.usage("PluginSignature for a plugin that updates a custom value as an argument")
.required(
"custom_value",
SyntaxShape::Any,
"the custom value to update",
)
.category(Category::Experimental),
]
}
@ -35,6 +43,7 @@ impl Plugin for CustomValuePlugin {
"custom-value generate" => self.generate(call, input),
"custom-value generate2" => self.generate2(call, input),
"custom-value update" => self.update(call, input),
"custom-value update-arg" => self.update(call, &call.req(0)?),
_ => Err(LabeledError {
label: "Plugin call with wrong name signature".into(),
msg: "the signature used to call the plugin does not match any name in the plugin signature vector".into(),

View File

@ -10,9 +10,8 @@
# register <path-to-py-file>
#
# Be careful with the spans. Miette will crash if a span is outside the
# size of the contents vector. For this example we are using 0 and 1, which will
# point to the beginning of the contents vector. We strongly suggest using the span
# found in the plugin call head
# size of the contents vector. We strongly suggest using the span found in the
# plugin call head as in this example.
#
# The plugin will be run using the active Python implementation. If you are in
# a Python environment, that is the Python version that is used
@ -113,7 +112,7 @@ def signatures():
}
def process_call(plugin_call):
def process_call(id, plugin_call):
"""
plugin_call is a dictionary with the information from the call
It should contain:
@ -127,277 +126,38 @@ def process_call(plugin_call):
sys.stderr.write(json.dumps(plugin_call, indent=4))
sys.stderr.write("\n")
# Get the span from the call
span = plugin_call["Run"]["call"]["head"]
# Creates a Value of type List that will be encoded and sent to Nushell
return {
value = {
"Value": {
"List": {
"vals": [
{
"Record": {
"val": {
"cols": ["one", "two", "three"],
"vals": [
{
"Int": {
"val": 0,
"span": {"start": 0, "end": 1},
"val": x * y,
"span": span
}
} for y in [0, 1, 2]
]
},
{
"Int": {
"val": 0,
"span": {"start": 0, "end": 1},
"span": span
}
},
{
"Int": {
"val": 0,
"span": {"start": 0, "end": 1},
}
},
} for x in range(0, 10)
],
"span": {"start": 0, "end": 1},
}
},
{
"Record": {
"cols": ["one", "two", "three"],
"vals": [
{
"Int": {
"val": 0,
"span": {"start": 0, "end": 1},
}
},
{
"Int": {
"val": 1,
"span": {"start": 0, "end": 1},
}
},
{
"Int": {
"val": 2,
"span": {"start": 0, "end": 1},
}
},
],
"span": {"start": 0, "end": 1},
}
},
{
"Record": {
"cols": ["one", "two", "three"],
"vals": [
{
"Int": {
"val": 0,
"span": {"start": 0, "end": 1},
}
},
{
"Int": {
"val": 2,
"span": {"start": 0, "end": 1},
}
},
{
"Int": {
"val": 4,
"span": {"start": 0, "end": 1},
}
},
],
"span": {"start": 0, "end": 1},
}
},
{
"Record": {
"cols": ["one", "two", "three"],
"vals": [
{
"Int": {
"val": 0,
"span": {"start": 0, "end": 1},
}
},
{
"Int": {
"val": 3,
"span": {"start": 0, "end": 1},
}
},
{
"Int": {
"val": 6,
"span": {"start": 0, "end": 1},
}
},
],
"span": {"start": 0, "end": 1},
}
},
{
"Record": {
"cols": ["one", "two", "three"],
"vals": [
{
"Int": {
"val": 0,
"span": {"start": 0, "end": 1},
}
},
{
"Int": {
"val": 4,
"span": {"start": 0, "end": 1},
}
},
{
"Int": {
"val": 8,
"span": {"start": 0, "end": 1},
}
},
],
"span": {"start": 0, "end": 1},
}
},
{
"Record": {
"cols": ["one", "two", "three"],
"vals": [
{
"Int": {
"val": 0,
"span": {"start": 0, "end": 1},
}
},
{
"Int": {
"val": 5,
"span": {"start": 0, "end": 1},
}
},
{
"Int": {
"val": 10,
"span": {"start": 0, "end": 1},
}
},
],
"span": {"start": 0, "end": 1},
}
},
{
"Record": {
"cols": ["one", "two", "three"],
"vals": [
{
"Int": {
"val": 0,
"span": {"start": 0, "end": 1},
}
},
{
"Int": {
"val": 6,
"span": {"start": 0, "end": 1},
}
},
{
"Int": {
"val": 12,
"span": {"start": 0, "end": 1},
}
},
],
"span": {"start": 0, "end": 1},
}
},
{
"Record": {
"cols": ["one", "two", "three"],
"vals": [
{
"Int": {
"val": 0,
"span": {"start": 0, "end": 1},
}
},
{
"Int": {
"val": 7,
"span": {"start": 0, "end": 1},
}
},
{
"Int": {
"val": 14,
"span": {"start": 0, "end": 1},
}
},
],
"span": {"start": 0, "end": 1},
}
},
{
"Record": {
"cols": ["one", "two", "three"],
"vals": [
{
"Int": {
"val": 0,
"span": {"start": 0, "end": 1},
}
},
{
"Int": {
"val": 8,
"span": {"start": 0, "end": 1},
}
},
{
"Int": {
"val": 16,
"span": {"start": 0, "end": 1},
}
},
],
"span": {"start": 0, "end": 1},
}
},
{
"Record": {
"cols": ["one", "two", "three"],
"vals": [
{
"Int": {
"val": 0,
"span": {"start": 0, "end": 1},
}
},
{
"Int": {
"val": 9,
"span": {"start": 0, "end": 1},
}
},
{
"Int": {
"val": 18,
"span": {"start": 0, "end": 1},
}
},
],
"span": {"start": 0, "end": 1},
}
},
],
"span": {"start": 0, "end": 1},
"span": span
}
}
}
write_response(id, {"PipelineData": value})
def tell_nushell_encoding():
sys.stdout.write(chr(4))
@ -406,30 +166,79 @@ def tell_nushell_encoding():
sys.stdout.flush()
def plugin():
tell_nushell_encoding()
call_str = ",".join(sys.stdin.readlines())
plugin_call = json.loads(call_str)
def tell_nushell_hello():
"""
A `Hello` message is required at startup to inform nushell of the protocol capabilities and
compatibility of the plugin. The version specified should be the version of nushell that this
plugin was tested and developed against.
"""
hello = {
"Hello": {
"protocol": "nu-plugin", # always this value
"version": "0.90.2",
"features": []
}
}
sys.stdout.write(json.dumps(hello))
sys.stdout.write("\n")
sys.stdout.flush()
if plugin_call == "Signature":
signature = json.dumps(signatures())
sys.stdout.write(signature)
elif "CallInfo" in plugin_call:
response = process_call(plugin_call)
sys.stdout.write(json.dumps(response))
def write_response(id, response):
"""
Use this format to send a response to a plugin call. The ID of the plugin call is required.
"""
wrapped_response = {
"CallResponse": [
id,
response,
]
}
sys.stdout.write(json.dumps(wrapped_response))
sys.stdout.write("\n")
sys.stdout.flush()
else:
# Use this error format if you want to return an error back to Nushell
def write_error(id, msg, span=None):
"""
Use this error format to send errors to nushell in response to a plugin call. The ID of the
plugin call is required.
"""
error = {
"Error": {
"label": "ERROR from plugin",
"msg": "error message pointing to call head span",
"span": {"start": 0, "end": 1},
"msg": msg,
"span": span
}
}
sys.stdout.write(json.dumps(error))
write_response(id, error)
def handle_input(input):
if "Hello" in input:
return
elif "Call" in input:
[id, plugin_call] = input["Call"]
if "Signature" in plugin_call:
write_response(id, signatures())
elif "Run" in plugin_call:
process_call(id, plugin_call)
else:
write_error(id, "Operation not supported: " + str(plugin_call))
else:
sys.stderr.write("Unknown message: " + str(input) + "\n")
exit(1)
def plugin():
tell_nushell_encoding()
tell_nushell_hello()
for line in sys.stdin:
input = json.loads(line)
handle_input(input)
if __name__ == "__main__":
if len(sys.argv) == 2 and sys.argv[1] == "--stdio":
plugin()
else:
print("Run me from inside nushell!")

View File

@ -0,0 +1,19 @@
[package]
authors = ["The Nushell Project Developers"]
description = "An example of stream handling in nushell plugins"
repository = "https://github.com/nushell/nushell/tree/main/crates/nu_plugin_stream_example"
edition = "2021"
license = "MIT"
name = "nu_plugin_stream_example"
version = "0.90.2"
[[bin]]
name = "nu_plugin_stream_example"
bench = false
[lib]
bench = false
[dependencies]
nu-plugin = { path = "../nu-plugin", version = "0.90.2" }
nu-protocol = { path = "../nu-protocol", version = "0.90.2", features = ["plugin"] }

View File

@ -0,0 +1,48 @@
# Streaming Plugin Example
Crate with a simple example of the `StreamingPlugin` trait that needs to be implemented
in order to create a binary that can be registered into nushell declaration list
## `stream_example seq`
This command demonstrates generating list streams. It generates numbers from the first argument
to the second argument just like the builtin `seq` command does.
Examples:
> ```nushell
> stream_example seq 1 10
> ```
[1 2 3 4 5 6 7 8 9 10]
> ```nushell
> stream_example seq 1 10 | describe
> ```
list<int> (stream)
## `stream_example sum`
This command demonstrates consuming list streams. It consumes a stream of numbers and calculates the
sum just like the builtin `math sum` command does.
Examples:
> ```nushell
> seq 1 5 | stream_example sum
> ```
15
## `stream_example collect-external`
This command demonstrates transforming streams into external streams. The list (or stream) of
strings on input will be concatenated into an external stream (raw input) on stdout.
> ```nushell
> [Hello "\n" world how are you] | stream_example collect-external
> ````
Hello
worldhowareyou

View File

@ -0,0 +1,67 @@
use nu_plugin::{EvaluatedCall, LabeledError};
use nu_protocol::{ListStream, PipelineData, RawStream, Value};
pub struct Example;
mod int_or_float;
use self::int_or_float::IntOrFloat;
impl Example {
pub fn seq(
&self,
call: &EvaluatedCall,
_input: PipelineData,
) -> Result<PipelineData, LabeledError> {
let first: i64 = call.req(0)?;
let last: i64 = call.req(1)?;
let span = call.head;
let iter = (first..=last).map(move |number| Value::int(number, span));
let list_stream = ListStream::from_stream(iter, None);
Ok(PipelineData::ListStream(list_stream, None))
}
pub fn sum(
&self,
call: &EvaluatedCall,
input: PipelineData,
) -> Result<PipelineData, LabeledError> {
let mut acc = IntOrFloat::Int(0);
let span = input.span();
for value in input {
if let Ok(n) = value.as_i64() {
acc.add_i64(n);
} else if let Ok(n) = value.as_f64() {
acc.add_f64(n);
} else {
return Err(LabeledError {
label: "Stream only accepts ints and floats".into(),
msg: format!("found {}", value.get_type()),
span,
});
}
}
Ok(PipelineData::Value(acc.to_value(call.head), None))
}
pub fn collect_external(
&self,
call: &EvaluatedCall,
input: PipelineData,
) -> Result<PipelineData, LabeledError> {
let stream = input.into_iter().map(|value| {
value
.as_str()
.map(|str| str.as_bytes())
.or_else(|_| value.as_binary())
.map(|bin| bin.to_vec())
});
Ok(PipelineData::ExternalStream {
stdout: Some(RawStream::new(Box::new(stream), None, call.head, None)),
stderr: None,
exit_code: None,
span: call.head,
metadata: None,
trim_end_newline: false,
})
}
}

View File

@ -0,0 +1,42 @@
use nu_protocol::Value;
use nu_protocol::Span;
/// Accumulates numbers into either an int or a float. Changes type to float on the first
/// float received.
#[derive(Clone, Copy)]
pub(crate) enum IntOrFloat {
Int(i64),
Float(f64),
}
impl IntOrFloat {
pub(crate) fn add_i64(&mut self, n: i64) {
match self {
IntOrFloat::Int(ref mut v) => {
*v += n;
}
IntOrFloat::Float(ref mut v) => {
*v += n as f64;
}
}
}
pub(crate) fn add_f64(&mut self, n: f64) {
match self {
IntOrFloat::Int(v) => {
*self = IntOrFloat::Float(*v as f64 + n);
}
IntOrFloat::Float(ref mut v) => {
*v += n;
}
}
}
pub(crate) fn to_value(self, span: Span) -> Value {
match self {
IntOrFloat::Int(v) => Value::int(v, span),
IntOrFloat::Float(v) => Value::float(v, span),
}
}
}

View File

@ -0,0 +1,4 @@
mod example;
mod nu;
pub use example::Example;

View File

@ -0,0 +1,30 @@
use nu_plugin::{serve_plugin, MsgPackSerializer};
use nu_plugin_stream_example::Example;
fn main() {
// When defining your plugin, you can select the Serializer that could be
// used to encode and decode the messages. The available options are
// MsgPackSerializer and JsonSerializer. Both are defined in the serializer
// folder in nu-plugin.
serve_plugin(&mut Example {}, MsgPackSerializer {})
// Note
// When creating plugins in other languages one needs to consider how a plugin
// is added and used in nushell.
// The steps are:
// - The plugin is register. In this stage nushell calls the binary file of
// the plugin sending information using the encoded PluginCall::PluginSignature object.
// Use this encoded data in your plugin to design the logic that will return
// the encoded signatures.
// Nushell is expecting and encoded PluginResponse::PluginSignature with all the
// plugin signatures
// - When calling the plugin, nushell sends to the binary file the encoded
// PluginCall::CallInfo which has all the call information, such as the
// values of the arguments, the name of the signature called and the input
// from the pipeline.
// Use this data to design your plugin login and to create the value that
// will be sent to nushell
// Nushell expects an encoded PluginResponse::Value from the plugin
// - If an error needs to be sent back to nushell, one can encode PluginResponse::Error.
// This is a labeled error that nushell can format for pretty printing
}

View File

@ -0,0 +1,86 @@
use crate::Example;
use nu_plugin::{EvaluatedCall, LabeledError, StreamingPlugin};
use nu_protocol::{
Category, PipelineData, PluginExample, PluginSignature, Span, SyntaxShape, Type, Value,
};
impl StreamingPlugin for Example {
fn signature(&self) -> Vec<PluginSignature> {
let span = Span::unknown();
vec![
PluginSignature::build("stream_example")
.usage("Examples for streaming plugins")
.search_terms(vec!["example".into()])
.category(Category::Experimental),
PluginSignature::build("stream_example seq")
.usage("Example stream generator for a list of values")
.search_terms(vec!["example".into()])
.required("first", SyntaxShape::Int, "first number to generate")
.required("last", SyntaxShape::Int, "last number to generate")
.input_output_type(Type::Nothing, Type::List(Type::Int.into()))
.plugin_examples(vec![PluginExample {
example: "stream_example seq 1 3".into(),
description: "generate a sequence from 1 to 3".into(),
result: Some(Value::list(
vec![
Value::int(1, span),
Value::int(2, span),
Value::int(3, span),
],
span,
)),
}])
.category(Category::Experimental),
PluginSignature::build("stream_example sum")
.usage("Example stream consumer for a list of values")
.search_terms(vec!["example".into()])
.input_output_types(vec![
(Type::List(Type::Int.into()), Type::Int),
(Type::List(Type::Float.into()), Type::Float),
])
.plugin_examples(vec![PluginExample {
example: "seq 1 5 | stream_example sum".into(),
description: "sum values from 1 to 5".into(),
result: Some(Value::int(15, span)),
}])
.category(Category::Experimental),
PluginSignature::build("stream_example collect-external")
.usage("Example transformer to raw external stream")
.search_terms(vec!["example".into()])
.input_output_types(vec![
(Type::List(Type::String.into()), Type::String),
(Type::List(Type::Binary.into()), Type::Binary),
])
.plugin_examples(vec![PluginExample {
example: "[a b] | stream_example collect-external".into(),
description: "collect strings into one stream".into(),
result: Some(Value::string("ab", span)),
}])
.category(Category::Experimental),
]
}
fn run(
&mut self,
name: &str,
_config: &Option<Value>,
call: &EvaluatedCall,
input: PipelineData,
) -> Result<PipelineData, LabeledError> {
match name {
"stream_example" => Err(LabeledError {
label: "No subcommand provided".into(),
msg: "add --help here to see usage".into(),
span: Some(call.head)
}),
"stream_example seq" => self.seq(call, input),
"stream_example sum" => self.sum(call, input),
"stream_example collect-external" => self.collect_external(call, input),
_ => Err(LabeledError {
label: "Plugin call with wrong name signature".into(),
msg: "the signature used to call the plugin does not match any name in the plugin signature vector".into(),
span: Some(call.head),
}),
}
}
}

View File

@ -26,6 +26,20 @@ fn can_get_custom_value_from_plugin_and_pass_it_over() {
);
}
#[test]
fn can_get_custom_value_from_plugin_and_pass_it_over_as_an_argument() {
let actual = nu_with_plugins!(
cwd: "tests",
plugin: ("nu_plugin_custom_values"),
"custom-value update-arg (custom-value generate)"
);
assert_eq!(
actual.out,
"I used to be a custom value! My data was (abcxyz)"
);
}
#[test]
fn can_generate_and_updated_multiple_types_of_custom_values() {
let actual = nu_with_plugins!(
@ -65,7 +79,10 @@ fn fails_if_passing_engine_custom_values_to_plugins() {
assert!(actual
.err
.contains("Plugin custom-value update can not handle the custom value SQLiteDatabase"));
.contains("`SQLiteDatabase` cannot be sent to plugin"));
assert!(actual
.err
.contains("the `custom_values` plugin does not support this kind of value"));
}
#[test]
@ -81,5 +98,8 @@ fn fails_if_passing_custom_values_across_plugins() {
assert!(actual
.err
.contains("Plugin inc can not handle the custom value CoolCustomValue"));
.contains("`CoolCustomValue` cannot be sent to plugin"));
assert!(actual
.err
.contains("the `inc` plugin does not support this kind of value"));
}

View File

@ -3,3 +3,4 @@ mod core_inc;
mod custom_values;
mod formats;
mod register;
mod stream;

166
tests/plugins/stream.rs Normal file
View File

@ -0,0 +1,166 @@
use nu_test_support::nu_with_plugins;
use pretty_assertions::assert_eq;
#[test]
fn seq_produces_stream() {
let actual = nu_with_plugins!(
cwd: "tests/fixtures/formats",
plugin: ("nu_plugin_stream_example"),
"stream_example seq 1 5 | describe"
);
assert_eq!(actual.out, "list<int> (stream)");
}
#[test]
fn seq_describe_no_collect_succeeds_without_error() {
// This tests to ensure that there's no error if the stream is suddenly closed
let actual = nu_with_plugins!(
cwd: "tests/fixtures/formats",
plugin: ("nu_plugin_stream_example"),
"stream_example seq 1 5 | describe --no-collect"
);
assert_eq!(actual.out, "stream");
assert_eq!(actual.err, "");
}
#[test]
fn seq_stream_collects_to_correct_list() {
let actual = nu_with_plugins!(
cwd: "tests/fixtures/formats",
plugin: ("nu_plugin_stream_example"),
"stream_example seq 1 5 | to json --raw"
);
assert_eq!(actual.out, "[1,2,3,4,5]");
let actual = nu_with_plugins!(
cwd: "tests/fixtures/formats",
plugin: ("nu_plugin_stream_example"),
"stream_example seq 1 0 | to json --raw"
);
assert_eq!(actual.out, "[]");
}
#[test]
fn seq_big_stream() {
// Testing big streams helps to ensure there are no deadlocking bugs
let actual = nu_with_plugins!(
cwd: "tests/fixtures/formats",
plugin: ("nu_plugin_stream_example"),
"stream_example seq 1 100000 | length"
);
assert_eq!(actual.out, "100000");
}
#[test]
fn sum_accepts_list_of_int() {
let actual = nu_with_plugins!(
cwd: "tests/fixtures/formats",
plugin: ("nu_plugin_stream_example"),
"[1 2 3] | stream_example sum"
);
assert_eq!(actual.out, "6");
}
#[test]
fn sum_accepts_list_of_float() {
let actual = nu_with_plugins!(
cwd: "tests/fixtures/formats",
plugin: ("nu_plugin_stream_example"),
"[1.0 2.0 3.5] | stream_example sum"
);
assert_eq!(actual.out, "6.5");
}
#[test]
fn sum_accepts_stream_of_int() {
let actual = nu_with_plugins!(
cwd: "tests/fixtures/formats",
plugin: ("nu_plugin_stream_example"),
"seq 1 5 | stream_example sum"
);
assert_eq!(actual.out, "15");
}
#[test]
fn sum_accepts_stream_of_float() {
let actual = nu_with_plugins!(
cwd: "tests/fixtures/formats",
plugin: ("nu_plugin_stream_example"),
"seq 1 5 | into float | stream_example sum"
);
assert_eq!(actual.out, "15");
}
#[test]
fn sum_big_stream() {
// Testing big streams helps to ensure there are no deadlocking bugs
let actual = nu_with_plugins!(
cwd: "tests/fixtures/formats",
plugin: ("nu_plugin_stream_example"),
"seq 1 100000 | stream_example sum"
);
assert_eq!(actual.out, "5000050000");
}
#[test]
fn collect_external_accepts_list_of_string() {
let actual = nu_with_plugins!(
cwd: "tests/fixtures/formats",
plugin: ("nu_plugin_stream_example"),
"[a b] | stream_example collect-external"
);
assert_eq!(actual.out, "ab");
}
#[test]
fn collect_external_accepts_list_of_binary() {
let actual = nu_with_plugins!(
cwd: "tests/fixtures/formats",
plugin: ("nu_plugin_stream_example"),
"[0x[41] 0x[42]] | stream_example collect-external"
);
assert_eq!(actual.out, "AB");
}
#[test]
fn collect_external_produces_raw_input() {
let actual = nu_with_plugins!(
cwd: "tests/fixtures/formats",
plugin: ("nu_plugin_stream_example"),
"[a b c] | stream_example collect-external | describe"
);
assert_eq!(actual.out, "raw input");
}
#[test]
fn collect_external_big_stream() {
// This in particular helps to ensure that a big stream can be both read and written at the same
// time without deadlocking
let actual = nu_with_plugins!(
cwd: "tests/fixtures/formats",
plugin: ("nu_plugin_stream_example"),
r#"(
seq 1 10000 |
to text |
each { into string } |
stream_example collect-external |
lines |
length
)"#
);
assert_eq!(actual.out, "10000");
}

View File

@ -281,6 +281,14 @@
Source='target\$(var.Profile)\nu_plugin_gstat.exe'
KeyPath='yes'/>
</Component>
<Component Id='binary23' Guid='*' Win64='$(var.Win64)'>
<File
Id='exe23'
Name='nu_plugin_stream_example.exe'
DiskId='1'
Source='target\$(var.Profile)\nu_plugin_stream_example.exe'
KeyPath='yes'/>
</Component>
</Directory>
</Directory>
</Directory>