mirror of
https://github.com/nushell/nushell.git
synced 2024-12-22 15:13:01 +01:00
Bidirectional communication and streams for plugins (#11911)
This commit is contained in:
parent
461f69ac5d
commit
88f1f386bb
12
Cargo.lock
generated
12
Cargo.lock
generated
@ -3153,11 +3153,15 @@ name = "nu-plugin"
|
||||
version = "0.90.2"
|
||||
dependencies = [
|
||||
"bincode",
|
||||
"log",
|
||||
"miette",
|
||||
"nu-engine",
|
||||
"nu-protocol",
|
||||
"rmp-serde",
|
||||
"semver",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"typetag",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@ -3330,6 +3334,14 @@ dependencies = [
|
||||
"sxd-xpath",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "nu_plugin_stream_example"
|
||||
version = "0.90.2"
|
||||
dependencies = [
|
||||
"nu-plugin",
|
||||
"nu-protocol",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "num"
|
||||
version = "0.2.1"
|
||||
|
@ -43,6 +43,7 @@ members = [
|
||||
"crates/nu_plugin_inc",
|
||||
"crates/nu_plugin_gstat",
|
||||
"crates/nu_plugin_example",
|
||||
"crates/nu_plugin_stream_example",
|
||||
"crates/nu_plugin_query",
|
||||
"crates/nu_plugin_custom_values",
|
||||
"crates/nu_plugin_formats",
|
||||
|
@ -1,7 +1,7 @@
|
||||
use criterion::{criterion_group, criterion_main, BatchSize, Criterion};
|
||||
use nu_cli::eval_source;
|
||||
use nu_parser::parse;
|
||||
use nu_plugin::{EncodingType, PluginResponse};
|
||||
use nu_plugin::{Encoder, EncodingType, PluginCallResponse, PluginOutput};
|
||||
use nu_protocol::{engine::EngineState, PipelineData, Span, Value};
|
||||
use nu_utils::{get_default_config, get_default_env};
|
||||
use std::path::{Path, PathBuf};
|
||||
@ -148,10 +148,12 @@ fn encoding_benchmarks(c: &mut Criterion) {
|
||||
for fmt in ["json", "msgpack"] {
|
||||
group.bench_function(&format!("{fmt} encode {row_cnt} * {col_cnt}"), |b| {
|
||||
let mut res = vec![];
|
||||
let test_data =
|
||||
PluginResponse::Value(Box::new(encoding_test_data(row_cnt, col_cnt)));
|
||||
let test_data = PluginOutput::CallResponse(
|
||||
0,
|
||||
PluginCallResponse::value(encoding_test_data(row_cnt, col_cnt)),
|
||||
);
|
||||
let encoder = EncodingType::try_from_bytes(fmt.as_bytes()).unwrap();
|
||||
b.iter(|| encoder.encode_response(&test_data, &mut res))
|
||||
b.iter(|| encoder.encode(&test_data, &mut res))
|
||||
});
|
||||
}
|
||||
}
|
||||
@ -165,14 +167,16 @@ fn decoding_benchmarks(c: &mut Criterion) {
|
||||
for fmt in ["json", "msgpack"] {
|
||||
group.bench_function(&format!("{fmt} decode for {row_cnt} * {col_cnt}"), |b| {
|
||||
let mut res = vec![];
|
||||
let test_data =
|
||||
PluginResponse::Value(Box::new(encoding_test_data(row_cnt, col_cnt)));
|
||||
let test_data = PluginOutput::CallResponse(
|
||||
0,
|
||||
PluginCallResponse::value(encoding_test_data(row_cnt, col_cnt)),
|
||||
);
|
||||
let encoder = EncodingType::try_from_bytes(fmt.as_bytes()).unwrap();
|
||||
encoder.encode_response(&test_data, &mut res).unwrap();
|
||||
encoder.encode(&test_data, &mut res).unwrap();
|
||||
let mut binary_data = std::io::Cursor::new(res);
|
||||
b.iter(|| {
|
||||
b.iter(|| -> Result<Option<PluginOutput>, _> {
|
||||
binary_data.set_position(0);
|
||||
encoder.decode_response(&mut binary_data)
|
||||
encoder.decode(&mut binary_data)
|
||||
})
|
||||
});
|
||||
}
|
||||
|
@ -18,3 +18,7 @@ bincode = "1.3"
|
||||
rmp-serde = "1.1"
|
||||
serde = { version = "1.0" }
|
||||
serde_json = { version = "1.0" }
|
||||
log = "0.4"
|
||||
miette = "7.0"
|
||||
semver = "1.0"
|
||||
typetag = "0.2"
|
||||
|
@ -15,7 +15,7 @@
|
||||
//! function, which will handle all of the input and output serialization when
|
||||
//! invoked by Nushell.
|
||||
//!
|
||||
//! ```
|
||||
//! ```rust,no_run
|
||||
//! use nu_plugin::{EvaluatedCall, LabeledError, MsgPackSerializer, Plugin, serve_plugin};
|
||||
//! use nu_protocol::{PluginSignature, Value};
|
||||
//!
|
||||
@ -46,8 +46,21 @@
|
||||
//! that demonstrates the full range of plugin capabilities.
|
||||
mod plugin;
|
||||
mod protocol;
|
||||
mod sequence;
|
||||
mod serializers;
|
||||
|
||||
pub use plugin::{get_signature, serve_plugin, Plugin, PluginDeclaration};
|
||||
pub use protocol::{EvaluatedCall, LabeledError, PluginResponse};
|
||||
pub use serializers::{json::JsonSerializer, msgpack::MsgPackSerializer, EncodingType};
|
||||
pub use plugin::{serve_plugin, Plugin, PluginEncoder, StreamingPlugin};
|
||||
pub use protocol::{EvaluatedCall, LabeledError};
|
||||
pub use serializers::{json::JsonSerializer, msgpack::MsgPackSerializer};
|
||||
|
||||
// Used by other nu crates.
|
||||
#[doc(hidden)]
|
||||
pub use plugin::{get_signature, PluginDeclaration};
|
||||
#[doc(hidden)]
|
||||
pub use serializers::EncodingType;
|
||||
|
||||
// Used by external benchmarks.
|
||||
#[doc(hidden)]
|
||||
pub use plugin::Encoder;
|
||||
#[doc(hidden)]
|
||||
pub use protocol::{PluginCallResponse, PluginOutput};
|
||||
|
46
crates/nu-plugin/src/plugin/context.rs
Normal file
46
crates/nu-plugin/src/plugin/context.rs
Normal file
@ -0,0 +1,46 @@
|
||||
use std::sync::{atomic::AtomicBool, Arc};
|
||||
|
||||
use nu_protocol::{
|
||||
ast::Call,
|
||||
engine::{EngineState, Stack},
|
||||
};
|
||||
|
||||
/// Object safe trait for abstracting operations required of the plugin context.
|
||||
pub(crate) trait PluginExecutionContext: Send + Sync {
|
||||
/// The interrupt signal, if present
|
||||
fn ctrlc(&self) -> Option<&Arc<AtomicBool>>;
|
||||
}
|
||||
|
||||
/// The execution context of a plugin command. May be extended with more fields in the future.
|
||||
pub(crate) struct PluginExecutionCommandContext {
|
||||
ctrlc: Option<Arc<AtomicBool>>,
|
||||
}
|
||||
|
||||
impl PluginExecutionCommandContext {
|
||||
pub fn new(
|
||||
engine_state: &EngineState,
|
||||
_stack: &Stack,
|
||||
_call: &Call,
|
||||
) -> PluginExecutionCommandContext {
|
||||
PluginExecutionCommandContext {
|
||||
ctrlc: engine_state.ctrlc.clone(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl PluginExecutionContext for PluginExecutionCommandContext {
|
||||
fn ctrlc(&self) -> Option<&Arc<AtomicBool>> {
|
||||
self.ctrlc.as_ref()
|
||||
}
|
||||
}
|
||||
|
||||
/// A bogus execution context for testing that doesn't really implement anything properly
|
||||
#[cfg(test)]
|
||||
pub(crate) struct PluginExecutionBogusContext;
|
||||
|
||||
#[cfg(test)]
|
||||
impl PluginExecutionContext for PluginExecutionBogusContext {
|
||||
fn ctrlc(&self) -> Option<&Arc<AtomicBool>> {
|
||||
None
|
||||
}
|
||||
}
|
@ -1,10 +1,7 @@
|
||||
use crate::EvaluatedCall;
|
||||
|
||||
use super::{call_plugin, create_command, get_plugin_encoding};
|
||||
use crate::protocol::{
|
||||
CallInfo, CallInput, PluginCall, PluginCustomValue, PluginData, PluginResponse,
|
||||
};
|
||||
use super::{PluginExecutionCommandContext, PluginIdentity};
|
||||
use crate::protocol::{CallInfo, EvaluatedCall};
|
||||
use std::path::{Path, PathBuf};
|
||||
use std::sync::Arc;
|
||||
|
||||
use nu_engine::eval_block;
|
||||
use nu_protocol::engine::{Command, EngineState, Stack};
|
||||
@ -16,8 +13,7 @@ use nu_protocol::{Example, PipelineData, ShellError, Value};
|
||||
pub struct PluginDeclaration {
|
||||
name: String,
|
||||
signature: PluginSignature,
|
||||
filename: PathBuf,
|
||||
shell: Option<PathBuf>,
|
||||
identity: Arc<PluginIdentity>,
|
||||
}
|
||||
|
||||
impl PluginDeclaration {
|
||||
@ -25,8 +21,7 @@ impl PluginDeclaration {
|
||||
Self {
|
||||
name: signature.sig.name.clone(),
|
||||
signature,
|
||||
filename,
|
||||
shell,
|
||||
identity: Arc::new(PluginIdentity::new(filename, shell)),
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -76,76 +71,18 @@ impl Command for PluginDeclaration {
|
||||
call: &Call,
|
||||
input: PipelineData,
|
||||
) -> Result<PipelineData, ShellError> {
|
||||
// Call the command with self path
|
||||
// Decode information from plugin
|
||||
// Create PipelineData
|
||||
let source_file = Path::new(&self.filename);
|
||||
let mut plugin_cmd = create_command(source_file, self.shell.as_deref());
|
||||
// We need the current environment variables for `python` based plugins
|
||||
// Or we'll likely have a problem when a plugin is implemented in a virtual Python environment.
|
||||
let current_envs = nu_engine::env::env_to_strings(engine_state, stack).unwrap_or_default();
|
||||
plugin_cmd.envs(current_envs);
|
||||
|
||||
let mut child = plugin_cmd.spawn().map_err(|err| {
|
||||
let decl = engine_state.get_decl(call.decl_id);
|
||||
ShellError::GenericError {
|
||||
error: format!("Unable to spawn plugin for {}", decl.name()),
|
||||
msg: format!("{err}"),
|
||||
span: Some(call.head),
|
||||
help: None,
|
||||
inner: vec![],
|
||||
}
|
||||
})?;
|
||||
|
||||
let input = input.into_value(call.head);
|
||||
let span = input.span();
|
||||
let input = match input {
|
||||
Value::CustomValue { val, .. } => {
|
||||
match val.as_any().downcast_ref::<PluginCustomValue>() {
|
||||
Some(plugin_data) if plugin_data.filename == self.filename => {
|
||||
CallInput::Data(PluginData {
|
||||
data: plugin_data.data.clone(),
|
||||
span,
|
||||
})
|
||||
}
|
||||
_ => {
|
||||
let custom_value_name = val.value_string();
|
||||
return Err(ShellError::GenericError {
|
||||
error: format!(
|
||||
"Plugin {} can not handle the custom value {}",
|
||||
self.name, custom_value_name
|
||||
),
|
||||
msg: format!("custom value {custom_value_name}"),
|
||||
span: Some(span),
|
||||
help: None,
|
||||
inner: vec![],
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
Value::LazyRecord { val, .. } => CallInput::Value(val.collect()?),
|
||||
value => CallInput::Value(value),
|
||||
};
|
||||
// Create the EvaluatedCall to send to the plugin first - it's best for this to fail early,
|
||||
// before we actually try to run the plugin command
|
||||
let evaluated_call = EvaluatedCall::try_from_call(call, engine_state, stack)?;
|
||||
|
||||
// Fetch the configuration for a plugin
|
||||
//
|
||||
// The `plugin` must match the registered name of a plugin. For
|
||||
// `register nu_plugin_example` the plugin config lookup uses `"example"`
|
||||
let config = self
|
||||
.filename
|
||||
.file_stem()
|
||||
.and_then(|file| {
|
||||
file.to_string_lossy()
|
||||
.clone()
|
||||
.strip_prefix("nu_plugin_")
|
||||
.map(|name| {
|
||||
nu_engine::get_config(engine_state, stack)
|
||||
.plugins
|
||||
.get(name)
|
||||
.cloned()
|
||||
})
|
||||
})
|
||||
.flatten()
|
||||
let config = nu_engine::get_config(engine_state, stack)
|
||||
.plugins
|
||||
.get(&self.identity.plugin_name)
|
||||
.cloned()
|
||||
.map(|value| {
|
||||
let span = value.span();
|
||||
match value {
|
||||
@ -164,70 +101,41 @@ impl Command for PluginDeclaration {
|
||||
}
|
||||
});
|
||||
|
||||
let plugin_call = PluginCall::CallInfo(CallInfo {
|
||||
name: self.name.clone(),
|
||||
call: EvaluatedCall::try_from_call(call, engine_state, stack)?,
|
||||
input,
|
||||
config,
|
||||
});
|
||||
// We need the current environment variables for `python` based plugins
|
||||
// Or we'll likely have a problem when a plugin is implemented in a virtual Python environment.
|
||||
let current_envs = nu_engine::env::env_to_strings(engine_state, stack).unwrap_or_default();
|
||||
|
||||
let encoding = {
|
||||
let stdout_reader = match &mut child.stdout {
|
||||
Some(out) => out,
|
||||
None => {
|
||||
return Err(ShellError::PluginFailedToLoad {
|
||||
msg: "Plugin missing stdout reader".into(),
|
||||
})
|
||||
}
|
||||
};
|
||||
get_plugin_encoding(stdout_reader)?
|
||||
};
|
||||
let response = call_plugin(&mut child, plugin_call, &encoding, call.head).map_err(|err| {
|
||||
// Start the plugin
|
||||
let plugin = self.identity.clone().spawn(current_envs).map_err(|err| {
|
||||
let decl = engine_state.get_decl(call.decl_id);
|
||||
ShellError::GenericError {
|
||||
error: format!("Unable to decode call for {}", decl.name()),
|
||||
error: format!("Unable to spawn plugin for `{}`", decl.name()),
|
||||
msg: err.to_string(),
|
||||
span: Some(call.head),
|
||||
help: None,
|
||||
inner: vec![],
|
||||
}
|
||||
});
|
||||
})?;
|
||||
|
||||
let pipeline_data = match response {
|
||||
Ok(PluginResponse::Value(value)) => {
|
||||
Ok(PipelineData::Value(value.as_ref().clone(), None))
|
||||
}
|
||||
Ok(PluginResponse::PluginData(name, plugin_data)) => Ok(PipelineData::Value(
|
||||
Value::custom_value(
|
||||
Box::new(PluginCustomValue {
|
||||
name,
|
||||
data: plugin_data.data,
|
||||
filename: self.filename.clone(),
|
||||
shell: self.shell.clone(),
|
||||
source: engine_state.get_decl(call.decl_id).name().to_owned(),
|
||||
}),
|
||||
plugin_data.span,
|
||||
),
|
||||
None,
|
||||
)),
|
||||
Ok(PluginResponse::Error(err)) => Err(err.into()),
|
||||
Ok(PluginResponse::Signature(..)) => Err(ShellError::GenericError {
|
||||
error: "Plugin missing value".into(),
|
||||
msg: "Received a signature from plugin instead of value".into(),
|
||||
span: Some(call.head),
|
||||
help: None,
|
||||
inner: vec![],
|
||||
}),
|
||||
Err(err) => Err(err),
|
||||
};
|
||||
// Create the context to execute in
|
||||
let context = Arc::new(PluginExecutionCommandContext::new(
|
||||
engine_state,
|
||||
stack,
|
||||
call,
|
||||
));
|
||||
|
||||
// We need to call .wait() on the child, or we'll risk summoning the zombie horde
|
||||
let _ = child.wait();
|
||||
|
||||
pipeline_data
|
||||
plugin.run(
|
||||
CallInfo {
|
||||
name: self.name.clone(),
|
||||
call: evaluated_call,
|
||||
input,
|
||||
config,
|
||||
},
|
||||
context,
|
||||
)
|
||||
}
|
||||
|
||||
fn is_plugin(&self) -> Option<(&Path, Option<&Path>)> {
|
||||
Some((&self.filename, self.shell.as_deref()))
|
||||
Some((&self.identity.filename, self.identity.shell.as_deref()))
|
||||
}
|
||||
}
|
||||
|
110
crates/nu-plugin/src/plugin/identity.rs
Normal file
110
crates/nu-plugin/src/plugin/identity.rs
Normal file
@ -0,0 +1,110 @@
|
||||
use std::{
|
||||
ffi::OsStr,
|
||||
path::{Path, PathBuf},
|
||||
sync::Arc,
|
||||
};
|
||||
|
||||
use nu_protocol::ShellError;
|
||||
|
||||
use super::{create_command, make_plugin_interface, PluginInterface};
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub struct PluginIdentity {
|
||||
/// The filename used to start the plugin
|
||||
pub(crate) filename: PathBuf,
|
||||
/// The shell used to start the plugin, if required
|
||||
pub(crate) shell: Option<PathBuf>,
|
||||
/// The friendly name of the plugin (e.g. `inc` for `C:\nu_plugin_inc.exe`)
|
||||
pub(crate) plugin_name: String,
|
||||
}
|
||||
|
||||
impl PluginIdentity {
|
||||
pub(crate) fn new(filename: impl Into<PathBuf>, shell: Option<PathBuf>) -> PluginIdentity {
|
||||
let filename = filename.into();
|
||||
// `C:\nu_plugin_inc.exe` becomes `inc`
|
||||
// `/home/nu/.cargo/bin/nu_plugin_inc` becomes `inc`
|
||||
// any other path, including if it doesn't start with nu_plugin_, becomes
|
||||
// `<invalid plugin name>`
|
||||
let plugin_name = filename
|
||||
.file_stem()
|
||||
.map(|stem| stem.to_string_lossy().into_owned())
|
||||
.and_then(|stem| stem.strip_prefix("nu_plugin_").map(|s| s.to_owned()))
|
||||
.unwrap_or_else(|| {
|
||||
log::warn!(
|
||||
"filename `{}` is not a valid plugin name, must start with nu_plugin_",
|
||||
filename.display()
|
||||
);
|
||||
"<invalid plugin name>".into()
|
||||
});
|
||||
PluginIdentity {
|
||||
filename,
|
||||
shell,
|
||||
plugin_name,
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(all(test, windows))]
|
||||
pub(crate) fn new_fake(name: &str) -> Arc<PluginIdentity> {
|
||||
Arc::new(PluginIdentity::new(
|
||||
format!(r"C:\fake\path\nu_plugin_{name}.exe"),
|
||||
None,
|
||||
))
|
||||
}
|
||||
|
||||
#[cfg(all(test, not(windows)))]
|
||||
pub(crate) fn new_fake(name: &str) -> Arc<PluginIdentity> {
|
||||
Arc::new(PluginIdentity::new(
|
||||
format!(r"/fake/path/nu_plugin_{name}"),
|
||||
None,
|
||||
))
|
||||
}
|
||||
|
||||
/// Run the plugin command stored in this [`PluginIdentity`], then set up and return the
|
||||
/// [`PluginInterface`] attached to it.
|
||||
pub(crate) fn spawn(
|
||||
self: Arc<Self>,
|
||||
envs: impl IntoIterator<Item = (impl AsRef<OsStr>, impl AsRef<OsStr>)>,
|
||||
) -> Result<PluginInterface, ShellError> {
|
||||
let source_file = Path::new(&self.filename);
|
||||
let mut plugin_cmd = create_command(source_file, self.shell.as_deref());
|
||||
|
||||
// We need the current environment variables for `python` based plugins
|
||||
// Or we'll likely have a problem when a plugin is implemented in a virtual Python environment.
|
||||
plugin_cmd.envs(envs);
|
||||
|
||||
let program_name = plugin_cmd.get_program().to_os_string().into_string();
|
||||
|
||||
// Run the plugin command
|
||||
let child = plugin_cmd.spawn().map_err(|err| {
|
||||
let error_msg = match err.kind() {
|
||||
std::io::ErrorKind::NotFound => match program_name {
|
||||
Ok(prog_name) => {
|
||||
format!("Can't find {prog_name}, please make sure that {prog_name} is in PATH.")
|
||||
}
|
||||
_ => {
|
||||
format!("Error spawning child process: {err}")
|
||||
}
|
||||
},
|
||||
_ => {
|
||||
format!("Error spawning child process: {err}")
|
||||
}
|
||||
};
|
||||
ShellError::PluginFailedToLoad { msg: error_msg }
|
||||
})?;
|
||||
|
||||
make_plugin_interface(child, self)
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parses_name_from_path() {
|
||||
assert_eq!("test", PluginIdentity::new_fake("test").plugin_name);
|
||||
assert_eq!(
|
||||
"<invalid plugin name>",
|
||||
PluginIdentity::new("other", None).plugin_name
|
||||
);
|
||||
assert_eq!(
|
||||
"<invalid plugin name>",
|
||||
PluginIdentity::new("", None).plugin_name
|
||||
);
|
||||
}
|
437
crates/nu-plugin/src/plugin/interface.rs
Normal file
437
crates/nu-plugin/src/plugin/interface.rs
Normal file
@ -0,0 +1,437 @@
|
||||
//! Implements the stream multiplexing interface for both the plugin side and the engine side.
|
||||
|
||||
use std::{
|
||||
io::Write,
|
||||
sync::{
|
||||
atomic::{AtomicBool, Ordering::Relaxed},
|
||||
Arc, Mutex,
|
||||
},
|
||||
thread,
|
||||
};
|
||||
|
||||
use nu_protocol::{ListStream, PipelineData, RawStream, ShellError};
|
||||
|
||||
use crate::{
|
||||
plugin::Encoder,
|
||||
protocol::{
|
||||
ExternalStreamInfo, ListStreamInfo, PipelineDataHeader, RawStreamInfo, StreamMessage,
|
||||
},
|
||||
sequence::Sequence,
|
||||
};
|
||||
|
||||
mod stream;
|
||||
|
||||
mod engine;
|
||||
pub(crate) use engine::{EngineInterfaceManager, ReceivedPluginCall};
|
||||
|
||||
mod plugin;
|
||||
pub(crate) use plugin::{PluginInterface, PluginInterfaceManager};
|
||||
|
||||
use self::stream::{StreamManager, StreamManagerHandle, StreamWriter, WriteStreamMessage};
|
||||
|
||||
#[cfg(test)]
|
||||
mod test_util;
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests;
|
||||
|
||||
/// The maximum number of list stream values to send without acknowledgement. This should be tuned
|
||||
/// with consideration for memory usage.
|
||||
const LIST_STREAM_HIGH_PRESSURE: i32 = 100;
|
||||
|
||||
/// The maximum number of raw stream buffers to send without acknowledgement. This should be tuned
|
||||
/// with consideration for memory usage.
|
||||
const RAW_STREAM_HIGH_PRESSURE: i32 = 50;
|
||||
|
||||
/// Read input/output from the stream.
|
||||
pub(crate) trait PluginRead<T> {
|
||||
/// Returns `Ok(None)` on end of stream.
|
||||
fn read(&mut self) -> Result<Option<T>, ShellError>;
|
||||
}
|
||||
|
||||
impl<R, E, T> PluginRead<T> for (R, E)
|
||||
where
|
||||
R: std::io::BufRead,
|
||||
E: Encoder<T>,
|
||||
{
|
||||
fn read(&mut self) -> Result<Option<T>, ShellError> {
|
||||
self.1.decode(&mut self.0)
|
||||
}
|
||||
}
|
||||
|
||||
impl<R, T> PluginRead<T> for &mut R
|
||||
where
|
||||
R: PluginRead<T>,
|
||||
{
|
||||
fn read(&mut self) -> Result<Option<T>, ShellError> {
|
||||
(**self).read()
|
||||
}
|
||||
}
|
||||
|
||||
/// Write input/output to the stream.
|
||||
///
|
||||
/// The write should be atomic, without interference from other threads.
|
||||
pub(crate) trait PluginWrite<T>: Send + Sync {
|
||||
fn write(&self, data: &T) -> Result<(), ShellError>;
|
||||
|
||||
/// Flush any internal buffers, if applicable.
|
||||
fn flush(&self) -> Result<(), ShellError>;
|
||||
}
|
||||
|
||||
impl<E, T> PluginWrite<T> for (std::io::Stdout, E)
|
||||
where
|
||||
E: Encoder<T>,
|
||||
{
|
||||
fn write(&self, data: &T) -> Result<(), ShellError> {
|
||||
let mut lock = self.0.lock();
|
||||
self.1.encode(data, &mut lock)
|
||||
}
|
||||
|
||||
fn flush(&self) -> Result<(), ShellError> {
|
||||
self.0.lock().flush().map_err(|err| ShellError::IOError {
|
||||
msg: err.to_string(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl<W, E, T> PluginWrite<T> for (Mutex<W>, E)
|
||||
where
|
||||
W: std::io::Write + Send,
|
||||
E: Encoder<T>,
|
||||
{
|
||||
fn write(&self, data: &T) -> Result<(), ShellError> {
|
||||
let mut lock = self.0.lock().map_err(|_| ShellError::NushellFailed {
|
||||
msg: "writer mutex poisoned".into(),
|
||||
})?;
|
||||
self.1.encode(data, &mut *lock)
|
||||
}
|
||||
|
||||
fn flush(&self) -> Result<(), ShellError> {
|
||||
let mut lock = self.0.lock().map_err(|_| ShellError::NushellFailed {
|
||||
msg: "writer mutex poisoned".into(),
|
||||
})?;
|
||||
lock.flush().map_err(|err| ShellError::IOError {
|
||||
msg: err.to_string(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl<W, T> PluginWrite<T> for &W
|
||||
where
|
||||
W: PluginWrite<T>,
|
||||
{
|
||||
fn write(&self, data: &T) -> Result<(), ShellError> {
|
||||
(**self).write(data)
|
||||
}
|
||||
|
||||
fn flush(&self) -> Result<(), ShellError> {
|
||||
(**self).flush()
|
||||
}
|
||||
}
|
||||
|
||||
/// An interface manager handles I/O and state management for communication between a plugin and the
|
||||
/// engine. See [`PluginInterfaceManager`] for communication from the engine side to a plugin, or
|
||||
/// [`EngineInterfaceManager`] for communication from the plugin side to the engine.
|
||||
///
|
||||
/// There is typically one [`InterfaceManager`] consuming input from a background thread, and
|
||||
/// managing shared state.
|
||||
pub(crate) trait InterfaceManager {
|
||||
/// The corresponding interface type.
|
||||
type Interface: Interface + 'static;
|
||||
|
||||
/// The input message type.
|
||||
type Input;
|
||||
|
||||
/// Make a new interface that communicates with this [`InterfaceManager`].
|
||||
fn get_interface(&self) -> Self::Interface;
|
||||
|
||||
/// Consume an input message.
|
||||
///
|
||||
/// When implementing, call [`.consume_stream_message()`] for any encapsulated
|
||||
/// [`StreamMessage`]s received.
|
||||
fn consume(&mut self, input: Self::Input) -> Result<(), ShellError>;
|
||||
|
||||
/// Get the [`StreamManager`] for handling operations related to stream messages.
|
||||
fn stream_manager(&self) -> &StreamManager;
|
||||
|
||||
/// Prepare [`PipelineData`] after reading. This is called by `read_pipeline_data()` as
|
||||
/// a hook so that values that need special handling can be taken care of.
|
||||
fn prepare_pipeline_data(&self, data: PipelineData) -> Result<PipelineData, ShellError>;
|
||||
|
||||
/// Consume an input stream message.
|
||||
///
|
||||
/// This method is provided for implementors to use.
|
||||
fn consume_stream_message(&mut self, message: StreamMessage) -> Result<(), ShellError> {
|
||||
self.stream_manager().handle_message(message)
|
||||
}
|
||||
|
||||
/// Generate `PipelineData` for reading a stream, given a [`PipelineDataHeader`] that was
|
||||
/// received from the other side.
|
||||
///
|
||||
/// This method is provided for implementors to use.
|
||||
fn read_pipeline_data(
|
||||
&self,
|
||||
header: PipelineDataHeader,
|
||||
ctrlc: Option<&Arc<AtomicBool>>,
|
||||
) -> Result<PipelineData, ShellError> {
|
||||
self.prepare_pipeline_data(match header {
|
||||
PipelineDataHeader::Empty => PipelineData::Empty,
|
||||
PipelineDataHeader::Value(value) => PipelineData::Value(value, None),
|
||||
PipelineDataHeader::ListStream(info) => {
|
||||
let handle = self.stream_manager().get_handle();
|
||||
let reader = handle.read_stream(info.id, self.get_interface())?;
|
||||
PipelineData::ListStream(ListStream::from_stream(reader, ctrlc.cloned()), None)
|
||||
}
|
||||
PipelineDataHeader::ExternalStream(info) => {
|
||||
let handle = self.stream_manager().get_handle();
|
||||
let span = info.span;
|
||||
let new_raw_stream = |raw_info: RawStreamInfo| {
|
||||
let reader = handle.read_stream(raw_info.id, self.get_interface())?;
|
||||
let mut stream =
|
||||
RawStream::new(Box::new(reader), ctrlc.cloned(), span, raw_info.known_size);
|
||||
stream.is_binary = raw_info.is_binary;
|
||||
Ok::<_, ShellError>(stream)
|
||||
};
|
||||
PipelineData::ExternalStream {
|
||||
stdout: info.stdout.map(new_raw_stream).transpose()?,
|
||||
stderr: info.stderr.map(new_raw_stream).transpose()?,
|
||||
exit_code: info
|
||||
.exit_code
|
||||
.map(|list_info| {
|
||||
handle
|
||||
.read_stream(list_info.id, self.get_interface())
|
||||
.map(|reader| ListStream::from_stream(reader, ctrlc.cloned()))
|
||||
})
|
||||
.transpose()?,
|
||||
span: info.span,
|
||||
metadata: None,
|
||||
trim_end_newline: info.trim_end_newline,
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
/// An interface provides an API for communicating with a plugin or the engine and facilitates
|
||||
/// stream I/O. See [`PluginInterface`] for the API from the engine side to a plugin, or
|
||||
/// [`EngineInterface`] for the API from the plugin side to the engine.
|
||||
///
|
||||
/// There can be multiple copies of the interface managed by a single [`InterfaceManager`].
|
||||
pub(crate) trait Interface: Clone + Send {
|
||||
/// The output message type, which must be capable of encapsulating a [`StreamMessage`].
|
||||
type Output: From<StreamMessage>;
|
||||
|
||||
/// Write an output message.
|
||||
fn write(&self, output: Self::Output) -> Result<(), ShellError>;
|
||||
|
||||
/// Flush the output buffer, so messages are visible to the other side.
|
||||
fn flush(&self) -> Result<(), ShellError>;
|
||||
|
||||
/// Get the sequence for generating new [`StreamId`](crate::protocol::StreamId)s.
|
||||
fn stream_id_sequence(&self) -> &Sequence;
|
||||
|
||||
/// Get the [`StreamManagerHandle`] for doing stream operations.
|
||||
fn stream_manager_handle(&self) -> &StreamManagerHandle;
|
||||
|
||||
/// Prepare [`PipelineData`] to be written. This is called by `init_write_pipeline_data()` as
|
||||
/// a hook so that values that need special handling can be taken care of.
|
||||
fn prepare_pipeline_data(&self, data: PipelineData) -> Result<PipelineData, ShellError>;
|
||||
|
||||
/// Initialize a write for [`PipelineData`]. This returns two parts: the header, which can be
|
||||
/// embedded in the particular message that references the stream, and a writer, which will
|
||||
/// write out all of the data in the pipeline when `.write()` is called.
|
||||
///
|
||||
/// Note that not all [`PipelineData`] starts a stream. You should call `write()` anyway, as
|
||||
/// it will automatically handle this case.
|
||||
///
|
||||
/// This method is provided for implementors to use.
|
||||
fn init_write_pipeline_data(
|
||||
&self,
|
||||
data: PipelineData,
|
||||
) -> Result<(PipelineDataHeader, PipelineDataWriter<Self>), ShellError> {
|
||||
// Allocate a stream id and a writer
|
||||
let new_stream = |high_pressure_mark: i32| {
|
||||
// Get a free stream id
|
||||
let id = self.stream_id_sequence().next()?;
|
||||
// Create the writer
|
||||
let writer =
|
||||
self.stream_manager_handle()
|
||||
.write_stream(id, self.clone(), high_pressure_mark)?;
|
||||
Ok::<_, ShellError>((id, writer))
|
||||
};
|
||||
match self.prepare_pipeline_data(data)? {
|
||||
PipelineData::Value(value, _) => {
|
||||
Ok((PipelineDataHeader::Value(value), PipelineDataWriter::None))
|
||||
}
|
||||
PipelineData::Empty => Ok((PipelineDataHeader::Empty, PipelineDataWriter::None)),
|
||||
PipelineData::ListStream(stream, _) => {
|
||||
let (id, writer) = new_stream(LIST_STREAM_HIGH_PRESSURE)?;
|
||||
Ok((
|
||||
PipelineDataHeader::ListStream(ListStreamInfo { id }),
|
||||
PipelineDataWriter::ListStream(writer, stream),
|
||||
))
|
||||
}
|
||||
PipelineData::ExternalStream {
|
||||
stdout,
|
||||
stderr,
|
||||
exit_code,
|
||||
span,
|
||||
metadata: _,
|
||||
trim_end_newline,
|
||||
} => {
|
||||
// Create the writers and stream ids
|
||||
let stdout_stream = stdout
|
||||
.is_some()
|
||||
.then(|| new_stream(RAW_STREAM_HIGH_PRESSURE))
|
||||
.transpose()?;
|
||||
let stderr_stream = stderr
|
||||
.is_some()
|
||||
.then(|| new_stream(RAW_STREAM_HIGH_PRESSURE))
|
||||
.transpose()?;
|
||||
let exit_code_stream = exit_code
|
||||
.is_some()
|
||||
.then(|| new_stream(LIST_STREAM_HIGH_PRESSURE))
|
||||
.transpose()?;
|
||||
// Generate the header, with the stream ids
|
||||
let header = PipelineDataHeader::ExternalStream(ExternalStreamInfo {
|
||||
span,
|
||||
stdout: stdout
|
||||
.as_ref()
|
||||
.zip(stdout_stream.as_ref())
|
||||
.map(|(stream, (id, _))| RawStreamInfo::new(*id, stream)),
|
||||
stderr: stderr
|
||||
.as_ref()
|
||||
.zip(stderr_stream.as_ref())
|
||||
.map(|(stream, (id, _))| RawStreamInfo::new(*id, stream)),
|
||||
exit_code: exit_code_stream
|
||||
.as_ref()
|
||||
.map(|&(id, _)| ListStreamInfo { id }),
|
||||
trim_end_newline,
|
||||
});
|
||||
// Collect the writers
|
||||
let writer = PipelineDataWriter::ExternalStream {
|
||||
stdout: stdout_stream.map(|(_, writer)| writer).zip(stdout),
|
||||
stderr: stderr_stream.map(|(_, writer)| writer).zip(stderr),
|
||||
exit_code: exit_code_stream.map(|(_, writer)| writer).zip(exit_code),
|
||||
};
|
||||
Ok((header, writer))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> WriteStreamMessage for T
|
||||
where
|
||||
T: Interface,
|
||||
{
|
||||
fn write_stream_message(&mut self, msg: StreamMessage) -> Result<(), ShellError> {
|
||||
self.write(msg.into())
|
||||
}
|
||||
|
||||
fn flush(&mut self) -> Result<(), ShellError> {
|
||||
<Self as Interface>::flush(self)
|
||||
}
|
||||
}
|
||||
|
||||
/// Completes the write operation for a [`PipelineData`]. You must call
|
||||
/// [`PipelineDataWriter::write()`] to write all of the data contained within the streams.
|
||||
#[derive(Default)]
|
||||
#[must_use]
|
||||
pub(crate) enum PipelineDataWriter<W: WriteStreamMessage> {
|
||||
#[default]
|
||||
None,
|
||||
ListStream(StreamWriter<W>, ListStream),
|
||||
ExternalStream {
|
||||
stdout: Option<(StreamWriter<W>, RawStream)>,
|
||||
stderr: Option<(StreamWriter<W>, RawStream)>,
|
||||
exit_code: Option<(StreamWriter<W>, ListStream)>,
|
||||
},
|
||||
}
|
||||
|
||||
impl<W> PipelineDataWriter<W>
|
||||
where
|
||||
W: WriteStreamMessage + Send + 'static,
|
||||
{
|
||||
/// Write all of the data in each of the streams. This method waits for completion.
|
||||
pub(crate) fn write(self) -> Result<(), ShellError> {
|
||||
match self {
|
||||
// If no stream was contained in the PipelineData, do nothing.
|
||||
PipelineDataWriter::None => Ok(()),
|
||||
// Write a list stream.
|
||||
PipelineDataWriter::ListStream(mut writer, stream) => {
|
||||
writer.write_all(stream)?;
|
||||
Ok(())
|
||||
}
|
||||
// Write all three possible streams of an ExternalStream on separate threads.
|
||||
PipelineDataWriter::ExternalStream {
|
||||
stdout,
|
||||
stderr,
|
||||
exit_code,
|
||||
} => {
|
||||
thread::scope(|scope| {
|
||||
let stderr_thread = stderr.map(|(mut writer, stream)| {
|
||||
thread::Builder::new()
|
||||
.name("plugin stderr writer".into())
|
||||
.spawn_scoped(scope, move || writer.write_all(raw_stream_iter(stream)))
|
||||
.expect("failed to spawn thread")
|
||||
});
|
||||
let exit_code_thread = exit_code.map(|(mut writer, stream)| {
|
||||
thread::Builder::new()
|
||||
.name("plugin exit_code writer".into())
|
||||
.spawn_scoped(scope, move || writer.write_all(stream))
|
||||
.expect("failed to spawn thread")
|
||||
});
|
||||
// Optimize for stdout: if only stdout is present, don't spawn any other
|
||||
// threads.
|
||||
if let Some((mut writer, stream)) = stdout {
|
||||
writer.write_all(raw_stream_iter(stream))?;
|
||||
}
|
||||
let panicked = |thread_name: &str| {
|
||||
Err(ShellError::NushellFailed {
|
||||
msg: format!(
|
||||
"{thread_name} thread panicked in PipelineDataWriter::write"
|
||||
),
|
||||
})
|
||||
};
|
||||
stderr_thread
|
||||
.map(|t| t.join().unwrap_or_else(|_| panicked("stderr")))
|
||||
.transpose()?;
|
||||
exit_code_thread
|
||||
.map(|t| t.join().unwrap_or_else(|_| panicked("exit_code")))
|
||||
.transpose()?;
|
||||
Ok(())
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Write all of the data in each of the streams. This method returns immediately; any necessary
|
||||
/// write will happen in the background. If a thread was spawned, its handle is returned.
|
||||
pub(crate) fn write_background(self) -> Option<thread::JoinHandle<Result<(), ShellError>>> {
|
||||
match self {
|
||||
PipelineDataWriter::None => None,
|
||||
_ => Some(
|
||||
thread::Builder::new()
|
||||
.name("plugin stream background writer".into())
|
||||
.spawn(move || {
|
||||
let result = self.write();
|
||||
if let Err(ref err) = result {
|
||||
// Assume that the background thread error probably won't be handled and log it
|
||||
// here just in case.
|
||||
log::warn!("Error while writing pipeline in background: {err}");
|
||||
}
|
||||
result
|
||||
})
|
||||
.expect("failed to spawn thread"),
|
||||
),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Custom iterator for [`RawStream`] that respects ctrlc, but still has binary chunks
|
||||
fn raw_stream_iter(stream: RawStream) -> impl Iterator<Item = Result<Vec<u8>, ShellError>> {
|
||||
let ctrlc = stream.ctrlc;
|
||||
stream
|
||||
.stream
|
||||
.take_while(move |_| ctrlc.as_ref().map(|b| !b.load(Relaxed)).unwrap_or(true))
|
||||
}
|
375
crates/nu-plugin/src/plugin/interface/engine.rs
Normal file
375
crates/nu-plugin/src/plugin/interface/engine.rs
Normal file
@ -0,0 +1,375 @@
|
||||
//! Interface used by the plugin to communicate with the engine.
|
||||
|
||||
use std::sync::{mpsc, Arc};
|
||||
|
||||
use nu_protocol::{
|
||||
IntoInterruptiblePipelineData, ListStream, PipelineData, PluginSignature, ShellError, Spanned,
|
||||
Value,
|
||||
};
|
||||
|
||||
use crate::{
|
||||
protocol::{
|
||||
CallInfo, CustomValueOp, PluginCall, PluginCallId, PluginCallResponse, PluginCustomValue,
|
||||
PluginInput, ProtocolInfo,
|
||||
},
|
||||
LabeledError, PluginOutput,
|
||||
};
|
||||
|
||||
use super::{
|
||||
stream::{StreamManager, StreamManagerHandle},
|
||||
Interface, InterfaceManager, PipelineDataWriter, PluginRead, PluginWrite,
|
||||
};
|
||||
use crate::sequence::Sequence;
|
||||
|
||||
/// Plugin calls that are received by the [`EngineInterfaceManager`] for handling.
|
||||
///
|
||||
/// With each call, an [`EngineInterface`] is included that can be provided to the plugin code
|
||||
/// and should be used to send the response. The interface sent includes the [`PluginCallId`] for
|
||||
/// sending associated messages with the correct context.
|
||||
#[derive(Debug)]
|
||||
pub(crate) enum ReceivedPluginCall {
|
||||
Signature {
|
||||
engine: EngineInterface,
|
||||
},
|
||||
Run {
|
||||
engine: EngineInterface,
|
||||
call: CallInfo<PipelineData>,
|
||||
},
|
||||
CustomValueOp {
|
||||
engine: EngineInterface,
|
||||
custom_value: Spanned<PluginCustomValue>,
|
||||
op: CustomValueOp,
|
||||
},
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests;
|
||||
|
||||
/// Internal shared state between the manager and each interface.
|
||||
struct EngineInterfaceState {
|
||||
/// Sequence for generating stream ids
|
||||
stream_id_sequence: Sequence,
|
||||
/// The synchronized output writer
|
||||
writer: Box<dyn PluginWrite<PluginOutput>>,
|
||||
}
|
||||
|
||||
impl std::fmt::Debug for EngineInterfaceState {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
f.debug_struct("EngineInterfaceState")
|
||||
.field("stream_id_sequence", &self.stream_id_sequence)
|
||||
.finish_non_exhaustive()
|
||||
}
|
||||
}
|
||||
|
||||
/// Manages reading and dispatching messages for [`EngineInterface`]s.
|
||||
#[derive(Debug)]
|
||||
pub(crate) struct EngineInterfaceManager {
|
||||
/// Shared state
|
||||
state: Arc<EngineInterfaceState>,
|
||||
/// Channel to send received PluginCalls to
|
||||
plugin_call_sender: mpsc::Sender<ReceivedPluginCall>,
|
||||
/// Receiver for PluginCalls. This is usually taken after initialization
|
||||
plugin_call_receiver: Option<mpsc::Receiver<ReceivedPluginCall>>,
|
||||
/// Manages stream messages and state
|
||||
stream_manager: StreamManager,
|
||||
/// Protocol version info, set after `Hello` received
|
||||
protocol_info: Option<ProtocolInfo>,
|
||||
}
|
||||
|
||||
impl EngineInterfaceManager {
|
||||
pub(crate) fn new(writer: impl PluginWrite<PluginOutput> + 'static) -> EngineInterfaceManager {
|
||||
let (plug_tx, plug_rx) = mpsc::channel();
|
||||
|
||||
EngineInterfaceManager {
|
||||
state: Arc::new(EngineInterfaceState {
|
||||
stream_id_sequence: Sequence::default(),
|
||||
writer: Box::new(writer),
|
||||
}),
|
||||
plugin_call_sender: plug_tx,
|
||||
plugin_call_receiver: Some(plug_rx),
|
||||
stream_manager: StreamManager::new(),
|
||||
protocol_info: None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Get the receiving end of the plugin call channel. Plugin calls that need to be handled
|
||||
/// will be sent here.
|
||||
pub(crate) fn take_plugin_call_receiver(
|
||||
&mut self,
|
||||
) -> Option<mpsc::Receiver<ReceivedPluginCall>> {
|
||||
self.plugin_call_receiver.take()
|
||||
}
|
||||
|
||||
/// Create an [`EngineInterface`] associated with the given call id.
|
||||
fn interface_for_context(&self, context: PluginCallId) -> EngineInterface {
|
||||
EngineInterface {
|
||||
state: self.state.clone(),
|
||||
stream_manager_handle: self.stream_manager.get_handle(),
|
||||
context: Some(context),
|
||||
}
|
||||
}
|
||||
|
||||
/// Send a [`ReceivedPluginCall`] to the channel
|
||||
fn send_plugin_call(&self, plugin_call: ReceivedPluginCall) -> Result<(), ShellError> {
|
||||
self.plugin_call_sender
|
||||
.send(plugin_call)
|
||||
.map_err(|_| ShellError::NushellFailed {
|
||||
msg: "Received a plugin call, but there's nowhere to send it".into(),
|
||||
})
|
||||
}
|
||||
|
||||
/// True if there are no other copies of the state (which would mean there are no interfaces
|
||||
/// and no stream readers/writers)
|
||||
pub(crate) fn is_finished(&self) -> bool {
|
||||
Arc::strong_count(&self.state) < 2
|
||||
}
|
||||
|
||||
/// Loop on input from the given reader as long as `is_finished()` is false
|
||||
///
|
||||
/// Any errors will be propagated to all read streams automatically.
|
||||
pub(crate) fn consume_all(
|
||||
&mut self,
|
||||
mut reader: impl PluginRead<PluginInput>,
|
||||
) -> Result<(), ShellError> {
|
||||
while let Some(msg) = reader.read().transpose() {
|
||||
if self.is_finished() {
|
||||
break;
|
||||
}
|
||||
|
||||
if let Err(err) = msg.and_then(|msg| self.consume(msg)) {
|
||||
let _ = self.stream_manager.broadcast_read_error(err.clone());
|
||||
return Err(err);
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl InterfaceManager for EngineInterfaceManager {
|
||||
type Interface = EngineInterface;
|
||||
type Input = PluginInput;
|
||||
|
||||
fn get_interface(&self) -> Self::Interface {
|
||||
EngineInterface {
|
||||
state: self.state.clone(),
|
||||
stream_manager_handle: self.stream_manager.get_handle(),
|
||||
context: None,
|
||||
}
|
||||
}
|
||||
|
||||
fn consume(&mut self, input: Self::Input) -> Result<(), ShellError> {
|
||||
log::trace!("from engine: {:?}", input);
|
||||
|
||||
match input {
|
||||
PluginInput::Hello(info) => {
|
||||
let local_info = ProtocolInfo::default();
|
||||
if local_info.is_compatible_with(&info)? {
|
||||
self.protocol_info = Some(info);
|
||||
Ok(())
|
||||
} else {
|
||||
self.protocol_info = None;
|
||||
Err(ShellError::PluginFailedToLoad {
|
||||
msg: format!(
|
||||
"Plugin is compiled for nushell version {}, \
|
||||
which is not compatible with version {}",
|
||||
local_info.version, info.version
|
||||
),
|
||||
})
|
||||
}
|
||||
}
|
||||
_ if self.protocol_info.is_none() => {
|
||||
// Must send protocol info first
|
||||
Err(ShellError::PluginFailedToLoad {
|
||||
msg: "Failed to receive initial Hello message. This engine might be too old"
|
||||
.into(),
|
||||
})
|
||||
}
|
||||
PluginInput::Stream(message) => self.consume_stream_message(message),
|
||||
PluginInput::Call(id, call) => match call {
|
||||
// We just let the receiver handle it rather than trying to store signature here
|
||||
// or something
|
||||
PluginCall::Signature => self.send_plugin_call(ReceivedPluginCall::Signature {
|
||||
engine: self.interface_for_context(id),
|
||||
}),
|
||||
// Set up the streams from the input and reformat to a ReceivedPluginCall
|
||||
PluginCall::Run(CallInfo {
|
||||
name,
|
||||
mut call,
|
||||
input,
|
||||
config,
|
||||
}) => {
|
||||
let interface = self.interface_for_context(id);
|
||||
// If there's an error with initialization of the input stream, just send
|
||||
// the error response rather than failing here
|
||||
match self.read_pipeline_data(input, None) {
|
||||
Ok(input) => {
|
||||
// Deserialize custom values in the arguments
|
||||
if let Err(err) = deserialize_call_args(&mut call) {
|
||||
return interface.write_response(Err(err))?.write();
|
||||
}
|
||||
// Send the plugin call to the receiver
|
||||
self.send_plugin_call(ReceivedPluginCall::Run {
|
||||
engine: interface,
|
||||
call: CallInfo {
|
||||
name,
|
||||
call,
|
||||
input,
|
||||
config,
|
||||
},
|
||||
})
|
||||
}
|
||||
err @ Err(_) => interface.write_response(err)?.write(),
|
||||
}
|
||||
}
|
||||
// Send request with the custom value
|
||||
PluginCall::CustomValueOp(custom_value, op) => {
|
||||
self.send_plugin_call(ReceivedPluginCall::CustomValueOp {
|
||||
engine: self.interface_for_context(id),
|
||||
custom_value,
|
||||
op,
|
||||
})
|
||||
}
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
fn stream_manager(&self) -> &StreamManager {
|
||||
&self.stream_manager
|
||||
}
|
||||
|
||||
fn prepare_pipeline_data(&self, mut data: PipelineData) -> Result<PipelineData, ShellError> {
|
||||
// Deserialize custom values in the pipeline data
|
||||
match data {
|
||||
PipelineData::Value(ref mut value, _) => {
|
||||
PluginCustomValue::deserialize_custom_values_in(value)?;
|
||||
Ok(data)
|
||||
}
|
||||
PipelineData::ListStream(ListStream { stream, ctrlc, .. }, meta) => Ok(stream
|
||||
.map(|mut value| {
|
||||
let span = value.span();
|
||||
PluginCustomValue::deserialize_custom_values_in(&mut value)
|
||||
.map(|()| value)
|
||||
.unwrap_or_else(|err| Value::error(err, span))
|
||||
})
|
||||
.into_pipeline_data_with_metadata(meta, ctrlc)),
|
||||
PipelineData::Empty | PipelineData::ExternalStream { .. } => Ok(data),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Deserialize custom values in call arguments
|
||||
fn deserialize_call_args(call: &mut crate::EvaluatedCall) -> Result<(), ShellError> {
|
||||
call.positional
|
||||
.iter_mut()
|
||||
.try_for_each(PluginCustomValue::deserialize_custom_values_in)?;
|
||||
call.named
|
||||
.iter_mut()
|
||||
.flat_map(|(_, value)| value.as_mut())
|
||||
.try_for_each(PluginCustomValue::deserialize_custom_values_in)
|
||||
}
|
||||
|
||||
/// A reference through which the nushell engine can be interacted with during execution.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct EngineInterface {
|
||||
/// Shared state with the manager
|
||||
state: Arc<EngineInterfaceState>,
|
||||
/// Handle to stream manager
|
||||
stream_manager_handle: StreamManagerHandle,
|
||||
/// The plugin call this interface belongs to.
|
||||
context: Option<PluginCallId>,
|
||||
}
|
||||
|
||||
impl EngineInterface {
|
||||
/// Write the protocol info. This should be done after initialization
|
||||
pub(crate) fn hello(&self) -> Result<(), ShellError> {
|
||||
self.write(PluginOutput::Hello(ProtocolInfo::default()))?;
|
||||
self.flush()
|
||||
}
|
||||
|
||||
fn context(&self) -> Result<PluginCallId, ShellError> {
|
||||
self.context.ok_or_else(|| ShellError::NushellFailed {
|
||||
msg: "Tried to call an EngineInterface method that requires a call context \
|
||||
outside of one"
|
||||
.into(),
|
||||
})
|
||||
}
|
||||
|
||||
/// Write a call response of either [`PipelineData`] or an error. Returns the stream writer
|
||||
/// to finish writing the stream
|
||||
pub(crate) fn write_response(
|
||||
&self,
|
||||
result: Result<PipelineData, impl Into<LabeledError>>,
|
||||
) -> Result<PipelineDataWriter<Self>, ShellError> {
|
||||
match result {
|
||||
Ok(data) => {
|
||||
let (header, writer) = match self.init_write_pipeline_data(data) {
|
||||
Ok(tup) => tup,
|
||||
// If we get an error while trying to construct the pipeline data, send that
|
||||
// instead
|
||||
Err(err) => return self.write_response(Err(err)),
|
||||
};
|
||||
// Write pipeline data header response, and the full stream
|
||||
let response = PluginCallResponse::PipelineData(header);
|
||||
self.write(PluginOutput::CallResponse(self.context()?, response))?;
|
||||
self.flush()?;
|
||||
Ok(writer)
|
||||
}
|
||||
Err(err) => {
|
||||
let response = PluginCallResponse::Error(err.into());
|
||||
self.write(PluginOutput::CallResponse(self.context()?, response))?;
|
||||
self.flush()?;
|
||||
Ok(Default::default())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Write a call response of plugin signatures.
|
||||
pub(crate) fn write_signature(
|
||||
&self,
|
||||
signature: Vec<PluginSignature>,
|
||||
) -> Result<(), ShellError> {
|
||||
let response = PluginCallResponse::Signature(signature);
|
||||
self.write(PluginOutput::CallResponse(self.context()?, response))?;
|
||||
self.flush()
|
||||
}
|
||||
}
|
||||
|
||||
impl Interface for EngineInterface {
|
||||
type Output = PluginOutput;
|
||||
|
||||
fn write(&self, output: PluginOutput) -> Result<(), ShellError> {
|
||||
log::trace!("to engine: {:?}", output);
|
||||
self.state.writer.write(&output)
|
||||
}
|
||||
|
||||
fn flush(&self) -> Result<(), ShellError> {
|
||||
self.state.writer.flush()
|
||||
}
|
||||
|
||||
fn stream_id_sequence(&self) -> &Sequence {
|
||||
&self.state.stream_id_sequence
|
||||
}
|
||||
|
||||
fn stream_manager_handle(&self) -> &StreamManagerHandle {
|
||||
&self.stream_manager_handle
|
||||
}
|
||||
|
||||
fn prepare_pipeline_data(&self, mut data: PipelineData) -> Result<PipelineData, ShellError> {
|
||||
// Serialize custom values in the pipeline data
|
||||
match data {
|
||||
PipelineData::Value(ref mut value, _) => {
|
||||
PluginCustomValue::serialize_custom_values_in(value)?;
|
||||
Ok(data)
|
||||
}
|
||||
PipelineData::ListStream(ListStream { stream, ctrlc, .. }, meta) => Ok(stream
|
||||
.map(|mut value| {
|
||||
let span = value.span();
|
||||
PluginCustomValue::serialize_custom_values_in(&mut value)
|
||||
.map(|_| value)
|
||||
.unwrap_or_else(|err| Value::error(err, span))
|
||||
})
|
||||
.into_pipeline_data_with_metadata(meta, ctrlc)),
|
||||
PipelineData::Empty | PipelineData::ExternalStream { .. } => Ok(data),
|
||||
}
|
||||
}
|
||||
}
|
779
crates/nu-plugin/src/plugin/interface/engine/tests.rs
Normal file
779
crates/nu-plugin/src/plugin/interface/engine/tests.rs
Normal file
@ -0,0 +1,779 @@
|
||||
use nu_protocol::{
|
||||
CustomValue, IntoInterruptiblePipelineData, PipelineData, PluginSignature, ShellError, Span,
|
||||
Spanned, Value,
|
||||
};
|
||||
|
||||
use crate::{
|
||||
plugin::interface::{test_util::TestCase, Interface, InterfaceManager},
|
||||
protocol::{
|
||||
test_util::{expected_test_custom_value, test_plugin_custom_value, TestCustomValue},
|
||||
CallInfo, CustomValueOp, ExternalStreamInfo, ListStreamInfo, PipelineDataHeader,
|
||||
PluginCall, PluginCustomValue, PluginInput, Protocol, ProtocolInfo, RawStreamInfo,
|
||||
StreamData, StreamMessage,
|
||||
},
|
||||
EvaluatedCall, LabeledError, PluginCallResponse, PluginOutput,
|
||||
};
|
||||
|
||||
use super::ReceivedPluginCall;
|
||||
|
||||
#[test]
|
||||
fn manager_consume_all_consumes_messages() -> Result<(), ShellError> {
|
||||
let mut test = TestCase::new();
|
||||
let mut manager = test.engine();
|
||||
|
||||
// This message should be non-problematic
|
||||
test.add(PluginInput::Hello(ProtocolInfo::default()));
|
||||
|
||||
manager.consume_all(&mut test)?;
|
||||
|
||||
assert!(!test.has_unconsumed_read());
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn manager_consume_all_exits_after_streams_and_interfaces_are_dropped() -> Result<(), ShellError> {
|
||||
let mut test = TestCase::new();
|
||||
let mut manager = test.engine();
|
||||
|
||||
// Add messages that won't cause errors
|
||||
for _ in 0..5 {
|
||||
test.add(PluginInput::Hello(ProtocolInfo::default()));
|
||||
}
|
||||
|
||||
// Create a stream...
|
||||
let stream = manager.read_pipeline_data(
|
||||
PipelineDataHeader::ListStream(ListStreamInfo { id: 0 }),
|
||||
None,
|
||||
)?;
|
||||
|
||||
// and an interface...
|
||||
let interface = manager.get_interface();
|
||||
|
||||
// Expect that is_finished is false
|
||||
assert!(
|
||||
!manager.is_finished(),
|
||||
"is_finished is true even though active stream/interface exists"
|
||||
);
|
||||
|
||||
// After dropping, it should be true
|
||||
drop(stream);
|
||||
drop(interface);
|
||||
|
||||
assert!(
|
||||
manager.is_finished(),
|
||||
"is_finished is false even though manager has no stream or interface"
|
||||
);
|
||||
|
||||
// When it's true, consume_all shouldn't consume everything
|
||||
manager.consume_all(&mut test)?;
|
||||
|
||||
assert!(
|
||||
test.has_unconsumed_read(),
|
||||
"consume_all consumed the messages"
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn test_io_error() -> ShellError {
|
||||
ShellError::IOError {
|
||||
msg: "test io error".into(),
|
||||
}
|
||||
}
|
||||
|
||||
fn check_test_io_error(error: &ShellError) {
|
||||
assert!(
|
||||
format!("{error:?}").contains("test io error"),
|
||||
"error: {error}"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn manager_consume_all_propagates_error_to_readers() -> Result<(), ShellError> {
|
||||
let mut test = TestCase::new();
|
||||
let mut manager = test.engine();
|
||||
|
||||
test.set_read_error(test_io_error());
|
||||
|
||||
let stream = manager.read_pipeline_data(
|
||||
PipelineDataHeader::ListStream(ListStreamInfo { id: 0 }),
|
||||
None,
|
||||
)?;
|
||||
|
||||
manager
|
||||
.consume_all(&mut test)
|
||||
.expect_err("consume_all did not error");
|
||||
|
||||
// Ensure end of stream
|
||||
drop(manager);
|
||||
|
||||
let value = stream.into_iter().next().expect("stream is empty");
|
||||
if let Value::Error { error, .. } = value {
|
||||
check_test_io_error(&error);
|
||||
Ok(())
|
||||
} else {
|
||||
panic!("did not get an error");
|
||||
}
|
||||
}
|
||||
|
||||
fn invalid_input() -> PluginInput {
|
||||
// This should definitely cause an error, as 0.0.0 is not compatible with any version other than
|
||||
// itself
|
||||
PluginInput::Hello(ProtocolInfo {
|
||||
protocol: Protocol::NuPlugin,
|
||||
version: "0.0.0".into(),
|
||||
features: vec![],
|
||||
})
|
||||
}
|
||||
|
||||
fn check_invalid_input_error(error: &ShellError) {
|
||||
// the error message should include something about the version...
|
||||
assert!(format!("{error:?}").contains("0.0.0"), "error: {error}");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn manager_consume_all_propagates_message_error_to_readers() -> Result<(), ShellError> {
|
||||
let mut test = TestCase::new();
|
||||
let mut manager = test.engine();
|
||||
|
||||
test.add(invalid_input());
|
||||
|
||||
let stream = manager.read_pipeline_data(
|
||||
PipelineDataHeader::ExternalStream(ExternalStreamInfo {
|
||||
span: Span::test_data(),
|
||||
stdout: Some(RawStreamInfo {
|
||||
id: 0,
|
||||
is_binary: false,
|
||||
known_size: None,
|
||||
}),
|
||||
stderr: None,
|
||||
exit_code: None,
|
||||
trim_end_newline: false,
|
||||
}),
|
||||
None,
|
||||
)?;
|
||||
|
||||
manager
|
||||
.consume_all(&mut test)
|
||||
.expect_err("consume_all did not error");
|
||||
|
||||
// Ensure end of stream
|
||||
drop(manager);
|
||||
|
||||
let value = stream.into_iter().next().expect("stream is empty");
|
||||
if let Value::Error { error, .. } = value {
|
||||
check_invalid_input_error(&error);
|
||||
Ok(())
|
||||
} else {
|
||||
panic!("did not get an error");
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn manager_consume_sets_protocol_info_on_hello() -> Result<(), ShellError> {
|
||||
let mut manager = TestCase::new().engine();
|
||||
|
||||
let info = ProtocolInfo::default();
|
||||
|
||||
manager.consume(PluginInput::Hello(info.clone()))?;
|
||||
|
||||
let set_info = manager
|
||||
.protocol_info
|
||||
.as_ref()
|
||||
.expect("protocol info not set");
|
||||
assert_eq!(info.version, set_info.version);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn manager_consume_errors_on_wrong_nushell_version() -> Result<(), ShellError> {
|
||||
let mut manager = TestCase::new().engine();
|
||||
|
||||
let info = ProtocolInfo {
|
||||
protocol: Protocol::NuPlugin,
|
||||
version: "0.0.0".into(),
|
||||
features: vec![],
|
||||
};
|
||||
|
||||
manager
|
||||
.consume(PluginInput::Hello(info))
|
||||
.expect_err("version 0.0.0 should cause an error");
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn manager_consume_errors_on_sending_other_messages_before_hello() -> Result<(), ShellError> {
|
||||
let mut manager = TestCase::new().engine();
|
||||
|
||||
// hello not set
|
||||
assert!(manager.protocol_info.is_none());
|
||||
|
||||
let error = manager
|
||||
.consume(PluginInput::Stream(StreamMessage::Drop(0)))
|
||||
.expect_err("consume before Hello should cause an error");
|
||||
|
||||
assert!(format!("{error:?}").contains("Hello"));
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn manager_consume_call_signature_forwards_to_receiver_with_context() -> Result<(), ShellError> {
|
||||
let mut manager = TestCase::new().engine();
|
||||
manager.protocol_info = Some(ProtocolInfo::default());
|
||||
|
||||
let rx = manager
|
||||
.take_plugin_call_receiver()
|
||||
.expect("couldn't take receiver");
|
||||
|
||||
manager.consume(PluginInput::Call(0, PluginCall::Signature))?;
|
||||
|
||||
match rx.try_recv().expect("call was not forwarded to receiver") {
|
||||
ReceivedPluginCall::Signature { engine } => {
|
||||
assert_eq!(Some(0), engine.context);
|
||||
Ok(())
|
||||
}
|
||||
call => panic!("wrong call type: {call:?}"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn manager_consume_call_run_forwards_to_receiver_with_context() -> Result<(), ShellError> {
|
||||
let mut manager = TestCase::new().engine();
|
||||
manager.protocol_info = Some(ProtocolInfo::default());
|
||||
|
||||
let rx = manager
|
||||
.take_plugin_call_receiver()
|
||||
.expect("couldn't take receiver");
|
||||
|
||||
manager.consume(PluginInput::Call(
|
||||
17,
|
||||
PluginCall::Run(CallInfo {
|
||||
name: "bar".into(),
|
||||
call: EvaluatedCall {
|
||||
head: Span::test_data(),
|
||||
positional: vec![],
|
||||
named: vec![],
|
||||
},
|
||||
input: PipelineDataHeader::Empty,
|
||||
config: None,
|
||||
}),
|
||||
))?;
|
||||
|
||||
// Make sure the streams end and we don't deadlock
|
||||
drop(manager);
|
||||
|
||||
match rx.try_recv().expect("call was not forwarded to receiver") {
|
||||
ReceivedPluginCall::Run { engine, call: _ } => {
|
||||
assert_eq!(Some(17), engine.context, "context");
|
||||
Ok(())
|
||||
}
|
||||
call => panic!("wrong call type: {call:?}"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn manager_consume_call_run_forwards_to_receiver_with_pipeline_data() -> Result<(), ShellError> {
|
||||
let mut manager = TestCase::new().engine();
|
||||
manager.protocol_info = Some(ProtocolInfo::default());
|
||||
|
||||
let rx = manager
|
||||
.take_plugin_call_receiver()
|
||||
.expect("couldn't take receiver");
|
||||
|
||||
manager.consume(PluginInput::Call(
|
||||
0,
|
||||
PluginCall::Run(CallInfo {
|
||||
name: "bar".into(),
|
||||
call: EvaluatedCall {
|
||||
head: Span::test_data(),
|
||||
positional: vec![],
|
||||
named: vec![],
|
||||
},
|
||||
input: PipelineDataHeader::ListStream(ListStreamInfo { id: 6 }),
|
||||
config: None,
|
||||
}),
|
||||
))?;
|
||||
|
||||
for i in 0..10 {
|
||||
manager.consume(PluginInput::Stream(StreamMessage::Data(
|
||||
6,
|
||||
Value::test_int(i).into(),
|
||||
)))?;
|
||||
}
|
||||
|
||||
manager.consume(PluginInput::Stream(StreamMessage::End(6)))?;
|
||||
|
||||
// Make sure the streams end and we don't deadlock
|
||||
drop(manager);
|
||||
|
||||
match rx.try_recv().expect("call was not forwarded to receiver") {
|
||||
ReceivedPluginCall::Run { engine: _, call } => {
|
||||
assert_eq!("bar", call.name);
|
||||
// Ensure we manage to receive the stream messages
|
||||
assert_eq!(10, call.input.into_iter().count());
|
||||
Ok(())
|
||||
}
|
||||
call => panic!("wrong call type: {call:?}"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn manager_consume_call_run_deserializes_custom_values_in_args() -> Result<(), ShellError> {
|
||||
let mut manager = TestCase::new().engine();
|
||||
manager.protocol_info = Some(ProtocolInfo::default());
|
||||
|
||||
let rx = manager
|
||||
.take_plugin_call_receiver()
|
||||
.expect("couldn't take receiver");
|
||||
|
||||
let value = Value::test_custom_value(Box::new(test_plugin_custom_value()));
|
||||
|
||||
manager.consume(PluginInput::Call(
|
||||
0,
|
||||
PluginCall::Run(CallInfo {
|
||||
name: "bar".into(),
|
||||
call: EvaluatedCall {
|
||||
head: Span::test_data(),
|
||||
positional: vec![value.clone()],
|
||||
named: vec![(
|
||||
Spanned {
|
||||
item: "flag".into(),
|
||||
span: Span::test_data(),
|
||||
},
|
||||
Some(value),
|
||||
)],
|
||||
},
|
||||
input: PipelineDataHeader::Empty,
|
||||
config: None,
|
||||
}),
|
||||
))?;
|
||||
|
||||
// Make sure the streams end and we don't deadlock
|
||||
drop(manager);
|
||||
|
||||
match rx.try_recv().expect("call was not forwarded to receiver") {
|
||||
ReceivedPluginCall::Run { engine: _, call } => {
|
||||
assert_eq!(1, call.call.positional.len());
|
||||
assert_eq!(1, call.call.named.len());
|
||||
|
||||
for arg in call.call.positional {
|
||||
let custom_value: &TestCustomValue = arg
|
||||
.as_custom_value()?
|
||||
.as_any()
|
||||
.downcast_ref()
|
||||
.expect("positional arg is not TestCustomValue");
|
||||
assert_eq!(expected_test_custom_value(), *custom_value, "positional");
|
||||
}
|
||||
|
||||
for (key, val) in call.call.named {
|
||||
let key = &key.item;
|
||||
let custom_value: &TestCustomValue = val
|
||||
.as_ref()
|
||||
.unwrap_or_else(|| panic!("found empty named argument: {key}"))
|
||||
.as_custom_value()?
|
||||
.as_any()
|
||||
.downcast_ref()
|
||||
.unwrap_or_else(|| panic!("named arg {key} is not TestCustomValue"));
|
||||
assert_eq!(expected_test_custom_value(), *custom_value, "named: {key}");
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
call => panic!("wrong call type: {call:?}"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn manager_consume_call_custom_value_op_forwards_to_receiver_with_context() -> Result<(), ShellError>
|
||||
{
|
||||
let mut manager = TestCase::new().engine();
|
||||
manager.protocol_info = Some(ProtocolInfo::default());
|
||||
|
||||
let rx = manager
|
||||
.take_plugin_call_receiver()
|
||||
.expect("couldn't take receiver");
|
||||
|
||||
manager.consume(PluginInput::Call(
|
||||
32,
|
||||
PluginCall::CustomValueOp(
|
||||
Spanned {
|
||||
item: test_plugin_custom_value(),
|
||||
span: Span::test_data(),
|
||||
},
|
||||
CustomValueOp::ToBaseValue,
|
||||
),
|
||||
))?;
|
||||
|
||||
match rx.try_recv().expect("call was not forwarded to receiver") {
|
||||
ReceivedPluginCall::CustomValueOp {
|
||||
engine,
|
||||
custom_value,
|
||||
op,
|
||||
} => {
|
||||
assert_eq!(Some(32), engine.context);
|
||||
assert_eq!("TestCustomValue", custom_value.item.name);
|
||||
assert!(
|
||||
matches!(op, CustomValueOp::ToBaseValue),
|
||||
"incorrect op: {op:?}"
|
||||
);
|
||||
}
|
||||
call => panic!("wrong call type: {call:?}"),
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn manager_prepare_pipeline_data_deserializes_custom_values() -> Result<(), ShellError> {
|
||||
let manager = TestCase::new().engine();
|
||||
|
||||
let data = manager.prepare_pipeline_data(PipelineData::Value(
|
||||
Value::test_custom_value(Box::new(test_plugin_custom_value())),
|
||||
None,
|
||||
))?;
|
||||
|
||||
let value = data
|
||||
.into_iter()
|
||||
.next()
|
||||
.expect("prepared pipeline data is empty");
|
||||
let custom_value: &TestCustomValue = value
|
||||
.as_custom_value()?
|
||||
.as_any()
|
||||
.downcast_ref()
|
||||
.expect("custom value is not a TestCustomValue, probably not deserialized");
|
||||
|
||||
assert_eq!(expected_test_custom_value(), *custom_value);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn manager_prepare_pipeline_data_deserializes_custom_values_in_streams() -> Result<(), ShellError> {
|
||||
let manager = TestCase::new().engine();
|
||||
|
||||
let data = manager.prepare_pipeline_data(
|
||||
[Value::test_custom_value(Box::new(
|
||||
test_plugin_custom_value(),
|
||||
))]
|
||||
.into_pipeline_data(None),
|
||||
)?;
|
||||
|
||||
let value = data
|
||||
.into_iter()
|
||||
.next()
|
||||
.expect("prepared pipeline data is empty");
|
||||
let custom_value: &TestCustomValue = value
|
||||
.as_custom_value()?
|
||||
.as_any()
|
||||
.downcast_ref()
|
||||
.expect("custom value is not a TestCustomValue, probably not deserialized");
|
||||
|
||||
assert_eq!(expected_test_custom_value(), *custom_value);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn manager_prepare_pipeline_data_embeds_deserialization_errors_in_streams() -> Result<(), ShellError>
|
||||
{
|
||||
let manager = TestCase::new().engine();
|
||||
|
||||
let invalid_custom_value = PluginCustomValue {
|
||||
name: "Invalid".into(),
|
||||
data: vec![0; 8], // should fail to decode to anything
|
||||
source: None,
|
||||
};
|
||||
|
||||
let span = Span::new(20, 30);
|
||||
let data = manager.prepare_pipeline_data(
|
||||
[Value::custom_value(Box::new(invalid_custom_value), span)].into_pipeline_data(None),
|
||||
)?;
|
||||
|
||||
let value = data
|
||||
.into_iter()
|
||||
.next()
|
||||
.expect("prepared pipeline data is empty");
|
||||
|
||||
match value {
|
||||
Value::Error { error, .. } => match *error {
|
||||
ShellError::CustomValueFailedToDecode {
|
||||
span: error_span, ..
|
||||
} => {
|
||||
assert_eq!(span, error_span, "error span not the same as the value's");
|
||||
}
|
||||
_ => panic!("expected ShellError::CustomValueFailedToDecode, but got {error:?}"),
|
||||
},
|
||||
_ => panic!("unexpected value, not error: {value:?}"),
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn interface_hello_sends_protocol_info() -> Result<(), ShellError> {
|
||||
let test = TestCase::new();
|
||||
let interface = test.engine().get_interface();
|
||||
interface.hello()?;
|
||||
|
||||
let written = test.next_written().expect("nothing written");
|
||||
|
||||
match written {
|
||||
PluginOutput::Hello(info) => {
|
||||
assert_eq!(ProtocolInfo::default().version, info.version);
|
||||
}
|
||||
_ => panic!("unexpected message written: {written:?}"),
|
||||
}
|
||||
|
||||
assert!(!test.has_unconsumed_write());
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn interface_write_response_with_value() -> Result<(), ShellError> {
|
||||
let test = TestCase::new();
|
||||
let interface = test.engine().interface_for_context(33);
|
||||
interface
|
||||
.write_response(Ok::<_, ShellError>(PipelineData::Value(
|
||||
Value::test_int(6),
|
||||
None,
|
||||
)))?
|
||||
.write()?;
|
||||
|
||||
let written = test.next_written().expect("nothing written");
|
||||
|
||||
match written {
|
||||
PluginOutput::CallResponse(id, response) => {
|
||||
assert_eq!(33, id, "id");
|
||||
match response {
|
||||
PluginCallResponse::PipelineData(header) => match header {
|
||||
PipelineDataHeader::Value(value) => assert_eq!(6, value.as_int()?),
|
||||
_ => panic!("unexpected pipeline data header: {header:?}"),
|
||||
},
|
||||
_ => panic!("unexpected response: {response:?}"),
|
||||
}
|
||||
}
|
||||
_ => panic!("unexpected message written: {written:?}"),
|
||||
}
|
||||
|
||||
assert!(!test.has_unconsumed_write());
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn interface_write_response_with_stream() -> Result<(), ShellError> {
|
||||
let test = TestCase::new();
|
||||
let manager = test.engine();
|
||||
let interface = manager.interface_for_context(34);
|
||||
|
||||
interface
|
||||
.write_response(Ok::<_, ShellError>(
|
||||
[Value::test_int(3), Value::test_int(4), Value::test_int(5)].into_pipeline_data(None),
|
||||
))?
|
||||
.write()?;
|
||||
|
||||
let written = test.next_written().expect("nothing written");
|
||||
|
||||
let info = match written {
|
||||
PluginOutput::CallResponse(_, response) => match response {
|
||||
PluginCallResponse::PipelineData(header) => match header {
|
||||
PipelineDataHeader::ListStream(info) => info,
|
||||
_ => panic!("expected ListStream header: {header:?}"),
|
||||
},
|
||||
_ => panic!("wrong response: {response:?}"),
|
||||
},
|
||||
_ => panic!("wrong output written: {written:?}"),
|
||||
};
|
||||
|
||||
for number in [3, 4, 5] {
|
||||
match test.next_written().expect("missing stream Data message") {
|
||||
PluginOutput::Stream(StreamMessage::Data(id, data)) => {
|
||||
assert_eq!(info.id, id, "Data id");
|
||||
match data {
|
||||
StreamData::List(val) => assert_eq!(number, val.as_int()?),
|
||||
_ => panic!("expected List data: {data:?}"),
|
||||
}
|
||||
}
|
||||
message => panic!("expected Stream(Data(..)): {message:?}"),
|
||||
}
|
||||
}
|
||||
|
||||
match test.next_written().expect("missing stream End message") {
|
||||
PluginOutput::Stream(StreamMessage::End(id)) => assert_eq!(info.id, id, "End id"),
|
||||
message => panic!("expected Stream(Data(..)): {message:?}"),
|
||||
}
|
||||
|
||||
assert!(!test.has_unconsumed_write());
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn interface_write_response_with_error() -> Result<(), ShellError> {
|
||||
let test = TestCase::new();
|
||||
let interface = test.engine().interface_for_context(35);
|
||||
let labeled_error = LabeledError {
|
||||
label: "this is an error".into(),
|
||||
msg: "a test error".into(),
|
||||
span: None,
|
||||
};
|
||||
interface
|
||||
.write_response(Err(labeled_error.clone()))?
|
||||
.write()?;
|
||||
|
||||
let written = test.next_written().expect("nothing written");
|
||||
|
||||
match written {
|
||||
PluginOutput::CallResponse(id, response) => {
|
||||
assert_eq!(35, id, "id");
|
||||
match response {
|
||||
PluginCallResponse::Error(err) => assert_eq!(labeled_error, err),
|
||||
_ => panic!("unexpected response: {response:?}"),
|
||||
}
|
||||
}
|
||||
_ => panic!("unexpected message written: {written:?}"),
|
||||
}
|
||||
|
||||
assert!(!test.has_unconsumed_write());
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn interface_write_signature() -> Result<(), ShellError> {
|
||||
let test = TestCase::new();
|
||||
let interface = test.engine().interface_for_context(36);
|
||||
let signatures = vec![PluginSignature::build("test command")];
|
||||
interface.write_signature(signatures.clone())?;
|
||||
|
||||
let written = test.next_written().expect("nothing written");
|
||||
|
||||
match written {
|
||||
PluginOutput::CallResponse(id, response) => {
|
||||
assert_eq!(36, id, "id");
|
||||
match response {
|
||||
PluginCallResponse::Signature(sigs) => assert_eq!(1, sigs.len(), "sigs.len"),
|
||||
_ => panic!("unexpected response: {response:?}"),
|
||||
}
|
||||
}
|
||||
_ => panic!("unexpected message written: {written:?}"),
|
||||
}
|
||||
|
||||
assert!(!test.has_unconsumed_write());
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn interface_prepare_pipeline_data_serializes_custom_values() -> Result<(), ShellError> {
|
||||
let interface = TestCase::new().engine().get_interface();
|
||||
|
||||
let data = interface.prepare_pipeline_data(PipelineData::Value(
|
||||
Value::test_custom_value(Box::new(expected_test_custom_value())),
|
||||
None,
|
||||
))?;
|
||||
|
||||
let value = data
|
||||
.into_iter()
|
||||
.next()
|
||||
.expect("prepared pipeline data is empty");
|
||||
let custom_value: &PluginCustomValue = value
|
||||
.as_custom_value()?
|
||||
.as_any()
|
||||
.downcast_ref()
|
||||
.expect("custom value is not a PluginCustomValue, probably not serialized");
|
||||
|
||||
let expected = test_plugin_custom_value();
|
||||
assert_eq!(expected.name, custom_value.name);
|
||||
assert_eq!(expected.data, custom_value.data);
|
||||
assert!(custom_value.source.is_none());
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn interface_prepare_pipeline_data_serializes_custom_values_in_streams() -> Result<(), ShellError> {
|
||||
let interface = TestCase::new().engine().get_interface();
|
||||
|
||||
let data = interface.prepare_pipeline_data(
|
||||
[Value::test_custom_value(Box::new(
|
||||
expected_test_custom_value(),
|
||||
))]
|
||||
.into_pipeline_data(None),
|
||||
)?;
|
||||
|
||||
let value = data
|
||||
.into_iter()
|
||||
.next()
|
||||
.expect("prepared pipeline data is empty");
|
||||
let custom_value: &PluginCustomValue = value
|
||||
.as_custom_value()?
|
||||
.as_any()
|
||||
.downcast_ref()
|
||||
.expect("custom value is not a PluginCustomValue, probably not serialized");
|
||||
|
||||
let expected = test_plugin_custom_value();
|
||||
assert_eq!(expected.name, custom_value.name);
|
||||
assert_eq!(expected.data, custom_value.data);
|
||||
assert!(custom_value.source.is_none());
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// A non-serializable custom value. Should cause a serialization error
|
||||
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
|
||||
enum CantSerialize {
|
||||
#[serde(skip_serializing)]
|
||||
BadVariant,
|
||||
}
|
||||
|
||||
#[typetag::serde]
|
||||
impl CustomValue for CantSerialize {
|
||||
fn clone_value(&self, span: Span) -> Value {
|
||||
Value::custom_value(Box::new(self.clone()), span)
|
||||
}
|
||||
|
||||
fn value_string(&self) -> String {
|
||||
"CantSerialize".into()
|
||||
}
|
||||
|
||||
fn to_base_value(&self, _span: Span) -> Result<Value, ShellError> {
|
||||
unimplemented!()
|
||||
}
|
||||
|
||||
fn as_any(&self) -> &dyn std::any::Any {
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn interface_prepare_pipeline_data_embeds_serialization_errors_in_streams() -> Result<(), ShellError>
|
||||
{
|
||||
let interface = TestCase::new().engine().get_interface();
|
||||
|
||||
let span = Span::new(40, 60);
|
||||
let data = interface.prepare_pipeline_data(
|
||||
[Value::custom_value(
|
||||
Box::new(CantSerialize::BadVariant),
|
||||
span,
|
||||
)]
|
||||
.into_pipeline_data(None),
|
||||
)?;
|
||||
|
||||
let value = data
|
||||
.into_iter()
|
||||
.next()
|
||||
.expect("prepared pipeline data is empty");
|
||||
|
||||
match value {
|
||||
Value::Error { error, .. } => match *error {
|
||||
ShellError::CustomValueFailedToEncode {
|
||||
span: error_span, ..
|
||||
} => {
|
||||
assert_eq!(span, error_span, "error span not the same as the value's");
|
||||
}
|
||||
_ => panic!("expected ShellError::CustomValueFailedToEncode, but got {error:?}"),
|
||||
},
|
||||
_ => panic!("unexpected value, not error: {value:?}"),
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
504
crates/nu-plugin/src/plugin/interface/plugin.rs
Normal file
504
crates/nu-plugin/src/plugin/interface/plugin.rs
Normal file
@ -0,0 +1,504 @@
|
||||
//! Interface used by the engine to communicate with the plugin.
|
||||
|
||||
use std::{
|
||||
collections::{btree_map, BTreeMap},
|
||||
sync::{mpsc, Arc},
|
||||
};
|
||||
|
||||
use nu_protocol::{
|
||||
IntoInterruptiblePipelineData, ListStream, PipelineData, PluginSignature, ShellError, Spanned,
|
||||
Value,
|
||||
};
|
||||
|
||||
use crate::{
|
||||
plugin::{context::PluginExecutionContext, PluginIdentity},
|
||||
protocol::{
|
||||
CallInfo, CustomValueOp, PluginCall, PluginCallId, PluginCallResponse, PluginCustomValue,
|
||||
PluginInput, PluginOutput, ProtocolInfo,
|
||||
},
|
||||
sequence::Sequence,
|
||||
};
|
||||
|
||||
use super::{
|
||||
stream::{StreamManager, StreamManagerHandle},
|
||||
Interface, InterfaceManager, PipelineDataWriter, PluginRead, PluginWrite,
|
||||
};
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests;
|
||||
|
||||
#[derive(Debug)]
|
||||
enum ReceivedPluginCallMessage {
|
||||
/// The final response to send
|
||||
Response(PluginCallResponse<PipelineData>),
|
||||
|
||||
/// An critical error with the interface
|
||||
Error(ShellError),
|
||||
}
|
||||
|
||||
/// Context for plugin call execution
|
||||
#[derive(Clone)]
|
||||
pub(crate) struct Context(Arc<dyn PluginExecutionContext>);
|
||||
|
||||
impl std::fmt::Debug for Context {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
f.write_str("Context")
|
||||
}
|
||||
}
|
||||
|
||||
impl std::ops::Deref for Context {
|
||||
type Target = dyn PluginExecutionContext;
|
||||
|
||||
fn deref(&self) -> &Self::Target {
|
||||
&*self.0
|
||||
}
|
||||
}
|
||||
|
||||
/// Internal shared state between the manager and each interface.
|
||||
struct PluginInterfaceState {
|
||||
/// The identity of the plugin being interfaced with
|
||||
identity: Arc<PluginIdentity>,
|
||||
/// Sequence for generating plugin call ids
|
||||
plugin_call_id_sequence: Sequence,
|
||||
/// Sequence for generating stream ids
|
||||
stream_id_sequence: Sequence,
|
||||
/// Sender to subscribe to a plugin call response
|
||||
plugin_call_subscription_sender: mpsc::Sender<(PluginCallId, PluginCallSubscription)>,
|
||||
/// The synchronized output writer
|
||||
writer: Box<dyn PluginWrite<PluginInput>>,
|
||||
}
|
||||
|
||||
impl std::fmt::Debug for PluginInterfaceState {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
f.debug_struct("PluginInterfaceState")
|
||||
.field("identity", &self.identity)
|
||||
.field("plugin_call_id_sequence", &self.plugin_call_id_sequence)
|
||||
.field("stream_id_sequence", &self.stream_id_sequence)
|
||||
.field(
|
||||
"plugin_call_subscription_sender",
|
||||
&self.plugin_call_subscription_sender,
|
||||
)
|
||||
.finish_non_exhaustive()
|
||||
}
|
||||
}
|
||||
|
||||
/// Sent to the [`PluginInterfaceManager`] before making a plugin call to indicate interest in its
|
||||
/// response.
|
||||
#[derive(Debug)]
|
||||
struct PluginCallSubscription {
|
||||
/// The sender back to the thread that is waiting for the plugin call response
|
||||
sender: mpsc::Sender<ReceivedPluginCallMessage>,
|
||||
/// Optional context for the environment of a plugin call
|
||||
context: Option<Context>,
|
||||
}
|
||||
|
||||
/// Manages reading and dispatching messages for [`PluginInterface`]s.
|
||||
#[derive(Debug)]
|
||||
pub(crate) struct PluginInterfaceManager {
|
||||
/// Shared state
|
||||
state: Arc<PluginInterfaceState>,
|
||||
/// Manages stream messages and state
|
||||
stream_manager: StreamManager,
|
||||
/// Protocol version info, set after `Hello` received
|
||||
protocol_info: Option<ProtocolInfo>,
|
||||
/// Subscriptions for messages related to plugin calls
|
||||
plugin_call_subscriptions: BTreeMap<PluginCallId, PluginCallSubscription>,
|
||||
/// Receiver for plugin call subscriptions
|
||||
plugin_call_subscription_receiver: mpsc::Receiver<(PluginCallId, PluginCallSubscription)>,
|
||||
}
|
||||
|
||||
impl PluginInterfaceManager {
|
||||
pub(crate) fn new(
|
||||
identity: Arc<PluginIdentity>,
|
||||
writer: impl PluginWrite<PluginInput> + 'static,
|
||||
) -> PluginInterfaceManager {
|
||||
let (subscription_tx, subscription_rx) = mpsc::channel();
|
||||
|
||||
PluginInterfaceManager {
|
||||
state: Arc::new(PluginInterfaceState {
|
||||
identity,
|
||||
plugin_call_id_sequence: Sequence::default(),
|
||||
stream_id_sequence: Sequence::default(),
|
||||
plugin_call_subscription_sender: subscription_tx,
|
||||
writer: Box::new(writer),
|
||||
}),
|
||||
stream_manager: StreamManager::new(),
|
||||
protocol_info: None,
|
||||
plugin_call_subscriptions: BTreeMap::new(),
|
||||
plugin_call_subscription_receiver: subscription_rx,
|
||||
}
|
||||
}
|
||||
|
||||
/// Consume pending messages in the `plugin_call_subscription_receiver`
|
||||
fn receive_plugin_call_subscriptions(&mut self) {
|
||||
while let Ok((id, subscription)) = self.plugin_call_subscription_receiver.try_recv() {
|
||||
if let btree_map::Entry::Vacant(e) = self.plugin_call_subscriptions.entry(id) {
|
||||
e.insert(subscription);
|
||||
} else {
|
||||
log::warn!("Duplicate plugin call ID ignored: {id}");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Find the context corresponding to the given plugin call id
|
||||
fn get_context(&mut self, id: PluginCallId) -> Result<Option<Context>, ShellError> {
|
||||
// Make sure we're up to date
|
||||
self.receive_plugin_call_subscriptions();
|
||||
// Find the subscription and return the context
|
||||
self.plugin_call_subscriptions
|
||||
.get(&id)
|
||||
.map(|sub| sub.context.clone())
|
||||
.ok_or_else(|| ShellError::PluginFailedToDecode {
|
||||
msg: format!("Unknown plugin call ID: {id}"),
|
||||
})
|
||||
}
|
||||
|
||||
/// Send a [`PluginCallResponse`] to the appropriate sender
|
||||
fn send_plugin_call_response(
|
||||
&mut self,
|
||||
id: PluginCallId,
|
||||
response: PluginCallResponse<PipelineData>,
|
||||
) -> Result<(), ShellError> {
|
||||
// Ensure we're caught up on the subscriptions made
|
||||
self.receive_plugin_call_subscriptions();
|
||||
|
||||
// Remove the subscription, since this would be the last message
|
||||
if let Some(subscription) = self.plugin_call_subscriptions.remove(&id) {
|
||||
if subscription
|
||||
.sender
|
||||
.send(ReceivedPluginCallMessage::Response(response))
|
||||
.is_err()
|
||||
{
|
||||
log::warn!("Received a plugin call response for id={id}, but the caller hung up");
|
||||
}
|
||||
Ok(())
|
||||
} else {
|
||||
Err(ShellError::PluginFailedToDecode {
|
||||
msg: format!("Unknown plugin call ID: {id}"),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
/// True if there are no other copies of the state (which would mean there are no interfaces
|
||||
/// and no stream readers/writers)
|
||||
pub(crate) fn is_finished(&self) -> bool {
|
||||
Arc::strong_count(&self.state) < 2
|
||||
}
|
||||
|
||||
/// Loop on input from the given reader as long as `is_finished()` is false
|
||||
///
|
||||
/// Any errors will be propagated to all read streams automatically.
|
||||
pub(crate) fn consume_all(
|
||||
&mut self,
|
||||
mut reader: impl PluginRead<PluginOutput>,
|
||||
) -> Result<(), ShellError> {
|
||||
while let Some(msg) = reader.read().transpose() {
|
||||
if self.is_finished() {
|
||||
break;
|
||||
}
|
||||
|
||||
if let Err(err) = msg.and_then(|msg| self.consume(msg)) {
|
||||
// Error to streams
|
||||
let _ = self.stream_manager.broadcast_read_error(err.clone());
|
||||
// Error to call waiters
|
||||
self.receive_plugin_call_subscriptions();
|
||||
for subscription in
|
||||
std::mem::take(&mut self.plugin_call_subscriptions).into_values()
|
||||
{
|
||||
let _ = subscription
|
||||
.sender
|
||||
.send(ReceivedPluginCallMessage::Error(err.clone()));
|
||||
}
|
||||
return Err(err);
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl InterfaceManager for PluginInterfaceManager {
|
||||
type Interface = PluginInterface;
|
||||
type Input = PluginOutput;
|
||||
|
||||
fn get_interface(&self) -> Self::Interface {
|
||||
PluginInterface {
|
||||
state: self.state.clone(),
|
||||
stream_manager_handle: self.stream_manager.get_handle(),
|
||||
}
|
||||
}
|
||||
|
||||
fn consume(&mut self, input: Self::Input) -> Result<(), ShellError> {
|
||||
log::trace!("from plugin: {:?}", input);
|
||||
|
||||
match input {
|
||||
PluginOutput::Hello(info) => {
|
||||
let local_info = ProtocolInfo::default();
|
||||
if local_info.is_compatible_with(&info)? {
|
||||
self.protocol_info = Some(info);
|
||||
Ok(())
|
||||
} else {
|
||||
self.protocol_info = None;
|
||||
Err(ShellError::PluginFailedToLoad {
|
||||
msg: format!(
|
||||
"Plugin is compiled for nushell version {}, \
|
||||
which is not compatible with version {}",
|
||||
info.version, local_info.version
|
||||
),
|
||||
})
|
||||
}
|
||||
}
|
||||
_ if self.protocol_info.is_none() => {
|
||||
// Must send protocol info first
|
||||
Err(ShellError::PluginFailedToLoad {
|
||||
msg: "Failed to receive initial Hello message. \
|
||||
This plugin might be too old"
|
||||
.into(),
|
||||
})
|
||||
}
|
||||
PluginOutput::Stream(message) => self.consume_stream_message(message),
|
||||
PluginOutput::CallResponse(id, response) => {
|
||||
// Handle reading the pipeline data, if any
|
||||
let response = match response {
|
||||
PluginCallResponse::Error(err) => PluginCallResponse::Error(err),
|
||||
PluginCallResponse::Signature(sigs) => PluginCallResponse::Signature(sigs),
|
||||
PluginCallResponse::PipelineData(data) => {
|
||||
// If there's an error with initializing this stream, change it to a plugin
|
||||
// error response, but send it anyway
|
||||
let exec_context = self.get_context(id)?;
|
||||
let ctrlc = exec_context.as_ref().and_then(|c| c.0.ctrlc());
|
||||
match self.read_pipeline_data(data, ctrlc) {
|
||||
Ok(data) => PluginCallResponse::PipelineData(data),
|
||||
Err(err) => PluginCallResponse::Error(err.into()),
|
||||
}
|
||||
}
|
||||
};
|
||||
self.send_plugin_call_response(id, response)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn stream_manager(&self) -> &StreamManager {
|
||||
&self.stream_manager
|
||||
}
|
||||
|
||||
fn prepare_pipeline_data(&self, mut data: PipelineData) -> Result<PipelineData, ShellError> {
|
||||
// Add source to any values
|
||||
match data {
|
||||
PipelineData::Value(ref mut value, _) => {
|
||||
PluginCustomValue::add_source(value, &self.state.identity);
|
||||
Ok(data)
|
||||
}
|
||||
PipelineData::ListStream(ListStream { stream, ctrlc, .. }, meta) => {
|
||||
let identity = self.state.identity.clone();
|
||||
Ok(stream
|
||||
.map(move |mut value| {
|
||||
PluginCustomValue::add_source(&mut value, &identity);
|
||||
value
|
||||
})
|
||||
.into_pipeline_data_with_metadata(meta, ctrlc))
|
||||
}
|
||||
PipelineData::Empty | PipelineData::ExternalStream { .. } => Ok(data),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// A reference through which a plugin can be interacted with during execution.
|
||||
#[derive(Debug, Clone)]
|
||||
pub(crate) struct PluginInterface {
|
||||
/// Shared state
|
||||
state: Arc<PluginInterfaceState>,
|
||||
/// Handle to stream manager
|
||||
stream_manager_handle: StreamManagerHandle,
|
||||
}
|
||||
|
||||
impl PluginInterface {
|
||||
/// Write the protocol info. This should be done after initialization
|
||||
pub(crate) fn hello(&self) -> Result<(), ShellError> {
|
||||
self.write(PluginInput::Hello(ProtocolInfo::default()))?;
|
||||
self.flush()
|
||||
}
|
||||
|
||||
/// Write a plugin call message. Returns the writer for the stream, and the receiver for
|
||||
/// messages (e.g. response) related to the plugin call
|
||||
fn write_plugin_call(
|
||||
&self,
|
||||
call: PluginCall<PipelineData>,
|
||||
context: Option<Context>,
|
||||
) -> Result<
|
||||
(
|
||||
PipelineDataWriter<Self>,
|
||||
mpsc::Receiver<ReceivedPluginCallMessage>,
|
||||
),
|
||||
ShellError,
|
||||
> {
|
||||
let id = self.state.plugin_call_id_sequence.next()?;
|
||||
let (tx, rx) = mpsc::channel();
|
||||
|
||||
// Convert the call into one with a header and handle the stream, if necessary
|
||||
let (call, writer) = match call {
|
||||
PluginCall::Signature => (PluginCall::Signature, Default::default()),
|
||||
PluginCall::CustomValueOp(value, op) => {
|
||||
(PluginCall::CustomValueOp(value, op), Default::default())
|
||||
}
|
||||
PluginCall::Run(CallInfo {
|
||||
name,
|
||||
call,
|
||||
input,
|
||||
config,
|
||||
}) => {
|
||||
let (header, writer) = self.init_write_pipeline_data(input)?;
|
||||
(
|
||||
PluginCall::Run(CallInfo {
|
||||
name,
|
||||
call,
|
||||
input: header,
|
||||
config,
|
||||
}),
|
||||
writer,
|
||||
)
|
||||
}
|
||||
};
|
||||
|
||||
// Register the subscription to the response, and the context
|
||||
self.state
|
||||
.plugin_call_subscription_sender
|
||||
.send((
|
||||
id,
|
||||
PluginCallSubscription {
|
||||
sender: tx,
|
||||
context,
|
||||
},
|
||||
))
|
||||
.map_err(|_| ShellError::NushellFailed {
|
||||
msg: "PluginInterfaceManager hung up and is no longer accepting plugin calls"
|
||||
.into(),
|
||||
})?;
|
||||
|
||||
// Write request
|
||||
self.write(PluginInput::Call(id, call))?;
|
||||
self.flush()?;
|
||||
|
||||
Ok((writer, rx))
|
||||
}
|
||||
|
||||
/// Read the channel for plugin call messages and handle them until the response is received.
|
||||
fn receive_plugin_call_response(
|
||||
&self,
|
||||
rx: mpsc::Receiver<ReceivedPluginCallMessage>,
|
||||
) -> Result<PluginCallResponse<PipelineData>, ShellError> {
|
||||
if let Ok(msg) = rx.recv() {
|
||||
// Handle message from receiver
|
||||
match msg {
|
||||
ReceivedPluginCallMessage::Response(resp) => Ok(resp),
|
||||
ReceivedPluginCallMessage::Error(err) => Err(err),
|
||||
}
|
||||
} else {
|
||||
// If we fail to get a response
|
||||
Err(ShellError::PluginFailedToDecode {
|
||||
msg: "Failed to receive response to plugin call".into(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
/// Perform a plugin call. Input and output streams are handled automatically.
|
||||
fn plugin_call(
|
||||
&self,
|
||||
call: PluginCall<PipelineData>,
|
||||
context: &Option<Context>,
|
||||
) -> Result<PluginCallResponse<PipelineData>, ShellError> {
|
||||
let (writer, rx) = self.write_plugin_call(call, context.clone())?;
|
||||
|
||||
// Finish writing stream in the background
|
||||
writer.write_background();
|
||||
|
||||
self.receive_plugin_call_response(rx)
|
||||
}
|
||||
|
||||
/// Get the command signatures from the plugin.
|
||||
pub(crate) fn get_signature(&self) -> Result<Vec<PluginSignature>, ShellError> {
|
||||
match self.plugin_call(PluginCall::Signature, &None)? {
|
||||
PluginCallResponse::Signature(sigs) => Ok(sigs),
|
||||
PluginCallResponse::Error(err) => Err(err.into()),
|
||||
_ => Err(ShellError::PluginFailedToDecode {
|
||||
msg: "Received unexpected response to plugin Signature call".into(),
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
||||
/// Run the plugin with the given call and execution context.
|
||||
pub(crate) fn run(
|
||||
&self,
|
||||
call: CallInfo<PipelineData>,
|
||||
context: Arc<impl PluginExecutionContext + 'static>,
|
||||
) -> Result<PipelineData, ShellError> {
|
||||
let context = Some(Context(context));
|
||||
match self.plugin_call(PluginCall::Run(call), &context)? {
|
||||
PluginCallResponse::PipelineData(data) => Ok(data),
|
||||
PluginCallResponse::Error(err) => Err(err.into()),
|
||||
_ => Err(ShellError::PluginFailedToDecode {
|
||||
msg: "Received unexpected response to plugin Run call".into(),
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
||||
/// Collapse a custom value to its base value.
|
||||
pub(crate) fn custom_value_to_base_value(
|
||||
&self,
|
||||
value: Spanned<PluginCustomValue>,
|
||||
) -> Result<Value, ShellError> {
|
||||
let span = value.span;
|
||||
let call = PluginCall::CustomValueOp(value, CustomValueOp::ToBaseValue);
|
||||
match self.plugin_call(call, &None)? {
|
||||
PluginCallResponse::PipelineData(out_data) => Ok(out_data.into_value(span)),
|
||||
PluginCallResponse::Error(err) => Err(err.into()),
|
||||
_ => Err(ShellError::PluginFailedToDecode {
|
||||
msg: "Received unexpected response to plugin CustomValueOp::ToBaseValue call"
|
||||
.into(),
|
||||
}),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Interface for PluginInterface {
|
||||
type Output = PluginInput;
|
||||
|
||||
fn write(&self, input: PluginInput) -> Result<(), ShellError> {
|
||||
log::trace!("to plugin: {:?}", input);
|
||||
self.state.writer.write(&input)
|
||||
}
|
||||
|
||||
fn flush(&self) -> Result<(), ShellError> {
|
||||
self.state.writer.flush()
|
||||
}
|
||||
|
||||
fn stream_id_sequence(&self) -> &Sequence {
|
||||
&self.state.stream_id_sequence
|
||||
}
|
||||
|
||||
fn stream_manager_handle(&self) -> &StreamManagerHandle {
|
||||
&self.stream_manager_handle
|
||||
}
|
||||
|
||||
fn prepare_pipeline_data(&self, data: PipelineData) -> Result<PipelineData, ShellError> {
|
||||
// Validate the destination of values in the pipeline data
|
||||
match data {
|
||||
PipelineData::Value(mut value, meta) => {
|
||||
PluginCustomValue::verify_source(&mut value, &self.state.identity)?;
|
||||
Ok(PipelineData::Value(value, meta))
|
||||
}
|
||||
PipelineData::ListStream(ListStream { stream, ctrlc, .. }, meta) => {
|
||||
let identity = self.state.identity.clone();
|
||||
Ok(stream
|
||||
.map(move |mut value| {
|
||||
match PluginCustomValue::verify_source(&mut value, &identity) {
|
||||
Ok(()) => value,
|
||||
// Put the error in the stream instead
|
||||
Err(err) => Value::error(err, value.span()),
|
||||
}
|
||||
})
|
||||
.into_pipeline_data_with_metadata(meta, ctrlc))
|
||||
}
|
||||
PipelineData::Empty | PipelineData::ExternalStream { .. } => Ok(data),
|
||||
}
|
||||
}
|
||||
}
|
842
crates/nu-plugin/src/plugin/interface/plugin/tests.rs
Normal file
842
crates/nu-plugin/src/plugin/interface/plugin/tests.rs
Normal file
@ -0,0 +1,842 @@
|
||||
use std::sync::mpsc;
|
||||
|
||||
use nu_protocol::{
|
||||
IntoInterruptiblePipelineData, PipelineData, PluginSignature, ShellError, Span, Spanned, Value,
|
||||
};
|
||||
|
||||
use crate::{
|
||||
plugin::{
|
||||
context::PluginExecutionBogusContext,
|
||||
interface::{test_util::TestCase, Interface, InterfaceManager},
|
||||
PluginIdentity,
|
||||
},
|
||||
protocol::{
|
||||
test_util::{expected_test_custom_value, test_plugin_custom_value},
|
||||
CallInfo, CustomValueOp, ExternalStreamInfo, ListStreamInfo, PipelineDataHeader,
|
||||
PluginCall, PluginCallId, PluginCustomValue, PluginInput, Protocol, ProtocolInfo,
|
||||
RawStreamInfo, StreamData, StreamMessage,
|
||||
},
|
||||
EvaluatedCall, PluginCallResponse, PluginOutput,
|
||||
};
|
||||
|
||||
use super::{
|
||||
PluginCallSubscription, PluginInterface, PluginInterfaceManager, ReceivedPluginCallMessage,
|
||||
};
|
||||
|
||||
#[test]
|
||||
fn manager_consume_all_consumes_messages() -> Result<(), ShellError> {
|
||||
let mut test = TestCase::new();
|
||||
let mut manager = test.plugin("test");
|
||||
|
||||
// This message should be non-problematic
|
||||
test.add(PluginOutput::Hello(ProtocolInfo::default()));
|
||||
|
||||
manager.consume_all(&mut test)?;
|
||||
|
||||
assert!(!test.has_unconsumed_read());
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn manager_consume_all_exits_after_streams_and_interfaces_are_dropped() -> Result<(), ShellError> {
|
||||
let mut test = TestCase::new();
|
||||
let mut manager = test.plugin("test");
|
||||
|
||||
// Add messages that won't cause errors
|
||||
for _ in 0..5 {
|
||||
test.add(PluginOutput::Hello(ProtocolInfo::default()));
|
||||
}
|
||||
|
||||
// Create a stream...
|
||||
let stream = manager.read_pipeline_data(
|
||||
PipelineDataHeader::ListStream(ListStreamInfo { id: 0 }),
|
||||
None,
|
||||
)?;
|
||||
|
||||
// and an interface...
|
||||
let interface = manager.get_interface();
|
||||
|
||||
// Expect that is_finished is false
|
||||
assert!(
|
||||
!manager.is_finished(),
|
||||
"is_finished is true even though active stream/interface exists"
|
||||
);
|
||||
|
||||
// After dropping, it should be true
|
||||
drop(stream);
|
||||
drop(interface);
|
||||
|
||||
assert!(
|
||||
manager.is_finished(),
|
||||
"is_finished is false even though manager has no stream or interface"
|
||||
);
|
||||
|
||||
// When it's true, consume_all shouldn't consume everything
|
||||
manager.consume_all(&mut test)?;
|
||||
|
||||
assert!(
|
||||
test.has_unconsumed_read(),
|
||||
"consume_all consumed the messages"
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn test_io_error() -> ShellError {
|
||||
ShellError::IOError {
|
||||
msg: "test io error".into(),
|
||||
}
|
||||
}
|
||||
|
||||
fn check_test_io_error(error: &ShellError) {
|
||||
assert!(
|
||||
format!("{error:?}").contains("test io error"),
|
||||
"error: {error}"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn manager_consume_all_propagates_io_error_to_readers() -> Result<(), ShellError> {
|
||||
let mut test = TestCase::new();
|
||||
let mut manager = test.plugin("test");
|
||||
|
||||
test.set_read_error(test_io_error());
|
||||
|
||||
let stream = manager.read_pipeline_data(
|
||||
PipelineDataHeader::ListStream(ListStreamInfo { id: 0 }),
|
||||
None,
|
||||
)?;
|
||||
|
||||
manager
|
||||
.consume_all(&mut test)
|
||||
.expect_err("consume_all did not error");
|
||||
|
||||
// Ensure end of stream
|
||||
drop(manager);
|
||||
|
||||
let value = stream.into_iter().next().expect("stream is empty");
|
||||
if let Value::Error { error, .. } = value {
|
||||
check_test_io_error(&error);
|
||||
Ok(())
|
||||
} else {
|
||||
panic!("did not get an error");
|
||||
}
|
||||
}
|
||||
|
||||
fn invalid_output() -> PluginOutput {
|
||||
// This should definitely cause an error, as 0.0.0 is not compatible with any version other than
|
||||
// itself
|
||||
PluginOutput::Hello(ProtocolInfo {
|
||||
protocol: Protocol::NuPlugin,
|
||||
version: "0.0.0".into(),
|
||||
features: vec![],
|
||||
})
|
||||
}
|
||||
|
||||
fn check_invalid_output_error(error: &ShellError) {
|
||||
// the error message should include something about the version...
|
||||
assert!(format!("{error:?}").contains("0.0.0"), "error: {error}");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn manager_consume_all_propagates_message_error_to_readers() -> Result<(), ShellError> {
|
||||
let mut test = TestCase::new();
|
||||
let mut manager = test.plugin("test");
|
||||
|
||||
test.add(invalid_output());
|
||||
|
||||
let stream = manager.read_pipeline_data(
|
||||
PipelineDataHeader::ExternalStream(ExternalStreamInfo {
|
||||
span: Span::test_data(),
|
||||
stdout: Some(RawStreamInfo {
|
||||
id: 0,
|
||||
is_binary: false,
|
||||
known_size: None,
|
||||
}),
|
||||
stderr: None,
|
||||
exit_code: None,
|
||||
trim_end_newline: false,
|
||||
}),
|
||||
None,
|
||||
)?;
|
||||
|
||||
manager
|
||||
.consume_all(&mut test)
|
||||
.expect_err("consume_all did not error");
|
||||
|
||||
// Ensure end of stream
|
||||
drop(manager);
|
||||
|
||||
let value = stream.into_iter().next().expect("stream is empty");
|
||||
if let Value::Error { error, .. } = value {
|
||||
check_invalid_output_error(&error);
|
||||
Ok(())
|
||||
} else {
|
||||
panic!("did not get an error");
|
||||
}
|
||||
}
|
||||
|
||||
fn fake_plugin_call(
|
||||
manager: &mut PluginInterfaceManager,
|
||||
id: PluginCallId,
|
||||
) -> mpsc::Receiver<ReceivedPluginCallMessage> {
|
||||
// Set up a fake plugin call subscription
|
||||
let (tx, rx) = mpsc::channel();
|
||||
|
||||
manager.plugin_call_subscriptions.insert(
|
||||
id,
|
||||
PluginCallSubscription {
|
||||
sender: tx,
|
||||
context: None,
|
||||
},
|
||||
);
|
||||
|
||||
rx
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn manager_consume_all_propagates_io_error_to_plugin_calls() -> Result<(), ShellError> {
|
||||
let mut test = TestCase::new();
|
||||
let mut manager = test.plugin("test");
|
||||
let interface = manager.get_interface();
|
||||
|
||||
test.set_read_error(test_io_error());
|
||||
|
||||
// Set up a fake plugin call subscription
|
||||
let rx = fake_plugin_call(&mut manager, 0);
|
||||
|
||||
manager
|
||||
.consume_all(&mut test)
|
||||
.expect_err("consume_all did not error");
|
||||
|
||||
// We have to hold interface until now otherwise consume_all won't try to process the message
|
||||
drop(interface);
|
||||
|
||||
let message = rx.try_recv().expect("failed to get plugin call message");
|
||||
match message {
|
||||
ReceivedPluginCallMessage::Error(error) => {
|
||||
check_test_io_error(&error);
|
||||
Ok(())
|
||||
}
|
||||
_ => panic!("received something other than an error: {message:?}"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn manager_consume_all_propagates_message_error_to_plugin_calls() -> Result<(), ShellError> {
|
||||
let mut test = TestCase::new();
|
||||
let mut manager = test.plugin("test");
|
||||
let interface = manager.get_interface();
|
||||
|
||||
test.add(invalid_output());
|
||||
|
||||
// Set up a fake plugin call subscription
|
||||
let rx = fake_plugin_call(&mut manager, 0);
|
||||
|
||||
manager
|
||||
.consume_all(&mut test)
|
||||
.expect_err("consume_all did not error");
|
||||
|
||||
// We have to hold interface until now otherwise consume_all won't try to process the message
|
||||
drop(interface);
|
||||
|
||||
let message = rx.try_recv().expect("failed to get plugin call message");
|
||||
match message {
|
||||
ReceivedPluginCallMessage::Error(error) => {
|
||||
check_invalid_output_error(&error);
|
||||
Ok(())
|
||||
}
|
||||
_ => panic!("received something other than an error: {message:?}"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn manager_consume_sets_protocol_info_on_hello() -> Result<(), ShellError> {
|
||||
let mut manager = TestCase::new().plugin("test");
|
||||
|
||||
let info = ProtocolInfo::default();
|
||||
|
||||
manager.consume(PluginOutput::Hello(info.clone()))?;
|
||||
|
||||
let set_info = manager
|
||||
.protocol_info
|
||||
.as_ref()
|
||||
.expect("protocol info not set");
|
||||
assert_eq!(info.version, set_info.version);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn manager_consume_errors_on_wrong_nushell_version() -> Result<(), ShellError> {
|
||||
let mut manager = TestCase::new().plugin("test");
|
||||
|
||||
let info = ProtocolInfo {
|
||||
protocol: Protocol::NuPlugin,
|
||||
version: "0.0.0".into(),
|
||||
features: vec![],
|
||||
};
|
||||
|
||||
manager
|
||||
.consume(PluginOutput::Hello(info))
|
||||
.expect_err("version 0.0.0 should cause an error");
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn manager_consume_errors_on_sending_other_messages_before_hello() -> Result<(), ShellError> {
|
||||
let mut manager = TestCase::new().plugin("test");
|
||||
|
||||
// hello not set
|
||||
assert!(manager.protocol_info.is_none());
|
||||
|
||||
let error = manager
|
||||
.consume(PluginOutput::Stream(StreamMessage::Drop(0)))
|
||||
.expect_err("consume before Hello should cause an error");
|
||||
|
||||
assert!(format!("{error:?}").contains("Hello"));
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn manager_consume_call_response_forwards_to_subscriber_with_pipeline_data(
|
||||
) -> Result<(), ShellError> {
|
||||
let mut manager = TestCase::new().plugin("test");
|
||||
manager.protocol_info = Some(ProtocolInfo::default());
|
||||
|
||||
let rx = fake_plugin_call(&mut manager, 0);
|
||||
|
||||
manager.consume(PluginOutput::CallResponse(
|
||||
0,
|
||||
PluginCallResponse::PipelineData(PipelineDataHeader::ListStream(ListStreamInfo { id: 0 })),
|
||||
))?;
|
||||
|
||||
for i in 0..2 {
|
||||
manager.consume(PluginOutput::Stream(StreamMessage::Data(
|
||||
0,
|
||||
Value::test_int(i).into(),
|
||||
)))?;
|
||||
}
|
||||
|
||||
manager.consume(PluginOutput::Stream(StreamMessage::End(0)))?;
|
||||
|
||||
// Make sure the streams end and we don't deadlock
|
||||
drop(manager);
|
||||
|
||||
let message = rx
|
||||
.try_recv()
|
||||
.expect("failed to get plugin call response message");
|
||||
|
||||
match message {
|
||||
ReceivedPluginCallMessage::Response(response) => match response {
|
||||
PluginCallResponse::PipelineData(data) => {
|
||||
// Ensure we manage to receive the stream messages
|
||||
assert_eq!(2, data.into_iter().count());
|
||||
Ok(())
|
||||
}
|
||||
_ => panic!("unexpected response: {response:?}"),
|
||||
},
|
||||
_ => panic!("unexpected response message: {message:?}"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn manager_prepare_pipeline_data_adds_source_to_values() -> Result<(), ShellError> {
|
||||
let manager = TestCase::new().plugin("test");
|
||||
|
||||
let data = manager.prepare_pipeline_data(PipelineData::Value(
|
||||
Value::test_custom_value(Box::new(test_plugin_custom_value())),
|
||||
None,
|
||||
))?;
|
||||
|
||||
let value = data
|
||||
.into_iter()
|
||||
.next()
|
||||
.expect("prepared pipeline data is empty");
|
||||
let custom_value: &PluginCustomValue = value
|
||||
.as_custom_value()?
|
||||
.as_any()
|
||||
.downcast_ref()
|
||||
.expect("custom value is not a PluginCustomValue");
|
||||
|
||||
if let Some(source) = &custom_value.source {
|
||||
assert_eq!("test", source.plugin_name);
|
||||
} else {
|
||||
panic!("source was not set");
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn manager_prepare_pipeline_data_adds_source_to_list_streams() -> Result<(), ShellError> {
|
||||
let manager = TestCase::new().plugin("test");
|
||||
|
||||
let data = manager.prepare_pipeline_data(
|
||||
[Value::test_custom_value(Box::new(
|
||||
test_plugin_custom_value(),
|
||||
))]
|
||||
.into_pipeline_data(None),
|
||||
)?;
|
||||
|
||||
let value = data
|
||||
.into_iter()
|
||||
.next()
|
||||
.expect("prepared pipeline data is empty");
|
||||
let custom_value: &PluginCustomValue = value
|
||||
.as_custom_value()?
|
||||
.as_any()
|
||||
.downcast_ref()
|
||||
.expect("custom value is not a PluginCustomValue");
|
||||
|
||||
if let Some(source) = &custom_value.source {
|
||||
assert_eq!("test", source.plugin_name);
|
||||
} else {
|
||||
panic!("source was not set");
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn interface_hello_sends_protocol_info() -> Result<(), ShellError> {
|
||||
let test = TestCase::new();
|
||||
let interface = test.plugin("test").get_interface();
|
||||
interface.hello()?;
|
||||
|
||||
let written = test.next_written().expect("nothing written");
|
||||
|
||||
match written {
|
||||
PluginInput::Hello(info) => {
|
||||
assert_eq!(ProtocolInfo::default().version, info.version);
|
||||
}
|
||||
_ => panic!("unexpected message written: {written:?}"),
|
||||
}
|
||||
|
||||
assert!(!test.has_unconsumed_write());
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn interface_write_plugin_call_registers_subscription() -> Result<(), ShellError> {
|
||||
let mut manager = TestCase::new().plugin("test");
|
||||
assert!(
|
||||
manager.plugin_call_subscriptions.is_empty(),
|
||||
"plugin call subscriptions not empty before start of test"
|
||||
);
|
||||
|
||||
let interface = manager.get_interface();
|
||||
let _ = interface.write_plugin_call(PluginCall::Signature, None)?;
|
||||
|
||||
manager.receive_plugin_call_subscriptions();
|
||||
assert!(
|
||||
!manager.plugin_call_subscriptions.is_empty(),
|
||||
"not registered"
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn interface_write_plugin_call_writes_signature() -> Result<(), ShellError> {
|
||||
let test = TestCase::new();
|
||||
let manager = test.plugin("test");
|
||||
let interface = manager.get_interface();
|
||||
|
||||
let (writer, _) = interface.write_plugin_call(PluginCall::Signature, None)?;
|
||||
writer.write()?;
|
||||
|
||||
let written = test.next_written().expect("nothing written");
|
||||
match written {
|
||||
PluginInput::Call(_, call) => assert!(
|
||||
matches!(call, PluginCall::Signature),
|
||||
"not Signature: {call:?}"
|
||||
),
|
||||
_ => panic!("unexpected message written: {written:?}"),
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn interface_write_plugin_call_writes_custom_value_op() -> Result<(), ShellError> {
|
||||
let test = TestCase::new();
|
||||
let manager = test.plugin("test");
|
||||
let interface = manager.get_interface();
|
||||
|
||||
let (writer, _) = interface.write_plugin_call(
|
||||
PluginCall::CustomValueOp(
|
||||
Spanned {
|
||||
item: test_plugin_custom_value(),
|
||||
span: Span::test_data(),
|
||||
},
|
||||
CustomValueOp::ToBaseValue,
|
||||
),
|
||||
None,
|
||||
)?;
|
||||
writer.write()?;
|
||||
|
||||
let written = test.next_written().expect("nothing written");
|
||||
match written {
|
||||
PluginInput::Call(_, call) => assert!(
|
||||
matches!(
|
||||
call,
|
||||
PluginCall::CustomValueOp(_, CustomValueOp::ToBaseValue)
|
||||
),
|
||||
"expected CustomValueOp(_, ToBaseValue), got {call:?}"
|
||||
),
|
||||
_ => panic!("unexpected message written: {written:?}"),
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn interface_write_plugin_call_writes_run_with_value_input() -> Result<(), ShellError> {
|
||||
let test = TestCase::new();
|
||||
let manager = test.plugin("test");
|
||||
let interface = manager.get_interface();
|
||||
|
||||
let (writer, _) = interface.write_plugin_call(
|
||||
PluginCall::Run(CallInfo {
|
||||
name: "foo".into(),
|
||||
call: EvaluatedCall {
|
||||
head: Span::test_data(),
|
||||
positional: vec![],
|
||||
named: vec![],
|
||||
},
|
||||
input: PipelineData::Value(Value::test_int(-1), None),
|
||||
config: None,
|
||||
}),
|
||||
None,
|
||||
)?;
|
||||
writer.write()?;
|
||||
|
||||
let written = test.next_written().expect("nothing written");
|
||||
match written {
|
||||
PluginInput::Call(_, call) => match call {
|
||||
PluginCall::Run(CallInfo { name, input, .. }) => {
|
||||
assert_eq!("foo", name);
|
||||
match input {
|
||||
PipelineDataHeader::Value(value) => assert_eq!(-1, value.as_int()?),
|
||||
_ => panic!("unexpected input header: {input:?}"),
|
||||
}
|
||||
}
|
||||
_ => panic!("unexpected Call: {call:?}"),
|
||||
},
|
||||
_ => panic!("unexpected message written: {written:?}"),
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn interface_write_plugin_call_writes_run_with_stream_input() -> Result<(), ShellError> {
|
||||
let test = TestCase::new();
|
||||
let manager = test.plugin("test");
|
||||
let interface = manager.get_interface();
|
||||
|
||||
let values = vec![Value::test_int(1), Value::test_int(2)];
|
||||
let (writer, _) = interface.write_plugin_call(
|
||||
PluginCall::Run(CallInfo {
|
||||
name: "foo".into(),
|
||||
call: EvaluatedCall {
|
||||
head: Span::test_data(),
|
||||
positional: vec![],
|
||||
named: vec![],
|
||||
},
|
||||
input: values.clone().into_pipeline_data(None),
|
||||
config: None,
|
||||
}),
|
||||
None,
|
||||
)?;
|
||||
writer.write()?;
|
||||
|
||||
let written = test.next_written().expect("nothing written");
|
||||
let info = match written {
|
||||
PluginInput::Call(_, call) => match call {
|
||||
PluginCall::Run(CallInfo { name, input, .. }) => {
|
||||
assert_eq!("foo", name);
|
||||
match input {
|
||||
PipelineDataHeader::ListStream(info) => info,
|
||||
_ => panic!("unexpected input header: {input:?}"),
|
||||
}
|
||||
}
|
||||
_ => panic!("unexpected Call: {call:?}"),
|
||||
},
|
||||
_ => panic!("unexpected message written: {written:?}"),
|
||||
};
|
||||
|
||||
// Expect stream messages
|
||||
for value in values {
|
||||
match test
|
||||
.next_written()
|
||||
.expect("failed to get Data stream message")
|
||||
{
|
||||
PluginInput::Stream(StreamMessage::Data(id, data)) => {
|
||||
assert_eq!(info.id, id, "id");
|
||||
match data {
|
||||
StreamData::List(data_value) => {
|
||||
assert_eq!(value, data_value, "wrong value in Data message");
|
||||
}
|
||||
_ => panic!("not List stream data: {data:?}"),
|
||||
}
|
||||
}
|
||||
message => panic!("expected Stream(Data(..)) message: {message:?}"),
|
||||
}
|
||||
}
|
||||
|
||||
match test
|
||||
.next_written()
|
||||
.expect("failed to get End stream message")
|
||||
{
|
||||
PluginInput::Stream(StreamMessage::End(id)) => {
|
||||
assert_eq!(info.id, id, "id");
|
||||
}
|
||||
message => panic!("expected Stream(End(_)) message: {message:?}"),
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn interface_receive_plugin_call_receives_response() -> Result<(), ShellError> {
|
||||
let interface = TestCase::new().plugin("test").get_interface();
|
||||
|
||||
// Set up a fake channel that has the response in it
|
||||
let (tx, rx) = mpsc::channel();
|
||||
tx.send(ReceivedPluginCallMessage::Response(
|
||||
PluginCallResponse::Signature(vec![]),
|
||||
))
|
||||
.expect("failed to send on new channel");
|
||||
drop(tx); // so we don't deadlock on recv()
|
||||
|
||||
let response = interface.receive_plugin_call_response(rx)?;
|
||||
assert!(
|
||||
matches!(response, PluginCallResponse::Signature(_)),
|
||||
"wrong response: {response:?}"
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn interface_receive_plugin_call_receives_error() -> Result<(), ShellError> {
|
||||
let interface = TestCase::new().plugin("test").get_interface();
|
||||
|
||||
// Set up a fake channel that has the error in it
|
||||
let (tx, rx) = mpsc::channel();
|
||||
tx.send(ReceivedPluginCallMessage::Error(
|
||||
ShellError::ExternalNotSupported {
|
||||
span: Span::test_data(),
|
||||
},
|
||||
))
|
||||
.expect("failed to send on new channel");
|
||||
drop(tx); // so we don't deadlock on recv()
|
||||
|
||||
let error = interface
|
||||
.receive_plugin_call_response(rx)
|
||||
.expect_err("did not receive error");
|
||||
assert!(
|
||||
matches!(error, ShellError::ExternalNotSupported { .. }),
|
||||
"wrong error: {error:?}"
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Fake responses to requests for plugin call messages
|
||||
fn start_fake_plugin_call_responder(
|
||||
manager: PluginInterfaceManager,
|
||||
take: usize,
|
||||
mut f: impl FnMut(PluginCallId) -> Vec<ReceivedPluginCallMessage> + Send + 'static,
|
||||
) {
|
||||
std::thread::Builder::new()
|
||||
.name("fake plugin call responder".into())
|
||||
.spawn(move || {
|
||||
for (id, sub) in manager
|
||||
.plugin_call_subscription_receiver
|
||||
.into_iter()
|
||||
.take(take)
|
||||
{
|
||||
for message in f(id) {
|
||||
sub.sender.send(message).expect("failed to send");
|
||||
}
|
||||
}
|
||||
})
|
||||
.expect("failed to spawn thread");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn interface_get_signature() -> Result<(), ShellError> {
|
||||
let test = TestCase::new();
|
||||
let manager = test.plugin("test");
|
||||
let interface = manager.get_interface();
|
||||
|
||||
start_fake_plugin_call_responder(manager, 1, |_| {
|
||||
vec![ReceivedPluginCallMessage::Response(
|
||||
PluginCallResponse::Signature(vec![PluginSignature::build("test")]),
|
||||
)]
|
||||
});
|
||||
|
||||
let signatures = interface.get_signature()?;
|
||||
|
||||
assert_eq!(1, signatures.len());
|
||||
assert!(test.has_unconsumed_write());
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn interface_run() -> Result<(), ShellError> {
|
||||
let test = TestCase::new();
|
||||
let manager = test.plugin("test");
|
||||
let interface = manager.get_interface();
|
||||
let number = 64;
|
||||
|
||||
start_fake_plugin_call_responder(manager, 1, move |_| {
|
||||
vec![ReceivedPluginCallMessage::Response(
|
||||
PluginCallResponse::PipelineData(PipelineData::Value(Value::test_int(number), None)),
|
||||
)]
|
||||
});
|
||||
|
||||
let result = interface.run(
|
||||
CallInfo {
|
||||
name: "bogus".into(),
|
||||
call: EvaluatedCall {
|
||||
head: Span::test_data(),
|
||||
positional: vec![],
|
||||
named: vec![],
|
||||
},
|
||||
input: PipelineData::Empty,
|
||||
config: None,
|
||||
},
|
||||
PluginExecutionBogusContext.into(),
|
||||
)?;
|
||||
|
||||
assert_eq!(
|
||||
Value::test_int(number),
|
||||
result.into_value(Span::test_data())
|
||||
);
|
||||
assert!(test.has_unconsumed_write());
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn interface_custom_value_to_base_value() -> Result<(), ShellError> {
|
||||
let test = TestCase::new();
|
||||
let manager = test.plugin("test");
|
||||
let interface = manager.get_interface();
|
||||
let string = "this is a test";
|
||||
|
||||
start_fake_plugin_call_responder(manager, 1, move |_| {
|
||||
vec![ReceivedPluginCallMessage::Response(
|
||||
PluginCallResponse::PipelineData(PipelineData::Value(Value::test_string(string), None)),
|
||||
)]
|
||||
});
|
||||
|
||||
let result = interface.custom_value_to_base_value(Spanned {
|
||||
item: test_plugin_custom_value(),
|
||||
span: Span::test_data(),
|
||||
})?;
|
||||
|
||||
assert_eq!(Value::test_string(string), result);
|
||||
assert!(test.has_unconsumed_write());
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn normal_values(interface: &PluginInterface) -> Vec<Value> {
|
||||
vec![
|
||||
Value::test_int(5),
|
||||
Value::test_custom_value(Box::new(PluginCustomValue {
|
||||
name: "SomeTest".into(),
|
||||
data: vec![1, 2, 3],
|
||||
// Has the same source, so it should be accepted
|
||||
source: Some(interface.state.identity.clone()),
|
||||
})),
|
||||
]
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn interface_prepare_pipeline_data_accepts_normal_values() -> Result<(), ShellError> {
|
||||
let interface = TestCase::new().plugin("test").get_interface();
|
||||
for value in normal_values(&interface) {
|
||||
match interface.prepare_pipeline_data(PipelineData::Value(value.clone(), None)) {
|
||||
Ok(data) => assert_eq!(
|
||||
value.get_type(),
|
||||
data.into_value(Span::test_data()).get_type()
|
||||
),
|
||||
Err(err) => panic!("failed to accept {value:?}: {err}"),
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn interface_prepare_pipeline_data_accepts_normal_streams() -> Result<(), ShellError> {
|
||||
let interface = TestCase::new().plugin("test").get_interface();
|
||||
let values = normal_values(&interface);
|
||||
let data = interface.prepare_pipeline_data(values.clone().into_pipeline_data(None))?;
|
||||
|
||||
let mut count = 0;
|
||||
for (expected_value, actual_value) in values.iter().zip(data) {
|
||||
assert!(
|
||||
!actual_value.is_error(),
|
||||
"error value instead of {expected_value:?} in stream: {actual_value:?}"
|
||||
);
|
||||
assert_eq!(expected_value.get_type(), actual_value.get_type());
|
||||
count += 1;
|
||||
}
|
||||
assert_eq!(
|
||||
values.len(),
|
||||
count,
|
||||
"didn't receive as many values as expected"
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn bad_custom_values() -> Vec<Value> {
|
||||
// These shouldn't be accepted
|
||||
vec![
|
||||
// Native custom value (not PluginCustomValue) should be rejected
|
||||
Value::test_custom_value(Box::new(expected_test_custom_value())),
|
||||
// Has no source, so it should be rejected
|
||||
Value::test_custom_value(Box::new(PluginCustomValue {
|
||||
name: "SomeTest".into(),
|
||||
data: vec![1, 2, 3],
|
||||
source: None,
|
||||
})),
|
||||
// Has a different source, so it should be rejected
|
||||
Value::test_custom_value(Box::new(PluginCustomValue {
|
||||
name: "SomeTest".into(),
|
||||
data: vec![1, 2, 3],
|
||||
source: Some(PluginIdentity::new_fake("pluto")),
|
||||
})),
|
||||
]
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn interface_prepare_pipeline_data_rejects_bad_custom_value() -> Result<(), ShellError> {
|
||||
let interface = TestCase::new().plugin("test").get_interface();
|
||||
for value in bad_custom_values() {
|
||||
match interface.prepare_pipeline_data(PipelineData::Value(value.clone(), None)) {
|
||||
Err(err) => match err {
|
||||
ShellError::CustomValueIncorrectForPlugin { .. } => (),
|
||||
_ => panic!("expected error type CustomValueIncorrectForPlugin, but got {err:?}"),
|
||||
},
|
||||
Ok(_) => panic!("mistakenly accepted {value:?}"),
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn interface_prepare_pipeline_data_rejects_bad_custom_value_in_a_stream() -> Result<(), ShellError>
|
||||
{
|
||||
let interface = TestCase::new().plugin("test").get_interface();
|
||||
let values = bad_custom_values();
|
||||
let data = interface.prepare_pipeline_data(values.clone().into_pipeline_data(None))?;
|
||||
|
||||
let mut count = 0;
|
||||
for value in data {
|
||||
assert!(value.is_error(), "expected error value for {value:?}");
|
||||
count += 1;
|
||||
}
|
||||
assert_eq!(
|
||||
values.len(),
|
||||
count,
|
||||
"didn't receive as many values as expected"
|
||||
);
|
||||
Ok(())
|
||||
}
|
621
crates/nu-plugin/src/plugin/interface/stream.rs
Normal file
621
crates/nu-plugin/src/plugin/interface/stream.rs
Normal file
@ -0,0 +1,621 @@
|
||||
use std::{
|
||||
collections::{btree_map, BTreeMap},
|
||||
iter::FusedIterator,
|
||||
marker::PhantomData,
|
||||
sync::{mpsc, Arc, Condvar, Mutex, MutexGuard, Weak},
|
||||
};
|
||||
|
||||
use nu_protocol::{ShellError, Span, Value};
|
||||
|
||||
use crate::protocol::{StreamData, StreamId, StreamMessage};
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests;
|
||||
|
||||
/// Receives messages from a stream read from input by a [`StreamManager`].
|
||||
///
|
||||
/// The receiver reads for messages of type `Result<Option<StreamData>, ShellError>` from the
|
||||
/// channel, which is managed by a [`StreamManager`]. Signalling for end-of-stream is explicit
|
||||
/// through `Ok(Some)`.
|
||||
///
|
||||
/// Failing to receive is an error. When end-of-stream is received, the `receiver` is set to `None`
|
||||
/// and all further calls to `next()` return `None`.
|
||||
///
|
||||
/// The type `T` must implement [`FromShellError`], so that errors in the stream can be represented,
|
||||
/// and `TryFrom<StreamData>` to convert it to the correct type.
|
||||
///
|
||||
/// For each message read, it sends [`StreamMessage::Ack`] to the writer. When dropped,
|
||||
/// it sends [`StreamMessage::Drop`].
|
||||
#[derive(Debug)]
|
||||
pub(crate) struct StreamReader<T, W>
|
||||
where
|
||||
W: WriteStreamMessage,
|
||||
{
|
||||
id: StreamId,
|
||||
receiver: Option<mpsc::Receiver<Result<Option<StreamData>, ShellError>>>,
|
||||
writer: W,
|
||||
/// Iterator requires the item type to be fixed, so we have to keep it as part of the type,
|
||||
/// even though we're actually receiving dynamic data.
|
||||
marker: PhantomData<fn() -> T>,
|
||||
}
|
||||
|
||||
impl<T, W> StreamReader<T, W>
|
||||
where
|
||||
T: TryFrom<StreamData, Error = ShellError>,
|
||||
W: WriteStreamMessage,
|
||||
{
|
||||
/// Create a new StreamReader from parts
|
||||
pub(crate) fn new(
|
||||
id: StreamId,
|
||||
receiver: mpsc::Receiver<Result<Option<StreamData>, ShellError>>,
|
||||
writer: W,
|
||||
) -> StreamReader<T, W> {
|
||||
StreamReader {
|
||||
id,
|
||||
receiver: Some(receiver),
|
||||
writer,
|
||||
marker: PhantomData,
|
||||
}
|
||||
}
|
||||
|
||||
/// Receive a message from the channel, or return an error if:
|
||||
///
|
||||
/// * the channel couldn't be received from
|
||||
/// * an error was sent on the channel
|
||||
/// * the message received couldn't be converted to `T`
|
||||
pub(crate) fn recv(&mut self) -> Result<Option<T>, ShellError> {
|
||||
let connection_lost = || ShellError::GenericError {
|
||||
error: "Stream ended unexpectedly".into(),
|
||||
msg: "connection lost before explicit end of stream".into(),
|
||||
span: None,
|
||||
help: None,
|
||||
inner: vec![],
|
||||
};
|
||||
|
||||
if let Some(ref rx) = self.receiver {
|
||||
// Try to receive a message first
|
||||
let msg = match rx.try_recv() {
|
||||
Ok(msg) => msg?,
|
||||
Err(mpsc::TryRecvError::Empty) => {
|
||||
// The receiver doesn't have any messages waiting for us. It's possible that the
|
||||
// other side hasn't seen our acknowledgements. Let's flush the writer and then
|
||||
// wait
|
||||
self.writer.flush()?;
|
||||
rx.recv().map_err(|_| connection_lost())??
|
||||
}
|
||||
Err(mpsc::TryRecvError::Disconnected) => return Err(connection_lost()),
|
||||
};
|
||||
|
||||
if let Some(data) = msg {
|
||||
// Acknowledge the message
|
||||
self.writer
|
||||
.write_stream_message(StreamMessage::Ack(self.id))?;
|
||||
// Try to convert it into the correct type
|
||||
Ok(Some(data.try_into()?))
|
||||
} else {
|
||||
// Remove the receiver, so that future recv() calls always return Ok(None)
|
||||
self.receiver = None;
|
||||
Ok(None)
|
||||
}
|
||||
} else {
|
||||
// Closed already
|
||||
Ok(None)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T, W> Iterator for StreamReader<T, W>
|
||||
where
|
||||
T: FromShellError + TryFrom<StreamData, Error = ShellError>,
|
||||
W: WriteStreamMessage,
|
||||
{
|
||||
type Item = T;
|
||||
|
||||
fn next(&mut self) -> Option<T> {
|
||||
// Converting the error to the value here makes the implementation a lot easier
|
||||
self.recv()
|
||||
.unwrap_or_else(|err| Some(T::from_shell_error(err)))
|
||||
}
|
||||
}
|
||||
|
||||
// Guaranteed not to return anything after the end
|
||||
impl<T, W> FusedIterator for StreamReader<T, W>
|
||||
where
|
||||
T: FromShellError + TryFrom<StreamData, Error = ShellError>,
|
||||
W: WriteStreamMessage,
|
||||
{
|
||||
}
|
||||
|
||||
impl<T, W> Drop for StreamReader<T, W>
|
||||
where
|
||||
W: WriteStreamMessage,
|
||||
{
|
||||
fn drop(&mut self) {
|
||||
if let Err(err) = self
|
||||
.writer
|
||||
.write_stream_message(StreamMessage::Drop(self.id))
|
||||
.and_then(|_| self.writer.flush())
|
||||
{
|
||||
log::warn!("Failed to send message to drop stream: {err}");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Values that can contain a `ShellError` to signal an error has occurred.
|
||||
pub(crate) trait FromShellError {
|
||||
fn from_shell_error(err: ShellError) -> Self;
|
||||
}
|
||||
|
||||
// For List streams.
|
||||
impl FromShellError for Value {
|
||||
fn from_shell_error(err: ShellError) -> Self {
|
||||
Value::error(err, Span::unknown())
|
||||
}
|
||||
}
|
||||
|
||||
// For Raw streams, mostly.
|
||||
impl<T> FromShellError for Result<T, ShellError> {
|
||||
fn from_shell_error(err: ShellError) -> Self {
|
||||
Err(err)
|
||||
}
|
||||
}
|
||||
|
||||
/// Writes messages to a stream, with flow control.
|
||||
///
|
||||
/// The `signal` contained
|
||||
#[derive(Debug)]
|
||||
pub(crate) struct StreamWriter<W: WriteStreamMessage> {
|
||||
id: StreamId,
|
||||
signal: Arc<StreamWriterSignal>,
|
||||
writer: W,
|
||||
ended: bool,
|
||||
}
|
||||
|
||||
impl<W> StreamWriter<W>
|
||||
where
|
||||
W: WriteStreamMessage,
|
||||
{
|
||||
pub(crate) fn new(id: StreamId, signal: Arc<StreamWriterSignal>, writer: W) -> StreamWriter<W> {
|
||||
StreamWriter {
|
||||
id,
|
||||
signal,
|
||||
writer,
|
||||
ended: false,
|
||||
}
|
||||
}
|
||||
|
||||
/// Check if the stream was dropped from the other end. Recommended to do this before calling
|
||||
/// [`.write()`], especially in a loop.
|
||||
pub(crate) fn is_dropped(&self) -> Result<bool, ShellError> {
|
||||
self.signal.is_dropped()
|
||||
}
|
||||
|
||||
/// Write a single piece of data to the stream.
|
||||
///
|
||||
/// Error if something failed with the write, or if [`.end()`] was already called
|
||||
/// previously.
|
||||
pub(crate) fn write(&mut self, data: impl Into<StreamData>) -> Result<(), ShellError> {
|
||||
if !self.ended {
|
||||
self.writer
|
||||
.write_stream_message(StreamMessage::Data(self.id, data.into()))?;
|
||||
// This implements flow control, so we don't write too many messages:
|
||||
if !self.signal.notify_sent()? {
|
||||
// Flush the output, and then wait for acknowledgements
|
||||
self.writer.flush()?;
|
||||
self.signal.wait_for_drain()
|
||||
} else {
|
||||
Ok(())
|
||||
}
|
||||
} else {
|
||||
Err(ShellError::GenericError {
|
||||
error: "Wrote to a stream after it ended".into(),
|
||||
msg: format!(
|
||||
"tried to write to stream {} after it was already ended",
|
||||
self.id
|
||||
),
|
||||
span: None,
|
||||
help: Some("this may be a bug in the nu-plugin crate".into()),
|
||||
inner: vec![],
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
/// Write a full iterator to the stream. Note that this doesn't end the stream, so you should
|
||||
/// still call [`.end()`].
|
||||
///
|
||||
/// If the stream is dropped from the other end, the iterator will not be fully consumed, and
|
||||
/// writing will terminate.
|
||||
///
|
||||
/// Returns `Ok(true)` if the iterator was fully consumed, or `Ok(false)` if a drop interrupted
|
||||
/// the stream from the other side.
|
||||
pub(crate) fn write_all<T>(
|
||||
&mut self,
|
||||
data: impl IntoIterator<Item = T>,
|
||||
) -> Result<bool, ShellError>
|
||||
where
|
||||
T: Into<StreamData>,
|
||||
{
|
||||
// Check before starting
|
||||
if self.is_dropped()? {
|
||||
return Ok(false);
|
||||
}
|
||||
|
||||
for item in data {
|
||||
// Check again after each item is consumed from the iterator, just in case the iterator
|
||||
// takes a while to produce a value
|
||||
if self.is_dropped()? {
|
||||
return Ok(false);
|
||||
}
|
||||
self.write(item)?;
|
||||
}
|
||||
Ok(true)
|
||||
}
|
||||
|
||||
/// End the stream. Recommend doing this instead of relying on `Drop` so that you can catch the
|
||||
/// error.
|
||||
pub(crate) fn end(&mut self) -> Result<(), ShellError> {
|
||||
if !self.ended {
|
||||
// Set the flag first so we don't double-report in the Drop
|
||||
self.ended = true;
|
||||
self.writer
|
||||
.write_stream_message(StreamMessage::End(self.id))?;
|
||||
self.writer.flush()
|
||||
} else {
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<W> Drop for StreamWriter<W>
|
||||
where
|
||||
W: WriteStreamMessage,
|
||||
{
|
||||
fn drop(&mut self) {
|
||||
// Make sure we ended the stream
|
||||
if let Err(err) = self.end() {
|
||||
log::warn!("Error while ending stream in Drop for StreamWriter: {err}");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Stores stream state for a writer, and can be blocked on to wait for messages to be acknowledged.
|
||||
/// A key part of managing stream lifecycle and flow control.
|
||||
#[derive(Debug)]
|
||||
pub(crate) struct StreamWriterSignal {
|
||||
mutex: Mutex<StreamWriterSignalState>,
|
||||
change_cond: Condvar,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub(crate) struct StreamWriterSignalState {
|
||||
/// Stream has been dropped and consumer is no longer interested in any messages.
|
||||
dropped: bool,
|
||||
/// Number of messages that have been sent without acknowledgement.
|
||||
unacknowledged: i32,
|
||||
/// Max number of messages to send before waiting for acknowledgement.
|
||||
high_pressure_mark: i32,
|
||||
}
|
||||
|
||||
impl StreamWriterSignal {
|
||||
/// Create a new signal.
|
||||
///
|
||||
/// If `notify_sent()` is called more than `high_pressure_mark` times, it will wait until
|
||||
/// `notify_acknowledge()` is called by another thread enough times to bring the number of
|
||||
/// unacknowledged sent messages below that threshold.
|
||||
pub fn new(high_pressure_mark: i32) -> StreamWriterSignal {
|
||||
assert!(high_pressure_mark > 0);
|
||||
|
||||
StreamWriterSignal {
|
||||
mutex: Mutex::new(StreamWriterSignalState {
|
||||
dropped: false,
|
||||
unacknowledged: 0,
|
||||
high_pressure_mark,
|
||||
}),
|
||||
change_cond: Condvar::new(),
|
||||
}
|
||||
}
|
||||
|
||||
fn lock(&self) -> Result<MutexGuard<StreamWriterSignalState>, ShellError> {
|
||||
self.mutex.lock().map_err(|_| ShellError::NushellFailed {
|
||||
msg: "StreamWriterSignal mutex poisoned due to panic".into(),
|
||||
})
|
||||
}
|
||||
|
||||
/// True if the stream was dropped and the consumer is no longer interested in it. Indicates
|
||||
/// that no more messages should be sent, other than `End`.
|
||||
pub fn is_dropped(&self) -> Result<bool, ShellError> {
|
||||
Ok(self.lock()?.dropped)
|
||||
}
|
||||
|
||||
/// Notify the writers that the stream has been dropped, so they can stop writing.
|
||||
pub fn set_dropped(&self) -> Result<(), ShellError> {
|
||||
let mut state = self.lock()?;
|
||||
state.dropped = true;
|
||||
// Unblock the writers so they can terminate
|
||||
self.change_cond.notify_all();
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Track that a message has been sent. Returns `Ok(true)` if more messages can be sent,
|
||||
/// or `Ok(false)` if the high pressure mark has been reached and [`.wait_for_drain()`] should
|
||||
/// be called to block.
|
||||
pub fn notify_sent(&self) -> Result<bool, ShellError> {
|
||||
let mut state = self.lock()?;
|
||||
state.unacknowledged =
|
||||
state
|
||||
.unacknowledged
|
||||
.checked_add(1)
|
||||
.ok_or_else(|| ShellError::NushellFailed {
|
||||
msg: "Overflow in counter: too many unacknowledged messages".into(),
|
||||
})?;
|
||||
|
||||
Ok(state.unacknowledged < state.high_pressure_mark)
|
||||
}
|
||||
|
||||
/// Wait for acknowledgements before sending more data. Also returns if the stream is dropped.
|
||||
pub fn wait_for_drain(&self) -> Result<(), ShellError> {
|
||||
let mut state = self.lock()?;
|
||||
while !state.dropped && state.unacknowledged >= state.high_pressure_mark {
|
||||
state = self
|
||||
.change_cond
|
||||
.wait(state)
|
||||
.map_err(|_| ShellError::NushellFailed {
|
||||
msg: "StreamWriterSignal mutex poisoned due to panic".into(),
|
||||
})?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Notify the writers that a message has been acknowledged, so they can continue to write
|
||||
/// if they were waiting.
|
||||
pub fn notify_acknowledged(&self) -> Result<(), ShellError> {
|
||||
let mut state = self.lock()?;
|
||||
state.unacknowledged =
|
||||
state
|
||||
.unacknowledged
|
||||
.checked_sub(1)
|
||||
.ok_or_else(|| ShellError::NushellFailed {
|
||||
msg: "Underflow in counter: too many message acknowledgements".into(),
|
||||
})?;
|
||||
// Unblock the writer
|
||||
self.change_cond.notify_one();
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
/// A sink for a [`StreamMessage`]
|
||||
pub(crate) trait WriteStreamMessage {
|
||||
fn write_stream_message(&mut self, msg: StreamMessage) -> Result<(), ShellError>;
|
||||
fn flush(&mut self) -> Result<(), ShellError>;
|
||||
}
|
||||
|
||||
#[derive(Debug, Default)]
|
||||
struct StreamManagerState {
|
||||
reading_streams: BTreeMap<StreamId, mpsc::Sender<Result<Option<StreamData>, ShellError>>>,
|
||||
writing_streams: BTreeMap<StreamId, Weak<StreamWriterSignal>>,
|
||||
}
|
||||
|
||||
impl StreamManagerState {
|
||||
/// Lock the state, or return a [`ShellError`] if the mutex is poisoned.
|
||||
fn lock(
|
||||
state: &Mutex<StreamManagerState>,
|
||||
) -> Result<MutexGuard<StreamManagerState>, ShellError> {
|
||||
state.lock().map_err(|_| ShellError::NushellFailed {
|
||||
msg: "StreamManagerState mutex poisoned due to a panic".into(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub(crate) struct StreamManager {
|
||||
state: Arc<Mutex<StreamManagerState>>,
|
||||
}
|
||||
|
||||
impl StreamManager {
|
||||
/// Create a new StreamManager.
|
||||
pub(crate) fn new() -> StreamManager {
|
||||
StreamManager {
|
||||
state: Default::default(),
|
||||
}
|
||||
}
|
||||
|
||||
fn lock(&self) -> Result<MutexGuard<StreamManagerState>, ShellError> {
|
||||
StreamManagerState::lock(&self.state)
|
||||
}
|
||||
|
||||
/// Create a new handle to the StreamManager for registering streams.
|
||||
pub(crate) fn get_handle(&self) -> StreamManagerHandle {
|
||||
StreamManagerHandle {
|
||||
state: Arc::downgrade(&self.state),
|
||||
}
|
||||
}
|
||||
|
||||
/// Process a stream message, and update internal state accordingly.
|
||||
pub(crate) fn handle_message(&self, message: StreamMessage) -> Result<(), ShellError> {
|
||||
let mut state = self.lock()?;
|
||||
match message {
|
||||
StreamMessage::Data(id, data) => {
|
||||
if let Some(sender) = state.reading_streams.get(&id) {
|
||||
// We should ignore the error on send. This just means the reader has dropped,
|
||||
// but it will have sent a Drop message to the other side, and we will receive
|
||||
// an End message at which point we can remove the channel.
|
||||
let _ = sender.send(Ok(Some(data)));
|
||||
Ok(())
|
||||
} else {
|
||||
Err(ShellError::PluginFailedToDecode {
|
||||
msg: format!("received Data for unknown stream {id}"),
|
||||
})
|
||||
}
|
||||
}
|
||||
StreamMessage::End(id) => {
|
||||
if let Some(sender) = state.reading_streams.remove(&id) {
|
||||
// We should ignore the error on the send, because the reader might have dropped
|
||||
// already
|
||||
let _ = sender.send(Ok(None));
|
||||
Ok(())
|
||||
} else {
|
||||
Err(ShellError::PluginFailedToDecode {
|
||||
msg: format!("received End for unknown stream {id}"),
|
||||
})
|
||||
}
|
||||
}
|
||||
StreamMessage::Drop(id) => {
|
||||
if let Some(signal) = state.writing_streams.remove(&id) {
|
||||
if let Some(signal) = signal.upgrade() {
|
||||
// This will wake blocked writers so they can stop writing, so it's ok
|
||||
signal.set_dropped()?;
|
||||
}
|
||||
}
|
||||
// It's possible that the stream has already finished writing and we don't have it
|
||||
// anymore, so we fall through to Ok
|
||||
Ok(())
|
||||
}
|
||||
StreamMessage::Ack(id) => {
|
||||
if let Some(signal) = state.writing_streams.get(&id) {
|
||||
if let Some(signal) = signal.upgrade() {
|
||||
// This will wake up a blocked writer
|
||||
signal.notify_acknowledged()?;
|
||||
} else {
|
||||
// We know it doesn't exist, so might as well remove it
|
||||
state.writing_streams.remove(&id);
|
||||
}
|
||||
}
|
||||
// It's possible that the stream has already finished writing and we don't have it
|
||||
// anymore, so we fall through to Ok
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Broadcast an error to all stream readers. This is useful for error propagation.
|
||||
pub(crate) fn broadcast_read_error(&self, error: ShellError) -> Result<(), ShellError> {
|
||||
let state = self.lock()?;
|
||||
for channel in state.reading_streams.values() {
|
||||
// Ignore send errors.
|
||||
let _ = channel.send(Err(error.clone()));
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// If the `StreamManager` is dropped, we should let all of the stream writers know that they
|
||||
// won't be able to write anymore. We don't need to do anything about the readers though
|
||||
// because they'll know when the `Sender` is dropped automatically
|
||||
fn drop_all_writers(&self) -> Result<(), ShellError> {
|
||||
let mut state = self.lock()?;
|
||||
let writers = std::mem::take(&mut state.writing_streams);
|
||||
for (_, signal) in writers {
|
||||
if let Some(signal) = signal.upgrade() {
|
||||
// more important that we send to all than handling an error
|
||||
let _ = signal.set_dropped();
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for StreamManager {
|
||||
fn drop(&mut self) {
|
||||
if let Err(err) = self.drop_all_writers() {
|
||||
log::warn!("error during Drop for StreamManager: {}", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// A [`StreamManagerHandle`] supports operations for interacting with the [`StreamManager`].
|
||||
///
|
||||
/// Streams can be registered for reading, returning a [`StreamReader`], or for writing, returning
|
||||
/// a [`StreamWriter`].
|
||||
#[derive(Debug, Clone)]
|
||||
pub(crate) struct StreamManagerHandle {
|
||||
state: Weak<Mutex<StreamManagerState>>,
|
||||
}
|
||||
|
||||
impl StreamManagerHandle {
|
||||
/// Because the handle only has a weak reference to the [`StreamManager`] state, we have to
|
||||
/// first try to upgrade to a strong reference and then lock. This function wraps those two
|
||||
/// operations together, handling errors appropriately.
|
||||
fn with_lock<T, F>(&self, f: F) -> Result<T, ShellError>
|
||||
where
|
||||
F: FnOnce(MutexGuard<StreamManagerState>) -> Result<T, ShellError>,
|
||||
{
|
||||
let upgraded = self
|
||||
.state
|
||||
.upgrade()
|
||||
.ok_or_else(|| ShellError::NushellFailed {
|
||||
msg: "StreamManager is no longer alive".into(),
|
||||
})?;
|
||||
let guard = upgraded.lock().map_err(|_| ShellError::NushellFailed {
|
||||
msg: "StreamManagerState mutex poisoned due to a panic".into(),
|
||||
})?;
|
||||
f(guard)
|
||||
}
|
||||
|
||||
/// Register a new stream for reading, and return a [`StreamReader`] that can be used to iterate
|
||||
/// on the values received. A [`StreamMessage`] writer is required for writing control messages
|
||||
/// back to the producer.
|
||||
pub(crate) fn read_stream<T, W>(
|
||||
&self,
|
||||
id: StreamId,
|
||||
writer: W,
|
||||
) -> Result<StreamReader<T, W>, ShellError>
|
||||
where
|
||||
T: TryFrom<StreamData, Error = ShellError>,
|
||||
W: WriteStreamMessage,
|
||||
{
|
||||
let (tx, rx) = mpsc::channel();
|
||||
self.with_lock(|mut state| {
|
||||
// Must be exclusive
|
||||
if let btree_map::Entry::Vacant(e) = state.reading_streams.entry(id) {
|
||||
e.insert(tx);
|
||||
Ok(())
|
||||
} else {
|
||||
Err(ShellError::GenericError {
|
||||
error: format!("Failed to acquire reader for stream {id}"),
|
||||
msg: "tried to get a reader for a stream that's already being read".into(),
|
||||
span: None,
|
||||
help: Some("this may be a bug in the nu-plugin crate".into()),
|
||||
inner: vec![],
|
||||
})
|
||||
}
|
||||
})?;
|
||||
Ok(StreamReader::new(id, rx, writer))
|
||||
}
|
||||
|
||||
/// Register a new stream for writing, and return a [`StreamWriter`] that can be used to send
|
||||
/// data to the stream.
|
||||
///
|
||||
/// The `high_pressure_mark` value controls how many messages can be written without receiving
|
||||
/// an acknowledgement before any further attempts to write will wait for the consumer to
|
||||
/// acknowledge them. This prevents overwhelming the reader.
|
||||
pub(crate) fn write_stream<W>(
|
||||
&self,
|
||||
id: StreamId,
|
||||
writer: W,
|
||||
high_pressure_mark: i32,
|
||||
) -> Result<StreamWriter<W>, ShellError>
|
||||
where
|
||||
W: WriteStreamMessage,
|
||||
{
|
||||
let signal = Arc::new(StreamWriterSignal::new(high_pressure_mark));
|
||||
self.with_lock(|mut state| {
|
||||
// Remove dead writing streams
|
||||
state
|
||||
.writing_streams
|
||||
.retain(|_, signal| signal.strong_count() > 0);
|
||||
// Must be exclusive
|
||||
if let btree_map::Entry::Vacant(e) = state.writing_streams.entry(id) {
|
||||
e.insert(Arc::downgrade(&signal));
|
||||
Ok(())
|
||||
} else {
|
||||
Err(ShellError::GenericError {
|
||||
error: format!("Failed to acquire writer for stream {id}"),
|
||||
msg: "tried to get a writer for a stream that's already being written".into(),
|
||||
span: None,
|
||||
help: Some("this may be a bug in the nu-plugin crate".into()),
|
||||
inner: vec![],
|
||||
})
|
||||
}
|
||||
})?;
|
||||
Ok(StreamWriter::new(id, signal, writer))
|
||||
}
|
||||
}
|
508
crates/nu-plugin/src/plugin/interface/stream/tests.rs
Normal file
508
crates/nu-plugin/src/plugin/interface/stream/tests.rs
Normal file
@ -0,0 +1,508 @@
|
||||
use std::{
|
||||
sync::{
|
||||
atomic::{AtomicBool, Ordering::Relaxed},
|
||||
mpsc, Arc,
|
||||
},
|
||||
time::Duration,
|
||||
};
|
||||
|
||||
use nu_protocol::{ShellError, Value};
|
||||
|
||||
use crate::protocol::{StreamData, StreamMessage};
|
||||
|
||||
use super::{StreamManager, StreamReader, StreamWriter, StreamWriterSignal, WriteStreamMessage};
|
||||
|
||||
// Should be long enough to definitely complete any quick operation, but not so long that tests are
|
||||
// slow to complete. 10 ms is a pretty long time
|
||||
const WAIT_DURATION: Duration = Duration::from_millis(10);
|
||||
|
||||
#[derive(Debug, Clone, Default)]
|
||||
struct TestSink(Vec<StreamMessage>);
|
||||
|
||||
impl WriteStreamMessage for TestSink {
|
||||
fn write_stream_message(&mut self, msg: StreamMessage) -> Result<(), ShellError> {
|
||||
self.0.push(msg);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn flush(&mut self) -> Result<(), ShellError> {
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl WriteStreamMessage for mpsc::Sender<StreamMessage> {
|
||||
fn write_stream_message(&mut self, msg: StreamMessage) -> Result<(), ShellError> {
|
||||
self.send(msg).map_err(|err| ShellError::NushellFailed {
|
||||
msg: err.to_string(),
|
||||
})
|
||||
}
|
||||
|
||||
fn flush(&mut self) -> Result<(), ShellError> {
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn reader_recv_list_messages() -> Result<(), ShellError> {
|
||||
let (tx, rx) = mpsc::channel();
|
||||
let mut reader = StreamReader::new(0, rx, TestSink::default());
|
||||
|
||||
tx.send(Ok(Some(StreamData::List(Value::test_int(5)))))
|
||||
.unwrap();
|
||||
drop(tx);
|
||||
|
||||
assert_eq!(Some(Value::test_int(5)), reader.recv()?);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn list_reader_recv_wrong_type() -> Result<(), ShellError> {
|
||||
let (tx, rx) = mpsc::channel();
|
||||
let mut reader = StreamReader::<Value, _>::new(0, rx, TestSink::default());
|
||||
|
||||
tx.send(Ok(Some(StreamData::Raw(Ok(vec![10, 20])))))
|
||||
.unwrap();
|
||||
tx.send(Ok(Some(StreamData::List(Value::test_nothing()))))
|
||||
.unwrap();
|
||||
drop(tx);
|
||||
|
||||
reader.recv().expect_err("should be an error");
|
||||
reader.recv().expect("should be able to recover");
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn reader_recv_raw_messages() -> Result<(), ShellError> {
|
||||
let (tx, rx) = mpsc::channel();
|
||||
let mut reader = StreamReader::new(0, rx, TestSink::default());
|
||||
|
||||
tx.send(Ok(Some(StreamData::Raw(Ok(vec![10, 20])))))
|
||||
.unwrap();
|
||||
drop(tx);
|
||||
|
||||
assert_eq!(Some(vec![10, 20]), reader.recv()?.transpose()?);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn raw_reader_recv_wrong_type() -> Result<(), ShellError> {
|
||||
let (tx, rx) = mpsc::channel();
|
||||
let mut reader =
|
||||
StreamReader::<Result<Vec<u8>, ShellError>, _>::new(0, rx, TestSink::default());
|
||||
|
||||
tx.send(Ok(Some(StreamData::List(Value::test_nothing()))))
|
||||
.unwrap();
|
||||
tx.send(Ok(Some(StreamData::Raw(Ok(vec![10, 20])))))
|
||||
.unwrap();
|
||||
drop(tx);
|
||||
|
||||
reader.recv().expect_err("should be an error");
|
||||
reader.recv().expect("should be able to recover");
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn reader_recv_acknowledge() -> Result<(), ShellError> {
|
||||
let (tx, rx) = mpsc::channel();
|
||||
let mut reader = StreamReader::<Value, _>::new(0, rx, TestSink::default());
|
||||
|
||||
tx.send(Ok(Some(StreamData::List(Value::test_int(5)))))
|
||||
.unwrap();
|
||||
tx.send(Ok(Some(StreamData::List(Value::test_int(6)))))
|
||||
.unwrap();
|
||||
drop(tx);
|
||||
|
||||
reader.recv()?;
|
||||
reader.recv()?;
|
||||
let wrote = &reader.writer.0;
|
||||
assert!(wrote.len() >= 2);
|
||||
assert!(
|
||||
matches!(wrote[0], StreamMessage::Ack(0)),
|
||||
"0 = {:?}",
|
||||
wrote[0]
|
||||
);
|
||||
assert!(
|
||||
matches!(wrote[1], StreamMessage::Ack(0)),
|
||||
"1 = {:?}",
|
||||
wrote[1]
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn reader_recv_end_of_stream() -> Result<(), ShellError> {
|
||||
let (tx, rx) = mpsc::channel();
|
||||
let mut reader = StreamReader::<Value, _>::new(0, rx, TestSink::default());
|
||||
|
||||
tx.send(Ok(Some(StreamData::List(Value::test_int(5)))))
|
||||
.unwrap();
|
||||
tx.send(Ok(None)).unwrap();
|
||||
drop(tx);
|
||||
|
||||
assert!(reader.recv()?.is_some(), "actual message");
|
||||
assert!(reader.recv()?.is_none(), "on close");
|
||||
assert!(reader.recv()?.is_none(), "after close");
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn reader_drop() {
|
||||
let (_tx, rx) = mpsc::channel();
|
||||
|
||||
// Flag set if drop message is received.
|
||||
struct Check(Arc<AtomicBool>);
|
||||
|
||||
impl WriteStreamMessage for Check {
|
||||
fn write_stream_message(&mut self, msg: StreamMessage) -> Result<(), ShellError> {
|
||||
assert!(matches!(msg, StreamMessage::Drop(1)), "got {:?}", msg);
|
||||
self.0.store(true, Relaxed);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn flush(&mut self) -> Result<(), ShellError> {
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
let flag = Arc::new(AtomicBool::new(false));
|
||||
|
||||
let reader = StreamReader::<Value, _>::new(1, rx, Check(flag.clone()));
|
||||
drop(reader);
|
||||
|
||||
assert!(flag.load(Relaxed));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn writer_write_all_stops_if_dropped() -> Result<(), ShellError> {
|
||||
let signal = Arc::new(StreamWriterSignal::new(20));
|
||||
let id = 1337;
|
||||
let mut writer = StreamWriter::new(id, signal.clone(), TestSink::default());
|
||||
|
||||
// Simulate this by having it consume a stream that will actually do the drop halfway through
|
||||
let iter = (0..5).map(Value::test_int).chain({
|
||||
let mut n = 5;
|
||||
std::iter::from_fn(move || {
|
||||
// produces numbers 5..10, but drops for the first one
|
||||
if n == 5 {
|
||||
signal.set_dropped().unwrap();
|
||||
}
|
||||
if n < 10 {
|
||||
let value = Value::test_int(n);
|
||||
n += 1;
|
||||
Some(value)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
});
|
||||
|
||||
writer.write_all(iter)?;
|
||||
|
||||
assert!(writer.is_dropped()?);
|
||||
|
||||
let wrote = &writer.writer.0;
|
||||
assert_eq!(5, wrote.len(), "length wrong: {wrote:?}");
|
||||
|
||||
for (n, message) in (0..5).zip(wrote) {
|
||||
match message {
|
||||
StreamMessage::Data(msg_id, StreamData::List(value)) => {
|
||||
assert_eq!(id, *msg_id, "id");
|
||||
assert_eq!(Value::test_int(n), *value, "value");
|
||||
}
|
||||
other => panic!("unexpected message: {other:?}"),
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn writer_end() -> Result<(), ShellError> {
|
||||
let signal = Arc::new(StreamWriterSignal::new(20));
|
||||
let mut writer = StreamWriter::new(9001, signal.clone(), TestSink::default());
|
||||
|
||||
writer.end()?;
|
||||
writer
|
||||
.write(Value::test_int(2))
|
||||
.expect_err("shouldn't be able to write after end");
|
||||
writer.end().expect("end twice should be ok");
|
||||
|
||||
let wrote = &writer.writer.0;
|
||||
assert!(
|
||||
matches!(wrote.last(), Some(StreamMessage::End(9001))),
|
||||
"didn't write end message: {wrote:?}"
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn signal_set_dropped() -> Result<(), ShellError> {
|
||||
let signal = StreamWriterSignal::new(4);
|
||||
assert!(!signal.is_dropped()?);
|
||||
signal.set_dropped()?;
|
||||
assert!(signal.is_dropped()?);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn signal_notify_sent_false_if_unacknowledged() -> Result<(), ShellError> {
|
||||
let signal = StreamWriterSignal::new(2);
|
||||
assert!(signal.notify_sent()?);
|
||||
for _ in 0..100 {
|
||||
assert!(!signal.notify_sent()?);
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn signal_notify_sent_never_false_if_flowing() -> Result<(), ShellError> {
|
||||
let signal = StreamWriterSignal::new(1);
|
||||
for _ in 0..100 {
|
||||
signal.notify_acknowledged()?;
|
||||
}
|
||||
for _ in 0..100 {
|
||||
assert!(signal.notify_sent()?);
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn signal_wait_for_drain_blocks_on_unacknowledged() -> Result<(), ShellError> {
|
||||
let signal = StreamWriterSignal::new(50);
|
||||
std::thread::scope(|scope| {
|
||||
let spawned = scope.spawn(|| {
|
||||
for _ in 0..100 {
|
||||
if !signal.notify_sent()? {
|
||||
signal.wait_for_drain()?;
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
});
|
||||
std::thread::sleep(WAIT_DURATION);
|
||||
assert!(!spawned.is_finished(), "didn't block");
|
||||
for _ in 0..100 {
|
||||
signal.notify_acknowledged()?;
|
||||
}
|
||||
std::thread::sleep(WAIT_DURATION);
|
||||
assert!(spawned.is_finished(), "blocked at end");
|
||||
spawned.join().unwrap()
|
||||
})
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn signal_wait_for_drain_unblocks_on_dropped() -> Result<(), ShellError> {
|
||||
let signal = StreamWriterSignal::new(1);
|
||||
std::thread::scope(|scope| {
|
||||
let spawned = scope.spawn(|| {
|
||||
while !signal.is_dropped()? {
|
||||
if !signal.notify_sent()? {
|
||||
signal.wait_for_drain()?;
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
});
|
||||
std::thread::sleep(WAIT_DURATION);
|
||||
assert!(!spawned.is_finished(), "didn't block");
|
||||
signal.set_dropped()?;
|
||||
std::thread::sleep(WAIT_DURATION);
|
||||
assert!(spawned.is_finished(), "still blocked at end");
|
||||
spawned.join().unwrap()
|
||||
})
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn stream_manager_single_stream_read_scenario() -> Result<(), ShellError> {
|
||||
let manager = StreamManager::new();
|
||||
let handle = manager.get_handle();
|
||||
let (tx, rx) = mpsc::channel();
|
||||
let readable = handle.read_stream::<Value, _>(2, tx)?;
|
||||
|
||||
let expected_values = vec![Value::test_int(40), Value::test_string("hello")];
|
||||
|
||||
for value in &expected_values {
|
||||
manager.handle_message(StreamMessage::Data(2, value.clone().into()))?;
|
||||
}
|
||||
manager.handle_message(StreamMessage::End(2))?;
|
||||
|
||||
let values = readable.collect::<Vec<Value>>();
|
||||
|
||||
assert_eq!(expected_values, values);
|
||||
|
||||
// Now check the sent messages on consumption
|
||||
// Should be Ack for each message, then Drop
|
||||
for _ in &expected_values {
|
||||
match rx.try_recv().expect("failed to receive Ack") {
|
||||
StreamMessage::Ack(2) => (),
|
||||
other => panic!("should have been an Ack: {other:?}"),
|
||||
}
|
||||
}
|
||||
match rx.try_recv().expect("failed to receive Drop") {
|
||||
StreamMessage::Drop(2) => (),
|
||||
other => panic!("should have been a Drop: {other:?}"),
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn stream_manager_multi_stream_read_scenario() -> Result<(), ShellError> {
|
||||
let manager = StreamManager::new();
|
||||
let handle = manager.get_handle();
|
||||
let (tx, rx) = mpsc::channel();
|
||||
let readable_list = handle.read_stream::<Value, _>(2, tx.clone())?;
|
||||
let readable_raw = handle.read_stream::<Result<Vec<u8>, _>, _>(3, tx)?;
|
||||
|
||||
let expected_values = (1..100).map(Value::test_int).collect::<Vec<_>>();
|
||||
let expected_raw_buffers = (1..100).map(|n| vec![n]).collect::<Vec<Vec<u8>>>();
|
||||
|
||||
for (value, buf) in expected_values.iter().zip(&expected_raw_buffers) {
|
||||
manager.handle_message(StreamMessage::Data(2, value.clone().into()))?;
|
||||
manager.handle_message(StreamMessage::Data(3, StreamData::Raw(Ok(buf.clone()))))?;
|
||||
}
|
||||
manager.handle_message(StreamMessage::End(2))?;
|
||||
manager.handle_message(StreamMessage::End(3))?;
|
||||
|
||||
let values = readable_list.collect::<Vec<Value>>();
|
||||
let bufs = readable_raw.collect::<Result<Vec<Vec<u8>>, _>>()?;
|
||||
|
||||
for (expected_value, value) in expected_values.iter().zip(&values) {
|
||||
assert_eq!(expected_value, value, "in List stream");
|
||||
}
|
||||
for (expected_buf, buf) in expected_raw_buffers.iter().zip(&bufs) {
|
||||
assert_eq!(expected_buf, buf, "in Raw stream");
|
||||
}
|
||||
|
||||
// Now check the sent messages on consumption
|
||||
// Should be Ack for each message, then Drop
|
||||
for _ in &expected_values {
|
||||
match rx.try_recv().expect("failed to receive Ack") {
|
||||
StreamMessage::Ack(2) => (),
|
||||
other => panic!("should have been an Ack(2): {other:?}"),
|
||||
}
|
||||
}
|
||||
match rx.try_recv().expect("failed to receive Drop") {
|
||||
StreamMessage::Drop(2) => (),
|
||||
other => panic!("should have been a Drop(2): {other:?}"),
|
||||
}
|
||||
for _ in &expected_values {
|
||||
match rx.try_recv().expect("failed to receive Ack") {
|
||||
StreamMessage::Ack(3) => (),
|
||||
other => panic!("should have been an Ack(3): {other:?}"),
|
||||
}
|
||||
}
|
||||
match rx.try_recv().expect("failed to receive Drop") {
|
||||
StreamMessage::Drop(3) => (),
|
||||
other => panic!("should have been a Drop(3): {other:?}"),
|
||||
}
|
||||
|
||||
// Should be end of stream
|
||||
assert!(
|
||||
rx.try_recv().is_err(),
|
||||
"more messages written to stream than expected"
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn stream_manager_write_scenario() -> Result<(), ShellError> {
|
||||
let manager = StreamManager::new();
|
||||
let handle = manager.get_handle();
|
||||
let (tx, rx) = mpsc::channel();
|
||||
let mut writable = handle.write_stream(4, tx, 100)?;
|
||||
|
||||
let expected_values = vec![b"hello".to_vec(), b"world".to_vec(), b"test".to_vec()];
|
||||
|
||||
for value in &expected_values {
|
||||
writable.write(Ok(value.clone()))?;
|
||||
}
|
||||
|
||||
// Now try signalling ack
|
||||
assert_eq!(
|
||||
expected_values.len() as i32,
|
||||
writable.signal.lock()?.unacknowledged,
|
||||
"unacknowledged initial count",
|
||||
);
|
||||
manager.handle_message(StreamMessage::Ack(4))?;
|
||||
assert_eq!(
|
||||
expected_values.len() as i32 - 1,
|
||||
writable.signal.lock()?.unacknowledged,
|
||||
"unacknowledged post-Ack count",
|
||||
);
|
||||
|
||||
// ...and Drop
|
||||
manager.handle_message(StreamMessage::Drop(4))?;
|
||||
assert!(writable.is_dropped()?);
|
||||
|
||||
// Drop the StreamWriter...
|
||||
drop(writable);
|
||||
|
||||
// now check what was actually written
|
||||
for value in &expected_values {
|
||||
match rx.try_recv().expect("failed to receive Data") {
|
||||
StreamMessage::Data(4, StreamData::Raw(Ok(received))) => {
|
||||
assert_eq!(*value, received);
|
||||
}
|
||||
other @ StreamMessage::Data(..) => panic!("wrong Data for {value:?}: {other:?}"),
|
||||
other => panic!("should have been Data: {other:?}"),
|
||||
}
|
||||
}
|
||||
match rx.try_recv().expect("failed to receive End") {
|
||||
StreamMessage::End(4) => (),
|
||||
other => panic!("should have been End: {other:?}"),
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn stream_manager_broadcast_read_error() -> Result<(), ShellError> {
|
||||
let manager = StreamManager::new();
|
||||
let handle = manager.get_handle();
|
||||
let mut readable0 = handle.read_stream::<Value, _>(0, TestSink::default())?;
|
||||
let mut readable1 = handle.read_stream::<Result<Vec<u8>, _>, _>(1, TestSink::default())?;
|
||||
|
||||
let error = ShellError::PluginFailedToDecode {
|
||||
msg: "test decode error".into(),
|
||||
};
|
||||
|
||||
manager.broadcast_read_error(error.clone())?;
|
||||
drop(manager);
|
||||
|
||||
assert_eq!(
|
||||
error.to_string(),
|
||||
readable0
|
||||
.recv()
|
||||
.transpose()
|
||||
.expect("nothing received from readable0")
|
||||
.expect_err("not an error received from readable0")
|
||||
.to_string()
|
||||
);
|
||||
assert_eq!(
|
||||
error.to_string(),
|
||||
readable1
|
||||
.next()
|
||||
.expect("nothing received from readable1")
|
||||
.expect_err("not an error received from readable1")
|
||||
.to_string()
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn stream_manager_drop_writers_on_drop() -> Result<(), ShellError> {
|
||||
let manager = StreamManager::new();
|
||||
let handle = manager.get_handle();
|
||||
let writable = handle.write_stream(4, TestSink::default(), 100)?;
|
||||
|
||||
assert!(!writable.is_dropped()?);
|
||||
|
||||
drop(manager);
|
||||
|
||||
assert!(writable.is_dropped()?);
|
||||
|
||||
Ok(())
|
||||
}
|
143
crates/nu-plugin/src/plugin/interface/test_util.rs
Normal file
143
crates/nu-plugin/src/plugin/interface/test_util.rs
Normal file
@ -0,0 +1,143 @@
|
||||
use std::{
|
||||
collections::VecDeque,
|
||||
sync::{Arc, Mutex},
|
||||
};
|
||||
|
||||
use nu_protocol::ShellError;
|
||||
|
||||
use crate::{plugin::PluginIdentity, protocol::PluginInput, PluginOutput};
|
||||
|
||||
use super::{EngineInterfaceManager, PluginInterfaceManager, PluginRead, PluginWrite};
|
||||
|
||||
/// Mock read/write helper for the engine and plugin interfaces.
|
||||
#[derive(Debug, Clone)]
|
||||
pub(crate) struct TestCase<I, O> {
|
||||
r#in: Arc<Mutex<TestData<I>>>,
|
||||
out: Arc<Mutex<TestData<O>>>,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub(crate) struct TestData<T> {
|
||||
data: VecDeque<T>,
|
||||
error: Option<ShellError>,
|
||||
flushed: bool,
|
||||
}
|
||||
|
||||
impl<T> Default for TestData<T> {
|
||||
fn default() -> Self {
|
||||
TestData {
|
||||
data: VecDeque::new(),
|
||||
error: None,
|
||||
flushed: false,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<I, O> PluginRead<I> for TestCase<I, O> {
|
||||
fn read(&mut self) -> Result<Option<I>, ShellError> {
|
||||
let mut lock = self.r#in.lock().unwrap();
|
||||
if let Some(err) = lock.error.take() {
|
||||
Err(err)
|
||||
} else {
|
||||
Ok(lock.data.pop_front())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<I, O> PluginWrite<O> for TestCase<I, O>
|
||||
where
|
||||
I: Send + Clone,
|
||||
O: Send + Clone,
|
||||
{
|
||||
fn write(&self, data: &O) -> Result<(), ShellError> {
|
||||
let mut lock = self.out.lock().unwrap();
|
||||
lock.flushed = false;
|
||||
|
||||
if let Some(err) = lock.error.take() {
|
||||
Err(err)
|
||||
} else {
|
||||
lock.data.push_back(data.clone());
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
fn flush(&self) -> Result<(), ShellError> {
|
||||
let mut lock = self.out.lock().unwrap();
|
||||
lock.flushed = true;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
impl<I, O> TestCase<I, O> {
|
||||
pub(crate) fn new() -> TestCase<I, O> {
|
||||
TestCase {
|
||||
r#in: Default::default(),
|
||||
out: Default::default(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Clear the read buffer.
|
||||
pub(crate) fn clear(&self) {
|
||||
self.r#in.lock().unwrap().data.truncate(0);
|
||||
}
|
||||
|
||||
/// Add input that will be read by the interface.
|
||||
pub(crate) fn add(&self, input: impl Into<I>) {
|
||||
self.r#in.lock().unwrap().data.push_back(input.into());
|
||||
}
|
||||
|
||||
/// Add multiple inputs that will be read by the interface.
|
||||
pub(crate) fn extend(&self, inputs: impl IntoIterator<Item = I>) {
|
||||
self.r#in.lock().unwrap().data.extend(inputs);
|
||||
}
|
||||
|
||||
/// Return an error from the next read operation.
|
||||
pub(crate) fn set_read_error(&self, err: ShellError) {
|
||||
self.r#in.lock().unwrap().error = Some(err);
|
||||
}
|
||||
|
||||
/// Return an error from the next write operation.
|
||||
pub(crate) fn set_write_error(&self, err: ShellError) {
|
||||
self.out.lock().unwrap().error = Some(err);
|
||||
}
|
||||
|
||||
/// Get the next output that was written.
|
||||
pub(crate) fn next_written(&self) -> Option<O> {
|
||||
self.out.lock().unwrap().data.pop_front()
|
||||
}
|
||||
|
||||
/// Iterator over written data.
|
||||
pub(crate) fn written(&self) -> impl Iterator<Item = O> + '_ {
|
||||
std::iter::from_fn(|| self.next_written())
|
||||
}
|
||||
|
||||
/// Returns true if the writer was flushed after the last write operation.
|
||||
pub(crate) fn was_flushed(&self) -> bool {
|
||||
self.out.lock().unwrap().flushed
|
||||
}
|
||||
|
||||
/// Returns true if the reader has unconsumed reads.
|
||||
pub(crate) fn has_unconsumed_read(&self) -> bool {
|
||||
!self.r#in.lock().unwrap().data.is_empty()
|
||||
}
|
||||
|
||||
/// Returns true if the writer has unconsumed writes.
|
||||
pub(crate) fn has_unconsumed_write(&self) -> bool {
|
||||
!self.out.lock().unwrap().data.is_empty()
|
||||
}
|
||||
}
|
||||
|
||||
impl TestCase<PluginOutput, PluginInput> {
|
||||
/// Create a new [`PluginInterfaceManager`] that writes to this test case.
|
||||
pub(crate) fn plugin(&self, name: &str) -> PluginInterfaceManager {
|
||||
PluginInterfaceManager::new(PluginIdentity::new_fake(name), self.clone())
|
||||
}
|
||||
}
|
||||
|
||||
impl TestCase<PluginInput, PluginOutput> {
|
||||
/// Create a new [`EngineInterfaceManager`] that writes to this test case.
|
||||
pub(crate) fn engine(&self) -> EngineInterfaceManager {
|
||||
EngineInterfaceManager::new(self.clone())
|
||||
}
|
||||
}
|
559
crates/nu-plugin/src/plugin/interface/tests.rs
Normal file
559
crates/nu-plugin/src/plugin/interface/tests.rs
Normal file
@ -0,0 +1,559 @@
|
||||
use std::{path::Path, sync::Arc};
|
||||
|
||||
use nu_protocol::{
|
||||
DataSource, ListStream, PipelineData, PipelineMetadata, RawStream, ShellError, Span, Value,
|
||||
};
|
||||
|
||||
use crate::{
|
||||
protocol::{
|
||||
ExternalStreamInfo, ListStreamInfo, PipelineDataHeader, PluginInput, PluginOutput,
|
||||
RawStreamInfo, StreamData, StreamMessage,
|
||||
},
|
||||
sequence::Sequence,
|
||||
};
|
||||
|
||||
use super::{
|
||||
stream::{StreamManager, StreamManagerHandle},
|
||||
test_util::TestCase,
|
||||
Interface, InterfaceManager, PluginRead, PluginWrite,
|
||||
};
|
||||
|
||||
fn test_metadata() -> PipelineMetadata {
|
||||
PipelineMetadata {
|
||||
data_source: DataSource::FilePath("/test/path".into()),
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct TestInterfaceManager {
|
||||
stream_manager: StreamManager,
|
||||
test: TestCase<PluginInput, PluginOutput>,
|
||||
seq: Arc<Sequence>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
struct TestInterface {
|
||||
stream_manager_handle: StreamManagerHandle,
|
||||
test: TestCase<PluginInput, PluginOutput>,
|
||||
seq: Arc<Sequence>,
|
||||
}
|
||||
|
||||
impl TestInterfaceManager {
|
||||
fn new(test: &TestCase<PluginInput, PluginOutput>) -> TestInterfaceManager {
|
||||
TestInterfaceManager {
|
||||
stream_manager: StreamManager::new(),
|
||||
test: test.clone(),
|
||||
seq: Arc::new(Sequence::default()),
|
||||
}
|
||||
}
|
||||
|
||||
fn consume_all(&mut self) -> Result<(), ShellError> {
|
||||
while let Some(msg) = self.test.read()? {
|
||||
self.consume(msg)?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl InterfaceManager for TestInterfaceManager {
|
||||
type Interface = TestInterface;
|
||||
type Input = PluginInput;
|
||||
|
||||
fn get_interface(&self) -> Self::Interface {
|
||||
TestInterface {
|
||||
stream_manager_handle: self.stream_manager.get_handle(),
|
||||
test: self.test.clone(),
|
||||
seq: self.seq.clone(),
|
||||
}
|
||||
}
|
||||
|
||||
fn consume(&mut self, input: Self::Input) -> Result<(), ShellError> {
|
||||
match input {
|
||||
PluginInput::Stream(msg) => self.consume_stream_message(msg),
|
||||
_ => unimplemented!(),
|
||||
}
|
||||
}
|
||||
|
||||
fn stream_manager(&self) -> &StreamManager {
|
||||
&self.stream_manager
|
||||
}
|
||||
|
||||
fn prepare_pipeline_data(&self, data: PipelineData) -> Result<PipelineData, ShellError> {
|
||||
Ok(data.set_metadata(Some(test_metadata())))
|
||||
}
|
||||
}
|
||||
|
||||
impl Interface for TestInterface {
|
||||
type Output = PluginOutput;
|
||||
|
||||
fn write(&self, output: Self::Output) -> Result<(), ShellError> {
|
||||
self.test.write(&output)
|
||||
}
|
||||
|
||||
fn flush(&self) -> Result<(), ShellError> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn stream_id_sequence(&self) -> &Sequence {
|
||||
&self.seq
|
||||
}
|
||||
|
||||
fn stream_manager_handle(&self) -> &StreamManagerHandle {
|
||||
&self.stream_manager_handle
|
||||
}
|
||||
|
||||
fn prepare_pipeline_data(&self, data: PipelineData) -> Result<PipelineData, ShellError> {
|
||||
// Add an arbitrary check to the data to verify this is being called
|
||||
match data {
|
||||
PipelineData::Value(Value::Binary { .. }, None) => Err(ShellError::NushellFailed {
|
||||
msg: "TEST can't send binary".into(),
|
||||
}),
|
||||
_ => Ok(data),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn read_pipeline_data_empty() -> Result<(), ShellError> {
|
||||
let manager = TestInterfaceManager::new(&TestCase::new());
|
||||
let header = PipelineDataHeader::Empty;
|
||||
|
||||
assert!(matches!(
|
||||
manager.read_pipeline_data(header, None)?,
|
||||
PipelineData::Empty
|
||||
));
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn read_pipeline_data_value() -> Result<(), ShellError> {
|
||||
let manager = TestInterfaceManager::new(&TestCase::new());
|
||||
let value = Value::test_int(4);
|
||||
let header = PipelineDataHeader::Value(value.clone());
|
||||
|
||||
match manager.read_pipeline_data(header, None)? {
|
||||
PipelineData::Value(read_value, _) => assert_eq!(value, read_value),
|
||||
PipelineData::ListStream(_, _) => panic!("unexpected ListStream"),
|
||||
PipelineData::ExternalStream { .. } => panic!("unexpected ExternalStream"),
|
||||
PipelineData::Empty => panic!("unexpected Empty"),
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn read_pipeline_data_list_stream() -> Result<(), ShellError> {
|
||||
let test = TestCase::new();
|
||||
let mut manager = TestInterfaceManager::new(&test);
|
||||
|
||||
let data = (0..100).map(Value::test_int).collect::<Vec<_>>();
|
||||
|
||||
for value in &data {
|
||||
test.add(StreamMessage::Data(7, value.clone().into()));
|
||||
}
|
||||
test.add(StreamMessage::End(7));
|
||||
|
||||
let header = PipelineDataHeader::ListStream(ListStreamInfo { id: 7 });
|
||||
|
||||
let pipe = manager.read_pipeline_data(header, None)?;
|
||||
assert!(
|
||||
matches!(pipe, PipelineData::ListStream(..)),
|
||||
"unexpected PipelineData: {pipe:?}"
|
||||
);
|
||||
|
||||
// need to consume input
|
||||
manager.consume_all()?;
|
||||
|
||||
let mut count = 0;
|
||||
for (expected, read) in data.into_iter().zip(pipe) {
|
||||
assert_eq!(expected, read);
|
||||
count += 1;
|
||||
}
|
||||
assert_eq!(100, count);
|
||||
|
||||
assert!(test.has_unconsumed_write());
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn read_pipeline_data_external_stream() -> Result<(), ShellError> {
|
||||
let test = TestCase::new();
|
||||
let mut manager = TestInterfaceManager::new(&test);
|
||||
|
||||
let iterations = 100;
|
||||
let out_pattern = b"hello".to_vec();
|
||||
let err_pattern = vec![5, 4, 3, 2];
|
||||
|
||||
test.add(StreamMessage::Data(14, Value::test_int(1).into()));
|
||||
for _ in 0..iterations {
|
||||
test.add(StreamMessage::Data(12, Ok(out_pattern.clone()).into()));
|
||||
test.add(StreamMessage::Data(13, Ok(err_pattern.clone()).into()));
|
||||
}
|
||||
test.add(StreamMessage::End(12));
|
||||
test.add(StreamMessage::End(13));
|
||||
test.add(StreamMessage::End(14));
|
||||
|
||||
let test_span = Span::new(10, 13);
|
||||
let header = PipelineDataHeader::ExternalStream(ExternalStreamInfo {
|
||||
span: test_span,
|
||||
stdout: Some(RawStreamInfo {
|
||||
id: 12,
|
||||
is_binary: false,
|
||||
known_size: Some((out_pattern.len() * iterations) as u64),
|
||||
}),
|
||||
stderr: Some(RawStreamInfo {
|
||||
id: 13,
|
||||
is_binary: true,
|
||||
known_size: None,
|
||||
}),
|
||||
exit_code: Some(ListStreamInfo { id: 14 }),
|
||||
trim_end_newline: true,
|
||||
});
|
||||
|
||||
let pipe = manager.read_pipeline_data(header, None)?;
|
||||
|
||||
// need to consume input
|
||||
manager.consume_all()?;
|
||||
|
||||
match pipe {
|
||||
PipelineData::ExternalStream {
|
||||
stdout,
|
||||
stderr,
|
||||
exit_code,
|
||||
span,
|
||||
metadata,
|
||||
trim_end_newline,
|
||||
} => {
|
||||
let stdout = stdout.expect("stdout is None");
|
||||
let stderr = stderr.expect("stderr is None");
|
||||
let exit_code = exit_code.expect("exit_code is None");
|
||||
assert_eq!(test_span, span);
|
||||
assert!(
|
||||
metadata.is_some(),
|
||||
"expected metadata to be Some due to prepare_pipeline_data()"
|
||||
);
|
||||
assert!(trim_end_newline);
|
||||
|
||||
assert!(!stdout.is_binary);
|
||||
assert!(stderr.is_binary);
|
||||
|
||||
assert_eq!(
|
||||
Some((out_pattern.len() * iterations) as u64),
|
||||
stdout.known_size
|
||||
);
|
||||
assert_eq!(None, stderr.known_size);
|
||||
|
||||
// check the streams
|
||||
let mut count = 0;
|
||||
for chunk in stdout.stream {
|
||||
assert_eq!(out_pattern, chunk?);
|
||||
count += 1;
|
||||
}
|
||||
assert_eq!(iterations, count, "stdout length");
|
||||
let mut count = 0;
|
||||
|
||||
for chunk in stderr.stream {
|
||||
assert_eq!(err_pattern, chunk?);
|
||||
count += 1;
|
||||
}
|
||||
assert_eq!(iterations, count, "stderr length");
|
||||
|
||||
assert_eq!(vec![Value::test_int(1)], exit_code.collect::<Vec<_>>());
|
||||
}
|
||||
_ => panic!("unexpected PipelineData: {pipe:?}"),
|
||||
}
|
||||
|
||||
// Don't need to check exactly what was written, just be sure that there is some output
|
||||
assert!(test.has_unconsumed_write());
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn read_pipeline_data_ctrlc() -> Result<(), ShellError> {
|
||||
let manager = TestInterfaceManager::new(&TestCase::new());
|
||||
let header = PipelineDataHeader::ListStream(ListStreamInfo { id: 0 });
|
||||
let ctrlc = Default::default();
|
||||
match manager.read_pipeline_data(header, Some(&ctrlc))? {
|
||||
PipelineData::ListStream(
|
||||
ListStream {
|
||||
ctrlc: stream_ctrlc,
|
||||
..
|
||||
},
|
||||
_,
|
||||
) => {
|
||||
assert!(Arc::ptr_eq(&ctrlc, &stream_ctrlc.expect("ctrlc not set")));
|
||||
Ok(())
|
||||
}
|
||||
_ => panic!("Unexpected PipelineData, should have been ListStream"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn read_pipeline_data_prepared_properly() -> Result<(), ShellError> {
|
||||
let manager = TestInterfaceManager::new(&TestCase::new());
|
||||
let header = PipelineDataHeader::ListStream(ListStreamInfo { id: 0 });
|
||||
match manager.read_pipeline_data(header, None)? {
|
||||
PipelineData::ListStream(_, meta) => match meta {
|
||||
Some(PipelineMetadata { data_source }) => match data_source {
|
||||
DataSource::FilePath(path) => {
|
||||
assert_eq!(Path::new("/test/path"), path);
|
||||
Ok(())
|
||||
}
|
||||
_ => panic!("wrong metadata: {data_source:?}"),
|
||||
},
|
||||
None => panic!("metadata not set"),
|
||||
},
|
||||
_ => panic!("Unexpected PipelineData, should have been ListStream"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn write_pipeline_data_empty() -> Result<(), ShellError> {
|
||||
let test = TestCase::new();
|
||||
let manager = TestInterfaceManager::new(&test);
|
||||
let interface = manager.get_interface();
|
||||
|
||||
let (header, writer) = interface.init_write_pipeline_data(PipelineData::Empty)?;
|
||||
|
||||
assert!(matches!(header, PipelineDataHeader::Empty));
|
||||
|
||||
writer.write()?;
|
||||
|
||||
assert!(
|
||||
!test.has_unconsumed_write(),
|
||||
"Empty shouldn't write any stream messages, test: {test:#?}"
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn write_pipeline_data_value() -> Result<(), ShellError> {
|
||||
let test = TestCase::new();
|
||||
let manager = TestInterfaceManager::new(&test);
|
||||
let interface = manager.get_interface();
|
||||
let value = Value::test_int(7);
|
||||
|
||||
let (header, writer) =
|
||||
interface.init_write_pipeline_data(PipelineData::Value(value.clone(), None))?;
|
||||
|
||||
match header {
|
||||
PipelineDataHeader::Value(read_value) => assert_eq!(value, read_value),
|
||||
_ => panic!("unexpected header: {header:?}"),
|
||||
}
|
||||
|
||||
writer.write()?;
|
||||
|
||||
assert!(
|
||||
!test.has_unconsumed_write(),
|
||||
"Value shouldn't write any stream messages, test: {test:#?}"
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn write_pipeline_data_prepared_properly() {
|
||||
let manager = TestInterfaceManager::new(&TestCase::new());
|
||||
let interface = manager.get_interface();
|
||||
|
||||
// Sending a binary should be an error in our test scenario
|
||||
let value = Value::test_binary(vec![7, 8]);
|
||||
|
||||
match interface.init_write_pipeline_data(PipelineData::Value(value, None)) {
|
||||
Ok(_) => panic!("prepare_pipeline_data was not called"),
|
||||
Err(err) => {
|
||||
assert_eq!(
|
||||
ShellError::NushellFailed {
|
||||
msg: "TEST can't send binary".into()
|
||||
}
|
||||
.to_string(),
|
||||
err.to_string()
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn write_pipeline_data_list_stream() -> Result<(), ShellError> {
|
||||
let test = TestCase::new();
|
||||
let manager = TestInterfaceManager::new(&test);
|
||||
let interface = manager.get_interface();
|
||||
|
||||
let values = vec![
|
||||
Value::test_int(40),
|
||||
Value::test_bool(false),
|
||||
Value::test_string("this is a test"),
|
||||
];
|
||||
|
||||
// Set up pipeline data for a list stream
|
||||
let pipe = PipelineData::ListStream(
|
||||
ListStream::from_stream(values.clone().into_iter(), None),
|
||||
None,
|
||||
);
|
||||
|
||||
let (header, writer) = interface.init_write_pipeline_data(pipe)?;
|
||||
|
||||
let info = match header {
|
||||
PipelineDataHeader::ListStream(info) => info,
|
||||
_ => panic!("unexpected header: {header:?}"),
|
||||
};
|
||||
|
||||
writer.write()?;
|
||||
|
||||
// Now make sure the stream messages have been written
|
||||
for value in values {
|
||||
match test.next_written().expect("unexpected end of stream") {
|
||||
PluginOutput::Stream(StreamMessage::Data(id, data)) => {
|
||||
assert_eq!(info.id, id, "Data id");
|
||||
match data {
|
||||
StreamData::List(read_value) => assert_eq!(value, read_value, "Data value"),
|
||||
_ => panic!("unexpected Data: {data:?}"),
|
||||
}
|
||||
}
|
||||
other => panic!("unexpected output: {other:?}"),
|
||||
}
|
||||
}
|
||||
|
||||
match test.next_written().expect("unexpected end of stream") {
|
||||
PluginOutput::Stream(StreamMessage::End(id)) => {
|
||||
assert_eq!(info.id, id, "End id");
|
||||
}
|
||||
other => panic!("unexpected output: {other:?}"),
|
||||
}
|
||||
|
||||
assert!(!test.has_unconsumed_write());
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn write_pipeline_data_external_stream() -> Result<(), ShellError> {
|
||||
let test = TestCase::new();
|
||||
let manager = TestInterfaceManager::new(&test);
|
||||
let interface = manager.get_interface();
|
||||
|
||||
let stdout_bufs = vec![
|
||||
b"hello".to_vec(),
|
||||
b"world".to_vec(),
|
||||
b"these are tests".to_vec(),
|
||||
];
|
||||
let stdout_len = stdout_bufs.iter().map(|b| b.len() as u64).sum::<u64>();
|
||||
let stderr_bufs = vec![b"error messages".to_vec(), b"go here".to_vec()];
|
||||
let exit_code = Value::test_int(7);
|
||||
|
||||
let span = Span::new(400, 500);
|
||||
|
||||
// Set up pipeline data for an external stream
|
||||
let pipe = PipelineData::ExternalStream {
|
||||
stdout: Some(RawStream::new(
|
||||
Box::new(stdout_bufs.clone().into_iter().map(Ok)),
|
||||
None,
|
||||
span,
|
||||
Some(stdout_len),
|
||||
)),
|
||||
stderr: Some(RawStream::new(
|
||||
Box::new(stderr_bufs.clone().into_iter().map(Ok)),
|
||||
None,
|
||||
span,
|
||||
None,
|
||||
)),
|
||||
exit_code: Some(ListStream::from_stream(
|
||||
std::iter::once(exit_code.clone()),
|
||||
None,
|
||||
)),
|
||||
span,
|
||||
metadata: None,
|
||||
trim_end_newline: true,
|
||||
};
|
||||
|
||||
let (header, writer) = interface.init_write_pipeline_data(pipe)?;
|
||||
|
||||
let info = match header {
|
||||
PipelineDataHeader::ExternalStream(info) => info,
|
||||
_ => panic!("unexpected header: {header:?}"),
|
||||
};
|
||||
|
||||
writer.write()?;
|
||||
|
||||
let stdout_info = info.stdout.as_ref().expect("stdout info is None");
|
||||
let stderr_info = info.stderr.as_ref().expect("stderr info is None");
|
||||
let exit_code_info = info.exit_code.as_ref().expect("exit code info is None");
|
||||
|
||||
assert_eq!(span, info.span);
|
||||
assert!(info.trim_end_newline);
|
||||
|
||||
assert_eq!(Some(stdout_len), stdout_info.known_size);
|
||||
assert_eq!(None, stderr_info.known_size);
|
||||
|
||||
// Now make sure the stream messages have been written
|
||||
let mut stdout_iter = stdout_bufs.into_iter();
|
||||
let mut stderr_iter = stderr_bufs.into_iter();
|
||||
let mut exit_code_iter = std::iter::once(exit_code);
|
||||
|
||||
let mut stdout_ended = false;
|
||||
let mut stderr_ended = false;
|
||||
let mut exit_code_ended = false;
|
||||
|
||||
// There's no specific order these messages must come in with respect to how the streams are
|
||||
// interleaved, but all of the data for each stream must be in its original order, and the
|
||||
// End must come after all Data
|
||||
for msg in test.written() {
|
||||
match msg {
|
||||
PluginOutput::Stream(StreamMessage::Data(id, data)) => {
|
||||
if id == stdout_info.id {
|
||||
let result: Result<Vec<u8>, ShellError> =
|
||||
data.try_into().expect("wrong data in stdout stream");
|
||||
assert_eq!(
|
||||
stdout_iter.next().expect("too much data in stdout"),
|
||||
result.expect("unexpected error in stdout stream")
|
||||
);
|
||||
} else if id == stderr_info.id {
|
||||
let result: Result<Vec<u8>, ShellError> =
|
||||
data.try_into().expect("wrong data in stderr stream");
|
||||
assert_eq!(
|
||||
stderr_iter.next().expect("too much data in stderr"),
|
||||
result.expect("unexpected error in stderr stream")
|
||||
);
|
||||
} else if id == exit_code_info.id {
|
||||
let code: Value = data.try_into().expect("wrong data in stderr stream");
|
||||
assert_eq!(
|
||||
exit_code_iter.next().expect("too much data in stderr"),
|
||||
code
|
||||
);
|
||||
} else {
|
||||
panic!("unrecognized stream id: {id}");
|
||||
}
|
||||
}
|
||||
PluginOutput::Stream(StreamMessage::End(id)) => {
|
||||
if id == stdout_info.id {
|
||||
assert!(!stdout_ended, "double End of stdout");
|
||||
assert!(stdout_iter.next().is_none(), "unexpected end of stdout");
|
||||
stdout_ended = true;
|
||||
} else if id == stderr_info.id {
|
||||
assert!(!stderr_ended, "double End of stderr");
|
||||
assert!(stderr_iter.next().is_none(), "unexpected end of stderr");
|
||||
stderr_ended = true;
|
||||
} else if id == exit_code_info.id {
|
||||
assert!(!exit_code_ended, "double End of exit_code");
|
||||
assert!(
|
||||
exit_code_iter.next().is_none(),
|
||||
"unexpected end of exit_code"
|
||||
);
|
||||
exit_code_ended = true;
|
||||
} else {
|
||||
panic!("unrecognized stream id: {id}");
|
||||
}
|
||||
}
|
||||
other => panic!("unexpected output: {other:?}"),
|
||||
}
|
||||
}
|
||||
|
||||
assert!(stdout_ended, "stdout did not End");
|
||||
assert!(stderr_ended, "stderr did not End");
|
||||
assert!(exit_code_ended, "exit_code did not End");
|
||||
|
||||
Ok(())
|
||||
}
|
@ -2,53 +2,69 @@ mod declaration;
|
||||
pub use declaration::PluginDeclaration;
|
||||
use nu_engine::documentation::get_flags_section;
|
||||
use std::collections::HashMap;
|
||||
use std::ffi::OsStr;
|
||||
use std::sync::{Arc, Mutex};
|
||||
|
||||
use crate::protocol::{CallInput, LabeledError, PluginCall, PluginData, PluginResponse};
|
||||
use crate::plugin::interface::{EngineInterfaceManager, ReceivedPluginCall};
|
||||
use crate::protocol::{CallInfo, CustomValueOp, LabeledError, PluginInput, PluginOutput};
|
||||
use crate::EncodingType;
|
||||
use std::env;
|
||||
use std::fmt::Write;
|
||||
use std::io::{BufReader, ErrorKind, Read, Write as WriteTrait};
|
||||
use std::io::{BufReader, Read, Write as WriteTrait};
|
||||
use std::path::Path;
|
||||
use std::process::{Child, ChildStdout, Command as CommandSys, Stdio};
|
||||
|
||||
use nu_protocol::{CustomValue, PluginSignature, ShellError, Span, Value};
|
||||
use nu_protocol::{PipelineData, PluginSignature, ShellError, Value};
|
||||
|
||||
mod interface;
|
||||
pub(crate) use interface::PluginInterface;
|
||||
|
||||
mod context;
|
||||
pub(crate) use context::PluginExecutionCommandContext;
|
||||
|
||||
mod identity;
|
||||
pub(crate) use identity::PluginIdentity;
|
||||
|
||||
use self::interface::{InterfaceManager, PluginInterfaceManager};
|
||||
|
||||
use super::EvaluatedCall;
|
||||
|
||||
pub(crate) const OUTPUT_BUFFER_SIZE: usize = 8192;
|
||||
|
||||
/// Encoding scheme that defines a plugin's communication protocol with Nu
|
||||
pub trait PluginEncoder: Clone {
|
||||
/// The name of the encoder (e.g., `json`)
|
||||
fn name(&self) -> &str;
|
||||
/// Encoder for a specific message type. Usually implemented on [`PluginInput`]
|
||||
/// and [`PluginOutput`].
|
||||
#[doc(hidden)]
|
||||
pub trait Encoder<T>: Clone + Send + Sync {
|
||||
/// Serialize a value in the [`PluginEncoder`]s format
|
||||
///
|
||||
/// Returns [ShellError::IOError] if there was a problem writing, or
|
||||
/// [ShellError::PluginFailedToEncode] for a serialization error.
|
||||
#[doc(hidden)]
|
||||
fn encode(&self, data: &T, writer: &mut impl std::io::Write) -> Result<(), ShellError>;
|
||||
|
||||
/// Serialize a `PluginCall` in the `PluginEncoder`s format
|
||||
fn encode_call(
|
||||
&self,
|
||||
plugin_call: &PluginCall,
|
||||
writer: &mut impl std::io::Write,
|
||||
) -> Result<(), ShellError>;
|
||||
|
||||
/// Deserialize a `PluginCall` from the `PluginEncoder`s format
|
||||
fn decode_call(&self, reader: &mut impl std::io::BufRead) -> Result<PluginCall, ShellError>;
|
||||
|
||||
/// Serialize a `PluginResponse` from the plugin in this `PluginEncoder`'s preferred
|
||||
/// format
|
||||
fn encode_response(
|
||||
&self,
|
||||
plugin_response: &PluginResponse,
|
||||
writer: &mut impl std::io::Write,
|
||||
) -> Result<(), ShellError>;
|
||||
|
||||
/// Deserialize a `PluginResponse` from the plugin from this `PluginEncoder`'s
|
||||
/// preferred format
|
||||
fn decode_response(
|
||||
&self,
|
||||
reader: &mut impl std::io::BufRead,
|
||||
) -> Result<PluginResponse, ShellError>;
|
||||
/// Deserialize a value from the [`PluginEncoder`]'s format
|
||||
///
|
||||
/// Returns `None` if there is no more output to receive.
|
||||
///
|
||||
/// Returns [ShellError::IOError] if there was a problem reading, or
|
||||
/// [ShellError::PluginFailedToDecode] for a deserialization error.
|
||||
#[doc(hidden)]
|
||||
fn decode(&self, reader: &mut impl std::io::BufRead) -> Result<Option<T>, ShellError>;
|
||||
}
|
||||
|
||||
pub(crate) fn create_command(path: &Path, shell: Option<&Path>) -> CommandSys {
|
||||
/// Encoding scheme that defines a plugin's communication protocol with Nu
|
||||
pub trait PluginEncoder: Encoder<PluginInput> + Encoder<PluginOutput> {
|
||||
/// The name of the encoder (e.g., `json`)
|
||||
fn name(&self) -> &str;
|
||||
}
|
||||
|
||||
fn create_command(path: &Path, shell: Option<&Path>) -> CommandSys {
|
||||
log::trace!("Starting plugin: {path:?}, shell = {shell:?}");
|
||||
|
||||
// There is only one mode supported at the moment, but the idea is that future
|
||||
// communication methods could be supported if desirable
|
||||
let mut input_arg = Some("--stdio");
|
||||
|
||||
let mut process = match (path.extension(), shell) {
|
||||
(_, Some(shell)) => {
|
||||
let mut process = std::process::Command::new(shell);
|
||||
@ -57,18 +73,25 @@ pub(crate) fn create_command(path: &Path, shell: Option<&Path>) -> CommandSys {
|
||||
process
|
||||
}
|
||||
(Some(extension), None) => {
|
||||
let (shell, separator) = match extension.to_str() {
|
||||
let (shell, command_switch) = match extension.to_str() {
|
||||
Some("cmd") | Some("bat") => (Some("cmd"), Some("/c")),
|
||||
Some("sh") => (Some("sh"), Some("-c")),
|
||||
Some("py") => (Some("python"), None),
|
||||
_ => (None, None),
|
||||
};
|
||||
|
||||
match (shell, separator) {
|
||||
(Some(shell), Some(separator)) => {
|
||||
match (shell, command_switch) {
|
||||
(Some(shell), Some(command_switch)) => {
|
||||
let mut process = std::process::Command::new(shell);
|
||||
process.arg(separator);
|
||||
process.arg(path);
|
||||
process.arg(command_switch);
|
||||
// If `command_switch` is set, we need to pass the path + arg as one argument
|
||||
// e.g. sh -c "nu_plugin_inc --stdio"
|
||||
let mut combined = path.as_os_str().to_owned();
|
||||
if let Some(arg) = input_arg.take() {
|
||||
combined.push(OsStr::new(" "));
|
||||
combined.push(OsStr::new(arg));
|
||||
}
|
||||
process.arg(combined);
|
||||
|
||||
process
|
||||
}
|
||||
@ -84,41 +107,60 @@ pub(crate) fn create_command(path: &Path, shell: Option<&Path>) -> CommandSys {
|
||||
(None, None) => std::process::Command::new(path),
|
||||
};
|
||||
|
||||
// Pass input_arg, unless we consumed it already
|
||||
if let Some(input_arg) = input_arg {
|
||||
process.arg(input_arg);
|
||||
}
|
||||
|
||||
// Both stdout and stdin are piped so we can receive information from the plugin
|
||||
process.stdout(Stdio::piped()).stdin(Stdio::piped());
|
||||
|
||||
process
|
||||
}
|
||||
|
||||
pub(crate) fn call_plugin(
|
||||
child: &mut Child,
|
||||
plugin_call: PluginCall,
|
||||
encoding: &EncodingType,
|
||||
span: Span,
|
||||
) -> Result<PluginResponse, ShellError> {
|
||||
if let Some(mut stdin_writer) = child.stdin.take() {
|
||||
let encoding_clone = encoding.clone();
|
||||
// If the child process fills its stdout buffer, it may end up waiting until the parent
|
||||
// reads the stdout, and not be able to read stdin in the meantime, causing a deadlock.
|
||||
// Writing from another thread ensures that stdout is being read at the same time, avoiding the problem.
|
||||
std::thread::spawn(move || encoding_clone.encode_call(&plugin_call, &mut stdin_writer));
|
||||
}
|
||||
fn make_plugin_interface(
|
||||
mut child: Child,
|
||||
identity: Arc<PluginIdentity>,
|
||||
) -> Result<PluginInterface, ShellError> {
|
||||
let stdin = child
|
||||
.stdin
|
||||
.take()
|
||||
.ok_or_else(|| ShellError::PluginFailedToLoad {
|
||||
msg: "plugin missing stdin writer".into(),
|
||||
})?;
|
||||
|
||||
// Deserialize response from plugin to extract the resulting value
|
||||
if let Some(stdout_reader) = &mut child.stdout {
|
||||
let reader = stdout_reader;
|
||||
let mut buf_read = BufReader::with_capacity(OUTPUT_BUFFER_SIZE, reader);
|
||||
let mut stdout = child
|
||||
.stdout
|
||||
.take()
|
||||
.ok_or_else(|| ShellError::PluginFailedToLoad {
|
||||
msg: "Plugin missing stdout writer".into(),
|
||||
})?;
|
||||
|
||||
encoding.decode_response(&mut buf_read)
|
||||
} else {
|
||||
Err(ShellError::GenericError {
|
||||
error: "Error with stdout reader".into(),
|
||||
msg: "no stdout reader".into(),
|
||||
span: Some(span),
|
||||
help: None,
|
||||
inner: vec![],
|
||||
let encoder = get_plugin_encoding(&mut stdout)?;
|
||||
|
||||
let reader = BufReader::with_capacity(OUTPUT_BUFFER_SIZE, stdout);
|
||||
|
||||
let mut manager = PluginInterfaceManager::new(identity, (Mutex::new(stdin), encoder));
|
||||
let interface = manager.get_interface();
|
||||
interface.hello()?;
|
||||
|
||||
// Spawn the reader on a new thread. We need to be able to read messages at the same time that
|
||||
// we write, because we are expected to be able to handle multiple messages coming in from the
|
||||
// plugin at any time, including stream messages like `Drop`.
|
||||
std::thread::Builder::new()
|
||||
.name("plugin interface reader".into())
|
||||
.spawn(move || {
|
||||
if let Err(err) = manager.consume_all((reader, encoder)) {
|
||||
log::warn!("Error in PluginInterfaceManager: {err}");
|
||||
}
|
||||
// If the loop has ended, drop the manager so everyone disconnects and then wait for the
|
||||
// child to exit
|
||||
drop(manager);
|
||||
let _ = child.wait();
|
||||
})
|
||||
}
|
||||
.expect("failed to spawn thread");
|
||||
|
||||
Ok(interface)
|
||||
}
|
||||
|
||||
#[doc(hidden)] // Note: not for plugin authors / only used in nu-parser
|
||||
@ -127,71 +169,9 @@ pub fn get_signature(
|
||||
shell: Option<&Path>,
|
||||
current_envs: &HashMap<String, String>,
|
||||
) -> Result<Vec<PluginSignature>, ShellError> {
|
||||
let mut plugin_cmd = create_command(path, shell);
|
||||
let program_name = plugin_cmd.get_program().to_os_string().into_string();
|
||||
|
||||
plugin_cmd.envs(current_envs);
|
||||
let mut child = plugin_cmd.spawn().map_err(|err| {
|
||||
let error_msg = match err.kind() {
|
||||
ErrorKind::NotFound => match program_name {
|
||||
Ok(prog_name) => {
|
||||
format!("Can't find {prog_name}, please make sure that {prog_name} is in PATH.")
|
||||
}
|
||||
_ => {
|
||||
format!("Error spawning child process: {err}")
|
||||
}
|
||||
},
|
||||
_ => {
|
||||
format!("Error spawning child process: {err}")
|
||||
}
|
||||
};
|
||||
|
||||
ShellError::PluginFailedToLoad { msg: error_msg }
|
||||
})?;
|
||||
|
||||
let mut stdin_writer = child
|
||||
.stdin
|
||||
.take()
|
||||
.ok_or_else(|| ShellError::PluginFailedToLoad {
|
||||
msg: "plugin missing stdin writer".into(),
|
||||
})?;
|
||||
let mut stdout_reader = child
|
||||
.stdout
|
||||
.take()
|
||||
.ok_or_else(|| ShellError::PluginFailedToLoad {
|
||||
msg: "Plugin missing stdout reader".into(),
|
||||
})?;
|
||||
let encoding = get_plugin_encoding(&mut stdout_reader)?;
|
||||
|
||||
// Create message to plugin to indicate that signature is required and
|
||||
// send call to plugin asking for signature
|
||||
let encoding_clone = encoding.clone();
|
||||
// If the child process fills its stdout buffer, it may end up waiting until the parent
|
||||
// reads the stdout, and not be able to read stdin in the meantime, causing a deadlock.
|
||||
// Writing from another thread ensures that stdout is being read at the same time, avoiding the problem.
|
||||
std::thread::spawn(move || {
|
||||
encoding_clone.encode_call(&PluginCall::Signature, &mut stdin_writer)
|
||||
});
|
||||
|
||||
// deserialize response from plugin to extract the signature
|
||||
let reader = stdout_reader;
|
||||
let mut buf_read = BufReader::with_capacity(OUTPUT_BUFFER_SIZE, reader);
|
||||
let response = encoding.decode_response(&mut buf_read)?;
|
||||
|
||||
let signatures = match response {
|
||||
PluginResponse::Signature(sign) => Ok(sign),
|
||||
PluginResponse::Error(err) => Err(err.into()),
|
||||
_ => Err(ShellError::PluginFailedToLoad {
|
||||
msg: "Plugin missing signature".into(),
|
||||
}),
|
||||
}?;
|
||||
|
||||
match child.wait() {
|
||||
Ok(_) => Ok(signatures),
|
||||
Err(err) => Err(ShellError::PluginFailedToLoad {
|
||||
msg: format!("{err}"),
|
||||
}),
|
||||
}
|
||||
Arc::new(PluginIdentity::new(path, shell.map(|s| s.to_owned())))
|
||||
.spawn(current_envs)?
|
||||
.get_signature()
|
||||
}
|
||||
|
||||
/// The basic API for a Nushell plugin
|
||||
@ -199,6 +179,9 @@ pub fn get_signature(
|
||||
/// This is the trait that Nushell plugins must implement. The methods defined on
|
||||
/// `Plugin` are invoked by [serve_plugin] during plugin registration and execution.
|
||||
///
|
||||
/// If large amounts of data are expected to need to be received or produced, it may be more
|
||||
/// appropriate to implement [StreamingPlugin] instead.
|
||||
///
|
||||
/// # Examples
|
||||
/// Basic usage:
|
||||
/// ```
|
||||
@ -224,6 +207,10 @@ pub fn get_signature(
|
||||
/// Ok(Value::string("Hello, World!".to_owned(), call.head))
|
||||
/// }
|
||||
/// }
|
||||
///
|
||||
/// # fn main() {
|
||||
/// # serve_plugin(&mut HelloPlugin{}, MsgPackSerializer)
|
||||
/// # }
|
||||
/// ```
|
||||
pub trait Plugin {
|
||||
/// The signature of the plugin
|
||||
@ -244,6 +231,9 @@ pub trait Plugin {
|
||||
/// invoked command will be passed in via this argument. The `call` contains
|
||||
/// metadata describing how the plugin was invoked and `input` contains the structured
|
||||
/// data passed to the command implemented by this [Plugin].
|
||||
///
|
||||
/// This variant does not support streaming. Consider implementing [StreamingPlugin] instead
|
||||
/// if streaming is desired.
|
||||
fn run(
|
||||
&mut self,
|
||||
name: &str,
|
||||
@ -253,13 +243,115 @@ pub trait Plugin {
|
||||
) -> Result<Value, LabeledError>;
|
||||
}
|
||||
|
||||
/// The streaming API for a Nushell plugin
|
||||
///
|
||||
/// This is a more low-level version of the [Plugin] trait that supports operating on streams of
|
||||
/// data. If you don't need to operate on streams, consider using that trait instead.
|
||||
///
|
||||
/// The methods defined on `StreamingPlugin` are invoked by [serve_plugin] during plugin
|
||||
/// registration and execution.
|
||||
///
|
||||
/// # Examples
|
||||
/// Basic usage:
|
||||
/// ```
|
||||
/// # use nu_plugin::*;
|
||||
/// # use nu_protocol::{PluginSignature, PipelineData, Type, Value};
|
||||
/// struct LowercasePlugin;
|
||||
///
|
||||
/// impl StreamingPlugin for LowercasePlugin {
|
||||
/// fn signature(&self) -> Vec<PluginSignature> {
|
||||
/// let sig = PluginSignature::build("lowercase")
|
||||
/// .usage("Convert each string in a stream to lowercase")
|
||||
/// .input_output_type(Type::List(Type::String.into()), Type::List(Type::String.into()));
|
||||
///
|
||||
/// vec![sig]
|
||||
/// }
|
||||
///
|
||||
/// fn run(
|
||||
/// &mut self,
|
||||
/// name: &str,
|
||||
/// config: &Option<Value>,
|
||||
/// call: &EvaluatedCall,
|
||||
/// input: PipelineData,
|
||||
/// ) -> Result<PipelineData, LabeledError> {
|
||||
/// let span = call.head;
|
||||
/// Ok(input.map(move |value| {
|
||||
/// value.as_str()
|
||||
/// .map(|string| Value::string(string.to_lowercase(), span))
|
||||
/// // Errors in a stream should be returned as values.
|
||||
/// .unwrap_or_else(|err| Value::error(err, span))
|
||||
/// }, None)?)
|
||||
/// }
|
||||
/// }
|
||||
///
|
||||
/// # fn main() {
|
||||
/// # serve_plugin(&mut LowercasePlugin{}, MsgPackSerializer)
|
||||
/// # }
|
||||
/// ```
|
||||
pub trait StreamingPlugin {
|
||||
/// The signature of the plugin
|
||||
///
|
||||
/// This method returns the [PluginSignature]s that describe the capabilities
|
||||
/// of this plugin. Since a single plugin executable can support multiple invocation
|
||||
/// patterns we return a `Vec` of signatures.
|
||||
fn signature(&self) -> Vec<PluginSignature>;
|
||||
|
||||
/// Perform the actual behavior of the plugin
|
||||
///
|
||||
/// The behavior of the plugin is defined by the implementation of this method.
|
||||
/// When Nushell invoked the plugin [serve_plugin] will call this method and
|
||||
/// print the serialized returned value or error to stdout, which Nushell will
|
||||
/// interpret.
|
||||
///
|
||||
/// The `name` is only relevant for plugins that implement multiple commands as the
|
||||
/// invoked command will be passed in via this argument. The `call` contains
|
||||
/// metadata describing how the plugin was invoked and `input` contains the structured
|
||||
/// data passed to the command implemented by this [Plugin].
|
||||
///
|
||||
/// This variant expects to receive and produce [PipelineData], which allows for stream-based
|
||||
/// handling of I/O. This is recommended if the plugin is expected to transform large lists or
|
||||
/// potentially large quantities of bytes. The API is more complex however, and [Plugin] is
|
||||
/// recommended instead if this is not a concern.
|
||||
fn run(
|
||||
&mut self,
|
||||
name: &str,
|
||||
config: &Option<Value>,
|
||||
call: &EvaluatedCall,
|
||||
input: PipelineData,
|
||||
) -> Result<PipelineData, LabeledError>;
|
||||
}
|
||||
|
||||
/// All [Plugin]s can be used as [StreamingPlugin]s, but input streams will be fully consumed
|
||||
/// before the plugin runs.
|
||||
impl<T: Plugin> StreamingPlugin for T {
|
||||
fn signature(&self) -> Vec<PluginSignature> {
|
||||
<Self as Plugin>::signature(self)
|
||||
}
|
||||
|
||||
fn run(
|
||||
&mut self,
|
||||
name: &str,
|
||||
config: &Option<Value>,
|
||||
call: &EvaluatedCall,
|
||||
input: PipelineData,
|
||||
) -> Result<PipelineData, LabeledError> {
|
||||
// Unwrap the PipelineData from input, consuming the potential stream, and pass it to the
|
||||
// simpler signature in Plugin
|
||||
let span = input.span().unwrap_or(call.head);
|
||||
let input_value = input.into_value(span);
|
||||
// Wrap the output in PipelineData::Value
|
||||
<Self as Plugin>::run(self, name, config, call, &input_value)
|
||||
.map(|value| PipelineData::Value(value, None))
|
||||
}
|
||||
}
|
||||
|
||||
/// Function used to implement the communication protocol between
|
||||
/// nushell and an external plugin.
|
||||
/// nushell and an external plugin. Both [Plugin] and [StreamingPlugin] are supported.
|
||||
///
|
||||
/// When creating a new plugin this function is typically used as the main entry
|
||||
/// point for the plugin, e.g.
|
||||
///
|
||||
/// ```
|
||||
/// ```rust,no_run
|
||||
/// # use nu_plugin::*;
|
||||
/// # use nu_protocol::{PluginSignature, Value};
|
||||
/// # struct MyPlugin;
|
||||
@ -273,22 +365,42 @@ pub trait Plugin {
|
||||
/// serve_plugin(&mut MyPlugin::new(), MsgPackSerializer)
|
||||
/// }
|
||||
/// ```
|
||||
///
|
||||
/// The object that is expected to be received by nushell is the `PluginResponse` struct.
|
||||
/// The `serve_plugin` function should ensure that it is encoded correctly and sent
|
||||
/// to StdOut for nushell to decode and and present its result.
|
||||
pub fn serve_plugin(plugin: &mut impl Plugin, encoder: impl PluginEncoder) {
|
||||
if env::args().any(|arg| (arg == "-h") || (arg == "--help")) {
|
||||
pub fn serve_plugin(plugin: &mut impl StreamingPlugin, encoder: impl PluginEncoder + 'static) {
|
||||
let mut args = env::args().skip(1);
|
||||
let number_of_args = args.len();
|
||||
let first_arg = args.next();
|
||||
|
||||
if number_of_args == 0
|
||||
|| first_arg
|
||||
.as_ref()
|
||||
.is_some_and(|arg| arg == "-h" || arg == "--help")
|
||||
{
|
||||
print_help(plugin, encoder);
|
||||
std::process::exit(0)
|
||||
}
|
||||
|
||||
// Must pass --stdio for plugin execution. Any other arg is an error to give us options in the
|
||||
// future.
|
||||
if number_of_args > 1 || !first_arg.is_some_and(|arg| arg == "--stdio") {
|
||||
eprintln!(
|
||||
"{}: This plugin must be run from within Nushell.",
|
||||
env::current_exe()
|
||||
.map(|path| path.display().to_string())
|
||||
.unwrap_or_else(|_| "plugin".into())
|
||||
);
|
||||
eprintln!(
|
||||
"If you are running from Nushell, this plugin may be incompatible with the \
|
||||
version of nushell you are using."
|
||||
);
|
||||
std::process::exit(1)
|
||||
}
|
||||
|
||||
// tell nushell encoding.
|
||||
//
|
||||
// 1 byte
|
||||
// encoding format: | content-length | content |
|
||||
let mut stdout = std::io::stdout();
|
||||
{
|
||||
let mut stdout = std::io::stdout();
|
||||
let encoding = encoder.name();
|
||||
let length = encoding.len() as u8;
|
||||
let mut encoding_content: Vec<u8> = encoding.as_bytes().to_vec();
|
||||
@ -301,91 +413,120 @@ pub fn serve_plugin(plugin: &mut impl Plugin, encoder: impl PluginEncoder) {
|
||||
.expect("Failed to tell nushell my encoding when flushing stdout");
|
||||
}
|
||||
|
||||
let mut stdin_buf = BufReader::with_capacity(OUTPUT_BUFFER_SIZE, std::io::stdin());
|
||||
let plugin_call = encoder.decode_call(&mut stdin_buf);
|
||||
let mut manager = EngineInterfaceManager::new((stdout, encoder.clone()));
|
||||
let call_receiver = manager
|
||||
.take_plugin_call_receiver()
|
||||
// This expect should be totally safe, as we just created the manager
|
||||
.expect("take_plugin_call_receiver returned None");
|
||||
|
||||
match plugin_call {
|
||||
Err(err) => {
|
||||
let response = PluginResponse::Error(err.into());
|
||||
encoder
|
||||
.encode_response(&response, &mut std::io::stdout())
|
||||
.expect("Error encoding response");
|
||||
}
|
||||
Ok(plugin_call) => {
|
||||
match plugin_call {
|
||||
// Sending the signature back to nushell to create the declaration definition
|
||||
PluginCall::Signature => {
|
||||
let response = PluginResponse::Signature(plugin.signature());
|
||||
encoder
|
||||
.encode_response(&response, &mut std::io::stdout())
|
||||
.expect("Error encoding response");
|
||||
}
|
||||
PluginCall::CallInfo(call_info) => {
|
||||
let input = match call_info.input {
|
||||
CallInput::Value(value) => Ok(value),
|
||||
CallInput::Data(plugin_data) => {
|
||||
bincode::deserialize::<Box<dyn CustomValue>>(&plugin_data.data)
|
||||
.map(|custom_value| {
|
||||
Value::custom_value(custom_value, plugin_data.span)
|
||||
})
|
||||
.map_err(|err| ShellError::PluginFailedToDecode {
|
||||
msg: err.to_string(),
|
||||
})
|
||||
}
|
||||
};
|
||||
// We need to hold on to the interface to keep the manager alive. We can drop it at the end
|
||||
let interface = manager.get_interface();
|
||||
|
||||
let value = match input {
|
||||
Ok(input) => {
|
||||
plugin.run(&call_info.name, &call_info.config, &call_info.call, &input)
|
||||
}
|
||||
Err(err) => Err(err.into()),
|
||||
};
|
||||
// Try an operation that could result in ShellError. Exit if an I/O error is encountered.
|
||||
// Try to report the error to nushell otherwise, and failing that, panic.
|
||||
macro_rules! try_or_report {
|
||||
($interface:expr, $expr:expr) => (match $expr {
|
||||
Ok(val) => val,
|
||||
// Just exit if there is an I/O error. Most likely this just means that nushell
|
||||
// interrupted us. If not, the error probably happened on the other side too, so we
|
||||
// don't need to also report it.
|
||||
Err(ShellError::IOError { .. }) => std::process::exit(1),
|
||||
// If there is another error, try to send it to nushell and then exit.
|
||||
Err(err) => {
|
||||
let _ = $interface.write_response(Err(err.clone())).unwrap_or_else(|_| {
|
||||
// If we can't send it to nushell, panic with it so at least we get the output
|
||||
panic!("{}", err)
|
||||
});
|
||||
std::process::exit(1)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
let response = match value {
|
||||
Ok(value) => {
|
||||
let span = value.span();
|
||||
match value {
|
||||
Value::CustomValue { val, .. } => match bincode::serialize(&val) {
|
||||
Ok(data) => {
|
||||
let name = val.value_string();
|
||||
PluginResponse::PluginData(name, PluginData { data, span })
|
||||
}
|
||||
Err(err) => PluginResponse::Error(
|
||||
ShellError::PluginFailedToEncode {
|
||||
msg: err.to_string(),
|
||||
}
|
||||
.into(),
|
||||
),
|
||||
},
|
||||
value => PluginResponse::Value(Box::new(value)),
|
||||
}
|
||||
}
|
||||
Err(err) => PluginResponse::Error(err),
|
||||
};
|
||||
encoder
|
||||
.encode_response(&response, &mut std::io::stdout())
|
||||
.expect("Error encoding response");
|
||||
}
|
||||
PluginCall::CollapseCustomValue(plugin_data) => {
|
||||
let response = bincode::deserialize::<Box<dyn CustomValue>>(&plugin_data.data)
|
||||
.map_err(|err| ShellError::PluginFailedToDecode {
|
||||
msg: err.to_string(),
|
||||
})
|
||||
.and_then(|val| val.to_base_value(plugin_data.span))
|
||||
.map(Box::new)
|
||||
.map_err(LabeledError::from)
|
||||
.map_or_else(PluginResponse::Error, PluginResponse::Value);
|
||||
// Send Hello message
|
||||
try_or_report!(interface, interface.hello());
|
||||
|
||||
encoder
|
||||
.encode_response(&response, &mut std::io::stdout())
|
||||
.expect("Error encoding response");
|
||||
// Spawn the reader thread
|
||||
std::thread::Builder::new()
|
||||
.name("engine interface reader".into())
|
||||
.spawn(move || {
|
||||
if let Err(err) = manager.consume_all((std::io::stdin().lock(), encoder)) {
|
||||
// Do our best to report the read error. Most likely there is some kind of
|
||||
// incompatibility between the plugin and nushell, so it makes more sense to try to
|
||||
// report it on stderr than to send something.
|
||||
let exe = std::env::current_exe().ok();
|
||||
|
||||
let plugin_name: String = exe
|
||||
.as_ref()
|
||||
.and_then(|path| path.file_stem())
|
||||
.map(|stem| stem.to_string_lossy().into_owned())
|
||||
.map(|stem| {
|
||||
stem.strip_prefix("nu_plugin_")
|
||||
.map(|s| s.to_owned())
|
||||
.unwrap_or(stem)
|
||||
})
|
||||
.unwrap_or_else(|| "(unknown)".into());
|
||||
|
||||
eprintln!("Plugin `{plugin_name}` read error: {err}");
|
||||
std::process::exit(1);
|
||||
}
|
||||
})
|
||||
.expect("failed to spawn thread");
|
||||
|
||||
for plugin_call in call_receiver {
|
||||
match plugin_call {
|
||||
// Sending the signature back to nushell to create the declaration definition
|
||||
ReceivedPluginCall::Signature { engine } => {
|
||||
try_or_report!(engine, engine.write_signature(plugin.signature()));
|
||||
}
|
||||
// Run the plugin, handling any input or output streams
|
||||
ReceivedPluginCall::Run {
|
||||
engine,
|
||||
call:
|
||||
CallInfo {
|
||||
name,
|
||||
config,
|
||||
call,
|
||||
input,
|
||||
},
|
||||
} => {
|
||||
let result = plugin.run(&name, &config, &call, input);
|
||||
let write_result = engine
|
||||
.write_response(result)
|
||||
.map(|writer| writer.write_background());
|
||||
try_or_report!(engine, write_result);
|
||||
}
|
||||
// Do an operation on a custom value
|
||||
ReceivedPluginCall::CustomValueOp {
|
||||
engine,
|
||||
custom_value,
|
||||
op,
|
||||
} => {
|
||||
let local_value = try_or_report!(
|
||||
engine,
|
||||
custom_value
|
||||
.item
|
||||
.deserialize_to_custom_value(custom_value.span)
|
||||
);
|
||||
match op {
|
||||
CustomValueOp::ToBaseValue => {
|
||||
let result = local_value
|
||||
.to_base_value(custom_value.span)
|
||||
.map(|value| PipelineData::Value(value, None));
|
||||
let write_result = engine
|
||||
.write_response(result)
|
||||
.map(|writer| writer.write_background());
|
||||
try_or_report!(engine, write_result);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// This will stop the manager
|
||||
drop(interface);
|
||||
}
|
||||
|
||||
fn print_help(plugin: &mut impl Plugin, encoder: impl PluginEncoder) {
|
||||
fn print_help(plugin: &mut impl StreamingPlugin, encoder: impl PluginEncoder) {
|
||||
println!("Nushell Plugin");
|
||||
println!("Encoder: {}", encoder.name());
|
||||
|
||||
|
@ -1,33 +1,201 @@
|
||||
mod evaluated_call;
|
||||
mod plugin_custom_value;
|
||||
mod plugin_data;
|
||||
mod protocol_info;
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests;
|
||||
|
||||
#[cfg(test)]
|
||||
pub(crate) mod test_util;
|
||||
|
||||
pub use evaluated_call::EvaluatedCall;
|
||||
use nu_protocol::{PluginSignature, ShellError, Span, Value};
|
||||
use nu_protocol::{PluginSignature, RawStream, ShellError, Span, Spanned, Value};
|
||||
pub use plugin_custom_value::PluginCustomValue;
|
||||
pub use plugin_data::PluginData;
|
||||
pub(crate) use protocol_info::ProtocolInfo;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
pub struct CallInfo {
|
||||
#[cfg(test)]
|
||||
pub(crate) use protocol_info::Protocol;
|
||||
|
||||
/// A sequential identifier for a stream
|
||||
pub type StreamId = usize;
|
||||
|
||||
/// A sequential identifier for a [`PluginCall`]
|
||||
pub type PluginCallId = usize;
|
||||
|
||||
/// Information about a plugin command invocation. This includes an [`EvaluatedCall`] as a
|
||||
/// serializable representation of [`nu_protocol::ast::Call`]. The type parameter determines
|
||||
/// the input type.
|
||||
#[derive(Serialize, Deserialize, Debug, Clone)]
|
||||
pub struct CallInfo<D> {
|
||||
/// The name of the command to be run
|
||||
pub name: String,
|
||||
/// Information about the invocation, including arguments
|
||||
pub call: EvaluatedCall,
|
||||
pub input: CallInput,
|
||||
/// Pipeline input. This is usually [`nu_protocol::PipelineData`] or [`PipelineDataHeader`]
|
||||
pub input: D,
|
||||
/// Plugin configuration, if available
|
||||
pub config: Option<Value>,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug, PartialEq)]
|
||||
pub enum CallInput {
|
||||
/// The initial (and perhaps only) part of any [`nu_protocol::PipelineData`] sent over the wire.
|
||||
///
|
||||
/// This may contain a single value, or may initiate a stream with a [`StreamId`].
|
||||
#[derive(Serialize, Deserialize, Debug, PartialEq, Clone)]
|
||||
pub enum PipelineDataHeader {
|
||||
/// No input
|
||||
Empty,
|
||||
/// A single value
|
||||
Value(Value),
|
||||
Data(PluginData),
|
||||
/// Initiate [`nu_protocol::PipelineData::ListStream`].
|
||||
///
|
||||
/// Items are sent via [`StreamData`]
|
||||
ListStream(ListStreamInfo),
|
||||
/// Initiate [`nu_protocol::PipelineData::ExternalStream`].
|
||||
///
|
||||
/// Items are sent via [`StreamData`]
|
||||
ExternalStream(ExternalStreamInfo),
|
||||
}
|
||||
|
||||
// Information sent to the plugin
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
pub enum PluginCall {
|
||||
/// Additional information about list (value) streams
|
||||
#[derive(Serialize, Deserialize, Debug, PartialEq, Eq, Clone)]
|
||||
pub struct ListStreamInfo {
|
||||
pub id: StreamId,
|
||||
}
|
||||
|
||||
/// Additional information about external streams
|
||||
#[derive(Serialize, Deserialize, Debug, PartialEq, Eq, Clone)]
|
||||
pub struct ExternalStreamInfo {
|
||||
pub span: Span,
|
||||
pub stdout: Option<RawStreamInfo>,
|
||||
pub stderr: Option<RawStreamInfo>,
|
||||
pub exit_code: Option<ListStreamInfo>,
|
||||
pub trim_end_newline: bool,
|
||||
}
|
||||
|
||||
/// Additional information about raw (byte) streams
|
||||
#[derive(Serialize, Deserialize, Debug, PartialEq, Eq, Clone)]
|
||||
pub struct RawStreamInfo {
|
||||
pub id: StreamId,
|
||||
pub is_binary: bool,
|
||||
pub known_size: Option<u64>,
|
||||
}
|
||||
|
||||
impl RawStreamInfo {
|
||||
pub(crate) fn new(id: StreamId, stream: &RawStream) -> Self {
|
||||
RawStreamInfo {
|
||||
id,
|
||||
is_binary: stream.is_binary,
|
||||
known_size: stream.known_size,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Calls that a plugin can execute. The type parameter determines the input type.
|
||||
#[derive(Serialize, Deserialize, Debug, Clone)]
|
||||
pub enum PluginCall<D> {
|
||||
Signature,
|
||||
CallInfo(CallInfo),
|
||||
CollapseCustomValue(PluginData),
|
||||
Run(CallInfo<D>),
|
||||
CustomValueOp(Spanned<PluginCustomValue>, CustomValueOp),
|
||||
}
|
||||
|
||||
/// Operations supported for custom values.
|
||||
#[derive(Serialize, Deserialize, Debug, Clone)]
|
||||
pub enum CustomValueOp {
|
||||
/// [`to_base_value()`](nu_protocol::CustomValue::to_base_value)
|
||||
ToBaseValue,
|
||||
}
|
||||
|
||||
/// Any data sent to the plugin
|
||||
#[derive(Serialize, Deserialize, Debug, Clone)]
|
||||
pub enum PluginInput {
|
||||
/// This must be the first message. Indicates supported protocol
|
||||
Hello(ProtocolInfo),
|
||||
/// Execute a [`PluginCall`], such as `Run` or `Signature`. The ID should not have been used
|
||||
/// before.
|
||||
Call(PluginCallId, PluginCall<PipelineDataHeader>),
|
||||
/// Stream control or data message. Untagged to keep them as small as possible.
|
||||
///
|
||||
/// For example, `Stream(Ack(0))` is encoded as `{"Ack": 0}`
|
||||
#[serde(untagged)]
|
||||
Stream(StreamMessage),
|
||||
}
|
||||
|
||||
impl TryFrom<PluginInput> for StreamMessage {
|
||||
type Error = PluginInput;
|
||||
|
||||
fn try_from(msg: PluginInput) -> Result<StreamMessage, PluginInput> {
|
||||
match msg {
|
||||
PluginInput::Stream(stream_msg) => Ok(stream_msg),
|
||||
_ => Err(msg),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<StreamMessage> for PluginInput {
|
||||
fn from(stream_msg: StreamMessage) -> PluginInput {
|
||||
PluginInput::Stream(stream_msg)
|
||||
}
|
||||
}
|
||||
|
||||
/// A single item of stream data for a stream.
|
||||
#[derive(Serialize, Deserialize, Debug, Clone)]
|
||||
pub enum StreamData {
|
||||
List(Value),
|
||||
Raw(Result<Vec<u8>, ShellError>),
|
||||
}
|
||||
|
||||
impl From<Value> for StreamData {
|
||||
fn from(value: Value) -> Self {
|
||||
StreamData::List(value)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<Result<Vec<u8>, ShellError>> for StreamData {
|
||||
fn from(value: Result<Vec<u8>, ShellError>) -> Self {
|
||||
StreamData::Raw(value)
|
||||
}
|
||||
}
|
||||
|
||||
impl TryFrom<StreamData> for Value {
|
||||
type Error = ShellError;
|
||||
|
||||
fn try_from(data: StreamData) -> Result<Value, ShellError> {
|
||||
match data {
|
||||
StreamData::List(value) => Ok(value),
|
||||
StreamData::Raw(_) => Err(ShellError::PluginFailedToDecode {
|
||||
msg: "expected list stream data, found raw data".into(),
|
||||
}),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl TryFrom<StreamData> for Result<Vec<u8>, ShellError> {
|
||||
type Error = ShellError;
|
||||
|
||||
fn try_from(data: StreamData) -> Result<Result<Vec<u8>, ShellError>, ShellError> {
|
||||
match data {
|
||||
StreamData::Raw(value) => Ok(value),
|
||||
StreamData::List(_) => Err(ShellError::PluginFailedToDecode {
|
||||
msg: "expected raw stream data, found list data".into(),
|
||||
}),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// A stream control or data message.
|
||||
#[derive(Serialize, Deserialize, Debug, Clone)]
|
||||
pub enum StreamMessage {
|
||||
/// Append data to the stream. Sent by the stream producer.
|
||||
Data(StreamId, StreamData),
|
||||
/// End of stream. Sent by the stream producer.
|
||||
End(StreamId),
|
||||
/// Notify that the read end of the stream has closed, and further messages should not be
|
||||
/// sent. Sent by the stream consumer.
|
||||
Drop(StreamId),
|
||||
/// Acknowledge that a message has been consumed. This is used to implement flow control by
|
||||
/// the stream producer. Sent by the stream consumer.
|
||||
Ack(StreamId),
|
||||
}
|
||||
|
||||
/// An error message with debugging information that can be passed to Nushell from the plugin
|
||||
@ -36,7 +204,7 @@ pub enum PluginCall {
|
||||
/// a [Plugin](crate::Plugin)'s [`run`](crate::Plugin::run()) method. It contains
|
||||
/// the error message along with optional [Span] data to support highlighting in the
|
||||
/// shell.
|
||||
#[derive(Clone, PartialEq, Eq, Serialize, Deserialize, Debug)]
|
||||
#[derive(Serialize, Deserialize, PartialEq, Eq, Debug, Clone)]
|
||||
pub struct LabeledError {
|
||||
/// The name of the error
|
||||
pub label: String,
|
||||
@ -48,81 +216,108 @@ pub struct LabeledError {
|
||||
|
||||
impl From<LabeledError> for ShellError {
|
||||
fn from(error: LabeledError) -> Self {
|
||||
match error.span {
|
||||
Some(span) => ShellError::GenericError {
|
||||
if error.span.is_some() {
|
||||
ShellError::GenericError {
|
||||
error: error.label,
|
||||
msg: error.msg,
|
||||
span: Some(span),
|
||||
span: error.span,
|
||||
help: None,
|
||||
inner: vec![],
|
||||
},
|
||||
None => ShellError::GenericError {
|
||||
}
|
||||
} else {
|
||||
ShellError::GenericError {
|
||||
error: error.label,
|
||||
msg: "".into(),
|
||||
span: None,
|
||||
help: Some(error.msg),
|
||||
help: (!error.msg.is_empty()).then_some(error.msg),
|
||||
inner: vec![],
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<ShellError> for LabeledError {
|
||||
fn from(error: ShellError) -> Self {
|
||||
match error {
|
||||
ShellError::GenericError {
|
||||
error: label,
|
||||
msg,
|
||||
span,
|
||||
..
|
||||
} => LabeledError { label, msg, span },
|
||||
ShellError::CantConvert {
|
||||
to_type: expected,
|
||||
from_type: input,
|
||||
span,
|
||||
help: _help,
|
||||
} => LabeledError {
|
||||
label: format!("Can't convert to {expected}"),
|
||||
msg: format!("can't convert from {input} to {expected}"),
|
||||
use miette::Diagnostic;
|
||||
// This is not perfect - we can only take the first labeled span as that's all we have
|
||||
// space for.
|
||||
if let Some(labeled_span) = error.labels().and_then(|mut iter| iter.nth(0)) {
|
||||
let offset = labeled_span.offset();
|
||||
let span = Span::new(offset, offset + labeled_span.len());
|
||||
LabeledError {
|
||||
label: error.to_string(),
|
||||
msg: labeled_span
|
||||
.label()
|
||||
.map(|label| label.to_owned())
|
||||
.unwrap_or_else(|| "".into()),
|
||||
span: Some(span),
|
||||
},
|
||||
ShellError::DidYouMean { suggestion, span } => LabeledError {
|
||||
label: "Name not found".into(),
|
||||
msg: format!("did you mean '{suggestion}'?"),
|
||||
span: Some(span),
|
||||
},
|
||||
ShellError::PluginFailedToLoad { msg } => LabeledError {
|
||||
label: "Plugin failed to load".into(),
|
||||
msg,
|
||||
}
|
||||
} else {
|
||||
LabeledError {
|
||||
label: error.to_string(),
|
||||
msg: error
|
||||
.help()
|
||||
.map(|help| help.to_string())
|
||||
.unwrap_or_else(|| "".into()),
|
||||
span: None,
|
||||
},
|
||||
ShellError::PluginFailedToEncode { msg } => LabeledError {
|
||||
label: "Plugin failed to encode".into(),
|
||||
msg,
|
||||
span: None,
|
||||
},
|
||||
ShellError::PluginFailedToDecode { msg } => LabeledError {
|
||||
label: "Plugin failed to decode".into(),
|
||||
msg,
|
||||
span: None,
|
||||
},
|
||||
err => LabeledError {
|
||||
label: "Error - Add to LabeledError From<ShellError>".into(),
|
||||
msg: err.to_string(),
|
||||
span: None,
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Information received from the plugin
|
||||
// Needs to be public to communicate with nu-parser but not typically
|
||||
// used by Plugin authors
|
||||
/// Response to a [`PluginCall`]. The type parameter determines the output type for pipeline data.
|
||||
///
|
||||
/// Note: exported for internal use, not public.
|
||||
#[derive(Serialize, Deserialize, Debug, Clone)]
|
||||
#[doc(hidden)]
|
||||
#[derive(Serialize, Deserialize)]
|
||||
pub enum PluginResponse {
|
||||
pub enum PluginCallResponse<D> {
|
||||
Error(LabeledError),
|
||||
Signature(Vec<PluginSignature>),
|
||||
Value(Box<Value>),
|
||||
PluginData(String, PluginData),
|
||||
PipelineData(D),
|
||||
}
|
||||
|
||||
impl PluginCallResponse<PipelineDataHeader> {
|
||||
/// Construct a plugin call response with a single value
|
||||
pub fn value(value: Value) -> PluginCallResponse<PipelineDataHeader> {
|
||||
if value.is_nothing() {
|
||||
PluginCallResponse::PipelineData(PipelineDataHeader::Empty)
|
||||
} else {
|
||||
PluginCallResponse::PipelineData(PipelineDataHeader::Value(value))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Information received from the plugin
|
||||
///
|
||||
/// Note: exported for internal use, not public.
|
||||
#[derive(Serialize, Deserialize, Debug, Clone)]
|
||||
#[doc(hidden)]
|
||||
pub enum PluginOutput {
|
||||
/// This must be the first message. Indicates supported protocol
|
||||
Hello(ProtocolInfo),
|
||||
/// A response to a [`PluginCall`]. The ID should be the same sent with the plugin call this
|
||||
/// is a response to
|
||||
CallResponse(PluginCallId, PluginCallResponse<PipelineDataHeader>),
|
||||
/// Stream control or data message. Untagged to keep them as small as possible.
|
||||
///
|
||||
/// For example, `Stream(Ack(0))` is encoded as `{"Ack": 0}`
|
||||
#[serde(untagged)]
|
||||
Stream(StreamMessage),
|
||||
}
|
||||
|
||||
impl TryFrom<PluginOutput> for StreamMessage {
|
||||
type Error = PluginOutput;
|
||||
|
||||
fn try_from(msg: PluginOutput) -> Result<StreamMessage, PluginOutput> {
|
||||
match msg {
|
||||
PluginOutput::Stream(stream_msg) => Ok(stream_msg),
|
||||
_ => Err(msg),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<StreamMessage> for PluginOutput {
|
||||
fn from(stream_msg: StreamMessage) -> PluginOutput {
|
||||
PluginOutput::Stream(stream_msg)
|
||||
}
|
||||
}
|
||||
|
@ -1,37 +1,39 @@
|
||||
use std::path::PathBuf;
|
||||
use std::sync::Arc;
|
||||
|
||||
use nu_protocol::{CustomValue, ShellError, Value};
|
||||
use serde::Serialize;
|
||||
use nu_protocol::{CustomValue, ShellError, Span, Spanned, Value};
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::plugin::{call_plugin, create_command, get_plugin_encoding};
|
||||
use crate::plugin::PluginIdentity;
|
||||
|
||||
use super::{PluginCall, PluginData, PluginResponse};
|
||||
#[cfg(test)]
|
||||
mod tests;
|
||||
|
||||
/// An opaque container for a custom value that is handled fully by a plugin
|
||||
///
|
||||
/// This is constructed by the main nushell engine when it receives [`PluginResponse::PluginData`]
|
||||
/// it stores that data as well as metadata related to the plugin to be able to call the plugin
|
||||
/// later.
|
||||
/// Since the data in it is opaque to the engine, there are only two final destinations for it:
|
||||
/// either it will be sent back to the plugin that generated it across a pipeline, or it will be
|
||||
/// sent to the plugin with a request to collapse it into a base value
|
||||
#[derive(Clone, Debug, Serialize)]
|
||||
/// This is the only type of custom value that is allowed to cross the plugin serialization
|
||||
/// boundary.
|
||||
///
|
||||
/// [`EngineInterface`](crate::interface::EngineInterface) is responsible for ensuring
|
||||
/// that local plugin custom values are converted to and from [`PluginCustomData`] on the boundary.
|
||||
///
|
||||
/// [`PluginInterface`](crate::interface::PluginInterface) is responsible for adding the
|
||||
/// appropriate [`PluginIdentity`](crate::plugin::PluginIdentity), ensuring that only
|
||||
/// [`PluginCustomData`] is contained within any values sent, and that the `source` of any
|
||||
/// values sent matches the plugin it is being sent to.
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
pub struct PluginCustomValue {
|
||||
/// The name of the custom value as defined by the plugin
|
||||
/// The name of the custom value as defined by the plugin (`value_string()`)
|
||||
pub name: String,
|
||||
/// The bincoded representation of the custom value on the plugin side
|
||||
pub data: Vec<u8>,
|
||||
pub filename: PathBuf,
|
||||
|
||||
// PluginCustomValue must implement Serialize because all CustomValues must implement Serialize
|
||||
// However, the main place where values are serialized and deserialized is when they are being
|
||||
// sent between plugins and nushell's main engine. PluginCustomValue is never meant to be sent
|
||||
// between that boundary
|
||||
#[serde(skip)]
|
||||
pub shell: Option<PathBuf>,
|
||||
#[serde(skip)]
|
||||
pub source: String,
|
||||
/// Which plugin the custom value came from. This is not defined on the plugin side. The engine
|
||||
/// side is responsible for maintaining it, and it is not sent over the serialization boundary.
|
||||
#[serde(skip, default)]
|
||||
pub source: Option<Arc<PluginIdentity>>,
|
||||
}
|
||||
|
||||
#[typetag::serde]
|
||||
impl CustomValue for PluginCustomValue {
|
||||
fn clone_value(&self, span: nu_protocol::Span) -> nu_protocol::Value {
|
||||
Value::custom_value(Box::new(self.clone()), span)
|
||||
@ -45,83 +47,295 @@ impl CustomValue for PluginCustomValue {
|
||||
&self,
|
||||
span: nu_protocol::Span,
|
||||
) -> Result<nu_protocol::Value, nu_protocol::ShellError> {
|
||||
let mut plugin_cmd = create_command(&self.filename, self.shell.as_deref());
|
||||
|
||||
let mut child = plugin_cmd.spawn().map_err(|err| ShellError::GenericError {
|
||||
let wrap_err = |err: ShellError| ShellError::GenericError {
|
||||
error: format!(
|
||||
"Unable to spawn plugin for {} to get base value",
|
||||
"Unable to spawn plugin `{}` to get base value",
|
||||
self.source
|
||||
.as_ref()
|
||||
.map(|s| s.plugin_name.as_str())
|
||||
.unwrap_or("<unknown>")
|
||||
),
|
||||
msg: format!("{err}"),
|
||||
msg: err.to_string(),
|
||||
span: Some(span),
|
||||
help: None,
|
||||
inner: vec![],
|
||||
inner: vec![err],
|
||||
};
|
||||
|
||||
let identity = self.source.clone().ok_or_else(|| {
|
||||
wrap_err(ShellError::NushellFailed {
|
||||
msg: "The plugin source for the custom value was not set".into(),
|
||||
})
|
||||
})?;
|
||||
|
||||
let plugin_call = PluginCall::CollapseCustomValue(PluginData {
|
||||
data: self.data.clone(),
|
||||
span,
|
||||
});
|
||||
let encoding = {
|
||||
let stdout_reader = match &mut child.stdout {
|
||||
Some(out) => out,
|
||||
None => {
|
||||
return Err(ShellError::PluginFailedToLoad {
|
||||
msg: "Plugin missing stdout reader".into(),
|
||||
})
|
||||
}
|
||||
};
|
||||
get_plugin_encoding(stdout_reader)?
|
||||
};
|
||||
let empty_env: Option<(String, String)> = None;
|
||||
let plugin = identity.spawn(empty_env).map_err(wrap_err)?;
|
||||
|
||||
let response = call_plugin(&mut child, plugin_call, &encoding, span).map_err(|err| {
|
||||
ShellError::GenericError {
|
||||
error: format!(
|
||||
"Unable to decode call for {} to get base value",
|
||||
self.source
|
||||
),
|
||||
msg: format!("{err}"),
|
||||
span: Some(span),
|
||||
help: None,
|
||||
inner: vec![],
|
||||
}
|
||||
});
|
||||
|
||||
let value = match response {
|
||||
Ok(PluginResponse::Value(value)) => Ok(*value),
|
||||
Ok(PluginResponse::PluginData(..)) => Err(ShellError::GenericError {
|
||||
error: "Plugin misbehaving".into(),
|
||||
msg: "Plugin returned custom data as a response to a collapse call".into(),
|
||||
span: Some(span),
|
||||
help: None,
|
||||
inner: vec![],
|
||||
}),
|
||||
Ok(PluginResponse::Error(err)) => Err(err.into()),
|
||||
Ok(PluginResponse::Signature(..)) => Err(ShellError::GenericError {
|
||||
error: "Plugin missing value".into(),
|
||||
msg: "Received a signature from plugin instead of value".into(),
|
||||
span: Some(span),
|
||||
help: None,
|
||||
inner: vec![],
|
||||
}),
|
||||
Err(err) => Err(err),
|
||||
};
|
||||
|
||||
// We need to call .wait() on the child, or we'll risk summoning the zombie horde
|
||||
let _ = child.wait();
|
||||
|
||||
value
|
||||
plugin
|
||||
.custom_value_to_base_value(Spanned {
|
||||
item: self.clone(),
|
||||
span,
|
||||
})
|
||||
.map_err(wrap_err)
|
||||
}
|
||||
|
||||
fn as_any(&self) -> &dyn std::any::Any {
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
fn typetag_name(&self) -> &'static str {
|
||||
"PluginCustomValue"
|
||||
impl PluginCustomValue {
|
||||
/// Serialize a custom value into a [`PluginCustomValue`]. This should only be done on the
|
||||
/// plugin side.
|
||||
pub(crate) fn serialize_from_custom_value(
|
||||
custom_value: &dyn CustomValue,
|
||||
span: Span,
|
||||
) -> Result<PluginCustomValue, ShellError> {
|
||||
let name = custom_value.value_string();
|
||||
bincode::serialize(custom_value)
|
||||
.map(|data| PluginCustomValue {
|
||||
name,
|
||||
data,
|
||||
source: None,
|
||||
})
|
||||
.map_err(|err| ShellError::CustomValueFailedToEncode {
|
||||
msg: err.to_string(),
|
||||
span,
|
||||
})
|
||||
}
|
||||
|
||||
fn typetag_deserialize(&self) {
|
||||
unimplemented!("typetag_deserialize")
|
||||
/// Deserialize a [`PluginCustomValue`] into a `Box<dyn CustomValue>`. This should only be done
|
||||
/// on the plugin side.
|
||||
pub(crate) fn deserialize_to_custom_value(
|
||||
&self,
|
||||
span: Span,
|
||||
) -> Result<Box<dyn CustomValue>, ShellError> {
|
||||
bincode::deserialize::<Box<dyn CustomValue>>(&self.data).map_err(|err| {
|
||||
ShellError::CustomValueFailedToDecode {
|
||||
msg: err.to_string(),
|
||||
span,
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
/// Add a [`PluginIdentity`] to all [`PluginCustomValue`]s within a value, recursively.
|
||||
pub(crate) fn add_source(value: &mut Value, source: &Arc<PluginIdentity>) {
|
||||
let span = value.span();
|
||||
match value {
|
||||
// Set source on custom value
|
||||
Value::CustomValue { ref val, .. } => {
|
||||
if let Some(custom_value) = val.as_any().downcast_ref::<PluginCustomValue>() {
|
||||
// Since there's no `as_mut_any()`, we have to copy the whole thing
|
||||
let mut custom_value = custom_value.clone();
|
||||
custom_value.source = Some(source.clone());
|
||||
*value = Value::custom_value(Box::new(custom_value), span);
|
||||
}
|
||||
}
|
||||
// Any values that can contain other values need to be handled recursively
|
||||
Value::Range { ref mut val, .. } => {
|
||||
Self::add_source(&mut val.from, source);
|
||||
Self::add_source(&mut val.to, source);
|
||||
Self::add_source(&mut val.incr, source);
|
||||
}
|
||||
Value::Record { ref mut val, .. } => {
|
||||
for (_, rec_value) in val.iter_mut() {
|
||||
Self::add_source(rec_value, source);
|
||||
}
|
||||
}
|
||||
Value::List { ref mut vals, .. } => {
|
||||
for list_value in vals.iter_mut() {
|
||||
Self::add_source(list_value, source);
|
||||
}
|
||||
}
|
||||
// All of these don't contain other values
|
||||
Value::Bool { .. }
|
||||
| Value::Int { .. }
|
||||
| Value::Float { .. }
|
||||
| Value::Filesize { .. }
|
||||
| Value::Duration { .. }
|
||||
| Value::Date { .. }
|
||||
| Value::String { .. }
|
||||
| Value::Glob { .. }
|
||||
| Value::Block { .. }
|
||||
| Value::Closure { .. }
|
||||
| Value::Nothing { .. }
|
||||
| Value::Error { .. }
|
||||
| Value::Binary { .. }
|
||||
| Value::CellPath { .. } => (),
|
||||
// LazyRecord could generate other values, but we shouldn't be receiving it anyway
|
||||
//
|
||||
// It's better to handle this as a bug
|
||||
Value::LazyRecord { .. } => unimplemented!("add_source for LazyRecord"),
|
||||
}
|
||||
}
|
||||
|
||||
/// Check that all [`CustomValue`]s present within the `value` are [`PluginCustomValue`]s that
|
||||
/// come from the given `source`, and return an error if not.
|
||||
///
|
||||
/// This method will collapse `LazyRecord` in-place as necessary to make the guarantee,
|
||||
/// since `LazyRecord` could return something different the next time it is called.
|
||||
pub(crate) fn verify_source(
|
||||
value: &mut Value,
|
||||
source: &PluginIdentity,
|
||||
) -> Result<(), ShellError> {
|
||||
let span = value.span();
|
||||
match value {
|
||||
// Set source on custom value
|
||||
Value::CustomValue { val, .. } => {
|
||||
if let Some(custom_value) = val.as_any().downcast_ref::<PluginCustomValue>() {
|
||||
if custom_value.source.as_deref() == Some(source) {
|
||||
Ok(())
|
||||
} else {
|
||||
Err(ShellError::CustomValueIncorrectForPlugin {
|
||||
name: custom_value.name.clone(),
|
||||
span,
|
||||
dest_plugin: source.plugin_name.clone(),
|
||||
src_plugin: custom_value.source.as_ref().map(|s| s.plugin_name.clone()),
|
||||
})
|
||||
}
|
||||
} else {
|
||||
// Only PluginCustomValues can be sent
|
||||
Err(ShellError::CustomValueIncorrectForPlugin {
|
||||
name: val.value_string(),
|
||||
span,
|
||||
dest_plugin: source.plugin_name.clone(),
|
||||
src_plugin: None,
|
||||
})
|
||||
}
|
||||
}
|
||||
// Any values that can contain other values need to be handled recursively
|
||||
Value::Range { val, .. } => {
|
||||
Self::verify_source(&mut val.from, source)?;
|
||||
Self::verify_source(&mut val.to, source)?;
|
||||
Self::verify_source(&mut val.incr, source)
|
||||
}
|
||||
Value::Record { ref mut val, .. } => val
|
||||
.iter_mut()
|
||||
.try_for_each(|(_, rec_value)| Self::verify_source(rec_value, source)),
|
||||
Value::List { ref mut vals, .. } => vals
|
||||
.iter_mut()
|
||||
.try_for_each(|list_value| Self::verify_source(list_value, source)),
|
||||
// All of these don't contain other values
|
||||
Value::Bool { .. }
|
||||
| Value::Int { .. }
|
||||
| Value::Float { .. }
|
||||
| Value::Filesize { .. }
|
||||
| Value::Duration { .. }
|
||||
| Value::Date { .. }
|
||||
| Value::String { .. }
|
||||
| Value::Glob { .. }
|
||||
| Value::Block { .. }
|
||||
| Value::Closure { .. }
|
||||
| Value::Nothing { .. }
|
||||
| Value::Error { .. }
|
||||
| Value::Binary { .. }
|
||||
| Value::CellPath { .. } => Ok(()),
|
||||
// LazyRecord would be a problem for us, since it could return something else the next
|
||||
// time, and we have to collect it anyway to serialize it. Collect it in place, and then
|
||||
// verify the source of the result
|
||||
Value::LazyRecord { val, .. } => {
|
||||
*value = val.collect()?;
|
||||
Self::verify_source(value, source)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Convert all plugin-native custom values to [`PluginCustomValue`] within the given `value`,
|
||||
/// recursively. This should only be done on the plugin side.
|
||||
pub(crate) fn serialize_custom_values_in(value: &mut Value) -> Result<(), ShellError> {
|
||||
let span = value.span();
|
||||
match value {
|
||||
Value::CustomValue { ref val, .. } => {
|
||||
if val.as_any().downcast_ref::<PluginCustomValue>().is_some() {
|
||||
// Already a PluginCustomValue
|
||||
Ok(())
|
||||
} else {
|
||||
let serialized = Self::serialize_from_custom_value(&**val, span)?;
|
||||
*value = Value::custom_value(Box::new(serialized), span);
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
// Any values that can contain other values need to be handled recursively
|
||||
Value::Range { ref mut val, .. } => {
|
||||
Self::serialize_custom_values_in(&mut val.from)?;
|
||||
Self::serialize_custom_values_in(&mut val.to)?;
|
||||
Self::serialize_custom_values_in(&mut val.incr)
|
||||
}
|
||||
Value::Record { ref mut val, .. } => val
|
||||
.iter_mut()
|
||||
.try_for_each(|(_, rec_value)| Self::serialize_custom_values_in(rec_value)),
|
||||
Value::List { ref mut vals, .. } => vals
|
||||
.iter_mut()
|
||||
.try_for_each(Self::serialize_custom_values_in),
|
||||
// All of these don't contain other values
|
||||
Value::Bool { .. }
|
||||
| Value::Int { .. }
|
||||
| Value::Float { .. }
|
||||
| Value::Filesize { .. }
|
||||
| Value::Duration { .. }
|
||||
| Value::Date { .. }
|
||||
| Value::String { .. }
|
||||
| Value::Glob { .. }
|
||||
| Value::Block { .. }
|
||||
| Value::Closure { .. }
|
||||
| Value::Nothing { .. }
|
||||
| Value::Error { .. }
|
||||
| Value::Binary { .. }
|
||||
| Value::CellPath { .. } => Ok(()),
|
||||
// Collect any lazy records that exist and try again
|
||||
Value::LazyRecord { val, .. } => {
|
||||
*value = val.collect()?;
|
||||
Self::serialize_custom_values_in(value)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Convert all [`PluginCustomValue`]s to plugin-native custom values within the given `value`,
|
||||
/// recursively. This should only be done on the plugin side.
|
||||
pub(crate) fn deserialize_custom_values_in(value: &mut Value) -> Result<(), ShellError> {
|
||||
let span = value.span();
|
||||
match value {
|
||||
Value::CustomValue { ref val, .. } => {
|
||||
if let Some(val) = val.as_any().downcast_ref::<PluginCustomValue>() {
|
||||
let deserialized = val.deserialize_to_custom_value(span)?;
|
||||
*value = Value::custom_value(deserialized, span);
|
||||
Ok(())
|
||||
} else {
|
||||
// Already not a PluginCustomValue
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
// Any values that can contain other values need to be handled recursively
|
||||
Value::Range { ref mut val, .. } => {
|
||||
Self::deserialize_custom_values_in(&mut val.from)?;
|
||||
Self::deserialize_custom_values_in(&mut val.to)?;
|
||||
Self::deserialize_custom_values_in(&mut val.incr)
|
||||
}
|
||||
Value::Record { ref mut val, .. } => val
|
||||
.iter_mut()
|
||||
.try_for_each(|(_, rec_value)| Self::deserialize_custom_values_in(rec_value)),
|
||||
Value::List { ref mut vals, .. } => vals
|
||||
.iter_mut()
|
||||
.try_for_each(Self::deserialize_custom_values_in),
|
||||
// All of these don't contain other values
|
||||
Value::Bool { .. }
|
||||
| Value::Int { .. }
|
||||
| Value::Float { .. }
|
||||
| Value::Filesize { .. }
|
||||
| Value::Duration { .. }
|
||||
| Value::Date { .. }
|
||||
| Value::String { .. }
|
||||
| Value::Glob { .. }
|
||||
| Value::Block { .. }
|
||||
| Value::Closure { .. }
|
||||
| Value::Nothing { .. }
|
||||
| Value::Error { .. }
|
||||
| Value::Binary { .. }
|
||||
| Value::CellPath { .. } => Ok(()),
|
||||
// Collect any lazy records that exist and try again
|
||||
Value::LazyRecord { val, .. } => {
|
||||
*value = val.collect()?;
|
||||
Self::deserialize_custom_values_in(value)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
492
crates/nu-plugin/src/protocol/plugin_custom_value/tests.rs
Normal file
492
crates/nu-plugin/src/protocol/plugin_custom_value/tests.rs
Normal file
@ -0,0 +1,492 @@
|
||||
use nu_protocol::{ast::RangeInclusion, record, CustomValue, Range, ShellError, Span, Value};
|
||||
|
||||
use crate::{
|
||||
plugin::PluginIdentity,
|
||||
protocol::test_util::{
|
||||
expected_test_custom_value, test_plugin_custom_value, test_plugin_custom_value_with_source,
|
||||
TestCustomValue,
|
||||
},
|
||||
};
|
||||
|
||||
use super::PluginCustomValue;
|
||||
|
||||
#[test]
|
||||
fn serialize_deserialize() -> Result<(), ShellError> {
|
||||
let original_value = TestCustomValue(32);
|
||||
let span = Span::test_data();
|
||||
let serialized = PluginCustomValue::serialize_from_custom_value(&original_value, span)?;
|
||||
assert_eq!(original_value.value_string(), serialized.name);
|
||||
assert!(serialized.source.is_none());
|
||||
let deserialized = serialized.deserialize_to_custom_value(span)?;
|
||||
let downcasted = deserialized
|
||||
.as_any()
|
||||
.downcast_ref::<TestCustomValue>()
|
||||
.expect("failed to downcast: not TestCustomValue");
|
||||
assert_eq!(original_value, *downcasted);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn expected_serialize_output() -> Result<(), ShellError> {
|
||||
let original_value = expected_test_custom_value();
|
||||
let span = Span::test_data();
|
||||
let serialized = PluginCustomValue::serialize_from_custom_value(&original_value, span)?;
|
||||
assert_eq!(
|
||||
test_plugin_custom_value().data,
|
||||
serialized.data,
|
||||
"The bincode configuration is probably different from what we expected. \
|
||||
Fix test_plugin_custom_value() to match it"
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn add_source_at_root() -> Result<(), ShellError> {
|
||||
let mut val = Value::test_custom_value(Box::new(test_plugin_custom_value()));
|
||||
let source = PluginIdentity::new_fake("foo");
|
||||
PluginCustomValue::add_source(&mut val, &source);
|
||||
|
||||
let custom_value = val.as_custom_value()?;
|
||||
let plugin_custom_value: &PluginCustomValue = custom_value
|
||||
.as_any()
|
||||
.downcast_ref()
|
||||
.expect("not PluginCustomValue");
|
||||
assert_eq!(Some(source), plugin_custom_value.source);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn check_range_custom_values(
|
||||
val: &Value,
|
||||
mut f: impl FnMut(&str, &dyn CustomValue) -> Result<(), ShellError>,
|
||||
) -> Result<(), ShellError> {
|
||||
let range = val.as_range()?;
|
||||
for (name, val) in [
|
||||
("from", &range.from),
|
||||
("incr", &range.incr),
|
||||
("to", &range.to),
|
||||
] {
|
||||
let custom_value = val
|
||||
.as_custom_value()
|
||||
.unwrap_or_else(|_| panic!("{name} not custom value"));
|
||||
f(name, custom_value)?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn add_source_nested_range() -> Result<(), ShellError> {
|
||||
let orig_custom_val = Value::test_custom_value(Box::new(test_plugin_custom_value()));
|
||||
let mut val = Value::test_range(Range {
|
||||
from: orig_custom_val.clone(),
|
||||
incr: orig_custom_val.clone(),
|
||||
to: orig_custom_val.clone(),
|
||||
inclusion: RangeInclusion::Inclusive,
|
||||
});
|
||||
let source = PluginIdentity::new_fake("foo");
|
||||
PluginCustomValue::add_source(&mut val, &source);
|
||||
|
||||
check_range_custom_values(&val, |name, custom_value| {
|
||||
let plugin_custom_value: &PluginCustomValue = custom_value
|
||||
.as_any()
|
||||
.downcast_ref()
|
||||
.unwrap_or_else(|| panic!("{name} not PluginCustomValue"));
|
||||
assert_eq!(
|
||||
Some(&source),
|
||||
plugin_custom_value.source.as_ref(),
|
||||
"{name} source not set correctly"
|
||||
);
|
||||
Ok(())
|
||||
})
|
||||
}
|
||||
|
||||
fn check_record_custom_values(
|
||||
val: &Value,
|
||||
keys: &[&str],
|
||||
mut f: impl FnMut(&str, &dyn CustomValue) -> Result<(), ShellError>,
|
||||
) -> Result<(), ShellError> {
|
||||
let record = val.as_record()?;
|
||||
for key in keys {
|
||||
let val = record
|
||||
.get(key)
|
||||
.unwrap_or_else(|| panic!("record does not contain '{key}'"));
|
||||
let custom_value = val
|
||||
.as_custom_value()
|
||||
.unwrap_or_else(|_| panic!("'{key}' not custom value"));
|
||||
f(key, custom_value)?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn add_source_nested_record() -> Result<(), ShellError> {
|
||||
let orig_custom_val = Value::test_custom_value(Box::new(test_plugin_custom_value()));
|
||||
let mut val = Value::test_record(record! {
|
||||
"foo" => orig_custom_val.clone(),
|
||||
"bar" => orig_custom_val.clone(),
|
||||
});
|
||||
let source = PluginIdentity::new_fake("foo");
|
||||
PluginCustomValue::add_source(&mut val, &source);
|
||||
|
||||
check_record_custom_values(&val, &["foo", "bar"], |key, custom_value| {
|
||||
let plugin_custom_value: &PluginCustomValue = custom_value
|
||||
.as_any()
|
||||
.downcast_ref()
|
||||
.unwrap_or_else(|| panic!("'{key}' not PluginCustomValue"));
|
||||
assert_eq!(
|
||||
Some(&source),
|
||||
plugin_custom_value.source.as_ref(),
|
||||
"'{key}' source not set correctly"
|
||||
);
|
||||
Ok(())
|
||||
})
|
||||
}
|
||||
|
||||
fn check_list_custom_values(
|
||||
val: &Value,
|
||||
indices: impl IntoIterator<Item = usize>,
|
||||
mut f: impl FnMut(usize, &dyn CustomValue) -> Result<(), ShellError>,
|
||||
) -> Result<(), ShellError> {
|
||||
let list = val.as_list()?;
|
||||
for index in indices {
|
||||
let val = list
|
||||
.get(index)
|
||||
.unwrap_or_else(|| panic!("[{index}] not present in list"));
|
||||
let custom_value = val
|
||||
.as_custom_value()
|
||||
.unwrap_or_else(|_| panic!("[{index}] not custom value"));
|
||||
f(index, custom_value)?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn add_source_nested_list() -> Result<(), ShellError> {
|
||||
let orig_custom_val = Value::test_custom_value(Box::new(test_plugin_custom_value()));
|
||||
let mut val = Value::test_list(vec![orig_custom_val.clone(), orig_custom_val.clone()]);
|
||||
let source = PluginIdentity::new_fake("foo");
|
||||
PluginCustomValue::add_source(&mut val, &source);
|
||||
|
||||
check_list_custom_values(&val, 0..=1, |index, custom_value| {
|
||||
let plugin_custom_value: &PluginCustomValue = custom_value
|
||||
.as_any()
|
||||
.downcast_ref()
|
||||
.unwrap_or_else(|| panic!("[{index}] not PluginCustomValue"));
|
||||
assert_eq!(
|
||||
Some(&source),
|
||||
plugin_custom_value.source.as_ref(),
|
||||
"[{index}] source not set correctly"
|
||||
);
|
||||
Ok(())
|
||||
})
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn verify_source_error_message() -> Result<(), ShellError> {
|
||||
let span = Span::new(5, 7);
|
||||
let mut ok_val = Value::custom_value(Box::new(test_plugin_custom_value_with_source()), span);
|
||||
let mut native_val = Value::custom_value(Box::new(TestCustomValue(32)), span);
|
||||
let mut foreign_val = {
|
||||
let mut val = test_plugin_custom_value();
|
||||
val.source = Some(PluginIdentity::new_fake("other"));
|
||||
Value::custom_value(Box::new(val), span)
|
||||
};
|
||||
let source = PluginIdentity::new_fake("test");
|
||||
|
||||
PluginCustomValue::verify_source(&mut ok_val, &source).expect("ok_val should be verified ok");
|
||||
|
||||
for (val, src_plugin) in [(&mut native_val, None), (&mut foreign_val, Some("other"))] {
|
||||
let error = PluginCustomValue::verify_source(val, &source).expect_err(&format!(
|
||||
"a custom value from {src_plugin:?} should result in an error"
|
||||
));
|
||||
if let ShellError::CustomValueIncorrectForPlugin {
|
||||
name,
|
||||
span: err_span,
|
||||
dest_plugin,
|
||||
src_plugin: err_src_plugin,
|
||||
} = error
|
||||
{
|
||||
assert_eq!("TestCustomValue", name, "error.name from {src_plugin:?}");
|
||||
assert_eq!(span, err_span, "error.span from {src_plugin:?}");
|
||||
assert_eq!("test", dest_plugin, "error.dest_plugin from {src_plugin:?}");
|
||||
assert_eq!(src_plugin, err_src_plugin.as_deref(), "error.src_plugin");
|
||||
} else {
|
||||
panic!("the error returned should be CustomValueIncorrectForPlugin");
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn verify_source_nested_range() -> Result<(), ShellError> {
|
||||
let native_val = Value::test_custom_value(Box::new(TestCustomValue(32)));
|
||||
let source = PluginIdentity::new_fake("test");
|
||||
for (name, mut val) in [
|
||||
(
|
||||
"from",
|
||||
Value::test_range(Range {
|
||||
from: native_val.clone(),
|
||||
incr: Value::test_nothing(),
|
||||
to: Value::test_nothing(),
|
||||
inclusion: RangeInclusion::RightExclusive,
|
||||
}),
|
||||
),
|
||||
(
|
||||
"incr",
|
||||
Value::test_range(Range {
|
||||
from: Value::test_nothing(),
|
||||
incr: native_val.clone(),
|
||||
to: Value::test_nothing(),
|
||||
inclusion: RangeInclusion::RightExclusive,
|
||||
}),
|
||||
),
|
||||
(
|
||||
"to",
|
||||
Value::test_range(Range {
|
||||
from: Value::test_nothing(),
|
||||
incr: Value::test_nothing(),
|
||||
to: native_val.clone(),
|
||||
inclusion: RangeInclusion::RightExclusive,
|
||||
}),
|
||||
),
|
||||
] {
|
||||
PluginCustomValue::verify_source(&mut val, &source)
|
||||
.expect_err(&format!("error not generated on {name}"));
|
||||
}
|
||||
|
||||
let mut ok_range = Value::test_range(Range {
|
||||
from: Value::test_nothing(),
|
||||
incr: Value::test_nothing(),
|
||||
to: Value::test_nothing(),
|
||||
inclusion: RangeInclusion::RightExclusive,
|
||||
});
|
||||
PluginCustomValue::verify_source(&mut ok_range, &source)
|
||||
.expect("ok_range should not generate error");
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn verify_source_nested_record() -> Result<(), ShellError> {
|
||||
let native_val = Value::test_custom_value(Box::new(TestCustomValue(32)));
|
||||
let source = PluginIdentity::new_fake("test");
|
||||
for (name, mut val) in [
|
||||
(
|
||||
"first element foo",
|
||||
Value::test_record(record! {
|
||||
"foo" => native_val.clone(),
|
||||
"bar" => Value::test_nothing(),
|
||||
}),
|
||||
),
|
||||
(
|
||||
"second element bar",
|
||||
Value::test_record(record! {
|
||||
"foo" => Value::test_nothing(),
|
||||
"bar" => native_val.clone(),
|
||||
}),
|
||||
),
|
||||
] {
|
||||
PluginCustomValue::verify_source(&mut val, &source)
|
||||
.expect_err(&format!("error not generated on {name}"));
|
||||
}
|
||||
|
||||
let mut ok_record = Value::test_record(record! {"foo" => Value::test_nothing()});
|
||||
PluginCustomValue::verify_source(&mut ok_record, &source)
|
||||
.expect("ok_record should not generate error");
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn verify_source_nested_list() -> Result<(), ShellError> {
|
||||
let native_val = Value::test_custom_value(Box::new(TestCustomValue(32)));
|
||||
let source = PluginIdentity::new_fake("test");
|
||||
for (name, mut val) in [
|
||||
(
|
||||
"first element",
|
||||
Value::test_list(vec![native_val.clone(), Value::test_nothing()]),
|
||||
),
|
||||
(
|
||||
"second element",
|
||||
Value::test_list(vec![Value::test_nothing(), native_val.clone()]),
|
||||
),
|
||||
] {
|
||||
PluginCustomValue::verify_source(&mut val, &source)
|
||||
.expect_err(&format!("error not generated on {name}"));
|
||||
}
|
||||
|
||||
let mut ok_list = Value::test_list(vec![Value::test_nothing()]);
|
||||
PluginCustomValue::verify_source(&mut ok_list, &source)
|
||||
.expect("ok_list should not generate error");
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn serialize_in_root() -> Result<(), ShellError> {
|
||||
let span = Span::new(4, 10);
|
||||
let mut val = Value::custom_value(Box::new(expected_test_custom_value()), span);
|
||||
PluginCustomValue::serialize_custom_values_in(&mut val)?;
|
||||
|
||||
assert_eq!(span, val.span());
|
||||
|
||||
let custom_value = val.as_custom_value()?;
|
||||
if let Some(plugin_custom_value) = custom_value.as_any().downcast_ref::<PluginCustomValue>() {
|
||||
assert_eq!("TestCustomValue", plugin_custom_value.name);
|
||||
assert_eq!(test_plugin_custom_value().data, plugin_custom_value.data);
|
||||
assert!(plugin_custom_value.source.is_none());
|
||||
} else {
|
||||
panic!("Failed to downcast to PluginCustomValue");
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn serialize_in_range() -> Result<(), ShellError> {
|
||||
let orig_custom_val = Value::test_custom_value(Box::new(TestCustomValue(-1)));
|
||||
let mut val = Value::test_range(Range {
|
||||
from: orig_custom_val.clone(),
|
||||
incr: orig_custom_val.clone(),
|
||||
to: orig_custom_val.clone(),
|
||||
inclusion: RangeInclusion::Inclusive,
|
||||
});
|
||||
PluginCustomValue::serialize_custom_values_in(&mut val)?;
|
||||
|
||||
check_range_custom_values(&val, |name, custom_value| {
|
||||
let plugin_custom_value: &PluginCustomValue = custom_value
|
||||
.as_any()
|
||||
.downcast_ref()
|
||||
.unwrap_or_else(|| panic!("{name} not PluginCustomValue"));
|
||||
assert_eq!(
|
||||
"TestCustomValue", plugin_custom_value.name,
|
||||
"{name} name not set correctly"
|
||||
);
|
||||
Ok(())
|
||||
})
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn serialize_in_record() -> Result<(), ShellError> {
|
||||
let orig_custom_val = Value::test_custom_value(Box::new(TestCustomValue(32)));
|
||||
let mut val = Value::test_record(record! {
|
||||
"foo" => orig_custom_val.clone(),
|
||||
"bar" => orig_custom_val.clone(),
|
||||
});
|
||||
PluginCustomValue::serialize_custom_values_in(&mut val)?;
|
||||
|
||||
check_record_custom_values(&val, &["foo", "bar"], |key, custom_value| {
|
||||
let plugin_custom_value: &PluginCustomValue = custom_value
|
||||
.as_any()
|
||||
.downcast_ref()
|
||||
.unwrap_or_else(|| panic!("'{key}' not PluginCustomValue"));
|
||||
assert_eq!(
|
||||
"TestCustomValue", plugin_custom_value.name,
|
||||
"'{key}' name not set correctly"
|
||||
);
|
||||
Ok(())
|
||||
})
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn serialize_in_list() -> Result<(), ShellError> {
|
||||
let orig_custom_val = Value::test_custom_value(Box::new(TestCustomValue(24)));
|
||||
let mut val = Value::test_list(vec![orig_custom_val.clone(), orig_custom_val.clone()]);
|
||||
PluginCustomValue::serialize_custom_values_in(&mut val)?;
|
||||
|
||||
check_list_custom_values(&val, 0..=1, |index, custom_value| {
|
||||
let plugin_custom_value: &PluginCustomValue = custom_value
|
||||
.as_any()
|
||||
.downcast_ref()
|
||||
.unwrap_or_else(|| panic!("[{index}] not PluginCustomValue"));
|
||||
assert_eq!(
|
||||
"TestCustomValue", plugin_custom_value.name,
|
||||
"[{index}] name not set correctly"
|
||||
);
|
||||
Ok(())
|
||||
})
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn deserialize_in_root() -> Result<(), ShellError> {
|
||||
let span = Span::new(4, 10);
|
||||
let mut val = Value::custom_value(Box::new(test_plugin_custom_value()), span);
|
||||
PluginCustomValue::deserialize_custom_values_in(&mut val)?;
|
||||
|
||||
assert_eq!(span, val.span());
|
||||
|
||||
let custom_value = val.as_custom_value()?;
|
||||
if let Some(test_custom_value) = custom_value.as_any().downcast_ref::<TestCustomValue>() {
|
||||
assert_eq!(expected_test_custom_value(), *test_custom_value);
|
||||
} else {
|
||||
panic!("Failed to downcast to TestCustomValue");
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn deserialize_in_range() -> Result<(), ShellError> {
|
||||
let orig_custom_val = Value::test_custom_value(Box::new(test_plugin_custom_value()));
|
||||
let mut val = Value::test_range(Range {
|
||||
from: orig_custom_val.clone(),
|
||||
incr: orig_custom_val.clone(),
|
||||
to: orig_custom_val.clone(),
|
||||
inclusion: RangeInclusion::Inclusive,
|
||||
});
|
||||
PluginCustomValue::deserialize_custom_values_in(&mut val)?;
|
||||
|
||||
check_range_custom_values(&val, |name, custom_value| {
|
||||
let test_custom_value: &TestCustomValue = custom_value
|
||||
.as_any()
|
||||
.downcast_ref()
|
||||
.unwrap_or_else(|| panic!("{name} not TestCustomValue"));
|
||||
assert_eq!(
|
||||
expected_test_custom_value(),
|
||||
*test_custom_value,
|
||||
"{name} not deserialized correctly"
|
||||
);
|
||||
Ok(())
|
||||
})
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn deserialize_in_record() -> Result<(), ShellError> {
|
||||
let orig_custom_val = Value::test_custom_value(Box::new(test_plugin_custom_value()));
|
||||
let mut val = Value::test_record(record! {
|
||||
"foo" => orig_custom_val.clone(),
|
||||
"bar" => orig_custom_val.clone(),
|
||||
});
|
||||
PluginCustomValue::deserialize_custom_values_in(&mut val)?;
|
||||
|
||||
check_record_custom_values(&val, &["foo", "bar"], |key, custom_value| {
|
||||
let test_custom_value: &TestCustomValue = custom_value
|
||||
.as_any()
|
||||
.downcast_ref()
|
||||
.unwrap_or_else(|| panic!("'{key}' not TestCustomValue"));
|
||||
assert_eq!(
|
||||
expected_test_custom_value(),
|
||||
*test_custom_value,
|
||||
"{key} not deserialized correctly"
|
||||
);
|
||||
Ok(())
|
||||
})
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn deserialize_in_list() -> Result<(), ShellError> {
|
||||
let orig_custom_val = Value::test_custom_value(Box::new(test_plugin_custom_value()));
|
||||
let mut val = Value::test_list(vec![orig_custom_val.clone(), orig_custom_val.clone()]);
|
||||
PluginCustomValue::deserialize_custom_values_in(&mut val)?;
|
||||
|
||||
check_list_custom_values(&val, 0..=1, |index, custom_value| {
|
||||
let test_custom_value: &TestCustomValue = custom_value
|
||||
.as_any()
|
||||
.downcast_ref()
|
||||
.unwrap_or_else(|| panic!("[{index}] not TestCustomValue"));
|
||||
assert_eq!(
|
||||
expected_test_custom_value(),
|
||||
*test_custom_value,
|
||||
"[{index}] name not deserialized correctly"
|
||||
);
|
||||
Ok(())
|
||||
})
|
||||
}
|
@ -1,8 +0,0 @@
|
||||
use nu_protocol::Span;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug, PartialEq, Eq)]
|
||||
pub struct PluginData {
|
||||
pub data: Vec<u8>,
|
||||
pub span: Span,
|
||||
}
|
80
crates/nu-plugin/src/protocol/protocol_info.rs
Normal file
80
crates/nu-plugin/src/protocol/protocol_info.rs
Normal file
@ -0,0 +1,80 @@
|
||||
use nu_protocol::ShellError;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
/// Protocol information, sent as a `Hello` message on initialization. This determines the
|
||||
/// compatibility of the plugin and engine. They are considered to be compatible if the lower
|
||||
/// version is semver compatible with the higher one.
|
||||
#[derive(Serialize, Deserialize, Debug, Clone)]
|
||||
pub struct ProtocolInfo {
|
||||
/// The name of the protocol being implemented. Only one protocol is supported. This field
|
||||
/// can be safely ignored, because not matching is a deserialization error
|
||||
pub protocol: Protocol,
|
||||
/// The semantic version of the protocol. This should be the version of the `nu-plugin`
|
||||
/// crate
|
||||
pub version: String,
|
||||
/// Supported optional features. This helps to maintain semver compatibility when adding new
|
||||
/// features
|
||||
pub features: Vec<Feature>,
|
||||
}
|
||||
|
||||
impl Default for ProtocolInfo {
|
||||
fn default() -> ProtocolInfo {
|
||||
ProtocolInfo {
|
||||
protocol: Protocol::NuPlugin,
|
||||
version: env!("CARGO_PKG_VERSION").into(),
|
||||
features: vec![],
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl ProtocolInfo {
|
||||
pub fn is_compatible_with(&self, other: &ProtocolInfo) -> Result<bool, ShellError> {
|
||||
fn parse_failed(error: semver::Error) -> ShellError {
|
||||
ShellError::PluginFailedToLoad {
|
||||
msg: format!("Failed to parse protocol version: {error}"),
|
||||
}
|
||||
}
|
||||
let mut versions = [
|
||||
semver::Version::parse(&self.version).map_err(parse_failed)?,
|
||||
semver::Version::parse(&other.version).map_err(parse_failed)?,
|
||||
];
|
||||
|
||||
versions.sort();
|
||||
|
||||
// For example, if the lower version is 1.1.0, and the higher version is 1.2.3, the
|
||||
// requirement is that 1.2.3 matches ^1.1.0 (which it does)
|
||||
Ok(semver::Comparator {
|
||||
op: semver::Op::Caret,
|
||||
major: versions[0].major,
|
||||
minor: Some(versions[0].minor),
|
||||
patch: Some(versions[0].patch),
|
||||
pre: versions[0].pre.clone(),
|
||||
}
|
||||
.matches(&versions[1]))
|
||||
}
|
||||
}
|
||||
|
||||
/// Indicates the protocol in use. Only one protocol is supported.
|
||||
#[derive(Serialize, Deserialize, Debug, Clone, Default)]
|
||||
pub enum Protocol {
|
||||
/// Serializes to the value `"nu-plugin"`
|
||||
#[serde(rename = "nu-plugin")]
|
||||
#[default]
|
||||
NuPlugin,
|
||||
}
|
||||
|
||||
/// Indicates optional protocol features. This can help to make non-breaking-change additions to
|
||||
/// the protocol. Features are not restricted to plain strings and can contain additional
|
||||
/// configuration data.
|
||||
///
|
||||
/// Optional features should not be used by the protocol if they are not present in the
|
||||
/// [`ProtocolInfo`] sent by the other side.
|
||||
#[derive(Serialize, Deserialize, Debug, Clone)]
|
||||
#[serde(tag = "name")]
|
||||
pub enum Feature {
|
||||
/// A feature that was not recognized on deserialization. Attempting to serialize this feature
|
||||
/// is an error. Matching against it may only be used if necessary to determine whether
|
||||
/// unsupported features are present.
|
||||
#[serde(other, skip_serializing)]
|
||||
Unknown,
|
||||
}
|
50
crates/nu-plugin/src/protocol/test_util.rs
Normal file
50
crates/nu-plugin/src/protocol/test_util.rs
Normal file
@ -0,0 +1,50 @@
|
||||
use nu_protocol::{CustomValue, ShellError, Span, Value};
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::plugin::PluginIdentity;
|
||||
|
||||
use super::PluginCustomValue;
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
|
||||
pub(crate) struct TestCustomValue(pub i32);
|
||||
|
||||
#[typetag::serde]
|
||||
impl CustomValue for TestCustomValue {
|
||||
fn clone_value(&self, span: Span) -> Value {
|
||||
Value::custom_value(Box::new(self.clone()), span)
|
||||
}
|
||||
|
||||
fn value_string(&self) -> String {
|
||||
"TestCustomValue".into()
|
||||
}
|
||||
|
||||
fn to_base_value(&self, span: Span) -> Result<Value, ShellError> {
|
||||
Ok(Value::int(self.0 as i64, span))
|
||||
}
|
||||
|
||||
fn as_any(&self) -> &dyn std::any::Any {
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn test_plugin_custom_value() -> PluginCustomValue {
|
||||
let data = bincode::serialize(&expected_test_custom_value() as &dyn CustomValue)
|
||||
.expect("bincode serialization of the expected_test_custom_value() failed");
|
||||
|
||||
PluginCustomValue {
|
||||
name: "TestCustomValue".into(),
|
||||
data,
|
||||
source: None,
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn expected_test_custom_value() -> TestCustomValue {
|
||||
TestCustomValue(-1)
|
||||
}
|
||||
|
||||
pub(crate) fn test_plugin_custom_value_with_source() -> PluginCustomValue {
|
||||
PluginCustomValue {
|
||||
source: Some(PluginIdentity::new_fake("test")),
|
||||
..test_plugin_custom_value()
|
||||
}
|
||||
}
|
35
crates/nu-plugin/src/protocol/tests.rs
Normal file
35
crates/nu-plugin/src/protocol/tests.rs
Normal file
@ -0,0 +1,35 @@
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn protocol_info_compatible() -> Result<(), ShellError> {
|
||||
let ver_1_2_3 = ProtocolInfo {
|
||||
protocol: Protocol::NuPlugin,
|
||||
version: "1.2.3".into(),
|
||||
features: vec![],
|
||||
};
|
||||
let ver_1_1_0 = ProtocolInfo {
|
||||
protocol: Protocol::NuPlugin,
|
||||
version: "1.1.0".into(),
|
||||
features: vec![],
|
||||
};
|
||||
assert!(ver_1_1_0.is_compatible_with(&ver_1_2_3)?);
|
||||
assert!(ver_1_2_3.is_compatible_with(&ver_1_1_0)?);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn protocol_info_incompatible() -> Result<(), ShellError> {
|
||||
let ver_2_0_0 = ProtocolInfo {
|
||||
protocol: Protocol::NuPlugin,
|
||||
version: "2.0.0".into(),
|
||||
features: vec![],
|
||||
};
|
||||
let ver_1_1_0 = ProtocolInfo {
|
||||
protocol: Protocol::NuPlugin,
|
||||
version: "1.1.0".into(),
|
||||
features: vec![],
|
||||
};
|
||||
assert!(!ver_2_0_0.is_compatible_with(&ver_1_1_0)?);
|
||||
assert!(!ver_1_1_0.is_compatible_with(&ver_2_0_0)?);
|
||||
Ok(())
|
||||
}
|
65
crates/nu-plugin/src/sequence.rs
Normal file
65
crates/nu-plugin/src/sequence.rs
Normal file
@ -0,0 +1,65 @@
|
||||
use std::sync::atomic::{AtomicUsize, Ordering::Relaxed};
|
||||
|
||||
use nu_protocol::ShellError;
|
||||
|
||||
/// Implements an atomically incrementing sequential series of numbers
|
||||
#[derive(Debug, Default)]
|
||||
pub(crate) struct Sequence(AtomicUsize);
|
||||
|
||||
impl Sequence {
|
||||
/// Return the next available id from a sequence, returning an error on overflow
|
||||
#[track_caller]
|
||||
pub(crate) fn next(&self) -> Result<usize, ShellError> {
|
||||
// It's totally safe to use Relaxed ordering here, as there aren't other memory operations
|
||||
// that depend on this value having been set for safety
|
||||
//
|
||||
// We're only not using `fetch_add` so that we can check for overflow, as wrapping with the
|
||||
// identifier would lead to a serious bug - however unlikely that is.
|
||||
self.0
|
||||
.fetch_update(Relaxed, Relaxed, |current| current.checked_add(1))
|
||||
.map_err(|_| ShellError::NushellFailedHelp {
|
||||
msg: "an accumulator for identifiers overflowed".into(),
|
||||
help: format!("see {}", std::panic::Location::caller()),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn output_is_sequential() {
|
||||
let sequence = Sequence::default();
|
||||
|
||||
for (expected, generated) in (0..1000).zip(std::iter::repeat_with(|| sequence.next())) {
|
||||
assert_eq!(expected, generated.expect("error in sequence"));
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn output_is_unique_even_under_contention() {
|
||||
let sequence = Sequence::default();
|
||||
|
||||
std::thread::scope(|scope| {
|
||||
// Spawn four threads, all advancing the sequence simultaneously
|
||||
let threads = (0..4)
|
||||
.map(|_| {
|
||||
scope.spawn(|| {
|
||||
(0..100000)
|
||||
.map(|_| sequence.next())
|
||||
.collect::<Result<Vec<_>, _>>()
|
||||
})
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
// Collect all of the results into a single flat vec
|
||||
let mut results = threads
|
||||
.into_iter()
|
||||
.flat_map(|thread| thread.join().expect("panicked").expect("error"))
|
||||
.collect::<Vec<usize>>();
|
||||
|
||||
// Check uniqueness
|
||||
results.sort();
|
||||
let initial_length = results.len();
|
||||
results.dedup();
|
||||
let deduplicated_length = results.len();
|
||||
assert_eq!(initial_length, deduplicated_length);
|
||||
})
|
||||
}
|
@ -1,53 +1,94 @@
|
||||
use crate::{
|
||||
plugin::{Encoder, PluginEncoder},
|
||||
protocol::{PluginInput, PluginOutput},
|
||||
};
|
||||
use nu_protocol::ShellError;
|
||||
use serde::Deserialize;
|
||||
|
||||
use crate::{plugin::PluginEncoder, protocol::PluginResponse};
|
||||
|
||||
/// A `PluginEncoder` that enables the plugin to communicate with Nushel with JSON
|
||||
/// A `PluginEncoder` that enables the plugin to communicate with Nushell with JSON
|
||||
/// serialized data.
|
||||
#[derive(Clone, Debug)]
|
||||
///
|
||||
/// Each message in the stream is followed by a newline when serializing, but is not required for
|
||||
/// deserialization. The output is not pretty printed and each object does not contain newlines.
|
||||
/// If it is more convenient, a plugin may choose to separate messages by newline.
|
||||
#[derive(Clone, Copy, Debug)]
|
||||
pub struct JsonSerializer;
|
||||
|
||||
impl PluginEncoder for JsonSerializer {
|
||||
fn name(&self) -> &str {
|
||||
"json"
|
||||
}
|
||||
}
|
||||
|
||||
fn encode_call(
|
||||
impl Encoder<PluginInput> for JsonSerializer {
|
||||
fn encode(
|
||||
&self,
|
||||
plugin_call: &crate::protocol::PluginCall,
|
||||
plugin_input: &PluginInput,
|
||||
writer: &mut impl std::io::Write,
|
||||
) -> Result<(), nu_protocol::ShellError> {
|
||||
serde_json::to_writer(writer, plugin_call).map_err(|err| ShellError::PluginFailedToEncode {
|
||||
serde_json::to_writer(&mut *writer, plugin_input).map_err(json_encode_err)?;
|
||||
writer.write_all(b"\n").map_err(|err| ShellError::IOError {
|
||||
msg: err.to_string(),
|
||||
})
|
||||
}
|
||||
|
||||
fn decode_call(
|
||||
fn decode(
|
||||
&self,
|
||||
reader: &mut impl std::io::BufRead,
|
||||
) -> Result<crate::protocol::PluginCall, nu_protocol::ShellError> {
|
||||
serde_json::from_reader(reader).map_err(|err| ShellError::PluginFailedToEncode {
|
||||
msg: err.to_string(),
|
||||
})
|
||||
) -> Result<Option<PluginInput>, nu_protocol::ShellError> {
|
||||
let mut de = serde_json::Deserializer::from_reader(reader);
|
||||
PluginInput::deserialize(&mut de)
|
||||
.map(Some)
|
||||
.or_else(json_decode_err)
|
||||
}
|
||||
}
|
||||
|
||||
fn encode_response(
|
||||
impl Encoder<PluginOutput> for JsonSerializer {
|
||||
fn encode(
|
||||
&self,
|
||||
plugin_response: &PluginResponse,
|
||||
plugin_output: &PluginOutput,
|
||||
writer: &mut impl std::io::Write,
|
||||
) -> Result<(), ShellError> {
|
||||
serde_json::to_writer(writer, plugin_response).map_err(|err| {
|
||||
ShellError::PluginFailedToEncode {
|
||||
msg: err.to_string(),
|
||||
}
|
||||
serde_json::to_writer(&mut *writer, plugin_output).map_err(json_encode_err)?;
|
||||
writer.write_all(b"\n").map_err(|err| ShellError::IOError {
|
||||
msg: err.to_string(),
|
||||
})
|
||||
}
|
||||
|
||||
fn decode_response(
|
||||
fn decode(
|
||||
&self,
|
||||
reader: &mut impl std::io::BufRead,
|
||||
) -> Result<PluginResponse, ShellError> {
|
||||
serde_json::from_reader(reader).map_err(|err| ShellError::PluginFailedToEncode {
|
||||
) -> Result<Option<PluginOutput>, ShellError> {
|
||||
let mut de = serde_json::Deserializer::from_reader(reader);
|
||||
PluginOutput::deserialize(&mut de)
|
||||
.map(Some)
|
||||
.or_else(json_decode_err)
|
||||
}
|
||||
}
|
||||
|
||||
/// Handle a `serde_json` encode error.
|
||||
fn json_encode_err(err: serde_json::Error) -> ShellError {
|
||||
if err.is_io() {
|
||||
ShellError::IOError {
|
||||
msg: err.to_string(),
|
||||
}
|
||||
} else {
|
||||
ShellError::PluginFailedToEncode {
|
||||
msg: err.to_string(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Handle a `serde_json` decode error. Returns `Ok(None)` on eof.
|
||||
fn json_decode_err<T>(err: serde_json::Error) -> Result<Option<T>, ShellError> {
|
||||
if err.is_eof() {
|
||||
Ok(None)
|
||||
} else if err.is_io() {
|
||||
Err(ShellError::IOError {
|
||||
msg: err.to_string(),
|
||||
})
|
||||
} else {
|
||||
Err(ShellError::PluginFailedToDecode {
|
||||
msg: err.to_string(),
|
||||
})
|
||||
}
|
||||
@ -56,306 +97,38 @@ impl PluginEncoder for JsonSerializer {
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::protocol::{
|
||||
CallInfo, CallInput, EvaluatedCall, LabeledError, PluginCall, PluginData, PluginResponse,
|
||||
};
|
||||
use nu_protocol::{PluginSignature, Span, Spanned, SyntaxShape, Value};
|
||||
crate::serializers::tests::generate_tests!(JsonSerializer {});
|
||||
|
||||
#[test]
|
||||
fn callinfo_round_trip_signature() {
|
||||
let plugin_call = PluginCall::Signature;
|
||||
let encoder = JsonSerializer {};
|
||||
|
||||
let mut buffer: Vec<u8> = Vec::new();
|
||||
encoder
|
||||
.encode_call(&plugin_call, &mut buffer)
|
||||
.expect("unable to serialize message");
|
||||
let returned = encoder
|
||||
.decode_call(&mut buffer.as_slice())
|
||||
.expect("unable to deserialize message");
|
||||
|
||||
match returned {
|
||||
PluginCall::Signature => {}
|
||||
PluginCall::CallInfo(_) => panic!("decoded into wrong value"),
|
||||
PluginCall::CollapseCustomValue(_) => panic!("decoded into wrong value"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn callinfo_round_trip_callinfo() {
|
||||
let name = "test".to_string();
|
||||
|
||||
let input = Value::bool(false, Span::new(1, 20));
|
||||
|
||||
let call = EvaluatedCall {
|
||||
head: Span::new(0, 10),
|
||||
positional: vec![
|
||||
Value::float(1.0, Span::new(0, 10)),
|
||||
Value::string("something", Span::new(0, 10)),
|
||||
],
|
||||
named: vec![(
|
||||
Spanned {
|
||||
item: "name".to_string(),
|
||||
span: Span::new(0, 10),
|
||||
},
|
||||
Some(Value::float(1.0, Span::new(0, 10))),
|
||||
)],
|
||||
};
|
||||
|
||||
let plugin_call = PluginCall::CallInfo(CallInfo {
|
||||
name: name.clone(),
|
||||
call: call.clone(),
|
||||
input: CallInput::Value(input.clone()),
|
||||
config: None,
|
||||
});
|
||||
|
||||
let encoder = JsonSerializer {};
|
||||
let mut buffer: Vec<u8> = Vec::new();
|
||||
encoder
|
||||
.encode_call(&plugin_call, &mut buffer)
|
||||
.expect("unable to serialize message");
|
||||
let returned = encoder
|
||||
.decode_call(&mut buffer.as_slice())
|
||||
.expect("unable to deserialize message");
|
||||
|
||||
match returned {
|
||||
PluginCall::Signature => panic!("returned wrong call type"),
|
||||
PluginCall::CallInfo(call_info) => {
|
||||
assert_eq!(name, call_info.name);
|
||||
assert_eq!(CallInput::Value(input), call_info.input);
|
||||
assert_eq!(call.head, call_info.call.head);
|
||||
assert_eq!(call.positional.len(), call_info.call.positional.len());
|
||||
|
||||
call.positional
|
||||
.iter()
|
||||
.zip(call_info.call.positional.iter())
|
||||
.for_each(|(lhs, rhs)| assert_eq!(lhs, rhs));
|
||||
|
||||
call.named
|
||||
.iter()
|
||||
.zip(call_info.call.named.iter())
|
||||
.for_each(|(lhs, rhs)| {
|
||||
// Comparing the keys
|
||||
assert_eq!(lhs.0.item, rhs.0.item);
|
||||
|
||||
match (&lhs.1, &rhs.1) {
|
||||
(None, None) => {}
|
||||
(Some(a), Some(b)) => assert_eq!(a, b),
|
||||
_ => panic!("not matching values"),
|
||||
}
|
||||
});
|
||||
}
|
||||
PluginCall::CollapseCustomValue(_) => panic!("returned wrong call type"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn callinfo_round_trip_collapsecustomvalue() {
|
||||
let data = vec![1, 2, 3, 4, 5, 6, 7];
|
||||
let span = Span::new(0, 20);
|
||||
|
||||
let collapse_custom_value = PluginCall::CollapseCustomValue(PluginData {
|
||||
data: data.clone(),
|
||||
span,
|
||||
});
|
||||
|
||||
let encoder = JsonSerializer {};
|
||||
let mut buffer: Vec<u8> = Vec::new();
|
||||
encoder
|
||||
.encode_call(&collapse_custom_value, &mut buffer)
|
||||
.expect("unable to serialize message");
|
||||
let returned = encoder
|
||||
.decode_call(&mut buffer.as_slice())
|
||||
.expect("unable to deserialize message");
|
||||
|
||||
match returned {
|
||||
PluginCall::Signature => panic!("returned wrong call type"),
|
||||
PluginCall::CallInfo(_) => panic!("returned wrong call type"),
|
||||
PluginCall::CollapseCustomValue(plugin_data) => {
|
||||
assert_eq!(data, plugin_data.data);
|
||||
assert_eq!(span, plugin_data.span);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn response_round_trip_signature() {
|
||||
let signature = PluginSignature::build("nu-plugin")
|
||||
.required("first", SyntaxShape::String, "first required")
|
||||
.required("second", SyntaxShape::Int, "second required")
|
||||
.required_named("first-named", SyntaxShape::String, "first named", Some('f'))
|
||||
.required_named(
|
||||
"second-named",
|
||||
SyntaxShape::String,
|
||||
"second named",
|
||||
Some('s'),
|
||||
)
|
||||
.rest("remaining", SyntaxShape::Int, "remaining");
|
||||
|
||||
let response = PluginResponse::Signature(vec![signature.clone()]);
|
||||
|
||||
let encoder = JsonSerializer {};
|
||||
let mut buffer: Vec<u8> = Vec::new();
|
||||
encoder
|
||||
.encode_response(&response, &mut buffer)
|
||||
.expect("unable to serialize message");
|
||||
let returned = encoder
|
||||
.decode_response(&mut buffer.as_slice())
|
||||
.expect("unable to deserialize message");
|
||||
|
||||
match returned {
|
||||
PluginResponse::Error(_) => panic!("returned wrong call type"),
|
||||
PluginResponse::Value(_) => panic!("returned wrong call type"),
|
||||
PluginResponse::PluginData(..) => panic!("returned wrong call type"),
|
||||
PluginResponse::Signature(returned_signature) => {
|
||||
assert_eq!(returned_signature.len(), 1);
|
||||
assert_eq!(signature.sig.name, returned_signature[0].sig.name);
|
||||
assert_eq!(signature.sig.usage, returned_signature[0].sig.usage);
|
||||
assert_eq!(
|
||||
signature.sig.extra_usage,
|
||||
returned_signature[0].sig.extra_usage
|
||||
);
|
||||
assert_eq!(signature.sig.is_filter, returned_signature[0].sig.is_filter);
|
||||
|
||||
signature
|
||||
.sig
|
||||
.required_positional
|
||||
.iter()
|
||||
.zip(returned_signature[0].sig.required_positional.iter())
|
||||
.for_each(|(lhs, rhs)| assert_eq!(lhs, rhs));
|
||||
|
||||
signature
|
||||
.sig
|
||||
.optional_positional
|
||||
.iter()
|
||||
.zip(returned_signature[0].sig.optional_positional.iter())
|
||||
.for_each(|(lhs, rhs)| assert_eq!(lhs, rhs));
|
||||
|
||||
signature
|
||||
.sig
|
||||
.named
|
||||
.iter()
|
||||
.zip(returned_signature[0].sig.named.iter())
|
||||
.for_each(|(lhs, rhs)| assert_eq!(lhs, rhs));
|
||||
|
||||
assert_eq!(
|
||||
signature.sig.rest_positional,
|
||||
returned_signature[0].sig.rest_positional,
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn response_round_trip_value() {
|
||||
let value = Value::int(10, Span::new(2, 30));
|
||||
|
||||
let response = PluginResponse::Value(Box::new(value.clone()));
|
||||
|
||||
let encoder = JsonSerializer {};
|
||||
let mut buffer: Vec<u8> = Vec::new();
|
||||
encoder
|
||||
.encode_response(&response, &mut buffer)
|
||||
.expect("unable to serialize message");
|
||||
let returned = encoder
|
||||
.decode_response(&mut buffer.as_slice())
|
||||
.expect("unable to deserialize message");
|
||||
|
||||
match returned {
|
||||
PluginResponse::Error(_) => panic!("returned wrong call type"),
|
||||
PluginResponse::Signature(_) => panic!("returned wrong call type"),
|
||||
PluginResponse::PluginData(..) => panic!("returned wrong call type"),
|
||||
PluginResponse::Value(returned_value) => {
|
||||
assert_eq!(&value, returned_value.as_ref())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn response_round_trip_plugin_data() {
|
||||
let name = "test".to_string();
|
||||
|
||||
let data = vec![1, 2, 3, 4, 5];
|
||||
let span = Span::new(2, 30);
|
||||
|
||||
let response = PluginResponse::PluginData(
|
||||
name.clone(),
|
||||
PluginData {
|
||||
data: data.clone(),
|
||||
span,
|
||||
},
|
||||
fn json_ends_in_newline() {
|
||||
let mut out = vec![];
|
||||
JsonSerializer {}
|
||||
.encode(&PluginInput::Call(0, PluginCall::Signature), &mut out)
|
||||
.expect("serialization error");
|
||||
let string = std::str::from_utf8(&out).expect("utf-8 error");
|
||||
assert!(
|
||||
string.ends_with('\n'),
|
||||
"doesn't end with newline: {:?}",
|
||||
string
|
||||
);
|
||||
|
||||
let encoder = JsonSerializer {};
|
||||
let mut buffer: Vec<u8> = Vec::new();
|
||||
encoder
|
||||
.encode_response(&response, &mut buffer)
|
||||
.expect("unable to serialize message");
|
||||
let returned = encoder
|
||||
.decode_response(&mut buffer.as_slice())
|
||||
.expect("unable to deserialize message");
|
||||
|
||||
match returned {
|
||||
PluginResponse::Error(_) => panic!("returned wrong call type"),
|
||||
PluginResponse::Signature(_) => panic!("returned wrong call type"),
|
||||
PluginResponse::Value(_) => panic!("returned wrong call type"),
|
||||
PluginResponse::PluginData(returned_name, returned_plugin_data) => {
|
||||
assert_eq!(name, returned_name);
|
||||
assert_eq!(data, returned_plugin_data.data);
|
||||
assert_eq!(span, returned_plugin_data.span);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn response_round_trip_error() {
|
||||
let error = LabeledError {
|
||||
label: "label".into(),
|
||||
msg: "msg".into(),
|
||||
span: Some(Span::new(2, 30)),
|
||||
};
|
||||
let response = PluginResponse::Error(error.clone());
|
||||
|
||||
let encoder = JsonSerializer {};
|
||||
let mut buffer: Vec<u8> = Vec::new();
|
||||
encoder
|
||||
.encode_response(&response, &mut buffer)
|
||||
.expect("unable to serialize message");
|
||||
let returned = encoder
|
||||
.decode_response(&mut buffer.as_slice())
|
||||
.expect("unable to deserialize message");
|
||||
|
||||
match returned {
|
||||
PluginResponse::Error(msg) => assert_eq!(error, msg),
|
||||
PluginResponse::Signature(_) => panic!("returned wrong call type"),
|
||||
PluginResponse::Value(_) => panic!("returned wrong call type"),
|
||||
PluginResponse::PluginData(..) => panic!("returned wrong call type"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn response_round_trip_error_none() {
|
||||
let error = LabeledError {
|
||||
label: "label".into(),
|
||||
msg: "msg".into(),
|
||||
span: None,
|
||||
};
|
||||
let response = PluginResponse::Error(error.clone());
|
||||
|
||||
let encoder = JsonSerializer {};
|
||||
let mut buffer: Vec<u8> = Vec::new();
|
||||
encoder
|
||||
.encode_response(&response, &mut buffer)
|
||||
.expect("unable to serialize message");
|
||||
let returned = encoder
|
||||
.decode_response(&mut buffer.as_slice())
|
||||
.expect("unable to deserialize message");
|
||||
|
||||
match returned {
|
||||
PluginResponse::Error(msg) => assert_eq!(error, msg),
|
||||
PluginResponse::Signature(_) => panic!("returned wrong call type"),
|
||||
PluginResponse::Value(_) => panic!("returned wrong call type"),
|
||||
PluginResponse::PluginData(..) => panic!("returned wrong call type"),
|
||||
}
|
||||
fn json_has_no_other_newlines() {
|
||||
let mut out = vec![];
|
||||
// use something deeply nested, to try to trigger any pretty printing
|
||||
let output = PluginOutput::Stream(StreamMessage::Data(
|
||||
0,
|
||||
StreamData::List(Value::test_list(vec![
|
||||
Value::test_int(4),
|
||||
// in case escaping failed
|
||||
Value::test_string("newline\ncontaining\nstring"),
|
||||
])),
|
||||
));
|
||||
JsonSerializer {}
|
||||
.encode(&output, &mut out)
|
||||
.expect("serialization error");
|
||||
let string = std::str::from_utf8(&out).expect("utf-8 error");
|
||||
assert_eq!(1, string.chars().filter(|ch| *ch == '\n').count());
|
||||
}
|
||||
}
|
||||
|
@ -1,14 +1,14 @@
|
||||
use crate::{
|
||||
plugin::PluginEncoder,
|
||||
protocol::{PluginCall, PluginResponse},
|
||||
};
|
||||
use crate::plugin::{Encoder, PluginEncoder};
|
||||
use nu_protocol::ShellError;
|
||||
|
||||
pub mod json;
|
||||
pub mod msgpack;
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests;
|
||||
|
||||
#[doc(hidden)]
|
||||
#[derive(Clone, Debug)]
|
||||
#[derive(Clone, Copy, Debug)]
|
||||
pub enum EncodingType {
|
||||
Json(json::JsonSerializer),
|
||||
MsgPack(msgpack::MsgPackSerializer),
|
||||
@ -23,48 +23,6 @@ impl EncodingType {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn encode_call(
|
||||
&self,
|
||||
plugin_call: &PluginCall,
|
||||
writer: &mut impl std::io::Write,
|
||||
) -> Result<(), ShellError> {
|
||||
match self {
|
||||
EncodingType::Json(encoder) => encoder.encode_call(plugin_call, writer),
|
||||
EncodingType::MsgPack(encoder) => encoder.encode_call(plugin_call, writer),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn decode_call(
|
||||
&self,
|
||||
reader: &mut impl std::io::BufRead,
|
||||
) -> Result<PluginCall, ShellError> {
|
||||
match self {
|
||||
EncodingType::Json(encoder) => encoder.decode_call(reader),
|
||||
EncodingType::MsgPack(encoder) => encoder.decode_call(reader),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn encode_response(
|
||||
&self,
|
||||
plugin_response: &PluginResponse,
|
||||
writer: &mut impl std::io::Write,
|
||||
) -> Result<(), ShellError> {
|
||||
match self {
|
||||
EncodingType::Json(encoder) => encoder.encode_response(plugin_response, writer),
|
||||
EncodingType::MsgPack(encoder) => encoder.encode_response(plugin_response, writer),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn decode_response(
|
||||
&self,
|
||||
reader: &mut impl std::io::BufRead,
|
||||
) -> Result<PluginResponse, ShellError> {
|
||||
match self {
|
||||
EncodingType::Json(encoder) => encoder.decode_response(reader),
|
||||
EncodingType::MsgPack(encoder) => encoder.decode_response(reader),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn to_str(&self) -> &'static str {
|
||||
match self {
|
||||
Self::Json(_) => "json",
|
||||
@ -72,3 +30,29 @@ impl EncodingType {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl PluginEncoder for EncodingType {
|
||||
fn name(&self) -> &str {
|
||||
self.to_str()
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> Encoder<T> for EncodingType
|
||||
where
|
||||
json::JsonSerializer: Encoder<T>,
|
||||
msgpack::MsgPackSerializer: Encoder<T>,
|
||||
{
|
||||
fn encode(&self, data: &T, writer: &mut impl std::io::Write) -> Result<(), ShellError> {
|
||||
match self {
|
||||
EncodingType::Json(encoder) => encoder.encode(data, writer),
|
||||
EncodingType::MsgPack(encoder) => encoder.encode(data, writer),
|
||||
}
|
||||
}
|
||||
|
||||
fn decode(&self, reader: &mut impl std::io::BufRead) -> Result<Option<T>, ShellError> {
|
||||
match self {
|
||||
EncodingType::Json(encoder) => encoder.decode(reader),
|
||||
EncodingType::MsgPack(encoder) => encoder.decode(reader),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -1,362 +1,110 @@
|
||||
use crate::{plugin::PluginEncoder, protocol::PluginResponse};
|
||||
use nu_protocol::ShellError;
|
||||
use std::io::ErrorKind;
|
||||
|
||||
/// A `PluginEncoder` that enables the plugin to communicate with Nushel with MsgPack
|
||||
use crate::{
|
||||
plugin::{Encoder, PluginEncoder},
|
||||
protocol::{PluginInput, PluginOutput},
|
||||
};
|
||||
use nu_protocol::ShellError;
|
||||
use serde::Deserialize;
|
||||
|
||||
/// A `PluginEncoder` that enables the plugin to communicate with Nushell with MsgPack
|
||||
/// serialized data.
|
||||
#[derive(Clone, Debug)]
|
||||
///
|
||||
/// Each message is written as a MessagePack object. There is no message envelope or separator.
|
||||
#[derive(Clone, Copy, Debug)]
|
||||
pub struct MsgPackSerializer;
|
||||
|
||||
impl PluginEncoder for MsgPackSerializer {
|
||||
fn name(&self) -> &str {
|
||||
"msgpack"
|
||||
}
|
||||
}
|
||||
|
||||
fn encode_call(
|
||||
impl Encoder<PluginInput> for MsgPackSerializer {
|
||||
fn encode(
|
||||
&self,
|
||||
plugin_call: &crate::protocol::PluginCall,
|
||||
plugin_input: &PluginInput,
|
||||
writer: &mut impl std::io::Write,
|
||||
) -> Result<(), nu_protocol::ShellError> {
|
||||
rmp_serde::encode::write(writer, plugin_call).map_err(|err| {
|
||||
ShellError::PluginFailedToEncode {
|
||||
msg: err.to_string(),
|
||||
}
|
||||
})
|
||||
rmp_serde::encode::write(writer, plugin_input).map_err(rmp_encode_err)
|
||||
}
|
||||
|
||||
fn decode_call(
|
||||
fn decode(
|
||||
&self,
|
||||
reader: &mut impl std::io::BufRead,
|
||||
) -> Result<crate::protocol::PluginCall, nu_protocol::ShellError> {
|
||||
rmp_serde::from_read(reader).map_err(|err| ShellError::PluginFailedToDecode {
|
||||
msg: err.to_string(),
|
||||
})
|
||||
) -> Result<Option<PluginInput>, ShellError> {
|
||||
let mut de = rmp_serde::Deserializer::new(reader);
|
||||
PluginInput::deserialize(&mut de)
|
||||
.map(Some)
|
||||
.or_else(rmp_decode_err)
|
||||
}
|
||||
}
|
||||
|
||||
fn encode_response(
|
||||
impl Encoder<PluginOutput> for MsgPackSerializer {
|
||||
fn encode(
|
||||
&self,
|
||||
plugin_response: &PluginResponse,
|
||||
plugin_output: &PluginOutput,
|
||||
writer: &mut impl std::io::Write,
|
||||
) -> Result<(), ShellError> {
|
||||
rmp_serde::encode::write(writer, plugin_response).map_err(|err| {
|
||||
rmp_serde::encode::write(writer, plugin_output).map_err(rmp_encode_err)
|
||||
}
|
||||
|
||||
fn decode(
|
||||
&self,
|
||||
reader: &mut impl std::io::BufRead,
|
||||
) -> Result<Option<PluginOutput>, ShellError> {
|
||||
let mut de = rmp_serde::Deserializer::new(reader);
|
||||
PluginOutput::deserialize(&mut de)
|
||||
.map(Some)
|
||||
.or_else(rmp_decode_err)
|
||||
}
|
||||
}
|
||||
|
||||
/// Handle a msgpack encode error
|
||||
fn rmp_encode_err(err: rmp_serde::encode::Error) -> ShellError {
|
||||
match err {
|
||||
rmp_serde::encode::Error::InvalidValueWrite(_) => {
|
||||
// I/O error
|
||||
ShellError::IOError {
|
||||
msg: err.to_string(),
|
||||
}
|
||||
}
|
||||
_ => {
|
||||
// Something else
|
||||
ShellError::PluginFailedToEncode {
|
||||
msg: err.to_string(),
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn decode_response(
|
||||
&self,
|
||||
reader: &mut impl std::io::BufRead,
|
||||
) -> Result<PluginResponse, ShellError> {
|
||||
rmp_serde::from_read(reader).map_err(|err| ShellError::PluginFailedToDecode {
|
||||
msg: err.to_string(),
|
||||
})
|
||||
/// Handle a msgpack decode error. Returns `Ok(None)` on eof
|
||||
fn rmp_decode_err<T>(err: rmp_serde::decode::Error) -> Result<Option<T>, ShellError> {
|
||||
match err {
|
||||
rmp_serde::decode::Error::InvalidMarkerRead(err)
|
||||
if matches!(err.kind(), ErrorKind::UnexpectedEof) =>
|
||||
{
|
||||
// EOF
|
||||
Ok(None)
|
||||
}
|
||||
rmp_serde::decode::Error::InvalidMarkerRead(_)
|
||||
| rmp_serde::decode::Error::InvalidDataRead(_) => {
|
||||
// I/O error
|
||||
Err(ShellError::IOError {
|
||||
msg: err.to_string(),
|
||||
})
|
||||
}
|
||||
_ => {
|
||||
// Something else
|
||||
Err(ShellError::PluginFailedToDecode {
|
||||
msg: err.to_string(),
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::protocol::{
|
||||
CallInfo, CallInput, EvaluatedCall, LabeledError, PluginCall, PluginData, PluginResponse,
|
||||
};
|
||||
use nu_protocol::{PluginSignature, Span, Spanned, SyntaxShape, Value};
|
||||
|
||||
#[test]
|
||||
fn callinfo_round_trip_signature() {
|
||||
let plugin_call = PluginCall::Signature;
|
||||
let encoder = MsgPackSerializer {};
|
||||
|
||||
let mut buffer: Vec<u8> = Vec::new();
|
||||
encoder
|
||||
.encode_call(&plugin_call, &mut buffer)
|
||||
.expect("unable to serialize message");
|
||||
let returned = encoder
|
||||
.decode_call(&mut buffer.as_slice())
|
||||
.expect("unable to deserialize message");
|
||||
|
||||
match returned {
|
||||
PluginCall::Signature => {}
|
||||
PluginCall::CallInfo(_) => panic!("decoded into wrong value"),
|
||||
PluginCall::CollapseCustomValue(_) => panic!("decoded into wrong value"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn callinfo_round_trip_callinfo() {
|
||||
let name = "test".to_string();
|
||||
|
||||
let input = Value::bool(false, Span::new(1, 20));
|
||||
|
||||
let call = EvaluatedCall {
|
||||
head: Span::new(0, 10),
|
||||
positional: vec![
|
||||
Value::float(1.0, Span::new(0, 10)),
|
||||
Value::string("something", Span::new(0, 10)),
|
||||
],
|
||||
named: vec![(
|
||||
Spanned {
|
||||
item: "name".to_string(),
|
||||
span: Span::new(0, 10),
|
||||
},
|
||||
Some(Value::float(1.0, Span::new(0, 10))),
|
||||
)],
|
||||
};
|
||||
|
||||
let plugin_call = PluginCall::CallInfo(CallInfo {
|
||||
name: name.clone(),
|
||||
call: call.clone(),
|
||||
input: CallInput::Value(input.clone()),
|
||||
config: None,
|
||||
});
|
||||
|
||||
let encoder = MsgPackSerializer {};
|
||||
let mut buffer: Vec<u8> = Vec::new();
|
||||
encoder
|
||||
.encode_call(&plugin_call, &mut buffer)
|
||||
.expect("unable to serialize message");
|
||||
let returned = encoder
|
||||
.decode_call(&mut buffer.as_slice())
|
||||
.expect("unable to deserialize message");
|
||||
|
||||
match returned {
|
||||
PluginCall::Signature => panic!("returned wrong call type"),
|
||||
PluginCall::CallInfo(call_info) => {
|
||||
assert_eq!(name, call_info.name);
|
||||
assert_eq!(CallInput::Value(input), call_info.input);
|
||||
assert_eq!(call.head, call_info.call.head);
|
||||
assert_eq!(call.positional.len(), call_info.call.positional.len());
|
||||
|
||||
call.positional
|
||||
.iter()
|
||||
.zip(call_info.call.positional.iter())
|
||||
.for_each(|(lhs, rhs)| assert_eq!(lhs, rhs));
|
||||
|
||||
call.named
|
||||
.iter()
|
||||
.zip(call_info.call.named.iter())
|
||||
.for_each(|(lhs, rhs)| {
|
||||
// Comparing the keys
|
||||
assert_eq!(lhs.0.item, rhs.0.item);
|
||||
|
||||
match (&lhs.1, &rhs.1) {
|
||||
(None, None) => {}
|
||||
(Some(a), Some(b)) => assert_eq!(a, b),
|
||||
_ => panic!("not matching values"),
|
||||
}
|
||||
});
|
||||
}
|
||||
PluginCall::CollapseCustomValue(_) => panic!("returned wrong call type"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn callinfo_round_trip_collapsecustomvalue() {
|
||||
let data = vec![1, 2, 3, 4, 5, 6, 7];
|
||||
let span = Span::new(0, 20);
|
||||
|
||||
let collapse_custom_value = PluginCall::CollapseCustomValue(PluginData {
|
||||
data: data.clone(),
|
||||
span,
|
||||
});
|
||||
|
||||
let encoder = MsgPackSerializer {};
|
||||
let mut buffer: Vec<u8> = Vec::new();
|
||||
encoder
|
||||
.encode_call(&collapse_custom_value, &mut buffer)
|
||||
.expect("unable to serialize message");
|
||||
let returned = encoder
|
||||
.decode_call(&mut buffer.as_slice())
|
||||
.expect("unable to deserialize message");
|
||||
|
||||
match returned {
|
||||
PluginCall::Signature => panic!("returned wrong call type"),
|
||||
PluginCall::CallInfo(_) => panic!("returned wrong call type"),
|
||||
PluginCall::CollapseCustomValue(plugin_data) => {
|
||||
assert_eq!(data, plugin_data.data);
|
||||
assert_eq!(span, plugin_data.span);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn response_round_trip_signature() {
|
||||
let signature = PluginSignature::build("nu-plugin")
|
||||
.required("first", SyntaxShape::String, "first required")
|
||||
.required("second", SyntaxShape::Int, "second required")
|
||||
.required_named("first-named", SyntaxShape::String, "first named", Some('f'))
|
||||
.required_named(
|
||||
"second-named",
|
||||
SyntaxShape::String,
|
||||
"second named",
|
||||
Some('s'),
|
||||
)
|
||||
.rest("remaining", SyntaxShape::Int, "remaining");
|
||||
|
||||
let response = PluginResponse::Signature(vec![signature.clone()]);
|
||||
|
||||
let encoder = MsgPackSerializer {};
|
||||
let mut buffer: Vec<u8> = Vec::new();
|
||||
encoder
|
||||
.encode_response(&response, &mut buffer)
|
||||
.expect("unable to serialize message");
|
||||
let returned = encoder
|
||||
.decode_response(&mut buffer.as_slice())
|
||||
.expect("unable to deserialize message");
|
||||
|
||||
match returned {
|
||||
PluginResponse::Error(_) => panic!("returned wrong call type"),
|
||||
PluginResponse::Value(_) => panic!("returned wrong call type"),
|
||||
PluginResponse::PluginData(..) => panic!("returned wrong call type"),
|
||||
PluginResponse::Signature(returned_signature) => {
|
||||
assert_eq!(returned_signature.len(), 1);
|
||||
assert_eq!(signature.sig.name, returned_signature[0].sig.name);
|
||||
assert_eq!(signature.sig.usage, returned_signature[0].sig.usage);
|
||||
assert_eq!(
|
||||
signature.sig.extra_usage,
|
||||
returned_signature[0].sig.extra_usage
|
||||
);
|
||||
assert_eq!(signature.sig.is_filter, returned_signature[0].sig.is_filter);
|
||||
|
||||
signature
|
||||
.sig
|
||||
.required_positional
|
||||
.iter()
|
||||
.zip(returned_signature[0].sig.required_positional.iter())
|
||||
.for_each(|(lhs, rhs)| assert_eq!(lhs, rhs));
|
||||
|
||||
signature
|
||||
.sig
|
||||
.optional_positional
|
||||
.iter()
|
||||
.zip(returned_signature[0].sig.optional_positional.iter())
|
||||
.for_each(|(lhs, rhs)| assert_eq!(lhs, rhs));
|
||||
|
||||
signature
|
||||
.sig
|
||||
.named
|
||||
.iter()
|
||||
.zip(returned_signature[0].sig.named.iter())
|
||||
.for_each(|(lhs, rhs)| assert_eq!(lhs, rhs));
|
||||
|
||||
assert_eq!(
|
||||
signature.sig.rest_positional,
|
||||
returned_signature[0].sig.rest_positional,
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn response_round_trip_value() {
|
||||
let value = Value::int(10, Span::new(2, 30));
|
||||
|
||||
let response = PluginResponse::Value(Box::new(value.clone()));
|
||||
|
||||
let encoder = MsgPackSerializer {};
|
||||
let mut buffer: Vec<u8> = Vec::new();
|
||||
encoder
|
||||
.encode_response(&response, &mut buffer)
|
||||
.expect("unable to serialize message");
|
||||
let returned = encoder
|
||||
.decode_response(&mut buffer.as_slice())
|
||||
.expect("unable to deserialize message");
|
||||
|
||||
match returned {
|
||||
PluginResponse::Error(_) => panic!("returned wrong call type"),
|
||||
PluginResponse::Signature(_) => panic!("returned wrong call type"),
|
||||
PluginResponse::PluginData(..) => panic!("returned wrong call type"),
|
||||
PluginResponse::Value(returned_value) => {
|
||||
assert_eq!(&value, returned_value.as_ref())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn response_round_trip_plugin_data() {
|
||||
let name = "test".to_string();
|
||||
|
||||
let data = vec![1, 2, 3, 4, 5];
|
||||
let span = Span::new(2, 30);
|
||||
|
||||
let response = PluginResponse::PluginData(
|
||||
name.clone(),
|
||||
PluginData {
|
||||
data: data.clone(),
|
||||
span,
|
||||
},
|
||||
);
|
||||
|
||||
let encoder = MsgPackSerializer {};
|
||||
let mut buffer: Vec<u8> = Vec::new();
|
||||
encoder
|
||||
.encode_response(&response, &mut buffer)
|
||||
.expect("unable to serialize message");
|
||||
let returned = encoder
|
||||
.decode_response(&mut buffer.as_slice())
|
||||
.expect("unable to deserialize message");
|
||||
|
||||
match returned {
|
||||
PluginResponse::Error(_) => panic!("returned wrong call type"),
|
||||
PluginResponse::Signature(_) => panic!("returned wrong call type"),
|
||||
PluginResponse::Value(_) => panic!("returned wrong call type"),
|
||||
PluginResponse::PluginData(returned_name, returned_plugin_data) => {
|
||||
assert_eq!(name, returned_name);
|
||||
assert_eq!(data, returned_plugin_data.data);
|
||||
assert_eq!(span, returned_plugin_data.span);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn response_round_trip_error() {
|
||||
let error = LabeledError {
|
||||
label: "label".into(),
|
||||
msg: "msg".into(),
|
||||
span: Some(Span::new(2, 30)),
|
||||
};
|
||||
let response = PluginResponse::Error(error.clone());
|
||||
|
||||
let encoder = MsgPackSerializer {};
|
||||
let mut buffer: Vec<u8> = Vec::new();
|
||||
encoder
|
||||
.encode_response(&response, &mut buffer)
|
||||
.expect("unable to serialize message");
|
||||
let returned = encoder
|
||||
.decode_response(&mut buffer.as_slice())
|
||||
.expect("unable to deserialize message");
|
||||
|
||||
match returned {
|
||||
PluginResponse::Error(msg) => assert_eq!(error, msg),
|
||||
PluginResponse::Signature(_) => panic!("returned wrong call type"),
|
||||
PluginResponse::Value(_) => panic!("returned wrong call type"),
|
||||
PluginResponse::PluginData(..) => panic!("returned wrong call type"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn response_round_trip_error_none() {
|
||||
let error = LabeledError {
|
||||
label: "label".into(),
|
||||
msg: "msg".into(),
|
||||
span: None,
|
||||
};
|
||||
let response = PluginResponse::Error(error.clone());
|
||||
|
||||
let encoder = MsgPackSerializer {};
|
||||
let mut buffer: Vec<u8> = Vec::new();
|
||||
encoder
|
||||
.encode_response(&response, &mut buffer)
|
||||
.expect("unable to serialize message");
|
||||
let returned = encoder
|
||||
.decode_response(&mut buffer.as_slice())
|
||||
.expect("unable to deserialize message");
|
||||
|
||||
match returned {
|
||||
PluginResponse::Error(msg) => assert_eq!(error, msg),
|
||||
PluginResponse::Signature(_) => panic!("returned wrong call type"),
|
||||
PluginResponse::Value(_) => panic!("returned wrong call type"),
|
||||
PluginResponse::PluginData(..) => panic!("returned wrong call type"),
|
||||
}
|
||||
}
|
||||
crate::serializers::tests::generate_tests!(MsgPackSerializer {});
|
||||
}
|
||||
|
538
crates/nu-plugin/src/serializers/tests.rs
Normal file
538
crates/nu-plugin/src/serializers/tests.rs
Normal file
@ -0,0 +1,538 @@
|
||||
macro_rules! generate_tests {
|
||||
($encoder:expr) => {
|
||||
use crate::protocol::{
|
||||
CallInfo, CustomValueOp, EvaluatedCall, LabeledError, PipelineDataHeader, PluginCall,
|
||||
PluginCallResponse, PluginCustomValue, PluginInput, PluginOutput, StreamData,
|
||||
StreamMessage,
|
||||
};
|
||||
use nu_protocol::{PluginSignature, Span, Spanned, SyntaxShape, Value};
|
||||
|
||||
#[test]
|
||||
fn decode_eof() {
|
||||
let mut buffer: &[u8] = &[];
|
||||
let encoder = $encoder;
|
||||
let result: Option<PluginInput> = encoder
|
||||
.decode(&mut buffer)
|
||||
.expect("eof should not result in an error");
|
||||
assert!(result.is_none(), "decode result: {result:?}");
|
||||
let result: Option<PluginOutput> = encoder
|
||||
.decode(&mut buffer)
|
||||
.expect("eof should not result in an error");
|
||||
assert!(result.is_none(), "decode result: {result:?}");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn decode_io_error() {
|
||||
struct ErrorProducer;
|
||||
impl std::io::Read for ErrorProducer {
|
||||
fn read(&mut self, _buf: &mut [u8]) -> std::io::Result<usize> {
|
||||
Err(std::io::Error::from(std::io::ErrorKind::ConnectionReset))
|
||||
}
|
||||
}
|
||||
let encoder = $encoder;
|
||||
let mut buffered = std::io::BufReader::new(ErrorProducer);
|
||||
match Encoder::<PluginInput>::decode(&encoder, &mut buffered) {
|
||||
Ok(_) => panic!("decode: i/o error was not passed through"),
|
||||
Err(ShellError::IOError { .. }) => (), // okay
|
||||
Err(other) => panic!(
|
||||
"decode: got other error, should have been a \
|
||||
ShellError::IOError: {other:?}"
|
||||
),
|
||||
}
|
||||
match Encoder::<PluginOutput>::decode(&encoder, &mut buffered) {
|
||||
Ok(_) => panic!("decode: i/o error was not passed through"),
|
||||
Err(ShellError::IOError { .. }) => (), // okay
|
||||
Err(other) => panic!(
|
||||
"decode: got other error, should have been a \
|
||||
ShellError::IOError: {other:?}"
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn decode_gibberish() {
|
||||
// just a sequence of bytes that shouldn't be valid in anything we use
|
||||
let gibberish: &[u8] = &[
|
||||
0, 80, 74, 85, 117, 122, 86, 100, 74, 115, 20, 104, 55, 98, 67, 203, 83, 85, 77,
|
||||
112, 74, 79, 254, 71, 80,
|
||||
];
|
||||
let encoder = $encoder;
|
||||
|
||||
let mut buffered = std::io::BufReader::new(&gibberish[..]);
|
||||
match Encoder::<PluginInput>::decode(&encoder, &mut buffered) {
|
||||
Ok(value) => panic!("decode: parsed successfully => {value:?}"),
|
||||
Err(ShellError::PluginFailedToDecode { .. }) => (), // okay
|
||||
Err(other) => panic!(
|
||||
"decode: got other error, should have been a \
|
||||
ShellError::PluginFailedToDecode: {other:?}"
|
||||
),
|
||||
}
|
||||
|
||||
let mut buffered = std::io::BufReader::new(&gibberish[..]);
|
||||
match Encoder::<PluginOutput>::decode(&encoder, &mut buffered) {
|
||||
Ok(value) => panic!("decode: parsed successfully => {value:?}"),
|
||||
Err(ShellError::PluginFailedToDecode { .. }) => (), // okay
|
||||
Err(other) => panic!(
|
||||
"decode: got other error, should have been a \
|
||||
ShellError::PluginFailedToDecode: {other:?}"
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn call_round_trip_signature() {
|
||||
let plugin_call = PluginCall::Signature;
|
||||
let plugin_input = PluginInput::Call(0, plugin_call);
|
||||
let encoder = $encoder;
|
||||
|
||||
let mut buffer: Vec<u8> = Vec::new();
|
||||
encoder
|
||||
.encode(&plugin_input, &mut buffer)
|
||||
.expect("unable to serialize message");
|
||||
let returned = encoder
|
||||
.decode(&mut buffer.as_slice())
|
||||
.expect("unable to deserialize message")
|
||||
.expect("eof");
|
||||
|
||||
match returned {
|
||||
PluginInput::Call(0, PluginCall::Signature) => {}
|
||||
_ => panic!("decoded into wrong value: {returned:?}"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn call_round_trip_run() {
|
||||
let name = "test".to_string();
|
||||
|
||||
let input = Value::bool(false, Span::new(1, 20));
|
||||
|
||||
let call = EvaluatedCall {
|
||||
head: Span::new(0, 10),
|
||||
positional: vec![
|
||||
Value::float(1.0, Span::new(0, 10)),
|
||||
Value::string("something", Span::new(0, 10)),
|
||||
],
|
||||
named: vec![(
|
||||
Spanned {
|
||||
item: "name".to_string(),
|
||||
span: Span::new(0, 10),
|
||||
},
|
||||
Some(Value::float(1.0, Span::new(0, 10))),
|
||||
)],
|
||||
};
|
||||
|
||||
let plugin_call = PluginCall::Run(CallInfo {
|
||||
name: name.clone(),
|
||||
call: call.clone(),
|
||||
input: PipelineDataHeader::Value(input.clone()),
|
||||
config: None,
|
||||
});
|
||||
|
||||
let plugin_input = PluginInput::Call(1, plugin_call);
|
||||
|
||||
let encoder = $encoder;
|
||||
let mut buffer: Vec<u8> = Vec::new();
|
||||
encoder
|
||||
.encode(&plugin_input, &mut buffer)
|
||||
.expect("unable to serialize message");
|
||||
let returned = encoder
|
||||
.decode(&mut buffer.as_slice())
|
||||
.expect("unable to deserialize message")
|
||||
.expect("eof");
|
||||
|
||||
match returned {
|
||||
PluginInput::Call(1, PluginCall::Run(call_info)) => {
|
||||
assert_eq!(name, call_info.name);
|
||||
assert_eq!(PipelineDataHeader::Value(input), call_info.input);
|
||||
assert_eq!(call.head, call_info.call.head);
|
||||
assert_eq!(call.positional.len(), call_info.call.positional.len());
|
||||
|
||||
call.positional
|
||||
.iter()
|
||||
.zip(call_info.call.positional.iter())
|
||||
.for_each(|(lhs, rhs)| assert_eq!(lhs, rhs));
|
||||
|
||||
call.named
|
||||
.iter()
|
||||
.zip(call_info.call.named.iter())
|
||||
.for_each(|(lhs, rhs)| {
|
||||
// Comparing the keys
|
||||
assert_eq!(lhs.0.item, rhs.0.item);
|
||||
|
||||
match (&lhs.1, &rhs.1) {
|
||||
(None, None) => {}
|
||||
(Some(a), Some(b)) => assert_eq!(a, b),
|
||||
_ => panic!("not matching values"),
|
||||
}
|
||||
});
|
||||
}
|
||||
_ => panic!("decoded into wrong value: {returned:?}"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn call_round_trip_customvalueop() {
|
||||
let data = vec![1, 2, 3, 4, 5, 6, 7];
|
||||
let span = Span::new(0, 20);
|
||||
|
||||
let custom_value_op = PluginCall::CustomValueOp(
|
||||
Spanned {
|
||||
item: PluginCustomValue {
|
||||
name: "Foo".into(),
|
||||
data: data.clone(),
|
||||
source: None,
|
||||
},
|
||||
span,
|
||||
},
|
||||
CustomValueOp::ToBaseValue,
|
||||
);
|
||||
|
||||
let plugin_input = PluginInput::Call(2, custom_value_op);
|
||||
|
||||
let encoder = $encoder;
|
||||
let mut buffer: Vec<u8> = Vec::new();
|
||||
encoder
|
||||
.encode(&plugin_input, &mut buffer)
|
||||
.expect("unable to serialize message");
|
||||
let returned = encoder
|
||||
.decode(&mut buffer.as_slice())
|
||||
.expect("unable to deserialize message")
|
||||
.expect("eof");
|
||||
|
||||
match returned {
|
||||
PluginInput::Call(2, PluginCall::CustomValueOp(val, op)) => {
|
||||
assert_eq!("Foo", val.item.name);
|
||||
assert_eq!(data, val.item.data);
|
||||
assert_eq!(span, val.span);
|
||||
#[allow(unreachable_patterns)]
|
||||
match op {
|
||||
CustomValueOp::ToBaseValue => (),
|
||||
_ => panic!("wrong op: {op:?}"),
|
||||
}
|
||||
}
|
||||
_ => panic!("decoded into wrong value: {returned:?}"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn response_round_trip_signature() {
|
||||
let signature = PluginSignature::build("nu-plugin")
|
||||
.required("first", SyntaxShape::String, "first required")
|
||||
.required("second", SyntaxShape::Int, "second required")
|
||||
.required_named("first-named", SyntaxShape::String, "first named", Some('f'))
|
||||
.required_named(
|
||||
"second-named",
|
||||
SyntaxShape::String,
|
||||
"second named",
|
||||
Some('s'),
|
||||
)
|
||||
.rest("remaining", SyntaxShape::Int, "remaining");
|
||||
|
||||
let response = PluginCallResponse::Signature(vec![signature.clone()]);
|
||||
let output = PluginOutput::CallResponse(3, response);
|
||||
|
||||
let encoder = $encoder;
|
||||
let mut buffer: Vec<u8> = Vec::new();
|
||||
encoder
|
||||
.encode(&output, &mut buffer)
|
||||
.expect("unable to serialize message");
|
||||
let returned = encoder
|
||||
.decode(&mut buffer.as_slice())
|
||||
.expect("unable to deserialize message")
|
||||
.expect("eof");
|
||||
|
||||
match returned {
|
||||
PluginOutput::CallResponse(
|
||||
3,
|
||||
PluginCallResponse::Signature(returned_signature),
|
||||
) => {
|
||||
assert_eq!(returned_signature.len(), 1);
|
||||
assert_eq!(signature.sig.name, returned_signature[0].sig.name);
|
||||
assert_eq!(signature.sig.usage, returned_signature[0].sig.usage);
|
||||
assert_eq!(
|
||||
signature.sig.extra_usage,
|
||||
returned_signature[0].sig.extra_usage
|
||||
);
|
||||
assert_eq!(signature.sig.is_filter, returned_signature[0].sig.is_filter);
|
||||
|
||||
signature
|
||||
.sig
|
||||
.required_positional
|
||||
.iter()
|
||||
.zip(returned_signature[0].sig.required_positional.iter())
|
||||
.for_each(|(lhs, rhs)| assert_eq!(lhs, rhs));
|
||||
|
||||
signature
|
||||
.sig
|
||||
.optional_positional
|
||||
.iter()
|
||||
.zip(returned_signature[0].sig.optional_positional.iter())
|
||||
.for_each(|(lhs, rhs)| assert_eq!(lhs, rhs));
|
||||
|
||||
signature
|
||||
.sig
|
||||
.named
|
||||
.iter()
|
||||
.zip(returned_signature[0].sig.named.iter())
|
||||
.for_each(|(lhs, rhs)| assert_eq!(lhs, rhs));
|
||||
|
||||
assert_eq!(
|
||||
signature.sig.rest_positional,
|
||||
returned_signature[0].sig.rest_positional,
|
||||
);
|
||||
}
|
||||
_ => panic!("decoded into wrong value: {returned:?}"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn response_round_trip_value() {
|
||||
let value = Value::int(10, Span::new(2, 30));
|
||||
|
||||
let response = PluginCallResponse::value(value.clone());
|
||||
let output = PluginOutput::CallResponse(4, response);
|
||||
|
||||
let encoder = $encoder;
|
||||
let mut buffer: Vec<u8> = Vec::new();
|
||||
encoder
|
||||
.encode(&output, &mut buffer)
|
||||
.expect("unable to serialize message");
|
||||
let returned = encoder
|
||||
.decode(&mut buffer.as_slice())
|
||||
.expect("unable to deserialize message")
|
||||
.expect("eof");
|
||||
|
||||
match returned {
|
||||
PluginOutput::CallResponse(
|
||||
4,
|
||||
PluginCallResponse::PipelineData(PipelineDataHeader::Value(returned_value)),
|
||||
) => {
|
||||
assert_eq!(value, returned_value)
|
||||
}
|
||||
_ => panic!("decoded into wrong value: {returned:?}"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn response_round_trip_plugin_custom_value() {
|
||||
let name = "test";
|
||||
|
||||
let data = vec![1, 2, 3, 4, 5];
|
||||
let span = Span::new(2, 30);
|
||||
|
||||
let value = Value::custom_value(
|
||||
Box::new(PluginCustomValue {
|
||||
name: name.into(),
|
||||
data: data.clone(),
|
||||
source: None,
|
||||
}),
|
||||
span,
|
||||
);
|
||||
|
||||
let response = PluginCallResponse::PipelineData(PipelineDataHeader::Value(value));
|
||||
let output = PluginOutput::CallResponse(5, response);
|
||||
|
||||
let encoder = $encoder;
|
||||
let mut buffer: Vec<u8> = Vec::new();
|
||||
encoder
|
||||
.encode(&output, &mut buffer)
|
||||
.expect("unable to serialize message");
|
||||
let returned = encoder
|
||||
.decode(&mut buffer.as_slice())
|
||||
.expect("unable to deserialize message")
|
||||
.expect("eof");
|
||||
|
||||
match returned {
|
||||
PluginOutput::CallResponse(
|
||||
5,
|
||||
PluginCallResponse::PipelineData(PipelineDataHeader::Value(returned_value)),
|
||||
) => {
|
||||
assert_eq!(span, returned_value.span());
|
||||
|
||||
if let Some(plugin_val) = returned_value
|
||||
.as_custom_value()
|
||||
.unwrap()
|
||||
.as_any()
|
||||
.downcast_ref::<PluginCustomValue>()
|
||||
{
|
||||
assert_eq!(name, plugin_val.name);
|
||||
assert_eq!(data, plugin_val.data);
|
||||
} else {
|
||||
panic!("returned CustomValue is not a PluginCustomValue");
|
||||
}
|
||||
}
|
||||
_ => panic!("decoded into wrong value: {returned:?}"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn response_round_trip_error() {
|
||||
let error = LabeledError {
|
||||
label: "label".into(),
|
||||
msg: "msg".into(),
|
||||
span: Some(Span::new(2, 30)),
|
||||
};
|
||||
let response = PluginCallResponse::Error(error.clone());
|
||||
let output = PluginOutput::CallResponse(6, response);
|
||||
|
||||
let encoder = $encoder;
|
||||
let mut buffer: Vec<u8> = Vec::new();
|
||||
encoder
|
||||
.encode(&output, &mut buffer)
|
||||
.expect("unable to serialize message");
|
||||
let returned = encoder
|
||||
.decode(&mut buffer.as_slice())
|
||||
.expect("unable to deserialize message")
|
||||
.expect("eof");
|
||||
|
||||
match returned {
|
||||
PluginOutput::CallResponse(6, PluginCallResponse::Error(msg)) => {
|
||||
assert_eq!(error, msg)
|
||||
}
|
||||
_ => panic!("decoded into wrong value: {returned:?}"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn response_round_trip_error_none() {
|
||||
let error = LabeledError {
|
||||
label: "label".into(),
|
||||
msg: "msg".into(),
|
||||
span: None,
|
||||
};
|
||||
let response = PluginCallResponse::Error(error.clone());
|
||||
let output = PluginOutput::CallResponse(7, response);
|
||||
|
||||
let encoder = $encoder;
|
||||
let mut buffer: Vec<u8> = Vec::new();
|
||||
encoder
|
||||
.encode(&output, &mut buffer)
|
||||
.expect("unable to serialize message");
|
||||
let returned = encoder
|
||||
.decode(&mut buffer.as_slice())
|
||||
.expect("unable to deserialize message")
|
||||
.expect("eof");
|
||||
|
||||
match returned {
|
||||
PluginOutput::CallResponse(7, PluginCallResponse::Error(msg)) => {
|
||||
assert_eq!(error, msg)
|
||||
}
|
||||
_ => panic!("decoded into wrong value: {returned:?}"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn input_round_trip_stream_data_list() {
|
||||
let span = Span::new(12, 30);
|
||||
let item = Value::int(1, span);
|
||||
|
||||
let stream_data = StreamData::List(item.clone());
|
||||
let plugin_input = PluginInput::Stream(StreamMessage::Data(0, stream_data));
|
||||
|
||||
let encoder = $encoder;
|
||||
let mut buffer: Vec<u8> = Vec::new();
|
||||
encoder
|
||||
.encode(&plugin_input, &mut buffer)
|
||||
.expect("unable to serialize message");
|
||||
let returned = encoder
|
||||
.decode(&mut buffer.as_slice())
|
||||
.expect("unable to deserialize message")
|
||||
.expect("eof");
|
||||
|
||||
match returned {
|
||||
PluginInput::Stream(StreamMessage::Data(id, StreamData::List(list_data))) => {
|
||||
assert_eq!(0, id);
|
||||
assert_eq!(item, list_data);
|
||||
}
|
||||
_ => panic!("decoded into wrong value: {returned:?}"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn input_round_trip_stream_data_raw() {
|
||||
let data = b"Hello world";
|
||||
|
||||
let stream_data = StreamData::Raw(Ok(data.to_vec()));
|
||||
let plugin_input = PluginInput::Stream(StreamMessage::Data(1, stream_data));
|
||||
|
||||
let encoder = $encoder;
|
||||
let mut buffer: Vec<u8> = Vec::new();
|
||||
encoder
|
||||
.encode(&plugin_input, &mut buffer)
|
||||
.expect("unable to serialize message");
|
||||
let returned = encoder
|
||||
.decode(&mut buffer.as_slice())
|
||||
.expect("unable to deserialize message")
|
||||
.expect("eof");
|
||||
|
||||
match returned {
|
||||
PluginInput::Stream(StreamMessage::Data(id, StreamData::Raw(bytes))) => {
|
||||
assert_eq!(1, id);
|
||||
match bytes {
|
||||
Ok(bytes) => assert_eq!(data, &bytes[..]),
|
||||
Err(err) => panic!("decoded into error variant: {err:?}"),
|
||||
}
|
||||
}
|
||||
_ => panic!("decoded into wrong value: {returned:?}"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn output_round_trip_stream_data_list() {
|
||||
let span = Span::new(12, 30);
|
||||
let item = Value::int(1, span);
|
||||
|
||||
let stream_data = StreamData::List(item.clone());
|
||||
let plugin_output = PluginOutput::Stream(StreamMessage::Data(4, stream_data));
|
||||
|
||||
let encoder = $encoder;
|
||||
let mut buffer: Vec<u8> = Vec::new();
|
||||
encoder
|
||||
.encode(&plugin_output, &mut buffer)
|
||||
.expect("unable to serialize message");
|
||||
let returned = encoder
|
||||
.decode(&mut buffer.as_slice())
|
||||
.expect("unable to deserialize message")
|
||||
.expect("eof");
|
||||
|
||||
match returned {
|
||||
PluginOutput::Stream(StreamMessage::Data(id, StreamData::List(list_data))) => {
|
||||
assert_eq!(4, id);
|
||||
assert_eq!(item, list_data);
|
||||
}
|
||||
_ => panic!("decoded into wrong value: {returned:?}"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn output_round_trip_stream_data_raw() {
|
||||
let data = b"Hello world";
|
||||
|
||||
let stream_data = StreamData::Raw(Ok(data.to_vec()));
|
||||
let plugin_output = PluginOutput::Stream(StreamMessage::Data(5, stream_data));
|
||||
|
||||
let encoder = $encoder;
|
||||
let mut buffer: Vec<u8> = Vec::new();
|
||||
encoder
|
||||
.encode(&plugin_output, &mut buffer)
|
||||
.expect("unable to serialize message");
|
||||
let returned = encoder
|
||||
.decode(&mut buffer.as_slice())
|
||||
.expect("unable to deserialize message")
|
||||
.expect("eof");
|
||||
|
||||
match returned {
|
||||
PluginOutput::Stream(StreamMessage::Data(id, StreamData::Raw(bytes))) => {
|
||||
assert_eq!(5, id);
|
||||
match bytes {
|
||||
Ok(bytes) => assert_eq!(data, &bytes[..]),
|
||||
Err(err) => panic!("decoded into error variant: {err:?}"),
|
||||
}
|
||||
}
|
||||
_ => panic!("decoded into wrong value: {returned:?}"),
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
pub(crate) use generate_tests;
|
@ -6,7 +6,7 @@ use crate::engine::Command;
|
||||
use crate::{BlockId, Category, Flag, PositionalArg, SyntaxShape, Type};
|
||||
|
||||
/// A simple wrapper for Signature that includes examples.
|
||||
#[derive(Clone, Serialize, Deserialize)]
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct PluginSignature {
|
||||
pub sig: Signature,
|
||||
pub examples: Vec<PluginExample>,
|
||||
|
@ -774,6 +774,54 @@ pub enum ShellError {
|
||||
#[diagnostic(code(nu::shell::plugin_failed_to_decode))]
|
||||
PluginFailedToDecode { msg: String },
|
||||
|
||||
/// A custom value cannot be sent to the given plugin.
|
||||
///
|
||||
/// ## Resolution
|
||||
///
|
||||
/// Custom values can only be used with the plugin they came from. Use a command from that
|
||||
/// plugin instead.
|
||||
#[error("Custom value `{name}` cannot be sent to plugin")]
|
||||
#[diagnostic(code(nu::shell::custom_value_incorrect_for_plugin))]
|
||||
CustomValueIncorrectForPlugin {
|
||||
name: String,
|
||||
#[label("the `{dest_plugin}` plugin does not support this kind of value")]
|
||||
span: Span,
|
||||
dest_plugin: String,
|
||||
#[help("this value came from the `{}` plugin")]
|
||||
src_plugin: Option<String>,
|
||||
},
|
||||
|
||||
/// The plugin failed to encode a custom value.
|
||||
///
|
||||
/// ## Resolution
|
||||
///
|
||||
/// This is likely a bug with the plugin itself. The plugin may have tried to send a custom
|
||||
/// value that is not serializable.
|
||||
#[error("Custom value failed to encode")]
|
||||
#[diagnostic(code(nu::shell::custom_value_failed_to_encode))]
|
||||
CustomValueFailedToEncode {
|
||||
msg: String,
|
||||
#[label("{msg}")]
|
||||
span: Span,
|
||||
},
|
||||
|
||||
/// The plugin failed to encode a custom value.
|
||||
///
|
||||
/// ## Resolution
|
||||
///
|
||||
/// This may be a bug within the plugin, or the plugin may have been updated in between the
|
||||
/// creation of the custom value and its use.
|
||||
#[error("Custom value failed to decode")]
|
||||
#[diagnostic(code(nu::shell::custom_value_failed_to_decode))]
|
||||
#[diagnostic(help(
|
||||
"the plugin may have been updated and no longer support this custom value"
|
||||
))]
|
||||
CustomValueFailedToDecode {
|
||||
msg: String,
|
||||
#[label("{msg}")]
|
||||
span: Span,
|
||||
},
|
||||
|
||||
/// I/O operation interrupted.
|
||||
///
|
||||
/// ## Resolution
|
||||
|
@ -164,7 +164,6 @@ pub enum Value {
|
||||
#[serde(rename = "span")]
|
||||
internal_span: Span,
|
||||
},
|
||||
#[serde(skip_serializing)]
|
||||
CustomValue {
|
||||
val: Box<dyn CustomValue>,
|
||||
// note: spans are being refactored out of Value
|
||||
|
@ -4,7 +4,7 @@ mod second_custom_value;
|
||||
use cool_custom_value::CoolCustomValue;
|
||||
use nu_plugin::{serve_plugin, MsgPackSerializer, Plugin};
|
||||
use nu_plugin::{EvaluatedCall, LabeledError};
|
||||
use nu_protocol::{Category, PluginSignature, ShellError, Value};
|
||||
use nu_protocol::{Category, PluginSignature, ShellError, SyntaxShape, Value};
|
||||
use second_custom_value::SecondCustomValue;
|
||||
|
||||
struct CustomValuePlugin;
|
||||
@ -21,6 +21,14 @@ impl Plugin for CustomValuePlugin {
|
||||
PluginSignature::build("custom-value update")
|
||||
.usage("PluginSignature for a plugin that updates a custom value")
|
||||
.category(Category::Experimental),
|
||||
PluginSignature::build("custom-value update-arg")
|
||||
.usage("PluginSignature for a plugin that updates a custom value as an argument")
|
||||
.required(
|
||||
"custom_value",
|
||||
SyntaxShape::Any,
|
||||
"the custom value to update",
|
||||
)
|
||||
.category(Category::Experimental),
|
||||
]
|
||||
}
|
||||
|
||||
@ -35,6 +43,7 @@ impl Plugin for CustomValuePlugin {
|
||||
"custom-value generate" => self.generate(call, input),
|
||||
"custom-value generate2" => self.generate2(call, input),
|
||||
"custom-value update" => self.update(call, input),
|
||||
"custom-value update-arg" => self.update(call, &call.req(0)?),
|
||||
_ => Err(LabeledError {
|
||||
label: "Plugin call with wrong name signature".into(),
|
||||
msg: "the signature used to call the plugin does not match any name in the plugin signature vector".into(),
|
||||
|
@ -10,9 +10,8 @@
|
||||
# register <path-to-py-file>
|
||||
#
|
||||
# Be careful with the spans. Miette will crash if a span is outside the
|
||||
# size of the contents vector. For this example we are using 0 and 1, which will
|
||||
# point to the beginning of the contents vector. We strongly suggest using the span
|
||||
# found in the plugin call head
|
||||
# size of the contents vector. We strongly suggest using the span found in the
|
||||
# plugin call head as in this example.
|
||||
#
|
||||
# The plugin will be run using the active Python implementation. If you are in
|
||||
# a Python environment, that is the Python version that is used
|
||||
@ -113,7 +112,7 @@ def signatures():
|
||||
}
|
||||
|
||||
|
||||
def process_call(plugin_call):
|
||||
def process_call(id, plugin_call):
|
||||
"""
|
||||
plugin_call is a dictionary with the information from the call
|
||||
It should contain:
|
||||
@ -127,277 +126,38 @@ def process_call(plugin_call):
|
||||
sys.stderr.write(json.dumps(plugin_call, indent=4))
|
||||
sys.stderr.write("\n")
|
||||
|
||||
# Get the span from the call
|
||||
span = plugin_call["Run"]["call"]["head"]
|
||||
|
||||
# Creates a Value of type List that will be encoded and sent to Nushell
|
||||
return {
|
||||
value = {
|
||||
"Value": {
|
||||
"List": {
|
||||
"vals": [
|
||||
{
|
||||
"Record": {
|
||||
"cols": ["one", "two", "three"],
|
||||
"vals": [
|
||||
{
|
||||
"Int": {
|
||||
"val": 0,
|
||||
"span": {"start": 0, "end": 1},
|
||||
}
|
||||
},
|
||||
{
|
||||
"Int": {
|
||||
"val": 0,
|
||||
"span": {"start": 0, "end": 1},
|
||||
}
|
||||
},
|
||||
{
|
||||
"Int": {
|
||||
"val": 0,
|
||||
"span": {"start": 0, "end": 1},
|
||||
}
|
||||
},
|
||||
],
|
||||
"span": {"start": 0, "end": 1},
|
||||
"val": {
|
||||
"cols": ["one", "two", "three"],
|
||||
"vals": [
|
||||
{
|
||||
"Int": {
|
||||
"val": x * y,
|
||||
"span": span
|
||||
}
|
||||
} for y in [0, 1, 2]
|
||||
]
|
||||
},
|
||||
"span": span
|
||||
}
|
||||
},
|
||||
{
|
||||
"Record": {
|
||||
"cols": ["one", "two", "three"],
|
||||
"vals": [
|
||||
{
|
||||
"Int": {
|
||||
"val": 0,
|
||||
"span": {"start": 0, "end": 1},
|
||||
}
|
||||
},
|
||||
{
|
||||
"Int": {
|
||||
"val": 1,
|
||||
"span": {"start": 0, "end": 1},
|
||||
}
|
||||
},
|
||||
{
|
||||
"Int": {
|
||||
"val": 2,
|
||||
"span": {"start": 0, "end": 1},
|
||||
}
|
||||
},
|
||||
],
|
||||
"span": {"start": 0, "end": 1},
|
||||
}
|
||||
},
|
||||
{
|
||||
"Record": {
|
||||
"cols": ["one", "two", "three"],
|
||||
"vals": [
|
||||
{
|
||||
"Int": {
|
||||
"val": 0,
|
||||
"span": {"start": 0, "end": 1},
|
||||
}
|
||||
},
|
||||
{
|
||||
"Int": {
|
||||
"val": 2,
|
||||
"span": {"start": 0, "end": 1},
|
||||
}
|
||||
},
|
||||
{
|
||||
"Int": {
|
||||
"val": 4,
|
||||
"span": {"start": 0, "end": 1},
|
||||
}
|
||||
},
|
||||
],
|
||||
"span": {"start": 0, "end": 1},
|
||||
}
|
||||
},
|
||||
{
|
||||
"Record": {
|
||||
"cols": ["one", "two", "three"],
|
||||
"vals": [
|
||||
{
|
||||
"Int": {
|
||||
"val": 0,
|
||||
"span": {"start": 0, "end": 1},
|
||||
}
|
||||
},
|
||||
{
|
||||
"Int": {
|
||||
"val": 3,
|
||||
"span": {"start": 0, "end": 1},
|
||||
}
|
||||
},
|
||||
{
|
||||
"Int": {
|
||||
"val": 6,
|
||||
"span": {"start": 0, "end": 1},
|
||||
}
|
||||
},
|
||||
],
|
||||
"span": {"start": 0, "end": 1},
|
||||
}
|
||||
},
|
||||
{
|
||||
"Record": {
|
||||
"cols": ["one", "two", "three"],
|
||||
"vals": [
|
||||
{
|
||||
"Int": {
|
||||
"val": 0,
|
||||
"span": {"start": 0, "end": 1},
|
||||
}
|
||||
},
|
||||
{
|
||||
"Int": {
|
||||
"val": 4,
|
||||
"span": {"start": 0, "end": 1},
|
||||
}
|
||||
},
|
||||
{
|
||||
"Int": {
|
||||
"val": 8,
|
||||
"span": {"start": 0, "end": 1},
|
||||
}
|
||||
},
|
||||
],
|
||||
"span": {"start": 0, "end": 1},
|
||||
}
|
||||
},
|
||||
{
|
||||
"Record": {
|
||||
"cols": ["one", "two", "three"],
|
||||
"vals": [
|
||||
{
|
||||
"Int": {
|
||||
"val": 0,
|
||||
"span": {"start": 0, "end": 1},
|
||||
}
|
||||
},
|
||||
{
|
||||
"Int": {
|
||||
"val": 5,
|
||||
"span": {"start": 0, "end": 1},
|
||||
}
|
||||
},
|
||||
{
|
||||
"Int": {
|
||||
"val": 10,
|
||||
"span": {"start": 0, "end": 1},
|
||||
}
|
||||
},
|
||||
],
|
||||
"span": {"start": 0, "end": 1},
|
||||
}
|
||||
},
|
||||
{
|
||||
"Record": {
|
||||
"cols": ["one", "two", "three"],
|
||||
"vals": [
|
||||
{
|
||||
"Int": {
|
||||
"val": 0,
|
||||
"span": {"start": 0, "end": 1},
|
||||
}
|
||||
},
|
||||
{
|
||||
"Int": {
|
||||
"val": 6,
|
||||
"span": {"start": 0, "end": 1},
|
||||
}
|
||||
},
|
||||
{
|
||||
"Int": {
|
||||
"val": 12,
|
||||
"span": {"start": 0, "end": 1},
|
||||
}
|
||||
},
|
||||
],
|
||||
"span": {"start": 0, "end": 1},
|
||||
}
|
||||
},
|
||||
{
|
||||
"Record": {
|
||||
"cols": ["one", "two", "three"],
|
||||
"vals": [
|
||||
{
|
||||
"Int": {
|
||||
"val": 0,
|
||||
"span": {"start": 0, "end": 1},
|
||||
}
|
||||
},
|
||||
{
|
||||
"Int": {
|
||||
"val": 7,
|
||||
"span": {"start": 0, "end": 1},
|
||||
}
|
||||
},
|
||||
{
|
||||
"Int": {
|
||||
"val": 14,
|
||||
"span": {"start": 0, "end": 1},
|
||||
}
|
||||
},
|
||||
],
|
||||
"span": {"start": 0, "end": 1},
|
||||
}
|
||||
},
|
||||
{
|
||||
"Record": {
|
||||
"cols": ["one", "two", "three"],
|
||||
"vals": [
|
||||
{
|
||||
"Int": {
|
||||
"val": 0,
|
||||
"span": {"start": 0, "end": 1},
|
||||
}
|
||||
},
|
||||
{
|
||||
"Int": {
|
||||
"val": 8,
|
||||
"span": {"start": 0, "end": 1},
|
||||
}
|
||||
},
|
||||
{
|
||||
"Int": {
|
||||
"val": 16,
|
||||
"span": {"start": 0, "end": 1},
|
||||
}
|
||||
},
|
||||
],
|
||||
"span": {"start": 0, "end": 1},
|
||||
}
|
||||
},
|
||||
{
|
||||
"Record": {
|
||||
"cols": ["one", "two", "three"],
|
||||
"vals": [
|
||||
{
|
||||
"Int": {
|
||||
"val": 0,
|
||||
"span": {"start": 0, "end": 1},
|
||||
}
|
||||
},
|
||||
{
|
||||
"Int": {
|
||||
"val": 9,
|
||||
"span": {"start": 0, "end": 1},
|
||||
}
|
||||
},
|
||||
{
|
||||
"Int": {
|
||||
"val": 18,
|
||||
"span": {"start": 0, "end": 1},
|
||||
}
|
||||
},
|
||||
],
|
||||
"span": {"start": 0, "end": 1},
|
||||
}
|
||||
},
|
||||
} for x in range(0, 10)
|
||||
],
|
||||
"span": {"start": 0, "end": 1},
|
||||
"span": span
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
write_response(id, {"PipelineData": value})
|
||||
|
||||
|
||||
def tell_nushell_encoding():
|
||||
sys.stdout.write(chr(4))
|
||||
@ -406,30 +166,79 @@ def tell_nushell_encoding():
|
||||
sys.stdout.flush()
|
||||
|
||||
|
||||
def tell_nushell_hello():
|
||||
"""
|
||||
A `Hello` message is required at startup to inform nushell of the protocol capabilities and
|
||||
compatibility of the plugin. The version specified should be the version of nushell that this
|
||||
plugin was tested and developed against.
|
||||
"""
|
||||
hello = {
|
||||
"Hello": {
|
||||
"protocol": "nu-plugin", # always this value
|
||||
"version": "0.90.2",
|
||||
"features": []
|
||||
}
|
||||
}
|
||||
sys.stdout.write(json.dumps(hello))
|
||||
sys.stdout.write("\n")
|
||||
sys.stdout.flush()
|
||||
|
||||
|
||||
def write_response(id, response):
|
||||
"""
|
||||
Use this format to send a response to a plugin call. The ID of the plugin call is required.
|
||||
"""
|
||||
wrapped_response = {
|
||||
"CallResponse": [
|
||||
id,
|
||||
response,
|
||||
]
|
||||
}
|
||||
sys.stdout.write(json.dumps(wrapped_response))
|
||||
sys.stdout.write("\n")
|
||||
sys.stdout.flush()
|
||||
|
||||
|
||||
def write_error(id, msg, span=None):
|
||||
"""
|
||||
Use this error format to send errors to nushell in response to a plugin call. The ID of the
|
||||
plugin call is required.
|
||||
"""
|
||||
error = {
|
||||
"Error": {
|
||||
"label": "ERROR from plugin",
|
||||
"msg": msg,
|
||||
"span": span
|
||||
}
|
||||
}
|
||||
write_response(id, error)
|
||||
|
||||
|
||||
def handle_input(input):
|
||||
if "Hello" in input:
|
||||
return
|
||||
elif "Call" in input:
|
||||
[id, plugin_call] = input["Call"]
|
||||
if "Signature" in plugin_call:
|
||||
write_response(id, signatures())
|
||||
elif "Run" in plugin_call:
|
||||
process_call(id, plugin_call)
|
||||
else:
|
||||
write_error(id, "Operation not supported: " + str(plugin_call))
|
||||
else:
|
||||
sys.stderr.write("Unknown message: " + str(input) + "\n")
|
||||
exit(1)
|
||||
|
||||
|
||||
def plugin():
|
||||
tell_nushell_encoding()
|
||||
call_str = ",".join(sys.stdin.readlines())
|
||||
plugin_call = json.loads(call_str)
|
||||
|
||||
if plugin_call == "Signature":
|
||||
signature = json.dumps(signatures())
|
||||
sys.stdout.write(signature)
|
||||
|
||||
elif "CallInfo" in plugin_call:
|
||||
response = process_call(plugin_call)
|
||||
sys.stdout.write(json.dumps(response))
|
||||
|
||||
else:
|
||||
# Use this error format if you want to return an error back to Nushell
|
||||
error = {
|
||||
"Error": {
|
||||
"label": "ERROR from plugin",
|
||||
"msg": "error message pointing to call head span",
|
||||
"span": {"start": 0, "end": 1},
|
||||
}
|
||||
}
|
||||
sys.stdout.write(json.dumps(error))
|
||||
|
||||
tell_nushell_hello()
|
||||
for line in sys.stdin:
|
||||
input = json.loads(line)
|
||||
handle_input(input)
|
||||
|
||||
if __name__ == "__main__":
|
||||
plugin()
|
||||
if len(sys.argv) == 2 and sys.argv[1] == "--stdio":
|
||||
plugin()
|
||||
else:
|
||||
print("Run me from inside nushell!")
|
||||
|
19
crates/nu_plugin_stream_example/Cargo.toml
Normal file
19
crates/nu_plugin_stream_example/Cargo.toml
Normal file
@ -0,0 +1,19 @@
|
||||
[package]
|
||||
authors = ["The Nushell Project Developers"]
|
||||
description = "An example of stream handling in nushell plugins"
|
||||
repository = "https://github.com/nushell/nushell/tree/main/crates/nu_plugin_stream_example"
|
||||
edition = "2021"
|
||||
license = "MIT"
|
||||
name = "nu_plugin_stream_example"
|
||||
version = "0.90.2"
|
||||
|
||||
[[bin]]
|
||||
name = "nu_plugin_stream_example"
|
||||
bench = false
|
||||
|
||||
[lib]
|
||||
bench = false
|
||||
|
||||
[dependencies]
|
||||
nu-plugin = { path = "../nu-plugin", version = "0.90.2" }
|
||||
nu-protocol = { path = "../nu-protocol", version = "0.90.2", features = ["plugin"] }
|
48
crates/nu_plugin_stream_example/README.md
Normal file
48
crates/nu_plugin_stream_example/README.md
Normal file
@ -0,0 +1,48 @@
|
||||
# Streaming Plugin Example
|
||||
|
||||
Crate with a simple example of the `StreamingPlugin` trait that needs to be implemented
|
||||
in order to create a binary that can be registered into nushell declaration list
|
||||
|
||||
## `stream_example seq`
|
||||
|
||||
This command demonstrates generating list streams. It generates numbers from the first argument
|
||||
to the second argument just like the builtin `seq` command does.
|
||||
|
||||
Examples:
|
||||
|
||||
> ```nushell
|
||||
> stream_example seq 1 10
|
||||
> ```
|
||||
|
||||
[1 2 3 4 5 6 7 8 9 10]
|
||||
|
||||
> ```nushell
|
||||
> stream_example seq 1 10 | describe
|
||||
> ```
|
||||
|
||||
list<int> (stream)
|
||||
|
||||
## `stream_example sum`
|
||||
|
||||
This command demonstrates consuming list streams. It consumes a stream of numbers and calculates the
|
||||
sum just like the builtin `math sum` command does.
|
||||
|
||||
Examples:
|
||||
|
||||
> ```nushell
|
||||
> seq 1 5 | stream_example sum
|
||||
> ```
|
||||
|
||||
15
|
||||
|
||||
## `stream_example collect-external`
|
||||
|
||||
This command demonstrates transforming streams into external streams. The list (or stream) of
|
||||
strings on input will be concatenated into an external stream (raw input) on stdout.
|
||||
|
||||
> ```nushell
|
||||
> [Hello "\n" world how are you] | stream_example collect-external
|
||||
> ````
|
||||
|
||||
Hello
|
||||
worldhowareyou
|
67
crates/nu_plugin_stream_example/src/example.rs
Normal file
67
crates/nu_plugin_stream_example/src/example.rs
Normal file
@ -0,0 +1,67 @@
|
||||
use nu_plugin::{EvaluatedCall, LabeledError};
|
||||
use nu_protocol::{ListStream, PipelineData, RawStream, Value};
|
||||
|
||||
pub struct Example;
|
||||
|
||||
mod int_or_float;
|
||||
use self::int_or_float::IntOrFloat;
|
||||
|
||||
impl Example {
|
||||
pub fn seq(
|
||||
&self,
|
||||
call: &EvaluatedCall,
|
||||
_input: PipelineData,
|
||||
) -> Result<PipelineData, LabeledError> {
|
||||
let first: i64 = call.req(0)?;
|
||||
let last: i64 = call.req(1)?;
|
||||
let span = call.head;
|
||||
let iter = (first..=last).map(move |number| Value::int(number, span));
|
||||
let list_stream = ListStream::from_stream(iter, None);
|
||||
Ok(PipelineData::ListStream(list_stream, None))
|
||||
}
|
||||
|
||||
pub fn sum(
|
||||
&self,
|
||||
call: &EvaluatedCall,
|
||||
input: PipelineData,
|
||||
) -> Result<PipelineData, LabeledError> {
|
||||
let mut acc = IntOrFloat::Int(0);
|
||||
let span = input.span();
|
||||
for value in input {
|
||||
if let Ok(n) = value.as_i64() {
|
||||
acc.add_i64(n);
|
||||
} else if let Ok(n) = value.as_f64() {
|
||||
acc.add_f64(n);
|
||||
} else {
|
||||
return Err(LabeledError {
|
||||
label: "Stream only accepts ints and floats".into(),
|
||||
msg: format!("found {}", value.get_type()),
|
||||
span,
|
||||
});
|
||||
}
|
||||
}
|
||||
Ok(PipelineData::Value(acc.to_value(call.head), None))
|
||||
}
|
||||
|
||||
pub fn collect_external(
|
||||
&self,
|
||||
call: &EvaluatedCall,
|
||||
input: PipelineData,
|
||||
) -> Result<PipelineData, LabeledError> {
|
||||
let stream = input.into_iter().map(|value| {
|
||||
value
|
||||
.as_str()
|
||||
.map(|str| str.as_bytes())
|
||||
.or_else(|_| value.as_binary())
|
||||
.map(|bin| bin.to_vec())
|
||||
});
|
||||
Ok(PipelineData::ExternalStream {
|
||||
stdout: Some(RawStream::new(Box::new(stream), None, call.head, None)),
|
||||
stderr: None,
|
||||
exit_code: None,
|
||||
span: call.head,
|
||||
metadata: None,
|
||||
trim_end_newline: false,
|
||||
})
|
||||
}
|
||||
}
|
42
crates/nu_plugin_stream_example/src/example/int_or_float.rs
Normal file
42
crates/nu_plugin_stream_example/src/example/int_or_float.rs
Normal file
@ -0,0 +1,42 @@
|
||||
use nu_protocol::Value;
|
||||
|
||||
use nu_protocol::Span;
|
||||
|
||||
/// Accumulates numbers into either an int or a float. Changes type to float on the first
|
||||
/// float received.
|
||||
#[derive(Clone, Copy)]
|
||||
pub(crate) enum IntOrFloat {
|
||||
Int(i64),
|
||||
Float(f64),
|
||||
}
|
||||
|
||||
impl IntOrFloat {
|
||||
pub(crate) fn add_i64(&mut self, n: i64) {
|
||||
match self {
|
||||
IntOrFloat::Int(ref mut v) => {
|
||||
*v += n;
|
||||
}
|
||||
IntOrFloat::Float(ref mut v) => {
|
||||
*v += n as f64;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn add_f64(&mut self, n: f64) {
|
||||
match self {
|
||||
IntOrFloat::Int(v) => {
|
||||
*self = IntOrFloat::Float(*v as f64 + n);
|
||||
}
|
||||
IntOrFloat::Float(ref mut v) => {
|
||||
*v += n;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn to_value(self, span: Span) -> Value {
|
||||
match self {
|
||||
IntOrFloat::Int(v) => Value::int(v, span),
|
||||
IntOrFloat::Float(v) => Value::float(v, span),
|
||||
}
|
||||
}
|
||||
}
|
4
crates/nu_plugin_stream_example/src/lib.rs
Normal file
4
crates/nu_plugin_stream_example/src/lib.rs
Normal file
@ -0,0 +1,4 @@
|
||||
mod example;
|
||||
mod nu;
|
||||
|
||||
pub use example::Example;
|
30
crates/nu_plugin_stream_example/src/main.rs
Normal file
30
crates/nu_plugin_stream_example/src/main.rs
Normal file
@ -0,0 +1,30 @@
|
||||
use nu_plugin::{serve_plugin, MsgPackSerializer};
|
||||
use nu_plugin_stream_example::Example;
|
||||
|
||||
fn main() {
|
||||
// When defining your plugin, you can select the Serializer that could be
|
||||
// used to encode and decode the messages. The available options are
|
||||
// MsgPackSerializer and JsonSerializer. Both are defined in the serializer
|
||||
// folder in nu-plugin.
|
||||
serve_plugin(&mut Example {}, MsgPackSerializer {})
|
||||
|
||||
// Note
|
||||
// When creating plugins in other languages one needs to consider how a plugin
|
||||
// is added and used in nushell.
|
||||
// The steps are:
|
||||
// - The plugin is register. In this stage nushell calls the binary file of
|
||||
// the plugin sending information using the encoded PluginCall::PluginSignature object.
|
||||
// Use this encoded data in your plugin to design the logic that will return
|
||||
// the encoded signatures.
|
||||
// Nushell is expecting and encoded PluginResponse::PluginSignature with all the
|
||||
// plugin signatures
|
||||
// - When calling the plugin, nushell sends to the binary file the encoded
|
||||
// PluginCall::CallInfo which has all the call information, such as the
|
||||
// values of the arguments, the name of the signature called and the input
|
||||
// from the pipeline.
|
||||
// Use this data to design your plugin login and to create the value that
|
||||
// will be sent to nushell
|
||||
// Nushell expects an encoded PluginResponse::Value from the plugin
|
||||
// - If an error needs to be sent back to nushell, one can encode PluginResponse::Error.
|
||||
// This is a labeled error that nushell can format for pretty printing
|
||||
}
|
86
crates/nu_plugin_stream_example/src/nu/mod.rs
Normal file
86
crates/nu_plugin_stream_example/src/nu/mod.rs
Normal file
@ -0,0 +1,86 @@
|
||||
use crate::Example;
|
||||
use nu_plugin::{EvaluatedCall, LabeledError, StreamingPlugin};
|
||||
use nu_protocol::{
|
||||
Category, PipelineData, PluginExample, PluginSignature, Span, SyntaxShape, Type, Value,
|
||||
};
|
||||
|
||||
impl StreamingPlugin for Example {
|
||||
fn signature(&self) -> Vec<PluginSignature> {
|
||||
let span = Span::unknown();
|
||||
vec![
|
||||
PluginSignature::build("stream_example")
|
||||
.usage("Examples for streaming plugins")
|
||||
.search_terms(vec!["example".into()])
|
||||
.category(Category::Experimental),
|
||||
PluginSignature::build("stream_example seq")
|
||||
.usage("Example stream generator for a list of values")
|
||||
.search_terms(vec!["example".into()])
|
||||
.required("first", SyntaxShape::Int, "first number to generate")
|
||||
.required("last", SyntaxShape::Int, "last number to generate")
|
||||
.input_output_type(Type::Nothing, Type::List(Type::Int.into()))
|
||||
.plugin_examples(vec![PluginExample {
|
||||
example: "stream_example seq 1 3".into(),
|
||||
description: "generate a sequence from 1 to 3".into(),
|
||||
result: Some(Value::list(
|
||||
vec![
|
||||
Value::int(1, span),
|
||||
Value::int(2, span),
|
||||
Value::int(3, span),
|
||||
],
|
||||
span,
|
||||
)),
|
||||
}])
|
||||
.category(Category::Experimental),
|
||||
PluginSignature::build("stream_example sum")
|
||||
.usage("Example stream consumer for a list of values")
|
||||
.search_terms(vec!["example".into()])
|
||||
.input_output_types(vec![
|
||||
(Type::List(Type::Int.into()), Type::Int),
|
||||
(Type::List(Type::Float.into()), Type::Float),
|
||||
])
|
||||
.plugin_examples(vec![PluginExample {
|
||||
example: "seq 1 5 | stream_example sum".into(),
|
||||
description: "sum values from 1 to 5".into(),
|
||||
result: Some(Value::int(15, span)),
|
||||
}])
|
||||
.category(Category::Experimental),
|
||||
PluginSignature::build("stream_example collect-external")
|
||||
.usage("Example transformer to raw external stream")
|
||||
.search_terms(vec!["example".into()])
|
||||
.input_output_types(vec![
|
||||
(Type::List(Type::String.into()), Type::String),
|
||||
(Type::List(Type::Binary.into()), Type::Binary),
|
||||
])
|
||||
.plugin_examples(vec![PluginExample {
|
||||
example: "[a b] | stream_example collect-external".into(),
|
||||
description: "collect strings into one stream".into(),
|
||||
result: Some(Value::string("ab", span)),
|
||||
}])
|
||||
.category(Category::Experimental),
|
||||
]
|
||||
}
|
||||
|
||||
fn run(
|
||||
&mut self,
|
||||
name: &str,
|
||||
_config: &Option<Value>,
|
||||
call: &EvaluatedCall,
|
||||
input: PipelineData,
|
||||
) -> Result<PipelineData, LabeledError> {
|
||||
match name {
|
||||
"stream_example" => Err(LabeledError {
|
||||
label: "No subcommand provided".into(),
|
||||
msg: "add --help here to see usage".into(),
|
||||
span: Some(call.head)
|
||||
}),
|
||||
"stream_example seq" => self.seq(call, input),
|
||||
"stream_example sum" => self.sum(call, input),
|
||||
"stream_example collect-external" => self.collect_external(call, input),
|
||||
_ => Err(LabeledError {
|
||||
label: "Plugin call with wrong name signature".into(),
|
||||
msg: "the signature used to call the plugin does not match any name in the plugin signature vector".into(),
|
||||
span: Some(call.head),
|
||||
}),
|
||||
}
|
||||
}
|
||||
}
|
@ -26,6 +26,20 @@ fn can_get_custom_value_from_plugin_and_pass_it_over() {
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn can_get_custom_value_from_plugin_and_pass_it_over_as_an_argument() {
|
||||
let actual = nu_with_plugins!(
|
||||
cwd: "tests",
|
||||
plugin: ("nu_plugin_custom_values"),
|
||||
"custom-value update-arg (custom-value generate)"
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
actual.out,
|
||||
"I used to be a custom value! My data was (abcxyz)"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn can_generate_and_updated_multiple_types_of_custom_values() {
|
||||
let actual = nu_with_plugins!(
|
||||
@ -65,7 +79,10 @@ fn fails_if_passing_engine_custom_values_to_plugins() {
|
||||
|
||||
assert!(actual
|
||||
.err
|
||||
.contains("Plugin custom-value update can not handle the custom value SQLiteDatabase"));
|
||||
.contains("`SQLiteDatabase` cannot be sent to plugin"));
|
||||
assert!(actual
|
||||
.err
|
||||
.contains("the `custom_values` plugin does not support this kind of value"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
@ -81,5 +98,8 @@ fn fails_if_passing_custom_values_across_plugins() {
|
||||
|
||||
assert!(actual
|
||||
.err
|
||||
.contains("Plugin inc can not handle the custom value CoolCustomValue"));
|
||||
.contains("`CoolCustomValue` cannot be sent to plugin"));
|
||||
assert!(actual
|
||||
.err
|
||||
.contains("the `inc` plugin does not support this kind of value"));
|
||||
}
|
||||
|
@ -3,3 +3,4 @@ mod core_inc;
|
||||
mod custom_values;
|
||||
mod formats;
|
||||
mod register;
|
||||
mod stream;
|
||||
|
166
tests/plugins/stream.rs
Normal file
166
tests/plugins/stream.rs
Normal file
@ -0,0 +1,166 @@
|
||||
use nu_test_support::nu_with_plugins;
|
||||
use pretty_assertions::assert_eq;
|
||||
|
||||
#[test]
|
||||
fn seq_produces_stream() {
|
||||
let actual = nu_with_plugins!(
|
||||
cwd: "tests/fixtures/formats",
|
||||
plugin: ("nu_plugin_stream_example"),
|
||||
"stream_example seq 1 5 | describe"
|
||||
);
|
||||
|
||||
assert_eq!(actual.out, "list<int> (stream)");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn seq_describe_no_collect_succeeds_without_error() {
|
||||
// This tests to ensure that there's no error if the stream is suddenly closed
|
||||
let actual = nu_with_plugins!(
|
||||
cwd: "tests/fixtures/formats",
|
||||
plugin: ("nu_plugin_stream_example"),
|
||||
"stream_example seq 1 5 | describe --no-collect"
|
||||
);
|
||||
|
||||
assert_eq!(actual.out, "stream");
|
||||
assert_eq!(actual.err, "");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn seq_stream_collects_to_correct_list() {
|
||||
let actual = nu_with_plugins!(
|
||||
cwd: "tests/fixtures/formats",
|
||||
plugin: ("nu_plugin_stream_example"),
|
||||
"stream_example seq 1 5 | to json --raw"
|
||||
);
|
||||
|
||||
assert_eq!(actual.out, "[1,2,3,4,5]");
|
||||
|
||||
let actual = nu_with_plugins!(
|
||||
cwd: "tests/fixtures/formats",
|
||||
plugin: ("nu_plugin_stream_example"),
|
||||
"stream_example seq 1 0 | to json --raw"
|
||||
);
|
||||
|
||||
assert_eq!(actual.out, "[]");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn seq_big_stream() {
|
||||
// Testing big streams helps to ensure there are no deadlocking bugs
|
||||
let actual = nu_with_plugins!(
|
||||
cwd: "tests/fixtures/formats",
|
||||
plugin: ("nu_plugin_stream_example"),
|
||||
"stream_example seq 1 100000 | length"
|
||||
);
|
||||
|
||||
assert_eq!(actual.out, "100000");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn sum_accepts_list_of_int() {
|
||||
let actual = nu_with_plugins!(
|
||||
cwd: "tests/fixtures/formats",
|
||||
plugin: ("nu_plugin_stream_example"),
|
||||
"[1 2 3] | stream_example sum"
|
||||
);
|
||||
|
||||
assert_eq!(actual.out, "6");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn sum_accepts_list_of_float() {
|
||||
let actual = nu_with_plugins!(
|
||||
cwd: "tests/fixtures/formats",
|
||||
plugin: ("nu_plugin_stream_example"),
|
||||
"[1.0 2.0 3.5] | stream_example sum"
|
||||
);
|
||||
|
||||
assert_eq!(actual.out, "6.5");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn sum_accepts_stream_of_int() {
|
||||
let actual = nu_with_plugins!(
|
||||
cwd: "tests/fixtures/formats",
|
||||
plugin: ("nu_plugin_stream_example"),
|
||||
"seq 1 5 | stream_example sum"
|
||||
);
|
||||
|
||||
assert_eq!(actual.out, "15");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn sum_accepts_stream_of_float() {
|
||||
let actual = nu_with_plugins!(
|
||||
cwd: "tests/fixtures/formats",
|
||||
plugin: ("nu_plugin_stream_example"),
|
||||
"seq 1 5 | into float | stream_example sum"
|
||||
);
|
||||
|
||||
assert_eq!(actual.out, "15");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn sum_big_stream() {
|
||||
// Testing big streams helps to ensure there are no deadlocking bugs
|
||||
let actual = nu_with_plugins!(
|
||||
cwd: "tests/fixtures/formats",
|
||||
plugin: ("nu_plugin_stream_example"),
|
||||
"seq 1 100000 | stream_example sum"
|
||||
);
|
||||
|
||||
assert_eq!(actual.out, "5000050000");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn collect_external_accepts_list_of_string() {
|
||||
let actual = nu_with_plugins!(
|
||||
cwd: "tests/fixtures/formats",
|
||||
plugin: ("nu_plugin_stream_example"),
|
||||
"[a b] | stream_example collect-external"
|
||||
);
|
||||
|
||||
assert_eq!(actual.out, "ab");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn collect_external_accepts_list_of_binary() {
|
||||
let actual = nu_with_plugins!(
|
||||
cwd: "tests/fixtures/formats",
|
||||
plugin: ("nu_plugin_stream_example"),
|
||||
"[0x[41] 0x[42]] | stream_example collect-external"
|
||||
);
|
||||
|
||||
assert_eq!(actual.out, "AB");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn collect_external_produces_raw_input() {
|
||||
let actual = nu_with_plugins!(
|
||||
cwd: "tests/fixtures/formats",
|
||||
plugin: ("nu_plugin_stream_example"),
|
||||
"[a b c] | stream_example collect-external | describe"
|
||||
);
|
||||
|
||||
assert_eq!(actual.out, "raw input");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn collect_external_big_stream() {
|
||||
// This in particular helps to ensure that a big stream can be both read and written at the same
|
||||
// time without deadlocking
|
||||
let actual = nu_with_plugins!(
|
||||
cwd: "tests/fixtures/formats",
|
||||
plugin: ("nu_plugin_stream_example"),
|
||||
r#"(
|
||||
seq 1 10000 |
|
||||
to text |
|
||||
each { into string } |
|
||||
stream_example collect-external |
|
||||
lines |
|
||||
length
|
||||
)"#
|
||||
);
|
||||
|
||||
assert_eq!(actual.out, "10000");
|
||||
}
|
@ -281,6 +281,14 @@
|
||||
Source='target\$(var.Profile)\nu_plugin_gstat.exe'
|
||||
KeyPath='yes'/>
|
||||
</Component>
|
||||
<Component Id='binary23' Guid='*' Win64='$(var.Win64)'>
|
||||
<File
|
||||
Id='exe23'
|
||||
Name='nu_plugin_stream_example.exe'
|
||||
DiskId='1'
|
||||
Source='target\$(var.Profile)\nu_plugin_stream_example.exe'
|
||||
KeyPath='yes'/>
|
||||
</Component>
|
||||
</Directory>
|
||||
</Directory>
|
||||
</Directory>
|
||||
|
Loading…
Reference in New Issue
Block a user