Split the plugin crate (#12563)

# Description

This breaks `nu-plugin` up into four crates:

- `nu-plugin-protocol`: just the type definitions for the protocol, no
I/O. If someone wanted to wire up something more bare metal, maybe for
async I/O, they could use this.
- `nu-plugin-core`: the shared stuff between engine/plugin. Less stable
interface.
- `nu-plugin-engine`: everything required for the engine to talk to
plugins. Less stable interface.
- `nu-plugin`: everything required for the plugin to talk to the engine,
what plugin developers use. Should be the most stable interface.

No changes are made to the interface exposed by `nu-plugin` - it should
all still be there. Re-exports from `nu-plugin-protocol` or
`nu-plugin-core` are used as required. Plugins shouldn't ever have to
use those crates directly.

This should be somewhat faster to compile as `nu-plugin-engine` and
`nu-plugin` can compile in parallel, and the engine doesn't need
`nu-plugin` and plugins don't need `nu-plugin-engine` (except for test
support), so that should reduce what needs to be compiled too.

The only significant change here other than splitting stuff up was to
break the `source` out of `PluginCustomValue` and create a new
`PluginCustomValueWithSource` type that contains that instead. One bonus
of that is we get rid of the option and it's now more type-safe, but it
also means that the logic for that stuff (actually running the plugin
for custom value ops) can live entirely within the `nu-plugin-engine`
crate.

# User-Facing Changes
- New crates.
- Added `local-socket` feature for `nu` to try to make it possible to
compile without that support if needed.

# Tests + Formatting
- 🟢 `toolkit fmt`
- 🟢 `toolkit clippy`
- 🟢 `toolkit test`
- 🟢 `toolkit test stdlib`
This commit is contained in:
Devyn Cairns
2024-04-27 10:08:12 -07:00
committed by GitHub
parent 884d5312bb
commit 0c4d5330ee
74 changed files with 3514 additions and 3110 deletions

View File

@ -0,0 +1,28 @@
[package]
authors = ["The Nushell Project Developers"]
description = "Shared internal functionality to support Nushell plugins"
repository = "https://github.com/nushell/nushell/tree/main/crates/nu-plugin-core"
edition = "2021"
license = "MIT"
name = "nu-plugin-core"
version = "0.92.3"
[lib]
bench = false
[dependencies]
nu-protocol = { path = "../nu-protocol", version = "0.92.3" }
nu-plugin-protocol = { path = "../nu-plugin-protocol", version = "0.92.3" }
rmp-serde = { workspace = true }
serde = { workspace = true }
serde_json = { workspace = true }
log = { workspace = true }
interprocess = { workspace = true, optional = true }
[features]
default = ["local-socket"]
local-socket = ["interprocess"]
[target.'cfg(target_os = "windows")'.dependencies]
windows = { workspace = true }

View File

@ -0,0 +1,21 @@
MIT License
Copyright (c) 2019 - 2023 The Nushell Project Developers
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

View File

@ -0,0 +1,3 @@
# nu-plugin-core
This crate provides functionality that is shared by the [Nushell](https://nushell.sh/) engine and plugins.

View File

@ -0,0 +1,84 @@
use std::ffi::OsString;
#[cfg(test)]
pub(crate) mod tests;
/// Generate a name to be used for a local socket specific to this `nu` process, described by the
/// given `unique_id`, which should be unique to the purpose of the socket.
///
/// On Unix, this is a path, which should generally be 100 characters or less for compatibility. On
/// Windows, this is a name within the `\\.\pipe` namespace.
#[cfg(unix)]
pub fn make_local_socket_name(unique_id: &str) -> OsString {
// Prefer to put it in XDG_RUNTIME_DIR if set, since that's user-local
let mut base = if let Some(runtime_dir) = std::env::var_os("XDG_RUNTIME_DIR") {
std::path::PathBuf::from(runtime_dir)
} else {
// Use std::env::temp_dir() for portability, especially since on Android this is probably
// not `/tmp`
std::env::temp_dir()
};
let socket_name = format!("nu.{}.{}.sock", std::process::id(), unique_id);
base.push(socket_name);
base.into()
}
/// Generate a name to be used for a local socket specific to this `nu` process, described by the
/// given `unique_id`, which should be unique to the purpose of the socket.
///
/// On Unix, this is a path, which should generally be 100 characters or less for compatibility. On
/// Windows, this is a name within the `\\.\pipe` namespace.
#[cfg(windows)]
pub fn make_local_socket_name(unique_id: &str) -> OsString {
format!("nu.{}.{}", std::process::id(), unique_id).into()
}
/// Determine if the error is just due to the listener not being ready yet in asynchronous mode
#[cfg(not(windows))]
pub fn is_would_block_err(err: &std::io::Error) -> bool {
err.kind() == std::io::ErrorKind::WouldBlock
}
/// Determine if the error is just due to the listener not being ready yet in asynchronous mode
#[cfg(windows)]
pub fn is_would_block_err(err: &std::io::Error) -> bool {
err.kind() == std::io::ErrorKind::WouldBlock
|| err.raw_os_error().is_some_and(|e| {
// Windows returns this error when trying to accept a pipe in non-blocking mode
e as i64 == windows::Win32::Foundation::ERROR_PIPE_LISTENING.0 as i64
})
}
/// Wraps the `interprocess` local socket stream for greater compatibility
#[derive(Debug)]
pub struct LocalSocketStream(pub interprocess::local_socket::LocalSocketStream);
impl From<interprocess::local_socket::LocalSocketStream> for LocalSocketStream {
fn from(value: interprocess::local_socket::LocalSocketStream) -> Self {
LocalSocketStream(value)
}
}
impl std::io::Read for LocalSocketStream {
fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
self.0.read(buf)
}
}
impl std::io::Write for LocalSocketStream {
fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
self.0.write(buf)
}
fn flush(&mut self) -> std::io::Result<()> {
// We don't actually flush the underlying socket on Windows. The flush operation on a
// Windows named pipe actually synchronizes with read on the other side, and won't finish
// until the other side is empty. This isn't how most of our other I/O methods work, so we
// just won't do it. The BufWriter above this will have still made a write call with the
// contents of the buffer, which should be good enough.
if cfg!(not(windows)) {
self.0.flush()?;
}
Ok(())
}
}

View File

@ -0,0 +1,19 @@
use super::make_local_socket_name;
#[test]
fn local_socket_path_contains_pid() {
let name = make_local_socket_name("test-string")
.to_string_lossy()
.into_owned();
println!("{}", name);
assert!(name.to_string().contains(&std::process::id().to_string()));
}
#[test]
fn local_socket_path_contains_provided_name() {
let name = make_local_socket_name("test-string")
.to_string_lossy()
.into_owned();
println!("{}", name);
assert!(name.to_string().contains("test-string"));
}

View File

@ -0,0 +1,249 @@
use std::ffi::OsStr;
use std::io::{Stdin, Stdout};
use std::process::{Child, ChildStdin, ChildStdout, Command, Stdio};
use nu_protocol::ShellError;
#[cfg(feature = "local-socket")]
use interprocess::local_socket::LocalSocketListener;
#[cfg(feature = "local-socket")]
mod local_socket;
#[cfg(feature = "local-socket")]
use local_socket::*;
/// The type of communication used between the plugin and the engine.
///
/// `Stdio` is required to be supported by all plugins, and is attempted initially. If the
/// `local-socket` feature is enabled and the plugin supports it, `LocalSocket` may be attempted.
///
/// Local socket communication has the benefit of not tying up stdio, so it's more compatible with
/// plugins that want to take user input from the terminal in some way.
#[derive(Debug, Clone)]
pub enum CommunicationMode {
/// Communicate using `stdin` and `stdout`.
Stdio,
/// Communicate using an operating system-specific local socket.
#[cfg(feature = "local-socket")]
LocalSocket(std::ffi::OsString),
}
impl CommunicationMode {
/// Generate a new local socket communication mode based on the given plugin exe path.
#[cfg(feature = "local-socket")]
pub fn local_socket(plugin_exe: &std::path::Path) -> CommunicationMode {
use std::hash::{Hash, Hasher};
use std::time::SystemTime;
// Generate the unique ID based on the plugin path and the current time. The actual
// algorithm here is not very important, we just want this to be relatively unique very
// briefly. Using the default hasher in the stdlib means zero extra dependencies.
let mut hasher = std::collections::hash_map::DefaultHasher::new();
plugin_exe.hash(&mut hasher);
SystemTime::now().hash(&mut hasher);
let unique_id = format!("{:016x}", hasher.finish());
CommunicationMode::LocalSocket(make_local_socket_name(&unique_id))
}
pub fn args(&self) -> Vec<&OsStr> {
match self {
CommunicationMode::Stdio => vec![OsStr::new("--stdio")],
#[cfg(feature = "local-socket")]
CommunicationMode::LocalSocket(path) => {
vec![OsStr::new("--local-socket"), path.as_os_str()]
}
}
}
pub fn setup_command_io(&self, command: &mut Command) {
match self {
CommunicationMode::Stdio => {
// Both stdout and stdin are piped so we can receive information from the plugin
command.stdin(Stdio::piped());
command.stdout(Stdio::piped());
}
#[cfg(feature = "local-socket")]
CommunicationMode::LocalSocket(_) => {
// Stdio can be used by the plugin to talk to the terminal in local socket mode,
// which is the big benefit
command.stdin(Stdio::inherit());
command.stdout(Stdio::inherit());
}
}
}
pub fn serve(&self) -> Result<PreparedServerCommunication, ShellError> {
match self {
// Nothing to set up for stdio - we just take it from the child.
CommunicationMode::Stdio => Ok(PreparedServerCommunication::Stdio),
// For sockets: we need to create the server so that the child won't fail to connect.
#[cfg(feature = "local-socket")]
CommunicationMode::LocalSocket(name) => {
let listener = LocalSocketListener::bind(name.as_os_str()).map_err(|err| {
ShellError::IOError {
msg: format!("failed to open socket for plugin: {err}"),
}
})?;
Ok(PreparedServerCommunication::LocalSocket {
name: name.clone(),
listener,
})
}
}
}
pub fn connect_as_client(&self) -> Result<ClientCommunicationIo, ShellError> {
match self {
CommunicationMode::Stdio => Ok(ClientCommunicationIo::Stdio(
std::io::stdin(),
std::io::stdout(),
)),
#[cfg(feature = "local-socket")]
CommunicationMode::LocalSocket(name) => {
// Connect to the specified socket.
let get_socket = || {
use interprocess::local_socket as ls;
ls::LocalSocketStream::connect(name.as_os_str())
.map_err(|err| ShellError::IOError {
msg: format!("failed to connect to socket: {err}"),
})
.map(LocalSocketStream::from)
};
// Reverse order from the server: read in, write out
let read_in = get_socket()?;
let write_out = get_socket()?;
Ok(ClientCommunicationIo::LocalSocket { read_in, write_out })
}
}
}
}
/// The result of [`CommunicationMode::serve()`], which acts as an intermediate stage for
/// communication modes that require some kind of socket binding to occur before the client process
/// can be started. Call [`.connect()`] once the client process has been started.
///
/// The socket may be cleaned up on `Drop` if applicable.
pub enum PreparedServerCommunication {
/// Will take stdin and stdout from the process on [`.connect()`].
Stdio,
/// Contains the listener to accept connections on. On Unix, the socket is unlinked on `Drop`.
#[cfg(feature = "local-socket")]
LocalSocket {
#[cfg_attr(windows, allow(dead_code))] // not used on Windows
name: std::ffi::OsString,
listener: LocalSocketListener,
},
}
impl PreparedServerCommunication {
pub fn connect(&self, child: &mut Child) -> Result<ServerCommunicationIo, ShellError> {
match self {
PreparedServerCommunication::Stdio => {
let stdin = child
.stdin
.take()
.ok_or_else(|| ShellError::PluginFailedToLoad {
msg: "Plugin missing stdin writer".into(),
})?;
let stdout = child
.stdout
.take()
.ok_or_else(|| ShellError::PluginFailedToLoad {
msg: "Plugin missing stdout writer".into(),
})?;
Ok(ServerCommunicationIo::Stdio(stdin, stdout))
}
#[cfg(feature = "local-socket")]
PreparedServerCommunication::LocalSocket { listener, .. } => {
use std::time::{Duration, Instant};
const RETRY_PERIOD: Duration = Duration::from_millis(1);
const TIMEOUT: Duration = Duration::from_secs(10);
let start = Instant::now();
// Use a loop to try to get two clients from the listener: one for read (the plugin
// output) and one for write (the plugin input)
listener.set_nonblocking(true)?;
let mut get_socket = || {
let mut result = None;
while let Ok(None) = child.try_wait() {
match listener.accept() {
Ok(stream) => {
// Success! But make sure the stream is in blocking mode.
stream.set_nonblocking(false)?;
result = Some(stream);
break;
}
Err(err) => {
if !is_would_block_err(&err) {
// `WouldBlock` is ok, just means it's not ready yet, but some other
// kind of error should be reported
return Err(err.into());
}
}
}
if Instant::now().saturating_duration_since(start) > TIMEOUT {
return Err(ShellError::PluginFailedToLoad {
msg: "Plugin timed out while waiting to connect to socket".into(),
});
} else {
std::thread::sleep(RETRY_PERIOD);
}
}
if let Some(stream) = result {
Ok(LocalSocketStream(stream))
} else {
// The process may have exited
Err(ShellError::PluginFailedToLoad {
msg: "Plugin exited without connecting".into(),
})
}
};
// Input stream always comes before output
let write_in = get_socket()?;
let read_out = get_socket()?;
Ok(ServerCommunicationIo::LocalSocket { read_out, write_in })
}
}
}
}
impl Drop for PreparedServerCommunication {
fn drop(&mut self) {
match self {
#[cfg(all(unix, feature = "local-socket"))]
PreparedServerCommunication::LocalSocket { name: path, .. } => {
// Just try to remove the socket file, it's ok if this fails
let _ = std::fs::remove_file(path);
}
_ => (),
}
}
}
/// The required streams for communication from the engine side, i.e. the server in socket terms.
pub enum ServerCommunicationIo {
Stdio(ChildStdin, ChildStdout),
#[cfg(feature = "local-socket")]
LocalSocket {
read_out: LocalSocketStream,
write_in: LocalSocketStream,
},
}
/// The required streams for communication from the plugin side, i.e. the client in socket terms.
pub enum ClientCommunicationIo {
Stdio(Stdin, Stdout),
#[cfg(feature = "local-socket")]
LocalSocket {
read_in: LocalSocketStream,
write_out: LocalSocketStream,
},
}

View File

@ -0,0 +1,453 @@
//! Implements the stream multiplexing interface for both the plugin side and the engine side.
use nu_plugin_protocol::{
ExternalStreamInfo, ListStreamInfo, PipelineDataHeader, RawStreamInfo, StreamMessage,
};
use nu_protocol::{ListStream, PipelineData, RawStream, ShellError};
use std::{
io::Write,
sync::{
atomic::{AtomicBool, Ordering::Relaxed},
Arc, Mutex,
},
thread,
};
pub mod stream;
use crate::{util::Sequence, Encoder};
use self::stream::{StreamManager, StreamManagerHandle, StreamWriter, WriteStreamMessage};
pub mod test_util;
#[cfg(test)]
mod tests;
/// The maximum number of list stream values to send without acknowledgement. This should be tuned
/// with consideration for memory usage.
const LIST_STREAM_HIGH_PRESSURE: i32 = 100;
/// The maximum number of raw stream buffers to send without acknowledgement. This should be tuned
/// with consideration for memory usage.
const RAW_STREAM_HIGH_PRESSURE: i32 = 50;
/// Read input/output from the stream.
pub trait PluginRead<T> {
/// Returns `Ok(None)` on end of stream.
fn read(&mut self) -> Result<Option<T>, ShellError>;
}
impl<R, E, T> PluginRead<T> for (R, E)
where
R: std::io::BufRead,
E: Encoder<T>,
{
fn read(&mut self) -> Result<Option<T>, ShellError> {
self.1.decode(&mut self.0)
}
}
impl<R, T> PluginRead<T> for &mut R
where
R: PluginRead<T>,
{
fn read(&mut self) -> Result<Option<T>, ShellError> {
(**self).read()
}
}
/// Write input/output to the stream.
///
/// The write should be atomic, without interference from other threads.
pub trait PluginWrite<T>: Send + Sync {
fn write(&self, data: &T) -> Result<(), ShellError>;
/// Flush any internal buffers, if applicable.
fn flush(&self) -> Result<(), ShellError>;
/// True if this output is stdout, so that plugins can avoid using stdout for their own purpose
fn is_stdout(&self) -> bool {
false
}
}
impl<E, T> PluginWrite<T> for (std::io::Stdout, E)
where
E: Encoder<T>,
{
fn write(&self, data: &T) -> Result<(), ShellError> {
let mut lock = self.0.lock();
self.1.encode(data, &mut lock)
}
fn flush(&self) -> Result<(), ShellError> {
self.0.lock().flush().map_err(|err| ShellError::IOError {
msg: err.to_string(),
})
}
fn is_stdout(&self) -> bool {
true
}
}
impl<W, E, T> PluginWrite<T> for (Mutex<W>, E)
where
W: std::io::Write + Send,
E: Encoder<T>,
{
fn write(&self, data: &T) -> Result<(), ShellError> {
let mut lock = self.0.lock().map_err(|_| ShellError::NushellFailed {
msg: "writer mutex poisoned".into(),
})?;
self.1.encode(data, &mut *lock)
}
fn flush(&self) -> Result<(), ShellError> {
let mut lock = self.0.lock().map_err(|_| ShellError::NushellFailed {
msg: "writer mutex poisoned".into(),
})?;
lock.flush().map_err(|err| ShellError::IOError {
msg: err.to_string(),
})
}
}
impl<W, T> PluginWrite<T> for &W
where
W: PluginWrite<T>,
{
fn write(&self, data: &T) -> Result<(), ShellError> {
(**self).write(data)
}
fn flush(&self) -> Result<(), ShellError> {
(**self).flush()
}
fn is_stdout(&self) -> bool {
(**self).is_stdout()
}
}
/// An interface manager handles I/O and state management for communication between a plugin and
/// the engine. See `PluginInterfaceManager` in `nu-plugin-engine` for communication from the engine
/// side to a plugin, or `EngineInterfaceManager` in `nu-plugin` for communication from the plugin
/// side to the engine.
///
/// There is typically one [`InterfaceManager`] consuming input from a background thread, and
/// managing shared state.
pub trait InterfaceManager {
/// The corresponding interface type.
type Interface: Interface + 'static;
/// The input message type.
type Input;
/// Make a new interface that communicates with this [`InterfaceManager`].
fn get_interface(&self) -> Self::Interface;
/// Consume an input message.
///
/// When implementing, call [`.consume_stream_message()`] for any encapsulated
/// [`StreamMessage`]s received.
fn consume(&mut self, input: Self::Input) -> Result<(), ShellError>;
/// Get the [`StreamManager`] for handling operations related to stream messages.
fn stream_manager(&self) -> &StreamManager;
/// Prepare [`PipelineData`] after reading. This is called by `read_pipeline_data()` as
/// a hook so that values that need special handling can be taken care of.
fn prepare_pipeline_data(&self, data: PipelineData) -> Result<PipelineData, ShellError>;
/// Consume an input stream message.
///
/// This method is provided for implementors to use.
fn consume_stream_message(&mut self, message: StreamMessage) -> Result<(), ShellError> {
self.stream_manager().handle_message(message)
}
/// Generate `PipelineData` for reading a stream, given a [`PipelineDataHeader`] that was
/// received from the other side.
///
/// This method is provided for implementors to use.
fn read_pipeline_data(
&self,
header: PipelineDataHeader,
ctrlc: Option<&Arc<AtomicBool>>,
) -> Result<PipelineData, ShellError> {
self.prepare_pipeline_data(match header {
PipelineDataHeader::Empty => PipelineData::Empty,
PipelineDataHeader::Value(value) => PipelineData::Value(value, None),
PipelineDataHeader::ListStream(info) => {
let handle = self.stream_manager().get_handle();
let reader = handle.read_stream(info.id, self.get_interface())?;
PipelineData::ListStream(ListStream::from_stream(reader, ctrlc.cloned()), None)
}
PipelineDataHeader::ExternalStream(info) => {
let handle = self.stream_manager().get_handle();
let span = info.span;
let new_raw_stream = |raw_info: RawStreamInfo| {
let reader = handle.read_stream(raw_info.id, self.get_interface())?;
let mut stream =
RawStream::new(Box::new(reader), ctrlc.cloned(), span, raw_info.known_size);
stream.is_binary = raw_info.is_binary;
Ok::<_, ShellError>(stream)
};
PipelineData::ExternalStream {
stdout: info.stdout.map(new_raw_stream).transpose()?,
stderr: info.stderr.map(new_raw_stream).transpose()?,
exit_code: info
.exit_code
.map(|list_info| {
handle
.read_stream(list_info.id, self.get_interface())
.map(|reader| ListStream::from_stream(reader, ctrlc.cloned()))
})
.transpose()?,
span: info.span,
metadata: None,
trim_end_newline: info.trim_end_newline,
}
}
})
}
}
/// An interface provides an API for communicating with a plugin or the engine and facilitates
/// stream I/O. See `PluginInterface` in `nu-plugin-engine` for the API from the engine side to a
/// plugin, or `EngineInterface` in `nu-plugin` for the API from the plugin side to the engine.
///
/// There can be multiple copies of the interface managed by a single [`InterfaceManager`].
pub trait Interface: Clone + Send {
/// The output message type, which must be capable of encapsulating a [`StreamMessage`].
type Output: From<StreamMessage>;
/// Any context required to construct [`PipelineData`]. Can be `()` if not needed.
type DataContext;
/// Write an output message.
fn write(&self, output: Self::Output) -> Result<(), ShellError>;
/// Flush the output buffer, so messages are visible to the other side.
fn flush(&self) -> Result<(), ShellError>;
/// Get the sequence for generating new [`StreamId`](nu_plugin_protocol::StreamId)s.
fn stream_id_sequence(&self) -> &Sequence;
/// Get the [`StreamManagerHandle`] for doing stream operations.
fn stream_manager_handle(&self) -> &StreamManagerHandle;
/// Prepare [`PipelineData`] to be written. This is called by `init_write_pipeline_data()` as
/// a hook so that values that need special handling can be taken care of.
fn prepare_pipeline_data(
&self,
data: PipelineData,
context: &Self::DataContext,
) -> Result<PipelineData, ShellError>;
/// Initialize a write for [`PipelineData`]. This returns two parts: the header, which can be
/// embedded in the particular message that references the stream, and a writer, which will
/// write out all of the data in the pipeline when `.write()` is called.
///
/// Note that not all [`PipelineData`] starts a stream. You should call `write()` anyway, as
/// it will automatically handle this case.
///
/// This method is provided for implementors to use.
fn init_write_pipeline_data(
&self,
data: PipelineData,
context: &Self::DataContext,
) -> Result<(PipelineDataHeader, PipelineDataWriter<Self>), ShellError> {
// Allocate a stream id and a writer
let new_stream = |high_pressure_mark: i32| {
// Get a free stream id
let id = self.stream_id_sequence().next()?;
// Create the writer
let writer =
self.stream_manager_handle()
.write_stream(id, self.clone(), high_pressure_mark)?;
Ok::<_, ShellError>((id, writer))
};
match self.prepare_pipeline_data(data, context)? {
PipelineData::Value(value, _) => {
Ok((PipelineDataHeader::Value(value), PipelineDataWriter::None))
}
PipelineData::Empty => Ok((PipelineDataHeader::Empty, PipelineDataWriter::None)),
PipelineData::ListStream(stream, _) => {
let (id, writer) = new_stream(LIST_STREAM_HIGH_PRESSURE)?;
Ok((
PipelineDataHeader::ListStream(ListStreamInfo { id }),
PipelineDataWriter::ListStream(writer, stream),
))
}
PipelineData::ExternalStream {
stdout,
stderr,
exit_code,
span,
metadata: _,
trim_end_newline,
} => {
// Create the writers and stream ids
let stdout_stream = stdout
.is_some()
.then(|| new_stream(RAW_STREAM_HIGH_PRESSURE))
.transpose()?;
let stderr_stream = stderr
.is_some()
.then(|| new_stream(RAW_STREAM_HIGH_PRESSURE))
.transpose()?;
let exit_code_stream = exit_code
.is_some()
.then(|| new_stream(LIST_STREAM_HIGH_PRESSURE))
.transpose()?;
// Generate the header, with the stream ids
let header = PipelineDataHeader::ExternalStream(ExternalStreamInfo {
span,
stdout: stdout
.as_ref()
.zip(stdout_stream.as_ref())
.map(|(stream, (id, _))| RawStreamInfo::new(*id, stream)),
stderr: stderr
.as_ref()
.zip(stderr_stream.as_ref())
.map(|(stream, (id, _))| RawStreamInfo::new(*id, stream)),
exit_code: exit_code_stream
.as_ref()
.map(|&(id, _)| ListStreamInfo { id }),
trim_end_newline,
});
// Collect the writers
let writer = PipelineDataWriter::ExternalStream {
stdout: stdout_stream.map(|(_, writer)| writer).zip(stdout),
stderr: stderr_stream.map(|(_, writer)| writer).zip(stderr),
exit_code: exit_code_stream.map(|(_, writer)| writer).zip(exit_code),
};
Ok((header, writer))
}
}
}
}
impl<T> WriteStreamMessage for T
where
T: Interface,
{
fn write_stream_message(&mut self, msg: StreamMessage) -> Result<(), ShellError> {
self.write(msg.into())
}
fn flush(&mut self) -> Result<(), ShellError> {
<Self as Interface>::flush(self)
}
}
/// Completes the write operation for a [`PipelineData`]. You must call
/// [`PipelineDataWriter::write()`] to write all of the data contained within the streams.
#[derive(Default)]
#[must_use]
pub enum PipelineDataWriter<W: WriteStreamMessage> {
#[default]
None,
ListStream(StreamWriter<W>, ListStream),
ExternalStream {
stdout: Option<(StreamWriter<W>, RawStream)>,
stderr: Option<(StreamWriter<W>, RawStream)>,
exit_code: Option<(StreamWriter<W>, ListStream)>,
},
}
impl<W> PipelineDataWriter<W>
where
W: WriteStreamMessage + Send + 'static,
{
/// Write all of the data in each of the streams. This method waits for completion.
pub fn write(self) -> Result<(), ShellError> {
match self {
// If no stream was contained in the PipelineData, do nothing.
PipelineDataWriter::None => Ok(()),
// Write a list stream.
PipelineDataWriter::ListStream(mut writer, stream) => {
writer.write_all(stream)?;
Ok(())
}
// Write all three possible streams of an ExternalStream on separate threads.
PipelineDataWriter::ExternalStream {
stdout,
stderr,
exit_code,
} => {
thread::scope(|scope| {
let stderr_thread = stderr
.map(|(mut writer, stream)| {
thread::Builder::new()
.name("plugin stderr writer".into())
.spawn_scoped(scope, move || {
writer.write_all(raw_stream_iter(stream))
})
})
.transpose()?;
let exit_code_thread = exit_code
.map(|(mut writer, stream)| {
thread::Builder::new()
.name("plugin exit_code writer".into())
.spawn_scoped(scope, move || writer.write_all(stream))
})
.transpose()?;
// Optimize for stdout: if only stdout is present, don't spawn any other
// threads.
if let Some((mut writer, stream)) = stdout {
writer.write_all(raw_stream_iter(stream))?;
}
let panicked = |thread_name: &str| {
Err(ShellError::NushellFailed {
msg: format!(
"{thread_name} thread panicked in PipelineDataWriter::write"
),
})
};
stderr_thread
.map(|t| t.join().unwrap_or_else(|_| panicked("stderr")))
.transpose()?;
exit_code_thread
.map(|t| t.join().unwrap_or_else(|_| panicked("exit_code")))
.transpose()?;
Ok(())
})
}
}
}
/// Write all of the data in each of the streams. This method returns immediately; any necessary
/// write will happen in the background. If a thread was spawned, its handle is returned.
pub fn write_background(
self,
) -> Result<Option<thread::JoinHandle<Result<(), ShellError>>>, ShellError> {
match self {
PipelineDataWriter::None => Ok(None),
_ => Ok(Some(
thread::Builder::new()
.name("plugin stream background writer".into())
.spawn(move || {
let result = self.write();
if let Err(ref err) = result {
// Assume that the background thread error probably won't be handled and log it
// here just in case.
log::warn!("Error while writing pipeline in background: {err}");
}
result
})?,
)),
}
}
}
/// Custom iterator for [`RawStream`] that respects ctrlc, but still has binary chunks
fn raw_stream_iter(stream: RawStream) -> impl Iterator<Item = Result<Vec<u8>, ShellError>> {
let ctrlc = stream.ctrlc;
stream
.stream
.take_while(move |_| ctrlc.as_ref().map(|b| !b.load(Relaxed)).unwrap_or(true))
}

View File

@ -0,0 +1,628 @@
use nu_plugin_protocol::{StreamData, StreamId, StreamMessage};
use nu_protocol::{ShellError, Span, Value};
use std::{
collections::{btree_map, BTreeMap},
iter::FusedIterator,
marker::PhantomData,
sync::{mpsc, Arc, Condvar, Mutex, MutexGuard, Weak},
};
#[cfg(test)]
mod tests;
/// Receives messages from a stream read from input by a [`StreamManager`].
///
/// The receiver reads for messages of type `Result<Option<StreamData>, ShellError>` from the
/// channel, which is managed by a [`StreamManager`]. Signalling for end-of-stream is explicit
/// through `Ok(Some)`.
///
/// Failing to receive is an error. When end-of-stream is received, the `receiver` is set to `None`
/// and all further calls to `next()` return `None`.
///
/// The type `T` must implement [`FromShellError`], so that errors in the stream can be represented,
/// and `TryFrom<StreamData>` to convert it to the correct type.
///
/// For each message read, it sends [`StreamMessage::Ack`] to the writer. When dropped,
/// it sends [`StreamMessage::Drop`].
#[derive(Debug)]
pub struct StreamReader<T, W>
where
W: WriteStreamMessage,
{
id: StreamId,
receiver: Option<mpsc::Receiver<Result<Option<StreamData>, ShellError>>>,
writer: W,
/// Iterator requires the item type to be fixed, so we have to keep it as part of the type,
/// even though we're actually receiving dynamic data.
marker: PhantomData<fn() -> T>,
}
impl<T, W> StreamReader<T, W>
where
T: TryFrom<StreamData, Error = ShellError>,
W: WriteStreamMessage,
{
/// Create a new StreamReader from parts
fn new(
id: StreamId,
receiver: mpsc::Receiver<Result<Option<StreamData>, ShellError>>,
writer: W,
) -> StreamReader<T, W> {
StreamReader {
id,
receiver: Some(receiver),
writer,
marker: PhantomData,
}
}
/// Receive a message from the channel, or return an error if:
///
/// * the channel couldn't be received from
/// * an error was sent on the channel
/// * the message received couldn't be converted to `T`
pub fn recv(&mut self) -> Result<Option<T>, ShellError> {
let connection_lost = || ShellError::GenericError {
error: "Stream ended unexpectedly".into(),
msg: "connection lost before explicit end of stream".into(),
span: None,
help: None,
inner: vec![],
};
if let Some(ref rx) = self.receiver {
// Try to receive a message first
let msg = match rx.try_recv() {
Ok(msg) => msg?,
Err(mpsc::TryRecvError::Empty) => {
// The receiver doesn't have any messages waiting for us. It's possible that the
// other side hasn't seen our acknowledgements. Let's flush the writer and then
// wait
self.writer.flush()?;
rx.recv().map_err(|_| connection_lost())??
}
Err(mpsc::TryRecvError::Disconnected) => return Err(connection_lost()),
};
if let Some(data) = msg {
// Acknowledge the message
self.writer
.write_stream_message(StreamMessage::Ack(self.id))?;
// Try to convert it into the correct type
Ok(Some(data.try_into()?))
} else {
// Remove the receiver, so that future recv() calls always return Ok(None)
self.receiver = None;
Ok(None)
}
} else {
// Closed already
Ok(None)
}
}
}
impl<T, W> Iterator for StreamReader<T, W>
where
T: FromShellError + TryFrom<StreamData, Error = ShellError>,
W: WriteStreamMessage,
{
type Item = T;
fn next(&mut self) -> Option<T> {
// Converting the error to the value here makes the implementation a lot easier
match self.recv() {
Ok(option) => option,
Err(err) => {
// Drop the receiver so we don't keep returning errors
self.receiver = None;
Some(T::from_shell_error(err))
}
}
}
}
// Guaranteed not to return anything after the end
impl<T, W> FusedIterator for StreamReader<T, W>
where
T: FromShellError + TryFrom<StreamData, Error = ShellError>,
W: WriteStreamMessage,
{
}
impl<T, W> Drop for StreamReader<T, W>
where
W: WriteStreamMessage,
{
fn drop(&mut self) {
if let Err(err) = self
.writer
.write_stream_message(StreamMessage::Drop(self.id))
.and_then(|_| self.writer.flush())
{
log::warn!("Failed to send message to drop stream: {err}");
}
}
}
/// Values that can contain a `ShellError` to signal an error has occurred.
pub trait FromShellError {
fn from_shell_error(err: ShellError) -> Self;
}
// For List streams.
impl FromShellError for Value {
fn from_shell_error(err: ShellError) -> Self {
Value::error(err, Span::unknown())
}
}
// For Raw streams, mostly.
impl<T> FromShellError for Result<T, ShellError> {
fn from_shell_error(err: ShellError) -> Self {
Err(err)
}
}
/// Writes messages to a stream, with flow control.
///
/// The `signal` contained
#[derive(Debug)]
pub struct StreamWriter<W: WriteStreamMessage> {
id: StreamId,
signal: Arc<StreamWriterSignal>,
writer: W,
ended: bool,
}
impl<W> StreamWriter<W>
where
W: WriteStreamMessage,
{
fn new(id: StreamId, signal: Arc<StreamWriterSignal>, writer: W) -> StreamWriter<W> {
StreamWriter {
id,
signal,
writer,
ended: false,
}
}
/// Check if the stream was dropped from the other end. Recommended to do this before calling
/// [`.write()`], especially in a loop.
pub fn is_dropped(&self) -> Result<bool, ShellError> {
self.signal.is_dropped()
}
/// Write a single piece of data to the stream.
///
/// Error if something failed with the write, or if [`.end()`] was already called
/// previously.
pub fn write(&mut self, data: impl Into<StreamData>) -> Result<(), ShellError> {
if !self.ended {
self.writer
.write_stream_message(StreamMessage::Data(self.id, data.into()))?;
// This implements flow control, so we don't write too many messages:
if !self.signal.notify_sent()? {
// Flush the output, and then wait for acknowledgements
self.writer.flush()?;
self.signal.wait_for_drain()
} else {
Ok(())
}
} else {
Err(ShellError::GenericError {
error: "Wrote to a stream after it ended".into(),
msg: format!(
"tried to write to stream {} after it was already ended",
self.id
),
span: None,
help: Some("this may be a bug in the nu-plugin crate".into()),
inner: vec![],
})
}
}
/// Write a full iterator to the stream. Note that this doesn't end the stream, so you should
/// still call [`.end()`].
///
/// If the stream is dropped from the other end, the iterator will not be fully consumed, and
/// writing will terminate.
///
/// Returns `Ok(true)` if the iterator was fully consumed, or `Ok(false)` if a drop interrupted
/// the stream from the other side.
pub fn write_all<T>(&mut self, data: impl IntoIterator<Item = T>) -> Result<bool, ShellError>
where
T: Into<StreamData>,
{
// Check before starting
if self.is_dropped()? {
return Ok(false);
}
for item in data {
// Check again after each item is consumed from the iterator, just in case the iterator
// takes a while to produce a value
if self.is_dropped()? {
return Ok(false);
}
self.write(item)?;
}
Ok(true)
}
/// End the stream. Recommend doing this instead of relying on `Drop` so that you can catch the
/// error.
pub fn end(&mut self) -> Result<(), ShellError> {
if !self.ended {
// Set the flag first so we don't double-report in the Drop
self.ended = true;
self.writer
.write_stream_message(StreamMessage::End(self.id))?;
self.writer.flush()
} else {
Ok(())
}
}
}
impl<W> Drop for StreamWriter<W>
where
W: WriteStreamMessage,
{
fn drop(&mut self) {
// Make sure we ended the stream
if let Err(err) = self.end() {
log::warn!("Error while ending stream in Drop for StreamWriter: {err}");
}
}
}
/// Stores stream state for a writer, and can be blocked on to wait for messages to be acknowledged.
/// A key part of managing stream lifecycle and flow control.
#[derive(Debug)]
pub struct StreamWriterSignal {
mutex: Mutex<StreamWriterSignalState>,
change_cond: Condvar,
}
#[derive(Debug)]
pub struct StreamWriterSignalState {
/// Stream has been dropped and consumer is no longer interested in any messages.
dropped: bool,
/// Number of messages that have been sent without acknowledgement.
unacknowledged: i32,
/// Max number of messages to send before waiting for acknowledgement.
high_pressure_mark: i32,
}
impl StreamWriterSignal {
/// Create a new signal.
///
/// If `notify_sent()` is called more than `high_pressure_mark` times, it will wait until
/// `notify_acknowledge()` is called by another thread enough times to bring the number of
/// unacknowledged sent messages below that threshold.
fn new(high_pressure_mark: i32) -> StreamWriterSignal {
assert!(high_pressure_mark > 0);
StreamWriterSignal {
mutex: Mutex::new(StreamWriterSignalState {
dropped: false,
unacknowledged: 0,
high_pressure_mark,
}),
change_cond: Condvar::new(),
}
}
fn lock(&self) -> Result<MutexGuard<StreamWriterSignalState>, ShellError> {
self.mutex.lock().map_err(|_| ShellError::NushellFailed {
msg: "StreamWriterSignal mutex poisoned due to panic".into(),
})
}
/// True if the stream was dropped and the consumer is no longer interested in it. Indicates
/// that no more messages should be sent, other than `End`.
pub fn is_dropped(&self) -> Result<bool, ShellError> {
Ok(self.lock()?.dropped)
}
/// Notify the writers that the stream has been dropped, so they can stop writing.
pub fn set_dropped(&self) -> Result<(), ShellError> {
let mut state = self.lock()?;
state.dropped = true;
// Unblock the writers so they can terminate
self.change_cond.notify_all();
Ok(())
}
/// Track that a message has been sent. Returns `Ok(true)` if more messages can be sent,
/// or `Ok(false)` if the high pressure mark has been reached and [`.wait_for_drain()`] should
/// be called to block.
pub fn notify_sent(&self) -> Result<bool, ShellError> {
let mut state = self.lock()?;
state.unacknowledged =
state
.unacknowledged
.checked_add(1)
.ok_or_else(|| ShellError::NushellFailed {
msg: "Overflow in counter: too many unacknowledged messages".into(),
})?;
Ok(state.unacknowledged < state.high_pressure_mark)
}
/// Wait for acknowledgements before sending more data. Also returns if the stream is dropped.
pub fn wait_for_drain(&self) -> Result<(), ShellError> {
let mut state = self.lock()?;
while !state.dropped && state.unacknowledged >= state.high_pressure_mark {
state = self
.change_cond
.wait(state)
.map_err(|_| ShellError::NushellFailed {
msg: "StreamWriterSignal mutex poisoned due to panic".into(),
})?;
}
Ok(())
}
/// Notify the writers that a message has been acknowledged, so they can continue to write
/// if they were waiting.
pub fn notify_acknowledged(&self) -> Result<(), ShellError> {
let mut state = self.lock()?;
state.unacknowledged =
state
.unacknowledged
.checked_sub(1)
.ok_or_else(|| ShellError::NushellFailed {
msg: "Underflow in counter: too many message acknowledgements".into(),
})?;
// Unblock the writer
self.change_cond.notify_one();
Ok(())
}
}
/// A sink for a [`StreamMessage`]
pub trait WriteStreamMessage {
fn write_stream_message(&mut self, msg: StreamMessage) -> Result<(), ShellError>;
fn flush(&mut self) -> Result<(), ShellError>;
}
#[derive(Debug, Default)]
struct StreamManagerState {
reading_streams: BTreeMap<StreamId, mpsc::Sender<Result<Option<StreamData>, ShellError>>>,
writing_streams: BTreeMap<StreamId, Weak<StreamWriterSignal>>,
}
impl StreamManagerState {
/// Lock the state, or return a [`ShellError`] if the mutex is poisoned.
fn lock(
state: &Mutex<StreamManagerState>,
) -> Result<MutexGuard<StreamManagerState>, ShellError> {
state.lock().map_err(|_| ShellError::NushellFailed {
msg: "StreamManagerState mutex poisoned due to a panic".into(),
})
}
}
#[derive(Debug)]
pub struct StreamManager {
state: Arc<Mutex<StreamManagerState>>,
}
impl StreamManager {
/// Create a new StreamManager.
pub fn new() -> StreamManager {
StreamManager {
state: Default::default(),
}
}
fn lock(&self) -> Result<MutexGuard<StreamManagerState>, ShellError> {
StreamManagerState::lock(&self.state)
}
/// Create a new handle to the StreamManager for registering streams.
pub fn get_handle(&self) -> StreamManagerHandle {
StreamManagerHandle {
state: Arc::downgrade(&self.state),
}
}
/// Process a stream message, and update internal state accordingly.
pub fn handle_message(&self, message: StreamMessage) -> Result<(), ShellError> {
let mut state = self.lock()?;
match message {
StreamMessage::Data(id, data) => {
if let Some(sender) = state.reading_streams.get(&id) {
// We should ignore the error on send. This just means the reader has dropped,
// but it will have sent a Drop message to the other side, and we will receive
// an End message at which point we can remove the channel.
let _ = sender.send(Ok(Some(data)));
Ok(())
} else {
Err(ShellError::PluginFailedToDecode {
msg: format!("received Data for unknown stream {id}"),
})
}
}
StreamMessage::End(id) => {
if let Some(sender) = state.reading_streams.remove(&id) {
// We should ignore the error on the send, because the reader might have dropped
// already
let _ = sender.send(Ok(None));
Ok(())
} else {
Err(ShellError::PluginFailedToDecode {
msg: format!("received End for unknown stream {id}"),
})
}
}
StreamMessage::Drop(id) => {
if let Some(signal) = state.writing_streams.remove(&id) {
if let Some(signal) = signal.upgrade() {
// This will wake blocked writers so they can stop writing, so it's ok
signal.set_dropped()?;
}
}
// It's possible that the stream has already finished writing and we don't have it
// anymore, so we fall through to Ok
Ok(())
}
StreamMessage::Ack(id) => {
if let Some(signal) = state.writing_streams.get(&id) {
if let Some(signal) = signal.upgrade() {
// This will wake up a blocked writer
signal.notify_acknowledged()?;
} else {
// We know it doesn't exist, so might as well remove it
state.writing_streams.remove(&id);
}
}
// It's possible that the stream has already finished writing and we don't have it
// anymore, so we fall through to Ok
Ok(())
}
}
}
/// Broadcast an error to all stream readers. This is useful for error propagation.
pub fn broadcast_read_error(&self, error: ShellError) -> Result<(), ShellError> {
let state = self.lock()?;
for channel in state.reading_streams.values() {
// Ignore send errors.
let _ = channel.send(Err(error.clone()));
}
Ok(())
}
// If the `StreamManager` is dropped, we should let all of the stream writers know that they
// won't be able to write anymore. We don't need to do anything about the readers though
// because they'll know when the `Sender` is dropped automatically
fn drop_all_writers(&self) -> Result<(), ShellError> {
let mut state = self.lock()?;
let writers = std::mem::take(&mut state.writing_streams);
for (_, signal) in writers {
if let Some(signal) = signal.upgrade() {
// more important that we send to all than handling an error
let _ = signal.set_dropped();
}
}
Ok(())
}
}
impl Default for StreamManager {
fn default() -> Self {
Self::new()
}
}
impl Drop for StreamManager {
fn drop(&mut self) {
if let Err(err) = self.drop_all_writers() {
log::warn!("error during Drop for StreamManager: {}", err)
}
}
}
/// A [`StreamManagerHandle`] supports operations for interacting with the [`StreamManager`].
///
/// Streams can be registered for reading, returning a [`StreamReader`], or for writing, returning
/// a [`StreamWriter`].
#[derive(Debug, Clone)]
pub struct StreamManagerHandle {
state: Weak<Mutex<StreamManagerState>>,
}
impl StreamManagerHandle {
/// Because the handle only has a weak reference to the [`StreamManager`] state, we have to
/// first try to upgrade to a strong reference and then lock. This function wraps those two
/// operations together, handling errors appropriately.
fn with_lock<T, F>(&self, f: F) -> Result<T, ShellError>
where
F: FnOnce(MutexGuard<StreamManagerState>) -> Result<T, ShellError>,
{
let upgraded = self
.state
.upgrade()
.ok_or_else(|| ShellError::NushellFailed {
msg: "StreamManager is no longer alive".into(),
})?;
let guard = upgraded.lock().map_err(|_| ShellError::NushellFailed {
msg: "StreamManagerState mutex poisoned due to a panic".into(),
})?;
f(guard)
}
/// Register a new stream for reading, and return a [`StreamReader`] that can be used to iterate
/// on the values received. A [`StreamMessage`] writer is required for writing control messages
/// back to the producer.
pub fn read_stream<T, W>(
&self,
id: StreamId,
writer: W,
) -> Result<StreamReader<T, W>, ShellError>
where
T: TryFrom<StreamData, Error = ShellError>,
W: WriteStreamMessage,
{
let (tx, rx) = mpsc::channel();
self.with_lock(|mut state| {
// Must be exclusive
if let btree_map::Entry::Vacant(e) = state.reading_streams.entry(id) {
e.insert(tx);
Ok(())
} else {
Err(ShellError::GenericError {
error: format!("Failed to acquire reader for stream {id}"),
msg: "tried to get a reader for a stream that's already being read".into(),
span: None,
help: Some("this may be a bug in the nu-plugin crate".into()),
inner: vec![],
})
}
})?;
Ok(StreamReader::new(id, rx, writer))
}
/// Register a new stream for writing, and return a [`StreamWriter`] that can be used to send
/// data to the stream.
///
/// The `high_pressure_mark` value controls how many messages can be written without receiving
/// an acknowledgement before any further attempts to write will wait for the consumer to
/// acknowledge them. This prevents overwhelming the reader.
pub fn write_stream<W>(
&self,
id: StreamId,
writer: W,
high_pressure_mark: i32,
) -> Result<StreamWriter<W>, ShellError>
where
W: WriteStreamMessage,
{
let signal = Arc::new(StreamWriterSignal::new(high_pressure_mark));
self.with_lock(|mut state| {
// Remove dead writing streams
state
.writing_streams
.retain(|_, signal| signal.strong_count() > 0);
// Must be exclusive
if let btree_map::Entry::Vacant(e) = state.writing_streams.entry(id) {
e.insert(Arc::downgrade(&signal));
Ok(())
} else {
Err(ShellError::GenericError {
error: format!("Failed to acquire writer for stream {id}"),
msg: "tried to get a writer for a stream that's already being written".into(),
span: None,
help: Some("this may be a bug in the nu-plugin crate".into()),
inner: vec![],
})
}
})?;
Ok(StreamWriter::new(id, signal, writer))
}
}

View File

@ -0,0 +1,550 @@
use std::{
sync::{
atomic::{AtomicBool, Ordering::Relaxed},
mpsc, Arc,
},
time::{Duration, Instant},
};
use super::{StreamManager, StreamReader, StreamWriter, StreamWriterSignal, WriteStreamMessage};
use nu_plugin_protocol::{StreamData, StreamMessage};
use nu_protocol::{ShellError, Value};
// Should be long enough to definitely complete any quick operation, but not so long that tests are
// slow to complete. 10 ms is a pretty long time
const WAIT_DURATION: Duration = Duration::from_millis(10);
// Maximum time to wait for a condition to be true
const MAX_WAIT_DURATION: Duration = Duration::from_millis(500);
/// Wait for a condition to be true, or panic if the duration exceeds MAX_WAIT_DURATION
#[track_caller]
fn wait_for_condition(mut cond: impl FnMut() -> bool, message: &str) {
// Early check
if cond() {
return;
}
let start = Instant::now();
loop {
std::thread::sleep(Duration::from_millis(10));
if cond() {
return;
}
let elapsed = Instant::now().saturating_duration_since(start);
if elapsed > MAX_WAIT_DURATION {
panic!(
"{message}: Waited {:.2}sec, which is more than the maximum of {:.2}sec",
elapsed.as_secs_f64(),
MAX_WAIT_DURATION.as_secs_f64(),
);
}
}
}
#[derive(Debug, Clone, Default)]
struct TestSink(Vec<StreamMessage>);
impl WriteStreamMessage for TestSink {
fn write_stream_message(&mut self, msg: StreamMessage) -> Result<(), ShellError> {
self.0.push(msg);
Ok(())
}
fn flush(&mut self) -> Result<(), ShellError> {
Ok(())
}
}
impl WriteStreamMessage for mpsc::Sender<StreamMessage> {
fn write_stream_message(&mut self, msg: StreamMessage) -> Result<(), ShellError> {
self.send(msg).map_err(|err| ShellError::NushellFailed {
msg: err.to_string(),
})
}
fn flush(&mut self) -> Result<(), ShellError> {
Ok(())
}
}
#[test]
fn reader_recv_list_messages() -> Result<(), ShellError> {
let (tx, rx) = mpsc::channel();
let mut reader = StreamReader::new(0, rx, TestSink::default());
tx.send(Ok(Some(StreamData::List(Value::test_int(5)))))
.unwrap();
drop(tx);
assert_eq!(Some(Value::test_int(5)), reader.recv()?);
Ok(())
}
#[test]
fn list_reader_recv_wrong_type() -> Result<(), ShellError> {
let (tx, rx) = mpsc::channel();
let mut reader = StreamReader::<Value, _>::new(0, rx, TestSink::default());
tx.send(Ok(Some(StreamData::Raw(Ok(vec![10, 20])))))
.unwrap();
tx.send(Ok(Some(StreamData::List(Value::test_nothing()))))
.unwrap();
drop(tx);
reader.recv().expect_err("should be an error");
reader.recv().expect("should be able to recover");
Ok(())
}
#[test]
fn reader_recv_raw_messages() -> Result<(), ShellError> {
let (tx, rx) = mpsc::channel();
let mut reader =
StreamReader::<Result<Vec<u8>, ShellError>, _>::new(0, rx, TestSink::default());
tx.send(Ok(Some(StreamData::Raw(Ok(vec![10, 20])))))
.unwrap();
drop(tx);
assert_eq!(Some(vec![10, 20]), reader.recv()?.transpose()?);
Ok(())
}
#[test]
fn raw_reader_recv_wrong_type() -> Result<(), ShellError> {
let (tx, rx) = mpsc::channel();
let mut reader =
StreamReader::<Result<Vec<u8>, ShellError>, _>::new(0, rx, TestSink::default());
tx.send(Ok(Some(StreamData::List(Value::test_nothing()))))
.unwrap();
tx.send(Ok(Some(StreamData::Raw(Ok(vec![10, 20])))))
.unwrap();
drop(tx);
reader.recv().expect_err("should be an error");
reader.recv().expect("should be able to recover");
Ok(())
}
#[test]
fn reader_recv_acknowledge() -> Result<(), ShellError> {
let (tx, rx) = mpsc::channel();
let mut reader = StreamReader::<Value, _>::new(0, rx, TestSink::default());
tx.send(Ok(Some(StreamData::List(Value::test_int(5)))))
.unwrap();
tx.send(Ok(Some(StreamData::List(Value::test_int(6)))))
.unwrap();
drop(tx);
reader.recv()?;
reader.recv()?;
let wrote = &reader.writer.0;
assert!(wrote.len() >= 2);
assert!(
matches!(wrote[0], StreamMessage::Ack(0)),
"0 = {:?}",
wrote[0]
);
assert!(
matches!(wrote[1], StreamMessage::Ack(0)),
"1 = {:?}",
wrote[1]
);
Ok(())
}
#[test]
fn reader_recv_end_of_stream() -> Result<(), ShellError> {
let (tx, rx) = mpsc::channel();
let mut reader = StreamReader::<Value, _>::new(0, rx, TestSink::default());
tx.send(Ok(Some(StreamData::List(Value::test_int(5)))))
.unwrap();
tx.send(Ok(None)).unwrap();
drop(tx);
assert!(reader.recv()?.is_some(), "actual message");
assert!(reader.recv()?.is_none(), "on close");
assert!(reader.recv()?.is_none(), "after close");
Ok(())
}
#[test]
fn reader_iter_fuse_on_error() -> Result<(), ShellError> {
let (tx, rx) = mpsc::channel();
let mut reader = StreamReader::<Value, _>::new(0, rx, TestSink::default());
drop(tx); // should cause error, because we didn't explicitly signal the end
assert!(
reader.next().is_some_and(|e| e.is_error()),
"should be error the first time"
);
assert!(reader.next().is_none(), "should be closed the second time");
Ok(())
}
#[test]
fn reader_drop() {
let (_tx, rx) = mpsc::channel();
// Flag set if drop message is received.
struct Check(Arc<AtomicBool>);
impl WriteStreamMessage for Check {
fn write_stream_message(&mut self, msg: StreamMessage) -> Result<(), ShellError> {
assert!(matches!(msg, StreamMessage::Drop(1)), "got {:?}", msg);
self.0.store(true, Relaxed);
Ok(())
}
fn flush(&mut self) -> Result<(), ShellError> {
Ok(())
}
}
let flag = Arc::new(AtomicBool::new(false));
let reader = StreamReader::<Value, _>::new(1, rx, Check(flag.clone()));
drop(reader);
assert!(flag.load(Relaxed));
}
#[test]
fn writer_write_all_stops_if_dropped() -> Result<(), ShellError> {
let signal = Arc::new(StreamWriterSignal::new(20));
let id = 1337;
let mut writer = StreamWriter::new(id, signal.clone(), TestSink::default());
// Simulate this by having it consume a stream that will actually do the drop halfway through
let iter = (0..5).map(Value::test_int).chain({
let mut n = 5;
std::iter::from_fn(move || {
// produces numbers 5..10, but drops for the first one
if n == 5 {
signal.set_dropped().unwrap();
}
if n < 10 {
let value = Value::test_int(n);
n += 1;
Some(value)
} else {
None
}
})
});
writer.write_all(iter)?;
assert!(writer.is_dropped()?);
let wrote = &writer.writer.0;
assert_eq!(5, wrote.len(), "length wrong: {wrote:?}");
for (n, message) in (0..5).zip(wrote) {
match message {
StreamMessage::Data(msg_id, StreamData::List(value)) => {
assert_eq!(id, *msg_id, "id");
assert_eq!(Value::test_int(n), *value, "value");
}
other => panic!("unexpected message: {other:?}"),
}
}
Ok(())
}
#[test]
fn writer_end() -> Result<(), ShellError> {
let signal = Arc::new(StreamWriterSignal::new(20));
let mut writer = StreamWriter::new(9001, signal.clone(), TestSink::default());
writer.end()?;
writer
.write(Value::test_int(2))
.expect_err("shouldn't be able to write after end");
writer.end().expect("end twice should be ok");
let wrote = &writer.writer.0;
assert!(
matches!(wrote.last(), Some(StreamMessage::End(9001))),
"didn't write end message: {wrote:?}"
);
Ok(())
}
#[test]
fn signal_set_dropped() -> Result<(), ShellError> {
let signal = StreamWriterSignal::new(4);
assert!(!signal.is_dropped()?);
signal.set_dropped()?;
assert!(signal.is_dropped()?);
Ok(())
}
#[test]
fn signal_notify_sent_false_if_unacknowledged() -> Result<(), ShellError> {
let signal = StreamWriterSignal::new(2);
assert!(signal.notify_sent()?);
for _ in 0..100 {
assert!(!signal.notify_sent()?);
}
Ok(())
}
#[test]
fn signal_notify_sent_never_false_if_flowing() -> Result<(), ShellError> {
let signal = StreamWriterSignal::new(1);
for _ in 0..100 {
signal.notify_acknowledged()?;
}
for _ in 0..100 {
assert!(signal.notify_sent()?);
}
Ok(())
}
#[test]
fn signal_wait_for_drain_blocks_on_unacknowledged() -> Result<(), ShellError> {
let signal = StreamWriterSignal::new(50);
std::thread::scope(|scope| {
let spawned = scope.spawn(|| {
for _ in 0..100 {
if !signal.notify_sent()? {
signal.wait_for_drain()?;
}
}
Ok(())
});
std::thread::sleep(WAIT_DURATION);
assert!(!spawned.is_finished(), "didn't block");
for _ in 0..100 {
signal.notify_acknowledged()?;
}
wait_for_condition(|| spawned.is_finished(), "blocked at end");
spawned.join().unwrap()
})
}
#[test]
fn signal_wait_for_drain_unblocks_on_dropped() -> Result<(), ShellError> {
let signal = StreamWriterSignal::new(1);
std::thread::scope(|scope| {
let spawned = scope.spawn(|| {
while !signal.is_dropped()? {
if !signal.notify_sent()? {
signal.wait_for_drain()?;
}
}
Ok(())
});
std::thread::sleep(WAIT_DURATION);
assert!(!spawned.is_finished(), "didn't block");
signal.set_dropped()?;
wait_for_condition(|| spawned.is_finished(), "still blocked at end");
spawned.join().unwrap()
})
}
#[test]
fn stream_manager_single_stream_read_scenario() -> Result<(), ShellError> {
let manager = StreamManager::new();
let handle = manager.get_handle();
let (tx, rx) = mpsc::channel();
let readable = handle.read_stream::<Value, _>(2, tx)?;
let expected_values = vec![Value::test_int(40), Value::test_string("hello")];
for value in &expected_values {
manager.handle_message(StreamMessage::Data(2, value.clone().into()))?;
}
manager.handle_message(StreamMessage::End(2))?;
let values = readable.collect::<Vec<Value>>();
assert_eq!(expected_values, values);
// Now check the sent messages on consumption
// Should be Ack for each message, then Drop
for _ in &expected_values {
match rx.try_recv().expect("failed to receive Ack") {
StreamMessage::Ack(2) => (),
other => panic!("should have been an Ack: {other:?}"),
}
}
match rx.try_recv().expect("failed to receive Drop") {
StreamMessage::Drop(2) => (),
other => panic!("should have been a Drop: {other:?}"),
}
Ok(())
}
#[test]
fn stream_manager_multi_stream_read_scenario() -> Result<(), ShellError> {
let manager = StreamManager::new();
let handle = manager.get_handle();
let (tx, rx) = mpsc::channel();
let readable_list = handle.read_stream::<Value, _>(2, tx.clone())?;
let readable_raw = handle.read_stream::<Result<Vec<u8>, _>, _>(3, tx)?;
let expected_values = (1..100).map(Value::test_int).collect::<Vec<_>>();
let expected_raw_buffers = (1..100).map(|n| vec![n]).collect::<Vec<Vec<u8>>>();
for (value, buf) in expected_values.iter().zip(&expected_raw_buffers) {
manager.handle_message(StreamMessage::Data(2, value.clone().into()))?;
manager.handle_message(StreamMessage::Data(3, StreamData::Raw(Ok(buf.clone()))))?;
}
manager.handle_message(StreamMessage::End(2))?;
manager.handle_message(StreamMessage::End(3))?;
let values = readable_list.collect::<Vec<Value>>();
let bufs = readable_raw.collect::<Result<Vec<Vec<u8>>, _>>()?;
for (expected_value, value) in expected_values.iter().zip(&values) {
assert_eq!(expected_value, value, "in List stream");
}
for (expected_buf, buf) in expected_raw_buffers.iter().zip(&bufs) {
assert_eq!(expected_buf, buf, "in Raw stream");
}
// Now check the sent messages on consumption
// Should be Ack for each message, then Drop
for _ in &expected_values {
match rx.try_recv().expect("failed to receive Ack") {
StreamMessage::Ack(2) => (),
other => panic!("should have been an Ack(2): {other:?}"),
}
}
match rx.try_recv().expect("failed to receive Drop") {
StreamMessage::Drop(2) => (),
other => panic!("should have been a Drop(2): {other:?}"),
}
for _ in &expected_values {
match rx.try_recv().expect("failed to receive Ack") {
StreamMessage::Ack(3) => (),
other => panic!("should have been an Ack(3): {other:?}"),
}
}
match rx.try_recv().expect("failed to receive Drop") {
StreamMessage::Drop(3) => (),
other => panic!("should have been a Drop(3): {other:?}"),
}
// Should be end of stream
assert!(
rx.try_recv().is_err(),
"more messages written to stream than expected"
);
Ok(())
}
#[test]
fn stream_manager_write_scenario() -> Result<(), ShellError> {
let manager = StreamManager::new();
let handle = manager.get_handle();
let (tx, rx) = mpsc::channel();
let mut writable = handle.write_stream(4, tx, 100)?;
let expected_values = vec![b"hello".to_vec(), b"world".to_vec(), b"test".to_vec()];
for value in &expected_values {
writable.write(Ok::<_, ShellError>(value.clone()))?;
}
// Now try signalling ack
assert_eq!(
expected_values.len() as i32,
writable.signal.lock()?.unacknowledged,
"unacknowledged initial count",
);
manager.handle_message(StreamMessage::Ack(4))?;
assert_eq!(
expected_values.len() as i32 - 1,
writable.signal.lock()?.unacknowledged,
"unacknowledged post-Ack count",
);
// ...and Drop
manager.handle_message(StreamMessage::Drop(4))?;
assert!(writable.is_dropped()?);
// Drop the StreamWriter...
drop(writable);
// now check what was actually written
for value in &expected_values {
match rx.try_recv().expect("failed to receive Data") {
StreamMessage::Data(4, StreamData::Raw(Ok(received))) => {
assert_eq!(*value, received);
}
other @ StreamMessage::Data(..) => panic!("wrong Data for {value:?}: {other:?}"),
other => panic!("should have been Data: {other:?}"),
}
}
match rx.try_recv().expect("failed to receive End") {
StreamMessage::End(4) => (),
other => panic!("should have been End: {other:?}"),
}
Ok(())
}
#[test]
fn stream_manager_broadcast_read_error() -> Result<(), ShellError> {
let manager = StreamManager::new();
let handle = manager.get_handle();
let mut readable0 = handle.read_stream::<Value, _>(0, TestSink::default())?;
let mut readable1 = handle.read_stream::<Result<Vec<u8>, _>, _>(1, TestSink::default())?;
let error = ShellError::PluginFailedToDecode {
msg: "test decode error".into(),
};
manager.broadcast_read_error(error.clone())?;
drop(manager);
assert_eq!(
error.to_string(),
readable0
.recv()
.transpose()
.expect("nothing received from readable0")
.expect_err("not an error received from readable0")
.to_string()
);
assert_eq!(
error.to_string(),
readable1
.next()
.expect("nothing received from readable1")
.expect_err("not an error received from readable1")
.to_string()
);
Ok(())
}
#[test]
fn stream_manager_drop_writers_on_drop() -> Result<(), ShellError> {
let manager = StreamManager::new();
let handle = manager.get_handle();
let writable = handle.write_stream(4, TestSink::default(), 100)?;
assert!(!writable.is_dropped()?);
drop(manager);
assert!(writable.is_dropped()?);
Ok(())
}

View File

@ -0,0 +1,134 @@
use nu_protocol::ShellError;
use std::{
collections::VecDeque,
sync::{Arc, Mutex},
};
use crate::{PluginRead, PluginWrite};
const FAILED: &str = "failed to lock TestCase";
/// Mock read/write helper for the engine and plugin interfaces.
#[derive(Debug, Clone)]
pub struct TestCase<I, O> {
r#in: Arc<Mutex<TestData<I>>>,
out: Arc<Mutex<TestData<O>>>,
}
#[derive(Debug)]
pub struct TestData<T> {
data: VecDeque<T>,
error: Option<ShellError>,
flushed: bool,
}
impl<T> Default for TestData<T> {
fn default() -> Self {
TestData {
data: VecDeque::new(),
error: None,
flushed: false,
}
}
}
impl<I, O> PluginRead<I> for TestCase<I, O> {
fn read(&mut self) -> Result<Option<I>, ShellError> {
let mut lock = self.r#in.lock().expect(FAILED);
if let Some(err) = lock.error.take() {
Err(err)
} else {
Ok(lock.data.pop_front())
}
}
}
impl<I, O> PluginWrite<O> for TestCase<I, O>
where
I: Send + Clone,
O: Send + Clone,
{
fn write(&self, data: &O) -> Result<(), ShellError> {
let mut lock = self.out.lock().expect(FAILED);
lock.flushed = false;
if let Some(err) = lock.error.take() {
Err(err)
} else {
lock.data.push_back(data.clone());
Ok(())
}
}
fn flush(&self) -> Result<(), ShellError> {
let mut lock = self.out.lock().expect(FAILED);
lock.flushed = true;
Ok(())
}
}
#[allow(dead_code)]
impl<I, O> TestCase<I, O> {
pub fn new() -> TestCase<I, O> {
TestCase {
r#in: Default::default(),
out: Default::default(),
}
}
/// Clear the read buffer.
pub fn clear(&self) {
self.r#in.lock().expect(FAILED).data.truncate(0);
}
/// Add input that will be read by the interface.
pub fn add(&self, input: impl Into<I>) {
self.r#in.lock().expect(FAILED).data.push_back(input.into());
}
/// Add multiple inputs that will be read by the interface.
pub fn extend(&self, inputs: impl IntoIterator<Item = I>) {
self.r#in.lock().expect(FAILED).data.extend(inputs);
}
/// Return an error from the next read operation.
pub fn set_read_error(&self, err: ShellError) {
self.r#in.lock().expect(FAILED).error = Some(err);
}
/// Return an error from the next write operation.
pub fn set_write_error(&self, err: ShellError) {
self.out.lock().expect(FAILED).error = Some(err);
}
/// Get the next output that was written.
pub fn next_written(&self) -> Option<O> {
self.out.lock().expect(FAILED).data.pop_front()
}
/// Iterator over written data.
pub fn written(&self) -> impl Iterator<Item = O> + '_ {
std::iter::from_fn(|| self.next_written())
}
/// Returns true if the writer was flushed after the last write operation.
pub fn was_flushed(&self) -> bool {
self.out.lock().expect(FAILED).flushed
}
/// Returns true if the reader has unconsumed reads.
pub fn has_unconsumed_read(&self) -> bool {
!self.r#in.lock().expect(FAILED).data.is_empty()
}
/// Returns true if the writer has unconsumed writes.
pub fn has_unconsumed_write(&self) -> bool {
!self.out.lock().expect(FAILED).data.is_empty()
}
}
impl<I, O> Default for TestCase<I, O> {
fn default() -> Self {
Self::new()
}
}

View File

@ -0,0 +1,573 @@
use crate::util::Sequence;
use super::{
stream::{StreamManager, StreamManagerHandle},
test_util::TestCase,
Interface, InterfaceManager, PluginRead, PluginWrite,
};
use nu_plugin_protocol::{
ExternalStreamInfo, ListStreamInfo, PipelineDataHeader, PluginInput, PluginOutput,
RawStreamInfo, StreamData, StreamMessage,
};
use nu_protocol::{
DataSource, ListStream, PipelineData, PipelineMetadata, RawStream, ShellError, Span, Value,
};
use std::{path::Path, sync::Arc};
fn test_metadata() -> PipelineMetadata {
PipelineMetadata {
data_source: DataSource::FilePath("/test/path".into()),
}
}
#[derive(Debug)]
struct TestInterfaceManager {
stream_manager: StreamManager,
test: TestCase<PluginInput, PluginOutput>,
seq: Arc<Sequence>,
}
#[derive(Debug, Clone)]
struct TestInterface {
stream_manager_handle: StreamManagerHandle,
test: TestCase<PluginInput, PluginOutput>,
seq: Arc<Sequence>,
}
impl TestInterfaceManager {
fn new(test: &TestCase<PluginInput, PluginOutput>) -> TestInterfaceManager {
TestInterfaceManager {
stream_manager: StreamManager::new(),
test: test.clone(),
seq: Arc::new(Sequence::default()),
}
}
fn consume_all(&mut self) -> Result<(), ShellError> {
while let Some(msg) = self.test.read()? {
self.consume(msg)?;
}
Ok(())
}
}
impl InterfaceManager for TestInterfaceManager {
type Interface = TestInterface;
type Input = PluginInput;
fn get_interface(&self) -> Self::Interface {
TestInterface {
stream_manager_handle: self.stream_manager.get_handle(),
test: self.test.clone(),
seq: self.seq.clone(),
}
}
fn consume(&mut self, input: Self::Input) -> Result<(), ShellError> {
match input {
PluginInput::Data(..)
| PluginInput::End(..)
| PluginInput::Drop(..)
| PluginInput::Ack(..) => self.consume_stream_message(
input
.try_into()
.expect("failed to convert message to StreamMessage"),
),
_ => unimplemented!(),
}
}
fn stream_manager(&self) -> &StreamManager {
&self.stream_manager
}
fn prepare_pipeline_data(&self, data: PipelineData) -> Result<PipelineData, ShellError> {
Ok(data.set_metadata(Some(test_metadata())))
}
}
impl Interface for TestInterface {
type Output = PluginOutput;
type DataContext = ();
fn write(&self, output: Self::Output) -> Result<(), ShellError> {
self.test.write(&output)
}
fn flush(&self) -> Result<(), ShellError> {
Ok(())
}
fn stream_id_sequence(&self) -> &Sequence {
&self.seq
}
fn stream_manager_handle(&self) -> &StreamManagerHandle {
&self.stream_manager_handle
}
fn prepare_pipeline_data(
&self,
data: PipelineData,
_context: &(),
) -> Result<PipelineData, ShellError> {
// Add an arbitrary check to the data to verify this is being called
match data {
PipelineData::Value(Value::Binary { .. }, None) => Err(ShellError::NushellFailed {
msg: "TEST can't send binary".into(),
}),
_ => Ok(data),
}
}
}
#[test]
fn read_pipeline_data_empty() -> Result<(), ShellError> {
let manager = TestInterfaceManager::new(&TestCase::new());
let header = PipelineDataHeader::Empty;
assert!(matches!(
manager.read_pipeline_data(header, None)?,
PipelineData::Empty
));
Ok(())
}
#[test]
fn read_pipeline_data_value() -> Result<(), ShellError> {
let manager = TestInterfaceManager::new(&TestCase::new());
let value = Value::test_int(4);
let header = PipelineDataHeader::Value(value.clone());
match manager.read_pipeline_data(header, None)? {
PipelineData::Value(read_value, _) => assert_eq!(value, read_value),
PipelineData::ListStream(_, _) => panic!("unexpected ListStream"),
PipelineData::ExternalStream { .. } => panic!("unexpected ExternalStream"),
PipelineData::Empty => panic!("unexpected Empty"),
}
Ok(())
}
#[test]
fn read_pipeline_data_list_stream() -> Result<(), ShellError> {
let test = TestCase::new();
let mut manager = TestInterfaceManager::new(&test);
let data = (0..100).map(Value::test_int).collect::<Vec<_>>();
for value in &data {
test.add(StreamMessage::Data(7, value.clone().into()));
}
test.add(StreamMessage::End(7));
let header = PipelineDataHeader::ListStream(ListStreamInfo { id: 7 });
let pipe = manager.read_pipeline_data(header, None)?;
assert!(
matches!(pipe, PipelineData::ListStream(..)),
"unexpected PipelineData: {pipe:?}"
);
// need to consume input
manager.consume_all()?;
let mut count = 0;
for (expected, read) in data.into_iter().zip(pipe) {
assert_eq!(expected, read);
count += 1;
}
assert_eq!(100, count);
assert!(test.has_unconsumed_write());
Ok(())
}
#[test]
fn read_pipeline_data_external_stream() -> Result<(), ShellError> {
let test = TestCase::new();
let mut manager = TestInterfaceManager::new(&test);
let iterations = 100;
let out_pattern = b"hello".to_vec();
let err_pattern = vec![5, 4, 3, 2];
test.add(StreamMessage::Data(14, Value::test_int(1).into()));
for _ in 0..iterations {
test.add(StreamMessage::Data(
12,
StreamData::Raw(Ok(out_pattern.clone())),
));
test.add(StreamMessage::Data(
13,
StreamData::Raw(Ok(err_pattern.clone())),
));
}
test.add(StreamMessage::End(12));
test.add(StreamMessage::End(13));
test.add(StreamMessage::End(14));
let test_span = Span::new(10, 13);
let header = PipelineDataHeader::ExternalStream(ExternalStreamInfo {
span: test_span,
stdout: Some(RawStreamInfo {
id: 12,
is_binary: false,
known_size: Some((out_pattern.len() * iterations) as u64),
}),
stderr: Some(RawStreamInfo {
id: 13,
is_binary: true,
known_size: None,
}),
exit_code: Some(ListStreamInfo { id: 14 }),
trim_end_newline: true,
});
let pipe = manager.read_pipeline_data(header, None)?;
// need to consume input
manager.consume_all()?;
match pipe {
PipelineData::ExternalStream {
stdout,
stderr,
exit_code,
span,
metadata,
trim_end_newline,
} => {
let stdout = stdout.expect("stdout is None");
let stderr = stderr.expect("stderr is None");
let exit_code = exit_code.expect("exit_code is None");
assert_eq!(test_span, span);
assert!(
metadata.is_some(),
"expected metadata to be Some due to prepare_pipeline_data()"
);
assert!(trim_end_newline);
assert!(!stdout.is_binary);
assert!(stderr.is_binary);
assert_eq!(
Some((out_pattern.len() * iterations) as u64),
stdout.known_size
);
assert_eq!(None, stderr.known_size);
// check the streams
let mut count = 0;
for chunk in stdout.stream {
assert_eq!(out_pattern, chunk?);
count += 1;
}
assert_eq!(iterations, count, "stdout length");
let mut count = 0;
for chunk in stderr.stream {
assert_eq!(err_pattern, chunk?);
count += 1;
}
assert_eq!(iterations, count, "stderr length");
assert_eq!(vec![Value::test_int(1)], exit_code.collect::<Vec<_>>());
}
_ => panic!("unexpected PipelineData: {pipe:?}"),
}
// Don't need to check exactly what was written, just be sure that there is some output
assert!(test.has_unconsumed_write());
Ok(())
}
#[test]
fn read_pipeline_data_ctrlc() -> Result<(), ShellError> {
let manager = TestInterfaceManager::new(&TestCase::new());
let header = PipelineDataHeader::ListStream(ListStreamInfo { id: 0 });
let ctrlc = Default::default();
match manager.read_pipeline_data(header, Some(&ctrlc))? {
PipelineData::ListStream(
ListStream {
ctrlc: stream_ctrlc,
..
},
_,
) => {
assert!(Arc::ptr_eq(&ctrlc, &stream_ctrlc.expect("ctrlc not set")));
Ok(())
}
_ => panic!("Unexpected PipelineData, should have been ListStream"),
}
}
#[test]
fn read_pipeline_data_prepared_properly() -> Result<(), ShellError> {
let manager = TestInterfaceManager::new(&TestCase::new());
let header = PipelineDataHeader::ListStream(ListStreamInfo { id: 0 });
match manager.read_pipeline_data(header, None)? {
PipelineData::ListStream(_, meta) => match meta {
Some(PipelineMetadata { data_source }) => match data_source {
DataSource::FilePath(path) => {
assert_eq!(Path::new("/test/path"), path);
Ok(())
}
_ => panic!("wrong metadata: {data_source:?}"),
},
None => panic!("metadata not set"),
},
_ => panic!("Unexpected PipelineData, should have been ListStream"),
}
}
#[test]
fn write_pipeline_data_empty() -> Result<(), ShellError> {
let test = TestCase::new();
let manager = TestInterfaceManager::new(&test);
let interface = manager.get_interface();
let (header, writer) = interface.init_write_pipeline_data(PipelineData::Empty, &())?;
assert!(matches!(header, PipelineDataHeader::Empty));
writer.write()?;
assert!(
!test.has_unconsumed_write(),
"Empty shouldn't write any stream messages, test: {test:#?}"
);
Ok(())
}
#[test]
fn write_pipeline_data_value() -> Result<(), ShellError> {
let test = TestCase::new();
let manager = TestInterfaceManager::new(&test);
let interface = manager.get_interface();
let value = Value::test_int(7);
let (header, writer) =
interface.init_write_pipeline_data(PipelineData::Value(value.clone(), None), &())?;
match header {
PipelineDataHeader::Value(read_value) => assert_eq!(value, read_value),
_ => panic!("unexpected header: {header:?}"),
}
writer.write()?;
assert!(
!test.has_unconsumed_write(),
"Value shouldn't write any stream messages, test: {test:#?}"
);
Ok(())
}
#[test]
fn write_pipeline_data_prepared_properly() {
let manager = TestInterfaceManager::new(&TestCase::new());
let interface = manager.get_interface();
// Sending a binary should be an error in our test scenario
let value = Value::test_binary(vec![7, 8]);
match interface.init_write_pipeline_data(PipelineData::Value(value, None), &()) {
Ok(_) => panic!("prepare_pipeline_data was not called"),
Err(err) => {
assert_eq!(
ShellError::NushellFailed {
msg: "TEST can't send binary".into()
}
.to_string(),
err.to_string()
);
}
}
}
#[test]
fn write_pipeline_data_list_stream() -> Result<(), ShellError> {
let test = TestCase::new();
let manager = TestInterfaceManager::new(&test);
let interface = manager.get_interface();
let values = vec![
Value::test_int(40),
Value::test_bool(false),
Value::test_string("this is a test"),
];
// Set up pipeline data for a list stream
let pipe = PipelineData::ListStream(
ListStream::from_stream(values.clone().into_iter(), None),
None,
);
let (header, writer) = interface.init_write_pipeline_data(pipe, &())?;
let info = match header {
PipelineDataHeader::ListStream(info) => info,
_ => panic!("unexpected header: {header:?}"),
};
writer.write()?;
// Now make sure the stream messages have been written
for value in values {
match test.next_written().expect("unexpected end of stream") {
PluginOutput::Data(id, data) => {
assert_eq!(info.id, id, "Data id");
match data {
StreamData::List(read_value) => assert_eq!(value, read_value, "Data value"),
_ => panic!("unexpected Data: {data:?}"),
}
}
other => panic!("unexpected output: {other:?}"),
}
}
match test.next_written().expect("unexpected end of stream") {
PluginOutput::End(id) => {
assert_eq!(info.id, id, "End id");
}
other => panic!("unexpected output: {other:?}"),
}
assert!(!test.has_unconsumed_write());
Ok(())
}
#[test]
fn write_pipeline_data_external_stream() -> Result<(), ShellError> {
let test = TestCase::new();
let manager = TestInterfaceManager::new(&test);
let interface = manager.get_interface();
let stdout_bufs = vec![
b"hello".to_vec(),
b"world".to_vec(),
b"these are tests".to_vec(),
];
let stdout_len = stdout_bufs.iter().map(|b| b.len() as u64).sum::<u64>();
let stderr_bufs = vec![b"error messages".to_vec(), b"go here".to_vec()];
let exit_code = Value::test_int(7);
let span = Span::new(400, 500);
// Set up pipeline data for an external stream
let pipe = PipelineData::ExternalStream {
stdout: Some(RawStream::new(
Box::new(stdout_bufs.clone().into_iter().map(Ok)),
None,
span,
Some(stdout_len),
)),
stderr: Some(RawStream::new(
Box::new(stderr_bufs.clone().into_iter().map(Ok)),
None,
span,
None,
)),
exit_code: Some(ListStream::from_stream(
std::iter::once(exit_code.clone()),
None,
)),
span,
metadata: None,
trim_end_newline: true,
};
let (header, writer) = interface.init_write_pipeline_data(pipe, &())?;
let info = match header {
PipelineDataHeader::ExternalStream(info) => info,
_ => panic!("unexpected header: {header:?}"),
};
writer.write()?;
let stdout_info = info.stdout.as_ref().expect("stdout info is None");
let stderr_info = info.stderr.as_ref().expect("stderr info is None");
let exit_code_info = info.exit_code.as_ref().expect("exit code info is None");
assert_eq!(span, info.span);
assert!(info.trim_end_newline);
assert_eq!(Some(stdout_len), stdout_info.known_size);
assert_eq!(None, stderr_info.known_size);
// Now make sure the stream messages have been written
let mut stdout_iter = stdout_bufs.into_iter();
let mut stderr_iter = stderr_bufs.into_iter();
let mut exit_code_iter = std::iter::once(exit_code);
let mut stdout_ended = false;
let mut stderr_ended = false;
let mut exit_code_ended = false;
// There's no specific order these messages must come in with respect to how the streams are
// interleaved, but all of the data for each stream must be in its original order, and the
// End must come after all Data
for msg in test.written() {
match msg {
PluginOutput::Data(id, data) => {
if id == stdout_info.id {
let result: Result<Vec<u8>, ShellError> =
data.try_into().expect("wrong data in stdout stream");
assert_eq!(
stdout_iter.next().expect("too much data in stdout"),
result.expect("unexpected error in stdout stream")
);
} else if id == stderr_info.id {
let result: Result<Vec<u8>, ShellError> =
data.try_into().expect("wrong data in stderr stream");
assert_eq!(
stderr_iter.next().expect("too much data in stderr"),
result.expect("unexpected error in stderr stream")
);
} else if id == exit_code_info.id {
let code: Value = data.try_into().expect("wrong data in stderr stream");
assert_eq!(
exit_code_iter.next().expect("too much data in stderr"),
code
);
} else {
panic!("unrecognized stream id: {id}");
}
}
PluginOutput::End(id) => {
if id == stdout_info.id {
assert!(!stdout_ended, "double End of stdout");
assert!(stdout_iter.next().is_none(), "unexpected end of stdout");
stdout_ended = true;
} else if id == stderr_info.id {
assert!(!stderr_ended, "double End of stderr");
assert!(stderr_iter.next().is_none(), "unexpected end of stderr");
stderr_ended = true;
} else if id == exit_code_info.id {
assert!(!exit_code_ended, "double End of exit_code");
assert!(
exit_code_iter.next().is_none(),
"unexpected end of exit_code"
);
exit_code_ended = true;
} else {
panic!("unrecognized stream id: {id}");
}
}
other => panic!("unexpected output: {other:?}"),
}
}
assert!(stdout_ended, "stdout did not End");
assert!(stderr_ended, "stderr did not End");
assert!(exit_code_ended, "exit_code did not End");
Ok(())
}

View File

@ -0,0 +1,24 @@
//! Functionality and types shared between the plugin and the engine, other than protocol types.
//!
//! If you are writing a plugin, you probably don't need this crate. We will make fewer guarantees
//! for the stability of the interface of this crate than for `nu_plugin`.
pub mod util;
mod communication_mode;
mod interface;
mod serializers;
pub use communication_mode::{
ClientCommunicationIo, CommunicationMode, PreparedServerCommunication, ServerCommunicationIo,
};
pub use interface::{
stream::{FromShellError, StreamManager, StreamManagerHandle, StreamReader, StreamWriter},
Interface, InterfaceManager, PipelineDataWriter, PluginRead, PluginWrite,
};
pub use serializers::{
json::JsonSerializer, msgpack::MsgPackSerializer, Encoder, EncodingType, PluginEncoder,
};
#[doc(hidden)]
pub use interface::test_util as interface_test_util;

View File

@ -0,0 +1,133 @@
use nu_plugin_protocol::{PluginInput, PluginOutput};
use nu_protocol::ShellError;
use serde::Deserialize;
use crate::{Encoder, PluginEncoder};
/// A `PluginEncoder` that enables the plugin to communicate with Nushell with JSON
/// serialized data.
///
/// Each message in the stream is followed by a newline when serializing, but is not required for
/// deserialization. The output is not pretty printed and each object does not contain newlines.
/// If it is more convenient, a plugin may choose to separate messages by newline.
#[derive(Clone, Copy, Debug)]
pub struct JsonSerializer;
impl PluginEncoder for JsonSerializer {
fn name(&self) -> &str {
"json"
}
}
impl Encoder<PluginInput> for JsonSerializer {
fn encode(
&self,
plugin_input: &PluginInput,
writer: &mut impl std::io::Write,
) -> Result<(), nu_protocol::ShellError> {
serde_json::to_writer(&mut *writer, plugin_input).map_err(json_encode_err)?;
writer.write_all(b"\n").map_err(|err| ShellError::IOError {
msg: err.to_string(),
})
}
fn decode(
&self,
reader: &mut impl std::io::BufRead,
) -> Result<Option<PluginInput>, nu_protocol::ShellError> {
let mut de = serde_json::Deserializer::from_reader(reader);
PluginInput::deserialize(&mut de)
.map(Some)
.or_else(json_decode_err)
}
}
impl Encoder<PluginOutput> for JsonSerializer {
fn encode(
&self,
plugin_output: &PluginOutput,
writer: &mut impl std::io::Write,
) -> Result<(), ShellError> {
serde_json::to_writer(&mut *writer, plugin_output).map_err(json_encode_err)?;
writer.write_all(b"\n").map_err(|err| ShellError::IOError {
msg: err.to_string(),
})
}
fn decode(
&self,
reader: &mut impl std::io::BufRead,
) -> Result<Option<PluginOutput>, ShellError> {
let mut de = serde_json::Deserializer::from_reader(reader);
PluginOutput::deserialize(&mut de)
.map(Some)
.or_else(json_decode_err)
}
}
/// Handle a `serde_json` encode error.
fn json_encode_err(err: serde_json::Error) -> ShellError {
if err.is_io() {
ShellError::IOError {
msg: err.to_string(),
}
} else {
ShellError::PluginFailedToEncode {
msg: err.to_string(),
}
}
}
/// Handle a `serde_json` decode error. Returns `Ok(None)` on eof.
fn json_decode_err<T>(err: serde_json::Error) -> Result<Option<T>, ShellError> {
if err.is_eof() {
Ok(None)
} else if err.is_io() {
Err(ShellError::IOError {
msg: err.to_string(),
})
} else {
Err(ShellError::PluginFailedToDecode {
msg: err.to_string(),
})
}
}
#[cfg(test)]
mod tests {
use super::*;
crate::serializers::tests::generate_tests!(JsonSerializer {});
#[test]
fn json_ends_in_newline() {
let mut out = vec![];
JsonSerializer {}
.encode(&PluginInput::Call(0, PluginCall::Signature), &mut out)
.expect("serialization error");
let string = std::str::from_utf8(&out).expect("utf-8 error");
assert!(
string.ends_with('\n'),
"doesn't end with newline: {:?}",
string
);
}
#[test]
fn json_has_no_other_newlines() {
let mut out = vec![];
// use something deeply nested, to try to trigger any pretty printing
let output = PluginOutput::Data(
0,
StreamData::List(Value::test_list(vec![
Value::test_int(4),
// in case escaping failed
Value::test_string("newline\ncontaining\nstring"),
])),
);
JsonSerializer {}
.encode(&output, &mut out)
.expect("serialization error");
let string = std::str::from_utf8(&out).expect("utf-8 error");
assert_eq!(1, string.chars().filter(|ch| *ch == '\n').count());
}
}

View File

@ -0,0 +1,71 @@
use nu_plugin_protocol::{PluginInput, PluginOutput};
use nu_protocol::ShellError;
pub mod json;
pub mod msgpack;
#[cfg(test)]
mod tests;
/// Encoder for a specific message type. Usually implemented on [`PluginInput`]
/// and [`PluginOutput`].
pub trait Encoder<T>: Clone + Send + Sync {
/// Serialize a value in the [`PluginEncoder`]s format
///
/// Returns [`ShellError::IOError`] if there was a problem writing, or
/// [`ShellError::PluginFailedToEncode`] for a serialization error.
fn encode(&self, data: &T, writer: &mut impl std::io::Write) -> Result<(), ShellError>;
/// Deserialize a value from the [`PluginEncoder`]'s format
///
/// Returns `None` if there is no more output to receive.
///
/// Returns [`ShellError::IOError`] if there was a problem reading, or
/// [`ShellError::PluginFailedToDecode`] for a deserialization error.
fn decode(&self, reader: &mut impl std::io::BufRead) -> Result<Option<T>, ShellError>;
}
/// Encoding scheme that defines a plugin's communication protocol with Nu
pub trait PluginEncoder: Encoder<PluginInput> + Encoder<PluginOutput> {
/// The name of the encoder (e.g., `json`)
fn name(&self) -> &str;
}
/// Enum that supports all of the plugin serialization formats.
#[derive(Clone, Copy, Debug)]
pub enum EncodingType {
Json(json::JsonSerializer),
MsgPack(msgpack::MsgPackSerializer),
}
impl EncodingType {
/// Determine the plugin encoding type from the provided byte string (either `b"json"` or
/// `b"msgpack"`).
pub fn try_from_bytes(bytes: &[u8]) -> Option<Self> {
match bytes {
b"json" => Some(Self::Json(json::JsonSerializer {})),
b"msgpack" => Some(Self::MsgPack(msgpack::MsgPackSerializer {})),
_ => None,
}
}
}
impl<T> Encoder<T> for EncodingType
where
json::JsonSerializer: Encoder<T>,
msgpack::MsgPackSerializer: Encoder<T>,
{
fn encode(&self, data: &T, writer: &mut impl std::io::Write) -> Result<(), ShellError> {
match self {
EncodingType::Json(encoder) => encoder.encode(data, writer),
EncodingType::MsgPack(encoder) => encoder.encode(data, writer),
}
}
fn decode(&self, reader: &mut impl std::io::BufRead) -> Result<Option<T>, ShellError> {
match self {
EncodingType::Json(encoder) => encoder.decode(reader),
EncodingType::MsgPack(encoder) => encoder.decode(reader),
}
}
}

View File

@ -0,0 +1,108 @@
use std::io::ErrorKind;
use nu_plugin_protocol::{PluginInput, PluginOutput};
use nu_protocol::ShellError;
use serde::Deserialize;
use crate::{Encoder, PluginEncoder};
/// A `PluginEncoder` that enables the plugin to communicate with Nushell with MsgPack
/// serialized data.
///
/// Each message is written as a MessagePack object. There is no message envelope or separator.
#[derive(Clone, Copy, Debug)]
pub struct MsgPackSerializer;
impl PluginEncoder for MsgPackSerializer {
fn name(&self) -> &str {
"msgpack"
}
}
impl Encoder<PluginInput> for MsgPackSerializer {
fn encode(
&self,
plugin_input: &PluginInput,
writer: &mut impl std::io::Write,
) -> Result<(), nu_protocol::ShellError> {
rmp_serde::encode::write_named(writer, plugin_input).map_err(rmp_encode_err)
}
fn decode(
&self,
reader: &mut impl std::io::BufRead,
) -> Result<Option<PluginInput>, ShellError> {
let mut de = rmp_serde::Deserializer::new(reader);
PluginInput::deserialize(&mut de)
.map(Some)
.or_else(rmp_decode_err)
}
}
impl Encoder<PluginOutput> for MsgPackSerializer {
fn encode(
&self,
plugin_output: &PluginOutput,
writer: &mut impl std::io::Write,
) -> Result<(), ShellError> {
rmp_serde::encode::write_named(writer, plugin_output).map_err(rmp_encode_err)
}
fn decode(
&self,
reader: &mut impl std::io::BufRead,
) -> Result<Option<PluginOutput>, ShellError> {
let mut de = rmp_serde::Deserializer::new(reader);
PluginOutput::deserialize(&mut de)
.map(Some)
.or_else(rmp_decode_err)
}
}
/// Handle a msgpack encode error
fn rmp_encode_err(err: rmp_serde::encode::Error) -> ShellError {
match err {
rmp_serde::encode::Error::InvalidValueWrite(_) => {
// I/O error
ShellError::IOError {
msg: err.to_string(),
}
}
_ => {
// Something else
ShellError::PluginFailedToEncode {
msg: err.to_string(),
}
}
}
}
/// Handle a msgpack decode error. Returns `Ok(None)` on eof
fn rmp_decode_err<T>(err: rmp_serde::decode::Error) -> Result<Option<T>, ShellError> {
match err {
rmp_serde::decode::Error::InvalidMarkerRead(err)
| rmp_serde::decode::Error::InvalidDataRead(err) => {
if matches!(err.kind(), ErrorKind::UnexpectedEof) {
// EOF
Ok(None)
} else {
// I/O error
Err(ShellError::IOError {
msg: err.to_string(),
})
}
}
_ => {
// Something else
Err(ShellError::PluginFailedToDecode {
msg: err.to_string(),
})
}
}
}
#[cfg(test)]
mod tests {
use super::*;
crate::serializers::tests::generate_tests!(MsgPackSerializer {});
}

View File

@ -0,0 +1,557 @@
macro_rules! generate_tests {
($encoder:expr) => {
use nu_plugin_protocol::{
CallInfo, CustomValueOp, EvaluatedCall, PipelineDataHeader, PluginCall,
PluginCallResponse, PluginCustomValue, PluginInput, PluginOption, PluginOutput,
StreamData,
};
use nu_protocol::{
LabeledError, PluginSignature, Signature, Span, Spanned, SyntaxShape, Value,
};
#[test]
fn decode_eof() {
let mut buffer: &[u8] = &[];
let encoder = $encoder;
let result: Option<PluginInput> = encoder
.decode(&mut buffer)
.expect("eof should not result in an error");
assert!(result.is_none(), "decode result: {result:?}");
let result: Option<PluginOutput> = encoder
.decode(&mut buffer)
.expect("eof should not result in an error");
assert!(result.is_none(), "decode result: {result:?}");
}
#[test]
fn decode_io_error() {
struct ErrorProducer;
impl std::io::Read for ErrorProducer {
fn read(&mut self, _buf: &mut [u8]) -> std::io::Result<usize> {
Err(std::io::Error::from(std::io::ErrorKind::ConnectionReset))
}
}
let encoder = $encoder;
let mut buffered = std::io::BufReader::new(ErrorProducer);
match Encoder::<PluginInput>::decode(&encoder, &mut buffered) {
Ok(_) => panic!("decode: i/o error was not passed through"),
Err(ShellError::IOError { .. }) => (), // okay
Err(other) => panic!(
"decode: got other error, should have been a \
ShellError::IOError: {other:?}"
),
}
match Encoder::<PluginOutput>::decode(&encoder, &mut buffered) {
Ok(_) => panic!("decode: i/o error was not passed through"),
Err(ShellError::IOError { .. }) => (), // okay
Err(other) => panic!(
"decode: got other error, should have been a \
ShellError::IOError: {other:?}"
),
}
}
#[test]
fn decode_gibberish() {
// just a sequence of bytes that shouldn't be valid in anything we use
let gibberish: &[u8] = &[
0, 80, 74, 85, 117, 122, 86, 100, 74, 115, 20, 104, 55, 98, 67, 203, 83, 85, 77,
112, 74, 79, 254, 71, 80,
];
let encoder = $encoder;
let mut buffered = std::io::BufReader::new(&gibberish[..]);
match Encoder::<PluginInput>::decode(&encoder, &mut buffered) {
Ok(value) => panic!("decode: parsed successfully => {value:?}"),
Err(ShellError::PluginFailedToDecode { .. }) => (), // okay
Err(other) => panic!(
"decode: got other error, should have been a \
ShellError::PluginFailedToDecode: {other:?}"
),
}
let mut buffered = std::io::BufReader::new(&gibberish[..]);
match Encoder::<PluginOutput>::decode(&encoder, &mut buffered) {
Ok(value) => panic!("decode: parsed successfully => {value:?}"),
Err(ShellError::PluginFailedToDecode { .. }) => (), // okay
Err(other) => panic!(
"decode: got other error, should have been a \
ShellError::PluginFailedToDecode: {other:?}"
),
}
}
#[test]
fn call_round_trip_signature() {
let plugin_call = PluginCall::Signature;
let plugin_input = PluginInput::Call(0, plugin_call);
let encoder = $encoder;
let mut buffer: Vec<u8> = Vec::new();
encoder
.encode(&plugin_input, &mut buffer)
.expect("unable to serialize message");
let returned = encoder
.decode(&mut buffer.as_slice())
.expect("unable to deserialize message")
.expect("eof");
match returned {
PluginInput::Call(0, PluginCall::Signature) => {}
_ => panic!("decoded into wrong value: {returned:?}"),
}
}
#[test]
fn call_round_trip_run() {
let name = "test".to_string();
let input = Value::bool(false, Span::new(1, 20));
let call = EvaluatedCall {
head: Span::new(0, 10),
positional: vec![
Value::float(1.0, Span::new(0, 10)),
Value::string("something", Span::new(0, 10)),
],
named: vec![(
Spanned {
item: "name".to_string(),
span: Span::new(0, 10),
},
Some(Value::float(1.0, Span::new(0, 10))),
)],
};
let plugin_call = PluginCall::Run(CallInfo {
name: name.clone(),
call: call.clone(),
input: PipelineDataHeader::Value(input.clone()),
});
let plugin_input = PluginInput::Call(1, plugin_call);
let encoder = $encoder;
let mut buffer: Vec<u8> = Vec::new();
encoder
.encode(&plugin_input, &mut buffer)
.expect("unable to serialize message");
let returned = encoder
.decode(&mut buffer.as_slice())
.expect("unable to deserialize message")
.expect("eof");
match returned {
PluginInput::Call(1, PluginCall::Run(call_info)) => {
assert_eq!(name, call_info.name);
assert_eq!(PipelineDataHeader::Value(input), call_info.input);
assert_eq!(call.head, call_info.call.head);
assert_eq!(call.positional.len(), call_info.call.positional.len());
call.positional
.iter()
.zip(call_info.call.positional.iter())
.for_each(|(lhs, rhs)| assert_eq!(lhs, rhs));
call.named
.iter()
.zip(call_info.call.named.iter())
.for_each(|(lhs, rhs)| {
// Comparing the keys
assert_eq!(lhs.0.item, rhs.0.item);
match (&lhs.1, &rhs.1) {
(None, None) => {}
(Some(a), Some(b)) => assert_eq!(a, b),
_ => panic!("not matching values"),
}
});
}
_ => panic!("decoded into wrong value: {returned:?}"),
}
}
#[test]
fn call_round_trip_customvalueop() {
let data = vec![1, 2, 3, 4, 5, 6, 7];
let span = Span::new(0, 20);
let custom_value_op = PluginCall::CustomValueOp(
Spanned {
item: PluginCustomValue::new("Foo".into(), data.clone(), false),
span,
},
CustomValueOp::ToBaseValue,
);
let plugin_input = PluginInput::Call(2, custom_value_op);
let encoder = $encoder;
let mut buffer: Vec<u8> = Vec::new();
encoder
.encode(&plugin_input, &mut buffer)
.expect("unable to serialize message");
let returned = encoder
.decode(&mut buffer.as_slice())
.expect("unable to deserialize message")
.expect("eof");
match returned {
PluginInput::Call(2, PluginCall::CustomValueOp(val, op)) => {
assert_eq!("Foo", val.item.name());
assert_eq!(data, val.item.data());
assert_eq!(span, val.span);
#[allow(unreachable_patterns)]
match op {
CustomValueOp::ToBaseValue => (),
_ => panic!("wrong op: {op:?}"),
}
}
_ => panic!("decoded into wrong value: {returned:?}"),
}
}
#[test]
fn response_round_trip_signature() {
let signature = PluginSignature::new(
Signature::build("nu-plugin")
.required("first", SyntaxShape::String, "first required")
.required("second", SyntaxShape::Int, "second required")
.required_named("first-named", SyntaxShape::String, "first named", Some('f'))
.required_named(
"second-named",
SyntaxShape::String,
"second named",
Some('s'),
)
.rest("remaining", SyntaxShape::Int, "remaining"),
vec![],
);
let response = PluginCallResponse::Signature(vec![signature.clone()]);
let output = PluginOutput::CallResponse(3, response);
let encoder = $encoder;
let mut buffer: Vec<u8> = Vec::new();
encoder
.encode(&output, &mut buffer)
.expect("unable to serialize message");
let returned = encoder
.decode(&mut buffer.as_slice())
.expect("unable to deserialize message")
.expect("eof");
match returned {
PluginOutput::CallResponse(
3,
PluginCallResponse::Signature(returned_signature),
) => {
assert_eq!(returned_signature.len(), 1);
assert_eq!(signature.sig.name, returned_signature[0].sig.name);
assert_eq!(signature.sig.usage, returned_signature[0].sig.usage);
assert_eq!(
signature.sig.extra_usage,
returned_signature[0].sig.extra_usage
);
assert_eq!(signature.sig.is_filter, returned_signature[0].sig.is_filter);
signature
.sig
.required_positional
.iter()
.zip(returned_signature[0].sig.required_positional.iter())
.for_each(|(lhs, rhs)| assert_eq!(lhs, rhs));
signature
.sig
.optional_positional
.iter()
.zip(returned_signature[0].sig.optional_positional.iter())
.for_each(|(lhs, rhs)| assert_eq!(lhs, rhs));
signature
.sig
.named
.iter()
.zip(returned_signature[0].sig.named.iter())
.for_each(|(lhs, rhs)| assert_eq!(lhs, rhs));
assert_eq!(
signature.sig.rest_positional,
returned_signature[0].sig.rest_positional,
);
}
_ => panic!("decoded into wrong value: {returned:?}"),
}
}
#[test]
fn response_round_trip_value() {
let value = Value::int(10, Span::new(2, 30));
let response = PluginCallResponse::value(value.clone());
let output = PluginOutput::CallResponse(4, response);
let encoder = $encoder;
let mut buffer: Vec<u8> = Vec::new();
encoder
.encode(&output, &mut buffer)
.expect("unable to serialize message");
let returned = encoder
.decode(&mut buffer.as_slice())
.expect("unable to deserialize message")
.expect("eof");
match returned {
PluginOutput::CallResponse(
4,
PluginCallResponse::PipelineData(PipelineDataHeader::Value(returned_value)),
) => {
assert_eq!(value, returned_value)
}
_ => panic!("decoded into wrong value: {returned:?}"),
}
}
#[test]
fn response_round_trip_plugin_custom_value() {
let name = "test";
let data = vec![1, 2, 3, 4, 5];
let span = Span::new(2, 30);
let value = Value::custom(
Box::new(PluginCustomValue::new(name.into(), data.clone(), true)),
span,
);
let response = PluginCallResponse::PipelineData(PipelineDataHeader::Value(value));
let output = PluginOutput::CallResponse(5, response);
let encoder = $encoder;
let mut buffer: Vec<u8> = Vec::new();
encoder
.encode(&output, &mut buffer)
.expect("unable to serialize message");
let returned = encoder
.decode(&mut buffer.as_slice())
.expect("unable to deserialize message")
.expect("eof");
match returned {
PluginOutput::CallResponse(
5,
PluginCallResponse::PipelineData(PipelineDataHeader::Value(returned_value)),
) => {
assert_eq!(span, returned_value.span());
if let Some(plugin_val) = returned_value
.as_custom_value()
.unwrap()
.as_any()
.downcast_ref::<PluginCustomValue>()
{
assert_eq!(name, plugin_val.name());
assert_eq!(data, plugin_val.data());
assert!(plugin_val.notify_on_drop());
} else {
panic!("returned CustomValue is not a PluginCustomValue");
}
}
_ => panic!("decoded into wrong value: {returned:?}"),
}
}
#[test]
fn response_round_trip_error() {
let error = LabeledError::new("label")
.with_code("test::error")
.with_url("https://example.org/test/error")
.with_help("some help")
.with_label("msg", Span::new(2, 30))
.with_inner(ShellError::IOError {
msg: "io error".into(),
});
let response = PluginCallResponse::Error(error.clone());
let output = PluginOutput::CallResponse(6, response);
let encoder = $encoder;
let mut buffer: Vec<u8> = Vec::new();
encoder
.encode(&output, &mut buffer)
.expect("unable to serialize message");
let returned = encoder
.decode(&mut buffer.as_slice())
.expect("unable to deserialize message")
.expect("eof");
match returned {
PluginOutput::CallResponse(6, PluginCallResponse::Error(msg)) => {
assert_eq!(error, msg)
}
_ => panic!("decoded into wrong value: {returned:?}"),
}
}
#[test]
fn response_round_trip_error_none() {
let error = LabeledError::new("error");
let response = PluginCallResponse::Error(error.clone());
let output = PluginOutput::CallResponse(7, response);
let encoder = $encoder;
let mut buffer: Vec<u8> = Vec::new();
encoder
.encode(&output, &mut buffer)
.expect("unable to serialize message");
let returned = encoder
.decode(&mut buffer.as_slice())
.expect("unable to deserialize message")
.expect("eof");
match returned {
PluginOutput::CallResponse(7, PluginCallResponse::Error(msg)) => {
assert_eq!(error, msg)
}
_ => panic!("decoded into wrong value: {returned:?}"),
}
}
#[test]
fn input_round_trip_stream_data_list() {
let span = Span::new(12, 30);
let item = Value::int(1, span);
let stream_data = StreamData::List(item.clone());
let plugin_input = PluginInput::Data(0, stream_data);
let encoder = $encoder;
let mut buffer: Vec<u8> = Vec::new();
encoder
.encode(&plugin_input, &mut buffer)
.expect("unable to serialize message");
let returned = encoder
.decode(&mut buffer.as_slice())
.expect("unable to deserialize message")
.expect("eof");
match returned {
PluginInput::Data(id, StreamData::List(list_data)) => {
assert_eq!(0, id);
assert_eq!(item, list_data);
}
_ => panic!("decoded into wrong value: {returned:?}"),
}
}
#[test]
fn input_round_trip_stream_data_raw() {
let data = b"Hello world";
let stream_data = StreamData::Raw(Ok(data.to_vec()));
let plugin_input = PluginInput::Data(1, stream_data);
let encoder = $encoder;
let mut buffer: Vec<u8> = Vec::new();
encoder
.encode(&plugin_input, &mut buffer)
.expect("unable to serialize message");
let returned = encoder
.decode(&mut buffer.as_slice())
.expect("unable to deserialize message")
.expect("eof");
match returned {
PluginInput::Data(id, StreamData::Raw(bytes)) => {
assert_eq!(1, id);
match bytes {
Ok(bytes) => assert_eq!(data, &bytes[..]),
Err(err) => panic!("decoded into error variant: {err:?}"),
}
}
_ => panic!("decoded into wrong value: {returned:?}"),
}
}
#[test]
fn output_round_trip_stream_data_list() {
let span = Span::new(12, 30);
let item = Value::int(1, span);
let stream_data = StreamData::List(item.clone());
let plugin_output = PluginOutput::Data(4, stream_data);
let encoder = $encoder;
let mut buffer: Vec<u8> = Vec::new();
encoder
.encode(&plugin_output, &mut buffer)
.expect("unable to serialize message");
let returned = encoder
.decode(&mut buffer.as_slice())
.expect("unable to deserialize message")
.expect("eof");
match returned {
PluginOutput::Data(id, StreamData::List(list_data)) => {
assert_eq!(4, id);
assert_eq!(item, list_data);
}
_ => panic!("decoded into wrong value: {returned:?}"),
}
}
#[test]
fn output_round_trip_stream_data_raw() {
let data = b"Hello world";
let stream_data = StreamData::Raw(Ok(data.to_vec()));
let plugin_output = PluginOutput::Data(5, stream_data);
let encoder = $encoder;
let mut buffer: Vec<u8> = Vec::new();
encoder
.encode(&plugin_output, &mut buffer)
.expect("unable to serialize message");
let returned = encoder
.decode(&mut buffer.as_slice())
.expect("unable to deserialize message")
.expect("eof");
match returned {
PluginOutput::Data(id, StreamData::Raw(bytes)) => {
assert_eq!(5, id);
match bytes {
Ok(bytes) => assert_eq!(data, &bytes[..]),
Err(err) => panic!("decoded into error variant: {err:?}"),
}
}
_ => panic!("decoded into wrong value: {returned:?}"),
}
}
#[test]
fn output_round_trip_option() {
let plugin_output = PluginOutput::Option(PluginOption::GcDisabled(true));
let encoder = $encoder;
let mut buffer: Vec<u8> = Vec::new();
encoder
.encode(&plugin_output, &mut buffer)
.expect("unable to serialize message");
let returned = encoder
.decode(&mut buffer.as_slice())
.expect("unable to deserialize message")
.expect("eof");
match returned {
PluginOutput::Option(PluginOption::GcDisabled(disabled)) => {
assert!(disabled);
}
_ => panic!("decoded into wrong value: {returned:?}"),
}
}
};
}
pub(crate) use generate_tests;

View File

@ -0,0 +1,7 @@
mod sequence;
mod waitable;
mod with_custom_values_in;
pub use sequence::Sequence;
pub use waitable::*;
pub use with_custom_values_in::with_custom_values_in;

View File

@ -0,0 +1,64 @@
use nu_protocol::ShellError;
use std::sync::atomic::{AtomicUsize, Ordering::Relaxed};
/// Implements an atomically incrementing sequential series of numbers
#[derive(Debug, Default)]
pub struct Sequence(AtomicUsize);
impl Sequence {
/// Return the next available id from a sequence, returning an error on overflow
#[track_caller]
pub fn next(&self) -> Result<usize, ShellError> {
// It's totally safe to use Relaxed ordering here, as there aren't other memory operations
// that depend on this value having been set for safety
//
// We're only not using `fetch_add` so that we can check for overflow, as wrapping with the
// identifier would lead to a serious bug - however unlikely that is.
self.0
.fetch_update(Relaxed, Relaxed, |current| current.checked_add(1))
.map_err(|_| ShellError::NushellFailedHelp {
msg: "an accumulator for identifiers overflowed".into(),
help: format!("see {}", std::panic::Location::caller()),
})
}
}
#[test]
fn output_is_sequential() {
let sequence = Sequence::default();
for (expected, generated) in (0..1000).zip(std::iter::repeat_with(|| sequence.next())) {
assert_eq!(expected, generated.expect("error in sequence"));
}
}
#[test]
fn output_is_unique_even_under_contention() {
let sequence = Sequence::default();
std::thread::scope(|scope| {
// Spawn four threads, all advancing the sequence simultaneously
let threads = (0..4)
.map(|_| {
scope.spawn(|| {
(0..100000)
.map(|_| sequence.next())
.collect::<Result<Vec<_>, _>>()
})
})
.collect::<Vec<_>>();
// Collect all of the results into a single flat vec
let mut results = threads
.into_iter()
.flat_map(|thread| thread.join().expect("panicked").expect("error"))
.collect::<Vec<usize>>();
// Check uniqueness
results.sort();
let initial_length = results.len();
results.dedup();
let deduplicated_length = results.len();
assert_eq!(initial_length, deduplicated_length);
})
}

View File

@ -0,0 +1,181 @@
use std::sync::{
atomic::{AtomicBool, Ordering},
Arc, Condvar, Mutex, MutexGuard, PoisonError,
};
use nu_protocol::ShellError;
/// A shared container that may be empty, and allows threads to block until it has a value.
///
/// This side is read-only - use [`WaitableMut`] on threads that might write a value.
#[derive(Debug, Clone)]
pub struct Waitable<T: Clone + Send> {
shared: Arc<WaitableShared<T>>,
}
#[derive(Debug)]
pub struct WaitableMut<T: Clone + Send> {
shared: Arc<WaitableShared<T>>,
}
#[derive(Debug)]
struct WaitableShared<T: Clone + Send> {
is_set: AtomicBool,
mutex: Mutex<SyncState<T>>,
condvar: Condvar,
}
#[derive(Debug)]
struct SyncState<T: Clone + Send> {
writers: usize,
value: Option<T>,
}
#[track_caller]
fn fail_if_poisoned<'a, T>(
result: Result<MutexGuard<'a, T>, PoisonError<MutexGuard<'a, T>>>,
) -> Result<MutexGuard<'a, T>, ShellError> {
match result {
Ok(guard) => Ok(guard),
Err(_) => Err(ShellError::NushellFailedHelp {
msg: "Waitable mutex poisoned".into(),
help: std::panic::Location::caller().to_string(),
}),
}
}
impl<T: Clone + Send> WaitableMut<T> {
/// Create a new empty `WaitableMut`. Call [`.reader()`] to get [`Waitable`].
pub fn new() -> WaitableMut<T> {
WaitableMut {
shared: Arc::new(WaitableShared {
is_set: AtomicBool::new(false),
mutex: Mutex::new(SyncState {
writers: 1,
value: None,
}),
condvar: Condvar::new(),
}),
}
}
pub fn reader(&self) -> Waitable<T> {
Waitable {
shared: self.shared.clone(),
}
}
/// Set the value and let waiting threads know.
#[track_caller]
pub fn set(&self, value: T) -> Result<(), ShellError> {
let mut sync_state = fail_if_poisoned(self.shared.mutex.lock())?;
self.shared.is_set.store(true, Ordering::SeqCst);
sync_state.value = Some(value);
self.shared.condvar.notify_all();
Ok(())
}
}
impl<T: Clone + Send> Default for WaitableMut<T> {
fn default() -> Self {
Self::new()
}
}
impl<T: Clone + Send> Clone for WaitableMut<T> {
fn clone(&self) -> Self {
let shared = self.shared.clone();
shared
.mutex
.lock()
.expect("failed to lock mutex to increment writers")
.writers += 1;
WaitableMut { shared }
}
}
impl<T: Clone + Send> Drop for WaitableMut<T> {
fn drop(&mut self) {
// Decrement writers...
if let Ok(mut sync_state) = self.shared.mutex.lock() {
sync_state.writers = sync_state
.writers
.checked_sub(1)
.expect("would decrement writers below zero");
}
// and notify waiting threads so they have a chance to see it.
self.shared.condvar.notify_all();
}
}
impl<T: Clone + Send> Waitable<T> {
/// Wait for a value to be available and then clone it.
///
/// Returns `Ok(None)` if there are no writers left that could possibly place a value.
#[track_caller]
pub fn get(&self) -> Result<Option<T>, ShellError> {
let sync_state = fail_if_poisoned(self.shared.mutex.lock())?;
if let Some(value) = sync_state.value.clone() {
Ok(Some(value))
} else if sync_state.writers == 0 {
// There can't possibly be a value written, so no point in waiting.
Ok(None)
} else {
let sync_state = fail_if_poisoned(
self.shared
.condvar
.wait_while(sync_state, |g| g.writers > 0 && g.value.is_none()),
)?;
Ok(sync_state.value.clone())
}
}
/// Clone the value if one is available, but don't wait if not.
#[track_caller]
pub fn try_get(&self) -> Result<Option<T>, ShellError> {
let sync_state = fail_if_poisoned(self.shared.mutex.lock())?;
Ok(sync_state.value.clone())
}
/// Returns true if value is available.
#[track_caller]
pub fn is_set(&self) -> bool {
self.shared.is_set.load(Ordering::SeqCst)
}
}
#[test]
fn set_from_other_thread() -> Result<(), ShellError> {
let waitable_mut = WaitableMut::new();
let waitable = waitable_mut.reader();
assert!(!waitable.is_set());
std::thread::spawn(move || {
waitable_mut.set(42).expect("error on set");
});
assert_eq!(Some(42), waitable.get()?);
assert_eq!(Some(42), waitable.try_get()?);
assert!(waitable.is_set());
Ok(())
}
#[test]
fn dont_deadlock_if_waiting_without_writer() {
use std::time::Duration;
let (tx, rx) = std::sync::mpsc::channel();
let writer = WaitableMut::<()>::new();
let waitable = writer.reader();
// Ensure there are no writers
drop(writer);
std::thread::spawn(move || {
let _ = tx.send(waitable.get());
});
let result = rx
.recv_timeout(Duration::from_secs(10))
.expect("timed out")
.expect("error");
assert!(result.is_none());
}

View File

@ -0,0 +1,96 @@
use nu_protocol::{CustomValue, IntoSpanned, ShellError, Spanned, Value};
/// Do something with all [`CustomValue`]s recursively within a `Value`. This is not limited to
/// plugin custom values.
///
/// `LazyRecord`s will be collected to plain values for completeness.
pub fn with_custom_values_in<E>(
value: &mut Value,
mut f: impl FnMut(Spanned<&mut Box<dyn CustomValue>>) -> Result<(), E>,
) -> Result<(), E>
where
E: From<ShellError>,
{
value.recurse_mut(&mut |value| {
let span = value.span();
match value {
Value::Custom { val, .. } => {
// Operate on a CustomValue.
f(val.into_spanned(span))
}
// LazyRecord would be a problem for us, since it could return something else the
// next time, and we have to collect it anyway to serialize it. Collect it in place,
// and then use the result
Value::LazyRecord { val, .. } => {
*value = val.collect()?;
Ok(())
}
_ => Ok(()),
}
})
}
#[test]
fn find_custom_values() {
use nu_plugin_protocol::test_util::test_plugin_custom_value;
use nu_protocol::{engine::Closure, record, LazyRecord, Span};
#[derive(Debug, Clone)]
struct Lazy;
impl<'a> LazyRecord<'a> for Lazy {
fn column_names(&'a self) -> Vec<&'a str> {
vec!["custom", "plain"]
}
fn get_column_value(&self, column: &str) -> Result<Value, ShellError> {
Ok(match column {
"custom" => Value::test_custom_value(Box::new(test_plugin_custom_value())),
"plain" => Value::test_int(42),
_ => unimplemented!(),
})
}
fn span(&self) -> Span {
Span::test_data()
}
fn clone_value(&self, span: Span) -> Value {
Value::lazy_record(Box::new(self.clone()), span)
}
}
let mut cv = Value::test_custom_value(Box::new(test_plugin_custom_value()));
let mut value = Value::test_record(record! {
"bare" => cv.clone(),
"list" => Value::test_list(vec![
cv.clone(),
Value::test_int(4),
]),
"closure" => Value::test_closure(
Closure {
block_id: 0,
captures: vec![(0, cv.clone()), (1, Value::test_string("foo"))]
}
),
"lazy" => Value::test_lazy_record(Box::new(Lazy)),
});
// Do with_custom_values_in, and count the number of custom values found
let mut found = 0;
with_custom_values_in::<ShellError>(&mut value, |_| {
found += 1;
Ok(())
})
.expect("error");
assert_eq!(4, found, "found in value");
// Try it on bare custom value too
found = 0;
with_custom_values_in::<ShellError>(&mut cv, |_| {
found += 1;
Ok(())
})
.expect("error");
assert_eq!(1, found, "bare custom value didn't work");
}