Split the plugin crate (#12563)

# Description

This breaks `nu-plugin` up into four crates:

- `nu-plugin-protocol`: just the type definitions for the protocol, no
I/O. If someone wanted to wire up something more bare metal, maybe for
async I/O, they could use this.
- `nu-plugin-core`: the shared stuff between engine/plugin. Less stable
interface.
- `nu-plugin-engine`: everything required for the engine to talk to
plugins. Less stable interface.
- `nu-plugin`: everything required for the plugin to talk to the engine,
what plugin developers use. Should be the most stable interface.

No changes are made to the interface exposed by `nu-plugin` - it should
all still be there. Re-exports from `nu-plugin-protocol` or
`nu-plugin-core` are used as required. Plugins shouldn't ever have to
use those crates directly.

This should be somewhat faster to compile as `nu-plugin-engine` and
`nu-plugin` can compile in parallel, and the engine doesn't need
`nu-plugin` and plugins don't need `nu-plugin-engine` (except for test
support), so that should reduce what needs to be compiled too.

The only significant change here other than splitting stuff up was to
break the `source` out of `PluginCustomValue` and create a new
`PluginCustomValueWithSource` type that contains it instead. One bonus
of that is we get rid of the `Option`, so it's now more type-safe, but it
also means that the logic that needs the source (actually running the
plugin for custom value ops) can live entirely within the
`nu-plugin-engine` crate.
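
A rough sketch of the shape of this change (illustrative only; the field names below
are not the actual definitions from the PR):

```rust
// Illustrative sketch -- not the actual definitions from this PR.
// Before: the source was carried as an `Option` inside the custom value itself.
// After: `PluginCustomValue` stays source-free in `nu-plugin-protocol`, and the
// engine wraps it whenever it needs to know which plugin a value came from.
use std::sync::Arc;

pub struct PluginCustomValue {
    /* serialized custom value data, no source */
}

pub struct PluginSource {
    /* handle to the plugin that produced the value */
}

pub struct PluginCustomValueWithSource {
    value: PluginCustomValue,
    source: Arc<PluginSource>, // no longer an `Option`
}
```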

# User-Facing Changes
- New crates.
- Added `local-socket` feature for `nu` to try to make it possible to
compile without that support if needed.

# Tests + Formatting
- 🟢 `toolkit fmt`
- 🟢 `toolkit clippy`
- 🟢 `toolkit test`
- 🟢 `toolkit test stdlib`
Commit 0c4d5330ee (parent 884d5312bb), authored by Devyn Cairns on 2024-04-27 10:08:12 -07:00, committed via GitHub.
74 changed files with 3514 additions and 3110 deletions.


@ -0,0 +1,84 @@
use std::ffi::OsString;
use std::io::{Read, Write};
#[cfg(test)]
pub(crate) mod tests;
/// Generate a name to be used for a local socket specific to this `nu` process, described by the
/// given `unique_id`, which should be unique to the purpose of the socket.
///
/// On Unix, this is a path, which should generally be 100 characters or less for compatibility. On
/// Windows, this is a name within the `\\.\pipe` namespace.
#[cfg(unix)]
pub fn make_local_socket_name(unique_id: &str) -> OsString {
// Prefer to put it in XDG_RUNTIME_DIR if set, since that's user-local
let mut base = if let Some(runtime_dir) = std::env::var_os("XDG_RUNTIME_DIR") {
std::path::PathBuf::from(runtime_dir)
} else {
// Use std::env::temp_dir() for portability, especially since on Android this is probably
// not `/tmp`
std::env::temp_dir()
};
let socket_name = format!("nu.{}.{}.sock", std::process::id(), unique_id);
base.push(socket_name);
base.into()
}
/// Generate a name to be used for a local socket specific to this `nu` process, described by the
/// given `unique_id`, which should be unique to the purpose of the socket.
///
/// On Unix, this is a path, which should generally be 100 characters or less for compatibility. On
/// Windows, this is a name within the `\\.\pipe` namespace.
#[cfg(windows)]
pub fn make_local_socket_name(unique_id: &str) -> OsString {
format!("nu.{}.{}", std::process::id(), unique_id).into()
}
/// Determine if the error is just due to the listener not being ready yet in asynchronous mode
#[cfg(not(windows))]
pub fn is_would_block_err(err: &std::io::Error) -> bool {
err.kind() == std::io::ErrorKind::WouldBlock
}
/// Determine if the error is just due to the listener not being ready yet in asynchronous mode
#[cfg(windows)]
pub fn is_would_block_err(err: &std::io::Error) -> bool {
err.kind() == std::io::ErrorKind::WouldBlock
|| err.raw_os_error().is_some_and(|e| {
// Windows returns this error when trying to accept a pipe in non-blocking mode
e as i64 == windows::Win32::Foundation::ERROR_PIPE_LISTENING.0 as i64
})
}
/// Wraps the `interprocess` local socket stream for greater compatibility
#[derive(Debug)]
pub struct LocalSocketStream(pub interprocess::local_socket::LocalSocketStream);
impl From<interprocess::local_socket::LocalSocketStream> for LocalSocketStream {
fn from(value: interprocess::local_socket::LocalSocketStream) -> Self {
LocalSocketStream(value)
}
}
impl std::io::Read for LocalSocketStream {
fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
self.0.read(buf)
}
}
impl std::io::Write for LocalSocketStream {
fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
self.0.write(buf)
}
fn flush(&mut self) -> std::io::Result<()> {
// We don't actually flush the underlying socket on Windows. The flush operation on a
// Windows named pipe actually synchronizes with read on the other side, and won't finish
// until the other side is empty. This isn't how most of our other I/O methods work, so we
// just won't do it. The BufWriter above this will have still made a write call with the
// contents of the buffer, which should be good enough.
if cfg!(not(windows)) {
self.0.flush()?;
}
Ok(())
}
}


@ -0,0 +1,19 @@
use super::make_local_socket_name;
#[test]
fn local_socket_path_contains_pid() {
let name = make_local_socket_name("test-string")
.to_string_lossy()
.into_owned();
println!("{}", name);
assert!(name.contains(&std::process::id().to_string()));
}
#[test]
fn local_socket_path_contains_provided_name() {
let name = make_local_socket_name("test-string")
.to_string_lossy()
.into_owned();
println!("{}", name);
assert!(name.to_string().contains("test-string"));
}


@ -0,0 +1,249 @@
use std::ffi::OsStr;
use std::io::{Stdin, Stdout};
use std::process::{Child, ChildStdin, ChildStdout, Command, Stdio};
use nu_protocol::ShellError;
#[cfg(feature = "local-socket")]
use interprocess::local_socket::LocalSocketListener;
#[cfg(feature = "local-socket")]
mod local_socket;
#[cfg(feature = "local-socket")]
use local_socket::*;
/// The type of communication used between the plugin and the engine.
///
/// `Stdio` is required to be supported by all plugins, and is attempted initially. If the
/// `local-socket` feature is enabled and the plugin supports it, `LocalSocket` may be attempted.
///
/// Local socket communication has the benefit of not tying up stdio, so it's more compatible with
/// plugins that want to take user input from the terminal in some way.
#[derive(Debug, Clone)]
pub enum CommunicationMode {
/// Communicate using `stdin` and `stdout`.
Stdio,
/// Communicate using an operating system-specific local socket.
#[cfg(feature = "local-socket")]
LocalSocket(std::ffi::OsString),
}
impl CommunicationMode {
/// Generate a new local socket communication mode based on the given plugin exe path.
#[cfg(feature = "local-socket")]
pub fn local_socket(plugin_exe: &std::path::Path) -> CommunicationMode {
use std::hash::{Hash, Hasher};
use std::time::SystemTime;
// Generate the unique ID based on the plugin path and the current time. The actual
// algorithm here is not very important, we just want this to be relatively unique very
// briefly. Using the default hasher in the stdlib means zero extra dependencies.
let mut hasher = std::collections::hash_map::DefaultHasher::new();
plugin_exe.hash(&mut hasher);
SystemTime::now().hash(&mut hasher);
let unique_id = format!("{:016x}", hasher.finish());
CommunicationMode::LocalSocket(make_local_socket_name(&unique_id))
}
pub fn args(&self) -> Vec<&OsStr> {
match self {
CommunicationMode::Stdio => vec![OsStr::new("--stdio")],
#[cfg(feature = "local-socket")]
CommunicationMode::LocalSocket(path) => {
vec![OsStr::new("--local-socket"), path.as_os_str()]
}
}
}
pub fn setup_command_io(&self, command: &mut Command) {
match self {
CommunicationMode::Stdio => {
// Both stdout and stdin are piped so we can receive information from the plugin
command.stdin(Stdio::piped());
command.stdout(Stdio::piped());
}
#[cfg(feature = "local-socket")]
CommunicationMode::LocalSocket(_) => {
// Stdio can be used by the plugin to talk to the terminal in local socket mode,
// which is the big benefit
command.stdin(Stdio::inherit());
command.stdout(Stdio::inherit());
}
}
}
pub fn serve(&self) -> Result<PreparedServerCommunication, ShellError> {
match self {
// Nothing to set up for stdio - we just take it from the child.
CommunicationMode::Stdio => Ok(PreparedServerCommunication::Stdio),
// For sockets: we need to create the server so that the child won't fail to connect.
#[cfg(feature = "local-socket")]
CommunicationMode::LocalSocket(name) => {
let listener = LocalSocketListener::bind(name.as_os_str()).map_err(|err| {
ShellError::IOError {
msg: format!("failed to open socket for plugin: {err}"),
}
})?;
Ok(PreparedServerCommunication::LocalSocket {
name: name.clone(),
listener,
})
}
}
}
pub fn connect_as_client(&self) -> Result<ClientCommunicationIo, ShellError> {
match self {
CommunicationMode::Stdio => Ok(ClientCommunicationIo::Stdio(
std::io::stdin(),
std::io::stdout(),
)),
#[cfg(feature = "local-socket")]
CommunicationMode::LocalSocket(name) => {
// Connect to the specified socket.
let get_socket = || {
use interprocess::local_socket as ls;
ls::LocalSocketStream::connect(name.as_os_str())
.map_err(|err| ShellError::IOError {
msg: format!("failed to connect to socket: {err}"),
})
.map(LocalSocketStream::from)
};
// Reverse order from the server: read in, write out
let read_in = get_socket()?;
let write_out = get_socket()?;
Ok(ClientCommunicationIo::LocalSocket { read_in, write_out })
}
}
}
}
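// Illustrative sketch (hypothetical helper, not part of this file): how the engine side
// might drive this API end to end. The mode's listener is bound with `serve()` *before*
// the child is spawned, so the plugin can't fail to connect, and `connect()` is called
// once the child process exists.
#[cfg(feature = "local-socket")]
#[allow(dead_code)]
fn spawn_plugin_sketch(plugin_exe: &std::path::Path) -> Result<ServerCommunicationIo, ShellError> {
    let mode = CommunicationMode::local_socket(plugin_exe);
    let mut command = Command::new(plugin_exe);
    command.args(mode.args());
    mode.setup_command_io(&mut command);
    // Bind the local socket listener before starting the child
    let prepared = mode.serve()?;
    let mut child = command.spawn().map_err(|err| ShellError::IOError {
        msg: format!("failed to spawn plugin: {err}"),
    })?;
    // Accept the two connections (read and write) from the plugin
    prepared.connect(&mut child)
}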
/// The result of [`CommunicationMode::serve()`], which acts as an intermediate stage for
/// communication modes that require some kind of socket binding to occur before the client process
/// can be started. Call [`.connect()`] once the client process has been started.
///
/// The socket may be cleaned up on `Drop` if applicable.
pub enum PreparedServerCommunication {
/// Will take stdin and stdout from the process on [`.connect()`].
Stdio,
/// Contains the listener to accept connections on. On Unix, the socket is unlinked on `Drop`.
#[cfg(feature = "local-socket")]
LocalSocket {
#[cfg_attr(windows, allow(dead_code))] // not used on Windows
name: std::ffi::OsString,
listener: LocalSocketListener,
},
}
impl PreparedServerCommunication {
pub fn connect(&self, child: &mut Child) -> Result<ServerCommunicationIo, ShellError> {
match self {
PreparedServerCommunication::Stdio => {
let stdin = child
.stdin
.take()
.ok_or_else(|| ShellError::PluginFailedToLoad {
msg: "Plugin missing stdin writer".into(),
})?;
let stdout = child
.stdout
.take()
.ok_or_else(|| ShellError::PluginFailedToLoad {
msg: "Plugin missing stdout writer".into(),
})?;
Ok(ServerCommunicationIo::Stdio(stdin, stdout))
}
#[cfg(feature = "local-socket")]
PreparedServerCommunication::LocalSocket { listener, .. } => {
use std::time::{Duration, Instant};
const RETRY_PERIOD: Duration = Duration::from_millis(1);
const TIMEOUT: Duration = Duration::from_secs(10);
let start = Instant::now();
// Use a loop to try to get two clients from the listener: one for read (the plugin
// output) and one for write (the plugin input)
listener.set_nonblocking(true)?;
let mut get_socket = || {
let mut result = None;
while let Ok(None) = child.try_wait() {
match listener.accept() {
Ok(stream) => {
// Success! But make sure the stream is in blocking mode.
stream.set_nonblocking(false)?;
result = Some(stream);
break;
}
Err(err) => {
if !is_would_block_err(&err) {
// `WouldBlock` is ok, just means it's not ready yet, but some other
// kind of error should be reported
return Err(err.into());
}
}
}
if Instant::now().saturating_duration_since(start) > TIMEOUT {
return Err(ShellError::PluginFailedToLoad {
msg: "Plugin timed out while waiting to connect to socket".into(),
});
} else {
std::thread::sleep(RETRY_PERIOD);
}
}
if let Some(stream) = result {
Ok(LocalSocketStream(stream))
} else {
// The process may have exited
Err(ShellError::PluginFailedToLoad {
msg: "Plugin exited without connecting".into(),
})
}
};
// Input stream always comes before output
let write_in = get_socket()?;
let read_out = get_socket()?;
Ok(ServerCommunicationIo::LocalSocket { read_out, write_in })
}
}
}
}
impl Drop for PreparedServerCommunication {
fn drop(&mut self) {
match self {
#[cfg(all(unix, feature = "local-socket"))]
PreparedServerCommunication::LocalSocket { name: path, .. } => {
// Just try to remove the socket file, it's ok if this fails
let _ = std::fs::remove_file(path);
}
_ => (),
}
}
}
/// The required streams for communication from the engine side, i.e. the server in socket terms.
pub enum ServerCommunicationIo {
Stdio(ChildStdin, ChildStdout),
#[cfg(feature = "local-socket")]
LocalSocket {
read_out: LocalSocketStream,
write_in: LocalSocketStream,
},
}
/// The required streams for communication from the plugin side, i.e. the client in socket terms.
pub enum ClientCommunicationIo {
Stdio(Stdin, Stdout),
#[cfg(feature = "local-socket")]
LocalSocket {
read_in: LocalSocketStream,
write_out: LocalSocketStream,
},
}


@ -0,0 +1,453 @@
//! Implements the stream multiplexing interface for both the plugin side and the engine side.
use nu_plugin_protocol::{
ExternalStreamInfo, ListStreamInfo, PipelineDataHeader, RawStreamInfo, StreamMessage,
};
use nu_protocol::{ListStream, PipelineData, RawStream, ShellError};
use std::{
io::Write,
sync::{
atomic::{AtomicBool, Ordering::Relaxed},
Arc, Mutex,
},
thread,
};
pub mod stream;
use crate::{util::Sequence, Encoder};
use self::stream::{StreamManager, StreamManagerHandle, StreamWriter, WriteStreamMessage};
pub mod test_util;
#[cfg(test)]
mod tests;
/// The maximum number of list stream values to send without acknowledgement. This should be tuned
/// with consideration for memory usage.
const LIST_STREAM_HIGH_PRESSURE: i32 = 100;
/// The maximum number of raw stream buffers to send without acknowledgement. This should be tuned
/// with consideration for memory usage.
const RAW_STREAM_HIGH_PRESSURE: i32 = 50;
/// Read input/output from the stream.
pub trait PluginRead<T> {
/// Returns `Ok(None)` on end of stream.
fn read(&mut self) -> Result<Option<T>, ShellError>;
}
impl<R, E, T> PluginRead<T> for (R, E)
where
R: std::io::BufRead,
E: Encoder<T>,
{
fn read(&mut self) -> Result<Option<T>, ShellError> {
self.1.decode(&mut self.0)
}
}
impl<R, T> PluginRead<T> for &mut R
where
R: PluginRead<T>,
{
fn read(&mut self) -> Result<Option<T>, ShellError> {
(**self).read()
}
}
/// Write input/output to the stream.
///
/// The write should be atomic, without interference from other threads.
pub trait PluginWrite<T>: Send + Sync {
fn write(&self, data: &T) -> Result<(), ShellError>;
/// Flush any internal buffers, if applicable.
fn flush(&self) -> Result<(), ShellError>;
/// True if this output is stdout, so that plugins can avoid using stdout for their own purposes.
fn is_stdout(&self) -> bool {
false
}
}
impl<E, T> PluginWrite<T> for (std::io::Stdout, E)
where
E: Encoder<T>,
{
fn write(&self, data: &T) -> Result<(), ShellError> {
let mut lock = self.0.lock();
self.1.encode(data, &mut lock)
}
fn flush(&self) -> Result<(), ShellError> {
self.0.lock().flush().map_err(|err| ShellError::IOError {
msg: err.to_string(),
})
}
fn is_stdout(&self) -> bool {
true
}
}
impl<W, E, T> PluginWrite<T> for (Mutex<W>, E)
where
W: std::io::Write + Send,
E: Encoder<T>,
{
fn write(&self, data: &T) -> Result<(), ShellError> {
let mut lock = self.0.lock().map_err(|_| ShellError::NushellFailed {
msg: "writer mutex poisoned".into(),
})?;
self.1.encode(data, &mut *lock)
}
fn flush(&self) -> Result<(), ShellError> {
let mut lock = self.0.lock().map_err(|_| ShellError::NushellFailed {
msg: "writer mutex poisoned".into(),
})?;
lock.flush().map_err(|err| ShellError::IOError {
msg: err.to_string(),
})
}
}
impl<W, T> PluginWrite<T> for &W
where
W: PluginWrite<T>,
{
fn write(&self, data: &T) -> Result<(), ShellError> {
(**self).write(data)
}
fn flush(&self) -> Result<(), ShellError> {
(**self).flush()
}
fn is_stdout(&self) -> bool {
(**self).is_stdout()
}
}
/// An interface manager handles I/O and state management for communication between a plugin and
/// the engine. See `PluginInterfaceManager` in `nu-plugin-engine` for communication from the engine
/// side to a plugin, or `EngineInterfaceManager` in `nu-plugin` for communication from the plugin
/// side to the engine.
///
/// There is typically one [`InterfaceManager`] consuming input from a background thread, and
/// managing shared state.
pub trait InterfaceManager {
/// The corresponding interface type.
type Interface: Interface + 'static;
/// The input message type.
type Input;
/// Make a new interface that communicates with this [`InterfaceManager`].
fn get_interface(&self) -> Self::Interface;
/// Consume an input message.
///
/// When implementing, call [`.consume_stream_message()`] for any encapsulated
/// [`StreamMessage`]s received.
fn consume(&mut self, input: Self::Input) -> Result<(), ShellError>;
/// Get the [`StreamManager`] for handling operations related to stream messages.
fn stream_manager(&self) -> &StreamManager;
/// Prepare [`PipelineData`] after reading. This is called by `read_pipeline_data()` as
/// a hook so that values that need special handling can be taken care of.
fn prepare_pipeline_data(&self, data: PipelineData) -> Result<PipelineData, ShellError>;
/// Consume an input stream message.
///
/// This method is provided for implementors to use.
fn consume_stream_message(&mut self, message: StreamMessage) -> Result<(), ShellError> {
self.stream_manager().handle_message(message)
}
/// Generate `PipelineData` for reading a stream, given a [`PipelineDataHeader`] that was
/// received from the other side.
///
/// This method is provided for implementors to use.
fn read_pipeline_data(
&self,
header: PipelineDataHeader,
ctrlc: Option<&Arc<AtomicBool>>,
) -> Result<PipelineData, ShellError> {
self.prepare_pipeline_data(match header {
PipelineDataHeader::Empty => PipelineData::Empty,
PipelineDataHeader::Value(value) => PipelineData::Value(value, None),
PipelineDataHeader::ListStream(info) => {
let handle = self.stream_manager().get_handle();
let reader = handle.read_stream(info.id, self.get_interface())?;
PipelineData::ListStream(ListStream::from_stream(reader, ctrlc.cloned()), None)
}
PipelineDataHeader::ExternalStream(info) => {
let handle = self.stream_manager().get_handle();
let span = info.span;
let new_raw_stream = |raw_info: RawStreamInfo| {
let reader = handle.read_stream(raw_info.id, self.get_interface())?;
let mut stream =
RawStream::new(Box::new(reader), ctrlc.cloned(), span, raw_info.known_size);
stream.is_binary = raw_info.is_binary;
Ok::<_, ShellError>(stream)
};
PipelineData::ExternalStream {
stdout: info.stdout.map(new_raw_stream).transpose()?,
stderr: info.stderr.map(new_raw_stream).transpose()?,
exit_code: info
.exit_code
.map(|list_info| {
handle
.read_stream(list_info.id, self.get_interface())
.map(|reader| ListStream::from_stream(reader, ctrlc.cloned()))
})
.transpose()?,
span: info.span,
metadata: None,
trim_end_newline: info.trim_end_newline,
}
}
})
}
}
/// An interface provides an API for communicating with a plugin or the engine and facilitates
/// stream I/O. See `PluginInterface` in `nu-plugin-engine` for the API from the engine side to a
/// plugin, or `EngineInterface` in `nu-plugin` for the API from the plugin side to the engine.
///
/// There can be multiple copies of the interface managed by a single [`InterfaceManager`].
pub trait Interface: Clone + Send {
/// The output message type, which must be capable of encapsulating a [`StreamMessage`].
type Output: From<StreamMessage>;
/// Any context required to construct [`PipelineData`]. Can be `()` if not needed.
type DataContext;
/// Write an output message.
fn write(&self, output: Self::Output) -> Result<(), ShellError>;
/// Flush the output buffer, so messages are visible to the other side.
fn flush(&self) -> Result<(), ShellError>;
/// Get the sequence for generating new [`StreamId`](nu_plugin_protocol::StreamId)s.
fn stream_id_sequence(&self) -> &Sequence;
/// Get the [`StreamManagerHandle`] for doing stream operations.
fn stream_manager_handle(&self) -> &StreamManagerHandle;
/// Prepare [`PipelineData`] to be written. This is called by `init_write_pipeline_data()` as
/// a hook so that values that need special handling can be taken care of.
fn prepare_pipeline_data(
&self,
data: PipelineData,
context: &Self::DataContext,
) -> Result<PipelineData, ShellError>;
/// Initialize a write for [`PipelineData`]. This returns two parts: the header, which can be
/// embedded in the particular message that references the stream, and a writer, which will
/// write out all of the data in the pipeline when `.write()` is called.
///
/// Note that not all [`PipelineData`] starts a stream. You should call `write()` anyway, as
/// it will automatically handle this case.
///
/// This method is provided for implementors to use.
fn init_write_pipeline_data(
&self,
data: PipelineData,
context: &Self::DataContext,
) -> Result<(PipelineDataHeader, PipelineDataWriter<Self>), ShellError> {
// Allocate a stream id and a writer
let new_stream = |high_pressure_mark: i32| {
// Get a free stream id
let id = self.stream_id_sequence().next()?;
// Create the writer
let writer =
self.stream_manager_handle()
.write_stream(id, self.clone(), high_pressure_mark)?;
Ok::<_, ShellError>((id, writer))
};
match self.prepare_pipeline_data(data, context)? {
PipelineData::Value(value, _) => {
Ok((PipelineDataHeader::Value(value), PipelineDataWriter::None))
}
PipelineData::Empty => Ok((PipelineDataHeader::Empty, PipelineDataWriter::None)),
PipelineData::ListStream(stream, _) => {
let (id, writer) = new_stream(LIST_STREAM_HIGH_PRESSURE)?;
Ok((
PipelineDataHeader::ListStream(ListStreamInfo { id }),
PipelineDataWriter::ListStream(writer, stream),
))
}
PipelineData::ExternalStream {
stdout,
stderr,
exit_code,
span,
metadata: _,
trim_end_newline,
} => {
// Create the writers and stream ids
let stdout_stream = stdout
.is_some()
.then(|| new_stream(RAW_STREAM_HIGH_PRESSURE))
.transpose()?;
let stderr_stream = stderr
.is_some()
.then(|| new_stream(RAW_STREAM_HIGH_PRESSURE))
.transpose()?;
let exit_code_stream = exit_code
.is_some()
.then(|| new_stream(LIST_STREAM_HIGH_PRESSURE))
.transpose()?;
// Generate the header, with the stream ids
let header = PipelineDataHeader::ExternalStream(ExternalStreamInfo {
span,
stdout: stdout
.as_ref()
.zip(stdout_stream.as_ref())
.map(|(stream, (id, _))| RawStreamInfo::new(*id, stream)),
stderr: stderr
.as_ref()
.zip(stderr_stream.as_ref())
.map(|(stream, (id, _))| RawStreamInfo::new(*id, stream)),
exit_code: exit_code_stream
.as_ref()
.map(|&(id, _)| ListStreamInfo { id }),
trim_end_newline,
});
// Collect the writers
let writer = PipelineDataWriter::ExternalStream {
stdout: stdout_stream.map(|(_, writer)| writer).zip(stdout),
stderr: stderr_stream.map(|(_, writer)| writer).zip(stderr),
exit_code: exit_code_stream.map(|(_, writer)| writer).zip(exit_code),
};
Ok((header, writer))
}
}
}
}
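// Illustrative sketch (hypothetical helper, not part of this file): how an implementor is
// expected to use `init_write_pipeline_data()`. The returned header is embedded in whatever
// message references the stream, and the returned writer is then consumed to actually send
// the stream contents. The `From<PipelineDataHeader>` bound is assumed purely for
// illustration; real output types wrap the header inside a larger message.
#[allow(dead_code)]
fn write_pipeline_data_sketch<I>(
    interface: &I,
    data: PipelineData,
    context: &I::DataContext,
) -> Result<(), ShellError>
where
    I: Interface + 'static,
    I::Output: From<PipelineDataHeader>, // assumed for illustration only
{
    let (header, writer) = interface.init_write_pipeline_data(data, context)?;
    // Send the message that references the stream(s) described by the header
    interface.write(header.into())?;
    interface.flush()?;
    // Then write out all of the stream data; this blocks until complete
    writer.write()?;
    Ok(())
}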
impl<T> WriteStreamMessage for T
where
T: Interface,
{
fn write_stream_message(&mut self, msg: StreamMessage) -> Result<(), ShellError> {
self.write(msg.into())
}
fn flush(&mut self) -> Result<(), ShellError> {
<Self as Interface>::flush(self)
}
}
/// Completes the write operation for a [`PipelineData`]. You must call
/// [`PipelineDataWriter::write()`] to write all of the data contained within the streams.
#[derive(Default)]
#[must_use]
pub enum PipelineDataWriter<W: WriteStreamMessage> {
#[default]
None,
ListStream(StreamWriter<W>, ListStream),
ExternalStream {
stdout: Option<(StreamWriter<W>, RawStream)>,
stderr: Option<(StreamWriter<W>, RawStream)>,
exit_code: Option<(StreamWriter<W>, ListStream)>,
},
}
impl<W> PipelineDataWriter<W>
where
W: WriteStreamMessage + Send + 'static,
{
/// Write all of the data in each of the streams. This method waits for completion.
pub fn write(self) -> Result<(), ShellError> {
match self {
// If no stream was contained in the PipelineData, do nothing.
PipelineDataWriter::None => Ok(()),
// Write a list stream.
PipelineDataWriter::ListStream(mut writer, stream) => {
writer.write_all(stream)?;
Ok(())
}
// Write all three possible streams of an ExternalStream on separate threads.
PipelineDataWriter::ExternalStream {
stdout,
stderr,
exit_code,
} => {
thread::scope(|scope| {
let stderr_thread = stderr
.map(|(mut writer, stream)| {
thread::Builder::new()
.name("plugin stderr writer".into())
.spawn_scoped(scope, move || {
writer.write_all(raw_stream_iter(stream))
})
})
.transpose()?;
let exit_code_thread = exit_code
.map(|(mut writer, stream)| {
thread::Builder::new()
.name("plugin exit_code writer".into())
.spawn_scoped(scope, move || writer.write_all(stream))
})
.transpose()?;
// Optimize for stdout: if only stdout is present, don't spawn any other
// threads.
if let Some((mut writer, stream)) = stdout {
writer.write_all(raw_stream_iter(stream))?;
}
let panicked = |thread_name: &str| {
Err(ShellError::NushellFailed {
msg: format!(
"{thread_name} thread panicked in PipelineDataWriter::write"
),
})
};
stderr_thread
.map(|t| t.join().unwrap_or_else(|_| panicked("stderr")))
.transpose()?;
exit_code_thread
.map(|t| t.join().unwrap_or_else(|_| panicked("exit_code")))
.transpose()?;
Ok(())
})
}
}
}
/// Write all of the data in each of the streams. This method returns immediately; any necessary
/// write will happen in the background. If a thread was spawned, its handle is returned.
pub fn write_background(
self,
) -> Result<Option<thread::JoinHandle<Result<(), ShellError>>>, ShellError> {
match self {
PipelineDataWriter::None => Ok(None),
_ => Ok(Some(
thread::Builder::new()
.name("plugin stream background writer".into())
.spawn(move || {
let result = self.write();
if let Err(ref err) = result {
// Assume that the background thread error probably won't be handled and log it
// here just in case.
log::warn!("Error while writing pipeline in background: {err}");
}
result
})?,
)),
}
}
}
/// Custom iterator for [`RawStream`] that respects ctrlc, but still has binary chunks
fn raw_stream_iter(stream: RawStream) -> impl Iterator<Item = Result<Vec<u8>, ShellError>> {
let ctrlc = stream.ctrlc;
stream
.stream
.take_while(move |_| ctrlc.as_ref().map(|b| !b.load(Relaxed)).unwrap_or(true))
}


@ -0,0 +1,628 @@
use nu_plugin_protocol::{StreamData, StreamId, StreamMessage};
use nu_protocol::{ShellError, Span, Value};
use std::{
collections::{btree_map, BTreeMap},
iter::FusedIterator,
marker::PhantomData,
sync::{mpsc, Arc, Condvar, Mutex, MutexGuard, Weak},
};
#[cfg(test)]
mod tests;
/// Receives messages from a stream read from input by a [`StreamManager`].
///
/// The receiver reads messages of type `Result<Option<StreamData>, ShellError>` from the
/// channel, which is managed by a [`StreamManager`]. End-of-stream is signalled explicitly
/// with an `Ok(None)` message.
///
/// Failing to receive is an error. When end-of-stream is received, the `receiver` is set to `None`
/// and all further calls to `next()` return `None`.
///
/// The type `T` must implement [`FromShellError`], so that errors in the stream can be represented,
/// and `TryFrom<StreamData>` to convert it to the correct type.
///
/// For each message read, it sends [`StreamMessage::Ack`] to the writer. When dropped,
/// it sends [`StreamMessage::Drop`].
#[derive(Debug)]
pub struct StreamReader<T, W>
where
W: WriteStreamMessage,
{
id: StreamId,
receiver: Option<mpsc::Receiver<Result<Option<StreamData>, ShellError>>>,
writer: W,
/// Iterator requires the item type to be fixed, so we have to keep it as part of the type,
/// even though we're actually receiving dynamic data.
marker: PhantomData<fn() -> T>,
}
impl<T, W> StreamReader<T, W>
where
T: TryFrom<StreamData, Error = ShellError>,
W: WriteStreamMessage,
{
/// Create a new StreamReader from parts
fn new(
id: StreamId,
receiver: mpsc::Receiver<Result<Option<StreamData>, ShellError>>,
writer: W,
) -> StreamReader<T, W> {
StreamReader {
id,
receiver: Some(receiver),
writer,
marker: PhantomData,
}
}
/// Receive a message from the channel, or return an error if:
///
/// * the channel couldn't be received from
/// * an error was sent on the channel
/// * the message received couldn't be converted to `T`
pub fn recv(&mut self) -> Result<Option<T>, ShellError> {
let connection_lost = || ShellError::GenericError {
error: "Stream ended unexpectedly".into(),
msg: "connection lost before explicit end of stream".into(),
span: None,
help: None,
inner: vec![],
};
if let Some(ref rx) = self.receiver {
// Try to receive a message first
let msg = match rx.try_recv() {
Ok(msg) => msg?,
Err(mpsc::TryRecvError::Empty) => {
// The receiver doesn't have any messages waiting for us. It's possible that the
// other side hasn't seen our acknowledgements. Let's flush the writer and then
// wait
self.writer.flush()?;
rx.recv().map_err(|_| connection_lost())??
}
Err(mpsc::TryRecvError::Disconnected) => return Err(connection_lost()),
};
if let Some(data) = msg {
// Acknowledge the message
self.writer
.write_stream_message(StreamMessage::Ack(self.id))?;
// Try to convert it into the correct type
Ok(Some(data.try_into()?))
} else {
// Remove the receiver, so that future recv() calls always return Ok(None)
self.receiver = None;
Ok(None)
}
} else {
// Closed already
Ok(None)
}
}
}
impl<T, W> Iterator for StreamReader<T, W>
where
T: FromShellError + TryFrom<StreamData, Error = ShellError>,
W: WriteStreamMessage,
{
type Item = T;
fn next(&mut self) -> Option<T> {
// Converting the error to the value here makes the implementation a lot easier
match self.recv() {
Ok(option) => option,
Err(err) => {
// Drop the receiver so we don't keep returning errors
self.receiver = None;
Some(T::from_shell_error(err))
}
}
}
}
// Guaranteed not to return anything after the end
impl<T, W> FusedIterator for StreamReader<T, W>
where
T: FromShellError + TryFrom<StreamData, Error = ShellError>,
W: WriteStreamMessage,
{
}
impl<T, W> Drop for StreamReader<T, W>
where
W: WriteStreamMessage,
{
fn drop(&mut self) {
if let Err(err) = self
.writer
.write_stream_message(StreamMessage::Drop(self.id))
.and_then(|_| self.writer.flush())
{
log::warn!("Failed to send message to drop stream: {err}");
}
}
}
/// Values that can contain a `ShellError` to signal an error has occurred.
pub trait FromShellError {
fn from_shell_error(err: ShellError) -> Self;
}
// For List streams.
impl FromShellError for Value {
fn from_shell_error(err: ShellError) -> Self {
Value::error(err, Span::unknown())
}
}
// For Raw streams, mostly.
impl<T> FromShellError for Result<T, ShellError> {
fn from_shell_error(err: ShellError) -> Self {
Err(err)
}
}
/// Writes messages to a stream, with flow control.
///
/// The `signal` contained within is shared with the [`StreamManager`] and implements flow
/// control: once too many messages have been sent without acknowledgement, further writes
/// block until the other side acknowledges some of them or drops the stream.
#[derive(Debug)]
pub struct StreamWriter<W: WriteStreamMessage> {
id: StreamId,
signal: Arc<StreamWriterSignal>,
writer: W,
ended: bool,
}
impl<W> StreamWriter<W>
where
W: WriteStreamMessage,
{
fn new(id: StreamId, signal: Arc<StreamWriterSignal>, writer: W) -> StreamWriter<W> {
StreamWriter {
id,
signal,
writer,
ended: false,
}
}
/// Check if the stream was dropped from the other end. Recommended to do this before calling
/// [`.write()`], especially in a loop.
pub fn is_dropped(&self) -> Result<bool, ShellError> {
self.signal.is_dropped()
}
/// Write a single piece of data to the stream.
///
/// Error if something failed with the write, or if [`.end()`] was already called
/// previously.
pub fn write(&mut self, data: impl Into<StreamData>) -> Result<(), ShellError> {
if !self.ended {
self.writer
.write_stream_message(StreamMessage::Data(self.id, data.into()))?;
// This implements flow control, so we don't write too many messages:
if !self.signal.notify_sent()? {
// Flush the output, and then wait for acknowledgements
self.writer.flush()?;
self.signal.wait_for_drain()
} else {
Ok(())
}
} else {
Err(ShellError::GenericError {
error: "Wrote to a stream after it ended".into(),
msg: format!(
"tried to write to stream {} after it was already ended",
self.id
),
span: None,
help: Some("this may be a bug in the nu-plugin crate".into()),
inner: vec![],
})
}
}
/// Write a full iterator to the stream. Note that this doesn't end the stream, so you should
/// still call [`.end()`].
///
/// If the stream is dropped from the other end, the iterator will not be fully consumed, and
/// writing will terminate.
///
/// Returns `Ok(true)` if the iterator was fully consumed, or `Ok(false)` if a drop interrupted
/// the stream from the other side.
pub fn write_all<T>(&mut self, data: impl IntoIterator<Item = T>) -> Result<bool, ShellError>
where
T: Into<StreamData>,
{
// Check before starting
if self.is_dropped()? {
return Ok(false);
}
for item in data {
// Check again after each item is consumed from the iterator, just in case the iterator
// takes a while to produce a value
if self.is_dropped()? {
return Ok(false);
}
self.write(item)?;
}
Ok(true)
}
/// End the stream. Recommend doing this instead of relying on `Drop` so that you can catch the
/// error.
pub fn end(&mut self) -> Result<(), ShellError> {
if !self.ended {
// Set the flag first so we don't double-report in the Drop
self.ended = true;
self.writer
.write_stream_message(StreamMessage::End(self.id))?;
self.writer.flush()
} else {
Ok(())
}
}
}
impl<W> Drop for StreamWriter<W>
where
W: WriteStreamMessage,
{
fn drop(&mut self) {
// Make sure we ended the stream
if let Err(err) = self.end() {
log::warn!("Error while ending stream in Drop for StreamWriter: {err}");
}
}
}
/// Stores stream state for a writer, and can be blocked on to wait for messages to be acknowledged.
/// A key part of managing stream lifecycle and flow control.
#[derive(Debug)]
pub struct StreamWriterSignal {
mutex: Mutex<StreamWriterSignalState>,
change_cond: Condvar,
}
#[derive(Debug)]
pub struct StreamWriterSignalState {
/// Stream has been dropped and consumer is no longer interested in any messages.
dropped: bool,
/// Number of messages that have been sent without acknowledgement.
unacknowledged: i32,
/// Max number of messages to send before waiting for acknowledgement.
high_pressure_mark: i32,
}
impl StreamWriterSignal {
/// Create a new signal.
///
/// If `notify_sent()` is called more than `high_pressure_mark` times, it will wait until
/// `notify_acknowledge()` is called by another thread enough times to bring the number of
/// unacknowledged sent messages below that threshold.
fn new(high_pressure_mark: i32) -> StreamWriterSignal {
assert!(high_pressure_mark > 0);
StreamWriterSignal {
mutex: Mutex::new(StreamWriterSignalState {
dropped: false,
unacknowledged: 0,
high_pressure_mark,
}),
change_cond: Condvar::new(),
}
}
fn lock(&self) -> Result<MutexGuard<StreamWriterSignalState>, ShellError> {
self.mutex.lock().map_err(|_| ShellError::NushellFailed {
msg: "StreamWriterSignal mutex poisoned due to panic".into(),
})
}
/// True if the stream was dropped and the consumer is no longer interested in it. Indicates
/// that no more messages should be sent, other than `End`.
pub fn is_dropped(&self) -> Result<bool, ShellError> {
Ok(self.lock()?.dropped)
}
/// Notify the writers that the stream has been dropped, so they can stop writing.
pub fn set_dropped(&self) -> Result<(), ShellError> {
let mut state = self.lock()?;
state.dropped = true;
// Unblock the writers so they can terminate
self.change_cond.notify_all();
Ok(())
}
/// Track that a message has been sent. Returns `Ok(true)` if more messages can be sent,
/// or `Ok(false)` if the high pressure mark has been reached and [`.wait_for_drain()`] should
/// be called to block.
pub fn notify_sent(&self) -> Result<bool, ShellError> {
let mut state = self.lock()?;
state.unacknowledged =
state
.unacknowledged
.checked_add(1)
.ok_or_else(|| ShellError::NushellFailed {
msg: "Overflow in counter: too many unacknowledged messages".into(),
})?;
Ok(state.unacknowledged < state.high_pressure_mark)
}
/// Wait for acknowledgements before sending more data. Also returns if the stream is dropped.
pub fn wait_for_drain(&self) -> Result<(), ShellError> {
let mut state = self.lock()?;
while !state.dropped && state.unacknowledged >= state.high_pressure_mark {
state = self
.change_cond
.wait(state)
.map_err(|_| ShellError::NushellFailed {
msg: "StreamWriterSignal mutex poisoned due to panic".into(),
})?;
}
Ok(())
}
/// Notify the writers that a message has been acknowledged, so they can continue to write
/// if they were waiting.
pub fn notify_acknowledged(&self) -> Result<(), ShellError> {
let mut state = self.lock()?;
state.unacknowledged =
state
.unacknowledged
.checked_sub(1)
.ok_or_else(|| ShellError::NushellFailed {
msg: "Underflow in counter: too many message acknowledgements".into(),
})?;
// Unblock the writer
self.change_cond.notify_one();
Ok(())
}
}
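// Illustrative sketch (hypothetical function, mirrors the tests): the flow-control contract
// between a writer and the acknowledgements coming back from the reader. With a high
// pressure mark of 2, the second unacknowledged send asks the writer to wait, and an
// acknowledgement lets `wait_for_drain()` return without blocking.
#[allow(dead_code)]
fn flow_control_sketch() -> Result<(), ShellError> {
    let signal = StreamWriterSignal::new(2);
    assert!(signal.notify_sent()?); // 1 unacknowledged: still under the mark
    assert!(!signal.notify_sent()?); // 2 unacknowledged: mark reached, should drain
    signal.notify_acknowledged()?; // the reader acknowledged one message...
    signal.wait_for_drain()?; // ...so this returns immediately
    Ok(())
}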
/// A sink for a [`StreamMessage`]
pub trait WriteStreamMessage {
fn write_stream_message(&mut self, msg: StreamMessage) -> Result<(), ShellError>;
fn flush(&mut self) -> Result<(), ShellError>;
}
#[derive(Debug, Default)]
struct StreamManagerState {
reading_streams: BTreeMap<StreamId, mpsc::Sender<Result<Option<StreamData>, ShellError>>>,
writing_streams: BTreeMap<StreamId, Weak<StreamWriterSignal>>,
}
impl StreamManagerState {
/// Lock the state, or return a [`ShellError`] if the mutex is poisoned.
fn lock(
state: &Mutex<StreamManagerState>,
) -> Result<MutexGuard<StreamManagerState>, ShellError> {
state.lock().map_err(|_| ShellError::NushellFailed {
msg: "StreamManagerState mutex poisoned due to a panic".into(),
})
}
}
#[derive(Debug)]
pub struct StreamManager {
state: Arc<Mutex<StreamManagerState>>,
}
impl StreamManager {
/// Create a new StreamManager.
pub fn new() -> StreamManager {
StreamManager {
state: Default::default(),
}
}
fn lock(&self) -> Result<MutexGuard<StreamManagerState>, ShellError> {
StreamManagerState::lock(&self.state)
}
/// Create a new handle to the StreamManager for registering streams.
pub fn get_handle(&self) -> StreamManagerHandle {
StreamManagerHandle {
state: Arc::downgrade(&self.state),
}
}
/// Process a stream message, and update internal state accordingly.
pub fn handle_message(&self, message: StreamMessage) -> Result<(), ShellError> {
let mut state = self.lock()?;
match message {
StreamMessage::Data(id, data) => {
if let Some(sender) = state.reading_streams.get(&id) {
// We should ignore the error on send. This just means the reader has dropped,
// but it will have sent a Drop message to the other side, and we will receive
// an End message at which point we can remove the channel.
let _ = sender.send(Ok(Some(data)));
Ok(())
} else {
Err(ShellError::PluginFailedToDecode {
msg: format!("received Data for unknown stream {id}"),
})
}
}
StreamMessage::End(id) => {
if let Some(sender) = state.reading_streams.remove(&id) {
// We should ignore the error on the send, because the reader might have dropped
// already
let _ = sender.send(Ok(None));
Ok(())
} else {
Err(ShellError::PluginFailedToDecode {
msg: format!("received End for unknown stream {id}"),
})
}
}
StreamMessage::Drop(id) => {
if let Some(signal) = state.writing_streams.remove(&id) {
if let Some(signal) = signal.upgrade() {
// This will wake blocked writers so they can stop writing, so it's ok
signal.set_dropped()?;
}
}
// It's possible that the stream has already finished writing and we don't have it
// anymore, so we fall through to Ok
Ok(())
}
StreamMessage::Ack(id) => {
if let Some(signal) = state.writing_streams.get(&id) {
if let Some(signal) = signal.upgrade() {
// This will wake up a blocked writer
signal.notify_acknowledged()?;
} else {
// We know it doesn't exist, so might as well remove it
state.writing_streams.remove(&id);
}
}
// It's possible that the stream has already finished writing and we don't have it
// anymore, so we fall through to Ok
Ok(())
}
}
}
/// Broadcast an error to all stream readers. This is useful for error propagation.
pub fn broadcast_read_error(&self, error: ShellError) -> Result<(), ShellError> {
let state = self.lock()?;
for channel in state.reading_streams.values() {
// Ignore send errors.
let _ = channel.send(Err(error.clone()));
}
Ok(())
}
// If the `StreamManager` is dropped, we should let all of the stream writers know that they
// won't be able to write anymore. We don't need to do anything about the readers though
// because they'll know when the `Sender` is dropped automatically
fn drop_all_writers(&self) -> Result<(), ShellError> {
let mut state = self.lock()?;
let writers = std::mem::take(&mut state.writing_streams);
for (_, signal) in writers {
if let Some(signal) = signal.upgrade() {
// more important that we send to all than handling an error
let _ = signal.set_dropped();
}
}
Ok(())
}
}
impl Default for StreamManager {
fn default() -> Self {
Self::new()
}
}
impl Drop for StreamManager {
fn drop(&mut self) {
if let Err(err) = self.drop_all_writers() {
log::warn!("error during Drop for StreamManager: {}", err)
}
}
}
/// A [`StreamManagerHandle`] supports operations for interacting with the [`StreamManager`].
///
/// Streams can be registered for reading, returning a [`StreamReader`], or for writing, returning
/// a [`StreamWriter`].
#[derive(Debug, Clone)]
pub struct StreamManagerHandle {
state: Weak<Mutex<StreamManagerState>>,
}
impl StreamManagerHandle {
/// Because the handle only has a weak reference to the [`StreamManager`] state, we have to
/// first try to upgrade to a strong reference and then lock. This function wraps those two
/// operations together, handling errors appropriately.
fn with_lock<T, F>(&self, f: F) -> Result<T, ShellError>
where
F: FnOnce(MutexGuard<StreamManagerState>) -> Result<T, ShellError>,
{
let upgraded = self
.state
.upgrade()
.ok_or_else(|| ShellError::NushellFailed {
msg: "StreamManager is no longer alive".into(),
})?;
let guard = upgraded.lock().map_err(|_| ShellError::NushellFailed {
msg: "StreamManagerState mutex poisoned due to a panic".into(),
})?;
f(guard)
}
/// Register a new stream for reading, and return a [`StreamReader`] that can be used to iterate
/// on the values received. A [`StreamMessage`] writer is required for writing control messages
/// back to the producer.
pub fn read_stream<T, W>(
&self,
id: StreamId,
writer: W,
) -> Result<StreamReader<T, W>, ShellError>
where
T: TryFrom<StreamData, Error = ShellError>,
W: WriteStreamMessage,
{
let (tx, rx) = mpsc::channel();
self.with_lock(|mut state| {
// Must be exclusive
if let btree_map::Entry::Vacant(e) = state.reading_streams.entry(id) {
e.insert(tx);
Ok(())
} else {
Err(ShellError::GenericError {
error: format!("Failed to acquire reader for stream {id}"),
msg: "tried to get a reader for a stream that's already being read".into(),
span: None,
help: Some("this may be a bug in the nu-plugin crate".into()),
inner: vec![],
})
}
})?;
Ok(StreamReader::new(id, rx, writer))
}
/// Register a new stream for writing, and return a [`StreamWriter`] that can be used to send
/// data to the stream.
///
/// The `high_pressure_mark` value controls how many messages can be written without receiving
/// an acknowledgement before any further attempts to write will wait for the consumer to
/// acknowledge them. This prevents overwhelming the reader.
pub fn write_stream<W>(
&self,
id: StreamId,
writer: W,
high_pressure_mark: i32,
) -> Result<StreamWriter<W>, ShellError>
where
W: WriteStreamMessage,
{
let signal = Arc::new(StreamWriterSignal::new(high_pressure_mark));
self.with_lock(|mut state| {
// Remove dead writing streams
state
.writing_streams
.retain(|_, signal| signal.strong_count() > 0);
// Must be exclusive
if let btree_map::Entry::Vacant(e) = state.writing_streams.entry(id) {
e.insert(Arc::downgrade(&signal));
Ok(())
} else {
Err(ShellError::GenericError {
error: format!("Failed to acquire writer for stream {id}"),
msg: "tried to get a writer for a stream that's already being written".into(),
span: None,
help: Some("this may be a bug in the nu-plugin crate".into()),
inner: vec![],
})
}
})?;
Ok(StreamWriter::new(id, signal, writer))
}
}


@ -0,0 +1,550 @@
use std::{
sync::{
atomic::{AtomicBool, Ordering::Relaxed},
mpsc, Arc,
},
time::{Duration, Instant},
};
use super::{StreamManager, StreamReader, StreamWriter, StreamWriterSignal, WriteStreamMessage};
use nu_plugin_protocol::{StreamData, StreamMessage};
use nu_protocol::{ShellError, Value};
// Should be long enough to definitely complete any quick operation, but not so long that tests are
// slow to complete. 10 ms is a pretty long time
const WAIT_DURATION: Duration = Duration::from_millis(10);
// Maximum time to wait for a condition to be true
const MAX_WAIT_DURATION: Duration = Duration::from_millis(500);
/// Wait for a condition to be true, or panic if the duration exceeds MAX_WAIT_DURATION
#[track_caller]
fn wait_for_condition(mut cond: impl FnMut() -> bool, message: &str) {
// Early check
if cond() {
return;
}
let start = Instant::now();
loop {
std::thread::sleep(Duration::from_millis(10));
if cond() {
return;
}
let elapsed = Instant::now().saturating_duration_since(start);
if elapsed > MAX_WAIT_DURATION {
panic!(
"{message}: Waited {:.2}sec, which is more than the maximum of {:.2}sec",
elapsed.as_secs_f64(),
MAX_WAIT_DURATION.as_secs_f64(),
);
}
}
}
#[derive(Debug, Clone, Default)]
struct TestSink(Vec<StreamMessage>);
impl WriteStreamMessage for TestSink {
fn write_stream_message(&mut self, msg: StreamMessage) -> Result<(), ShellError> {
self.0.push(msg);
Ok(())
}
fn flush(&mut self) -> Result<(), ShellError> {
Ok(())
}
}
impl WriteStreamMessage for mpsc::Sender<StreamMessage> {
fn write_stream_message(&mut self, msg: StreamMessage) -> Result<(), ShellError> {
self.send(msg).map_err(|err| ShellError::NushellFailed {
msg: err.to_string(),
})
}
fn flush(&mut self) -> Result<(), ShellError> {
Ok(())
}
}
#[test]
fn reader_recv_list_messages() -> Result<(), ShellError> {
let (tx, rx) = mpsc::channel();
let mut reader = StreamReader::new(0, rx, TestSink::default());
tx.send(Ok(Some(StreamData::List(Value::test_int(5)))))
.unwrap();
drop(tx);
assert_eq!(Some(Value::test_int(5)), reader.recv()?);
Ok(())
}
#[test]
fn list_reader_recv_wrong_type() -> Result<(), ShellError> {
let (tx, rx) = mpsc::channel();
let mut reader = StreamReader::<Value, _>::new(0, rx, TestSink::default());
tx.send(Ok(Some(StreamData::Raw(Ok(vec![10, 20])))))
.unwrap();
tx.send(Ok(Some(StreamData::List(Value::test_nothing()))))
.unwrap();
drop(tx);
reader.recv().expect_err("should be an error");
reader.recv().expect("should be able to recover");
Ok(())
}
#[test]
fn reader_recv_raw_messages() -> Result<(), ShellError> {
let (tx, rx) = mpsc::channel();
let mut reader =
StreamReader::<Result<Vec<u8>, ShellError>, _>::new(0, rx, TestSink::default());
tx.send(Ok(Some(StreamData::Raw(Ok(vec![10, 20])))))
.unwrap();
drop(tx);
assert_eq!(Some(vec![10, 20]), reader.recv()?.transpose()?);
Ok(())
}
#[test]
fn raw_reader_recv_wrong_type() -> Result<(), ShellError> {
let (tx, rx) = mpsc::channel();
let mut reader =
StreamReader::<Result<Vec<u8>, ShellError>, _>::new(0, rx, TestSink::default());
tx.send(Ok(Some(StreamData::List(Value::test_nothing()))))
.unwrap();
tx.send(Ok(Some(StreamData::Raw(Ok(vec![10, 20])))))
.unwrap();
drop(tx);
reader.recv().expect_err("should be an error");
reader.recv().expect("should be able to recover");
Ok(())
}
#[test]
fn reader_recv_acknowledge() -> Result<(), ShellError> {
let (tx, rx) = mpsc::channel();
let mut reader = StreamReader::<Value, _>::new(0, rx, TestSink::default());
tx.send(Ok(Some(StreamData::List(Value::test_int(5)))))
.unwrap();
tx.send(Ok(Some(StreamData::List(Value::test_int(6)))))
.unwrap();
drop(tx);
reader.recv()?;
reader.recv()?;
let wrote = &reader.writer.0;
assert!(wrote.len() >= 2);
assert!(
matches!(wrote[0], StreamMessage::Ack(0)),
"0 = {:?}",
wrote[0]
);
assert!(
matches!(wrote[1], StreamMessage::Ack(0)),
"1 = {:?}",
wrote[1]
);
Ok(())
}
#[test]
fn reader_recv_end_of_stream() -> Result<(), ShellError> {
let (tx, rx) = mpsc::channel();
let mut reader = StreamReader::<Value, _>::new(0, rx, TestSink::default());
tx.send(Ok(Some(StreamData::List(Value::test_int(5)))))
.unwrap();
tx.send(Ok(None)).unwrap();
drop(tx);
assert!(reader.recv()?.is_some(), "actual message");
assert!(reader.recv()?.is_none(), "on close");
assert!(reader.recv()?.is_none(), "after close");
Ok(())
}
#[test]
fn reader_iter_fuse_on_error() -> Result<(), ShellError> {
let (tx, rx) = mpsc::channel();
let mut reader = StreamReader::<Value, _>::new(0, rx, TestSink::default());
drop(tx); // should cause error, because we didn't explicitly signal the end
assert!(
reader.next().is_some_and(|e| e.is_error()),
"should be error the first time"
);
assert!(reader.next().is_none(), "should be closed the second time");
Ok(())
}
#[test]
fn reader_drop() {
let (_tx, rx) = mpsc::channel();
// Flag set if drop message is received.
struct Check(Arc<AtomicBool>);
impl WriteStreamMessage for Check {
fn write_stream_message(&mut self, msg: StreamMessage) -> Result<(), ShellError> {
assert!(matches!(msg, StreamMessage::Drop(1)), "got {:?}", msg);
self.0.store(true, Relaxed);
Ok(())
}
fn flush(&mut self) -> Result<(), ShellError> {
Ok(())
}
}
let flag = Arc::new(AtomicBool::new(false));
let reader = StreamReader::<Value, _>::new(1, rx, Check(flag.clone()));
drop(reader);
assert!(flag.load(Relaxed));
}
#[test]
fn writer_write_all_stops_if_dropped() -> Result<(), ShellError> {
let signal = Arc::new(StreamWriterSignal::new(20));
let id = 1337;
let mut writer = StreamWriter::new(id, signal.clone(), TestSink::default());
// Simulate this by having it consume a stream that will actually do the drop halfway through
let iter = (0..5).map(Value::test_int).chain({
let mut n = 5;
std::iter::from_fn(move || {
// produces numbers 5..10, but drops for the first one
if n == 5 {
signal.set_dropped().unwrap();
}
if n < 10 {
let value = Value::test_int(n);
n += 1;
Some(value)
} else {
None
}
})
});
writer.write_all(iter)?;
assert!(writer.is_dropped()?);
let wrote = &writer.writer.0;
assert_eq!(5, wrote.len(), "length wrong: {wrote:?}");
for (n, message) in (0..5).zip(wrote) {
match message {
StreamMessage::Data(msg_id, StreamData::List(value)) => {
assert_eq!(id, *msg_id, "id");
assert_eq!(Value::test_int(n), *value, "value");
}
other => panic!("unexpected message: {other:?}"),
}
}
Ok(())
}
#[test]
fn writer_end() -> Result<(), ShellError> {
let signal = Arc::new(StreamWriterSignal::new(20));
let mut writer = StreamWriter::new(9001, signal.clone(), TestSink::default());
writer.end()?;
writer
.write(Value::test_int(2))
.expect_err("shouldn't be able to write after end");
writer.end().expect("end twice should be ok");
let wrote = &writer.writer.0;
assert!(
matches!(wrote.last(), Some(StreamMessage::End(9001))),
"didn't write end message: {wrote:?}"
);
Ok(())
}
#[test]
fn signal_set_dropped() -> Result<(), ShellError> {
let signal = StreamWriterSignal::new(4);
assert!(!signal.is_dropped()?);
signal.set_dropped()?;
assert!(signal.is_dropped()?);
Ok(())
}
#[test]
fn signal_notify_sent_false_if_unacknowledged() -> Result<(), ShellError> {
let signal = StreamWriterSignal::new(2);
assert!(signal.notify_sent()?);
for _ in 0..100 {
assert!(!signal.notify_sent()?);
}
Ok(())
}
#[test]
fn signal_notify_sent_never_false_if_flowing() -> Result<(), ShellError> {
let signal = StreamWriterSignal::new(1);
for _ in 0..100 {
signal.notify_acknowledged()?;
}
for _ in 0..100 {
assert!(signal.notify_sent()?);
}
Ok(())
}
#[test]
fn signal_wait_for_drain_blocks_on_unacknowledged() -> Result<(), ShellError> {
let signal = StreamWriterSignal::new(50);
std::thread::scope(|scope| {
let spawned = scope.spawn(|| {
for _ in 0..100 {
if !signal.notify_sent()? {
signal.wait_for_drain()?;
}
}
Ok(())
});
std::thread::sleep(WAIT_DURATION);
assert!(!spawned.is_finished(), "didn't block");
for _ in 0..100 {
signal.notify_acknowledged()?;
}
wait_for_condition(|| spawned.is_finished(), "blocked at end");
spawned.join().unwrap()
})
}
#[test]
fn signal_wait_for_drain_unblocks_on_dropped() -> Result<(), ShellError> {
let signal = StreamWriterSignal::new(1);
std::thread::scope(|scope| {
let spawned = scope.spawn(|| {
while !signal.is_dropped()? {
if !signal.notify_sent()? {
signal.wait_for_drain()?;
}
}
Ok(())
});
std::thread::sleep(WAIT_DURATION);
assert!(!spawned.is_finished(), "didn't block");
signal.set_dropped()?;
wait_for_condition(|| spawned.is_finished(), "still blocked at end");
spawned.join().unwrap()
})
}
#[test]
fn stream_manager_single_stream_read_scenario() -> Result<(), ShellError> {
let manager = StreamManager::new();
let handle = manager.get_handle();
let (tx, rx) = mpsc::channel();
let readable = handle.read_stream::<Value, _>(2, tx)?;
let expected_values = vec![Value::test_int(40), Value::test_string("hello")];
for value in &expected_values {
manager.handle_message(StreamMessage::Data(2, value.clone().into()))?;
}
manager.handle_message(StreamMessage::End(2))?;
let values = readable.collect::<Vec<Value>>();
assert_eq!(expected_values, values);
// Now check the sent messages on consumption
// Should be Ack for each message, then Drop
for _ in &expected_values {
match rx.try_recv().expect("failed to receive Ack") {
StreamMessage::Ack(2) => (),
other => panic!("should have been an Ack: {other:?}"),
}
}
match rx.try_recv().expect("failed to receive Drop") {
StreamMessage::Drop(2) => (),
other => panic!("should have been a Drop: {other:?}"),
}
Ok(())
}
#[test]
fn stream_manager_multi_stream_read_scenario() -> Result<(), ShellError> {
let manager = StreamManager::new();
let handle = manager.get_handle();
let (tx, rx) = mpsc::channel();
let readable_list = handle.read_stream::<Value, _>(2, tx.clone())?;
let readable_raw = handle.read_stream::<Result<Vec<u8>, _>, _>(3, tx)?;
let expected_values = (1..100).map(Value::test_int).collect::<Vec<_>>();
let expected_raw_buffers = (1..100).map(|n| vec![n]).collect::<Vec<Vec<u8>>>();
for (value, buf) in expected_values.iter().zip(&expected_raw_buffers) {
manager.handle_message(StreamMessage::Data(2, value.clone().into()))?;
manager.handle_message(StreamMessage::Data(3, StreamData::Raw(Ok(buf.clone()))))?;
}
manager.handle_message(StreamMessage::End(2))?;
manager.handle_message(StreamMessage::End(3))?;
let values = readable_list.collect::<Vec<Value>>();
let bufs = readable_raw.collect::<Result<Vec<Vec<u8>>, _>>()?;
for (expected_value, value) in expected_values.iter().zip(&values) {
assert_eq!(expected_value, value, "in List stream");
}
for (expected_buf, buf) in expected_raw_buffers.iter().zip(&bufs) {
assert_eq!(expected_buf, buf, "in Raw stream");
}
// Now check the sent messages on consumption
// Should be Ack for each message, then Drop
for _ in &expected_values {
match rx.try_recv().expect("failed to receive Ack") {
StreamMessage::Ack(2) => (),
other => panic!("should have been an Ack(2): {other:?}"),
}
}
match rx.try_recv().expect("failed to receive Drop") {
StreamMessage::Drop(2) => (),
other => panic!("should have been a Drop(2): {other:?}"),
}
for _ in &expected_values {
match rx.try_recv().expect("failed to receive Ack") {
StreamMessage::Ack(3) => (),
other => panic!("should have been an Ack(3): {other:?}"),
}
}
match rx.try_recv().expect("failed to receive Drop") {
StreamMessage::Drop(3) => (),
other => panic!("should have been a Drop(3): {other:?}"),
}
// Should be end of stream
assert!(
rx.try_recv().is_err(),
"more messages written to stream than expected"
);
Ok(())
}
#[test]
fn stream_manager_write_scenario() -> Result<(), ShellError> {
let manager = StreamManager::new();
let handle = manager.get_handle();
let (tx, rx) = mpsc::channel();
let mut writable = handle.write_stream(4, tx, 100)?;
let expected_values = vec![b"hello".to_vec(), b"world".to_vec(), b"test".to_vec()];
for value in &expected_values {
writable.write(Ok::<_, ShellError>(value.clone()))?;
}
// Now try signalling ack
assert_eq!(
expected_values.len() as i32,
writable.signal.lock()?.unacknowledged,
"unacknowledged initial count",
);
manager.handle_message(StreamMessage::Ack(4))?;
assert_eq!(
expected_values.len() as i32 - 1,
writable.signal.lock()?.unacknowledged,
"unacknowledged post-Ack count",
);
// ...and Drop
manager.handle_message(StreamMessage::Drop(4))?;
assert!(writable.is_dropped()?);
// Drop the StreamWriter...
drop(writable);
// now check what was actually written
for value in &expected_values {
match rx.try_recv().expect("failed to receive Data") {
StreamMessage::Data(4, StreamData::Raw(Ok(received))) => {
assert_eq!(*value, received);
}
other @ StreamMessage::Data(..) => panic!("wrong Data for {value:?}: {other:?}"),
other => panic!("should have been Data: {other:?}"),
}
}
match rx.try_recv().expect("failed to receive End") {
StreamMessage::End(4) => (),
other => panic!("should have been End: {other:?}"),
}
Ok(())
}
#[test]
fn stream_manager_broadcast_read_error() -> Result<(), ShellError> {
let manager = StreamManager::new();
let handle = manager.get_handle();
let mut readable0 = handle.read_stream::<Value, _>(0, TestSink::default())?;
let mut readable1 = handle.read_stream::<Result<Vec<u8>, _>, _>(1, TestSink::default())?;
let error = ShellError::PluginFailedToDecode {
msg: "test decode error".into(),
};
manager.broadcast_read_error(error.clone())?;
drop(manager);
assert_eq!(
error.to_string(),
readable0
.recv()
.transpose()
.expect("nothing received from readable0")
.expect_err("not an error received from readable0")
.to_string()
);
assert_eq!(
error.to_string(),
readable1
.next()
.expect("nothing received from readable1")
.expect_err("not an error received from readable1")
.to_string()
);
Ok(())
}
#[test]
fn stream_manager_drop_writers_on_drop() -> Result<(), ShellError> {
let manager = StreamManager::new();
let handle = manager.get_handle();
let writable = handle.write_stream(4, TestSink::default(), 100)?;
assert!(!writable.is_dropped()?);
drop(manager);
assert!(writable.is_dropped()?);
Ok(())
}


@@ -0,0 +1,134 @@
use nu_protocol::ShellError;
use std::{
collections::VecDeque,
sync::{Arc, Mutex},
};
use crate::{PluginRead, PluginWrite};
const FAILED: &str = "failed to lock TestCase";
/// Mock read/write helper for the engine and plugin interfaces.
#[derive(Debug, Clone)]
pub struct TestCase<I, O> {
r#in: Arc<Mutex<TestData<I>>>,
out: Arc<Mutex<TestData<O>>>,
}
#[derive(Debug)]
pub struct TestData<T> {
data: VecDeque<T>,
error: Option<ShellError>,
flushed: bool,
}
impl<T> Default for TestData<T> {
fn default() -> Self {
TestData {
data: VecDeque::new(),
error: None,
flushed: false,
}
}
}
impl<I, O> PluginRead<I> for TestCase<I, O> {
fn read(&mut self) -> Result<Option<I>, ShellError> {
let mut lock = self.r#in.lock().expect(FAILED);
if let Some(err) = lock.error.take() {
Err(err)
} else {
Ok(lock.data.pop_front())
}
}
}
impl<I, O> PluginWrite<O> for TestCase<I, O>
where
I: Send + Clone,
O: Send + Clone,
{
fn write(&self, data: &O) -> Result<(), ShellError> {
let mut lock = self.out.lock().expect(FAILED);
lock.flushed = false;
if let Some(err) = lock.error.take() {
Err(err)
} else {
lock.data.push_back(data.clone());
Ok(())
}
}
fn flush(&self) -> Result<(), ShellError> {
let mut lock = self.out.lock().expect(FAILED);
lock.flushed = true;
Ok(())
}
}
#[allow(dead_code)]
impl<I, O> TestCase<I, O> {
pub fn new() -> TestCase<I, O> {
TestCase {
r#in: Default::default(),
out: Default::default(),
}
}
/// Clear the read buffer.
pub fn clear(&self) {
self.r#in.lock().expect(FAILED).data.truncate(0);
}
/// Add input that will be read by the interface.
pub fn add(&self, input: impl Into<I>) {
self.r#in.lock().expect(FAILED).data.push_back(input.into());
}
/// Add multiple inputs that will be read by the interface.
pub fn extend(&self, inputs: impl IntoIterator<Item = I>) {
self.r#in.lock().expect(FAILED).data.extend(inputs);
}
/// Return an error from the next read operation.
pub fn set_read_error(&self, err: ShellError) {
self.r#in.lock().expect(FAILED).error = Some(err);
}
/// Return an error from the next write operation.
pub fn set_write_error(&self, err: ShellError) {
self.out.lock().expect(FAILED).error = Some(err);
}
/// Get the next output that was written.
pub fn next_written(&self) -> Option<O> {
self.out.lock().expect(FAILED).data.pop_front()
}
/// Iterator over written data.
pub fn written(&self) -> impl Iterator<Item = O> + '_ {
std::iter::from_fn(|| self.next_written())
}
/// Returns true if the writer was flushed after the last write operation.
pub fn was_flushed(&self) -> bool {
self.out.lock().expect(FAILED).flushed
}
/// Returns true if the reader has unconsumed reads.
pub fn has_unconsumed_read(&self) -> bool {
!self.r#in.lock().expect(FAILED).data.is_empty()
}
/// Returns true if the writer has unconsumed writes.
pub fn has_unconsumed_write(&self) -> bool {
!self.out.lock().expect(FAILED).data.is_empty()
}
}
impl<I, O> Default for TestCase<I, O> {
fn default() -> Self {
Self::new()
}
}
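// A sketch, not part of the original file: how this mock is typically driven. Queue input with
// `add`, read it back through `PluginRead`, then write through `PluginWrite` and inspect the
// result with `next_written`. The test name and `u32` message types here are illustrative only.
#[cfg(test)]
#[test]
fn test_case_round_trip_sketch() -> Result<(), ShellError> {
    let mut test: TestCase<u32, u32> = TestCase::new();
    // Input queued here is what `read` returns, in order
    test.add(1u32);
    assert_eq!(Some(1), test.read()?);
    assert_eq!(None, test.read()?, "input should be exhausted");
    // Output written through `PluginWrite` can be inspected afterwards
    test.write(&2u32)?;
    test.flush()?;
    assert!(test.was_flushed());
    assert_eq!(Some(2), test.next_written());
    assert!(!test.has_unconsumed_write());
    Ok(())
}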


@@ -0,0 +1,573 @@
use crate::util::Sequence;
use super::{
stream::{StreamManager, StreamManagerHandle},
test_util::TestCase,
Interface, InterfaceManager, PluginRead, PluginWrite,
};
use nu_plugin_protocol::{
ExternalStreamInfo, ListStreamInfo, PipelineDataHeader, PluginInput, PluginOutput,
RawStreamInfo, StreamData, StreamMessage,
};
use nu_protocol::{
DataSource, ListStream, PipelineData, PipelineMetadata, RawStream, ShellError, Span, Value,
};
use std::{path::Path, sync::Arc};
fn test_metadata() -> PipelineMetadata {
PipelineMetadata {
data_source: DataSource::FilePath("/test/path".into()),
}
}
#[derive(Debug)]
struct TestInterfaceManager {
stream_manager: StreamManager,
test: TestCase<PluginInput, PluginOutput>,
seq: Arc<Sequence>,
}
#[derive(Debug, Clone)]
struct TestInterface {
stream_manager_handle: StreamManagerHandle,
test: TestCase<PluginInput, PluginOutput>,
seq: Arc<Sequence>,
}
impl TestInterfaceManager {
fn new(test: &TestCase<PluginInput, PluginOutput>) -> TestInterfaceManager {
TestInterfaceManager {
stream_manager: StreamManager::new(),
test: test.clone(),
seq: Arc::new(Sequence::default()),
}
}
fn consume_all(&mut self) -> Result<(), ShellError> {
while let Some(msg) = self.test.read()? {
self.consume(msg)?;
}
Ok(())
}
}
impl InterfaceManager for TestInterfaceManager {
type Interface = TestInterface;
type Input = PluginInput;
fn get_interface(&self) -> Self::Interface {
TestInterface {
stream_manager_handle: self.stream_manager.get_handle(),
test: self.test.clone(),
seq: self.seq.clone(),
}
}
fn consume(&mut self, input: Self::Input) -> Result<(), ShellError> {
match input {
PluginInput::Data(..)
| PluginInput::End(..)
| PluginInput::Drop(..)
| PluginInput::Ack(..) => self.consume_stream_message(
input
.try_into()
.expect("failed to convert message to StreamMessage"),
),
_ => unimplemented!(),
}
}
fn stream_manager(&self) -> &StreamManager {
&self.stream_manager
}
fn prepare_pipeline_data(&self, data: PipelineData) -> Result<PipelineData, ShellError> {
Ok(data.set_metadata(Some(test_metadata())))
}
}
impl Interface for TestInterface {
type Output = PluginOutput;
type DataContext = ();
fn write(&self, output: Self::Output) -> Result<(), ShellError> {
self.test.write(&output)
}
fn flush(&self) -> Result<(), ShellError> {
Ok(())
}
fn stream_id_sequence(&self) -> &Sequence {
&self.seq
}
fn stream_manager_handle(&self) -> &StreamManagerHandle {
&self.stream_manager_handle
}
fn prepare_pipeline_data(
&self,
data: PipelineData,
_context: &(),
) -> Result<PipelineData, ShellError> {
// Add an arbitrary check to the data to verify this is being called
match data {
PipelineData::Value(Value::Binary { .. }, None) => Err(ShellError::NushellFailed {
msg: "TEST can't send binary".into(),
}),
_ => Ok(data),
}
}
}
#[test]
fn read_pipeline_data_empty() -> Result<(), ShellError> {
let manager = TestInterfaceManager::new(&TestCase::new());
let header = PipelineDataHeader::Empty;
assert!(matches!(
manager.read_pipeline_data(header, None)?,
PipelineData::Empty
));
Ok(())
}
#[test]
fn read_pipeline_data_value() -> Result<(), ShellError> {
let manager = TestInterfaceManager::new(&TestCase::new());
let value = Value::test_int(4);
let header = PipelineDataHeader::Value(value.clone());
match manager.read_pipeline_data(header, None)? {
PipelineData::Value(read_value, _) => assert_eq!(value, read_value),
PipelineData::ListStream(_, _) => panic!("unexpected ListStream"),
PipelineData::ExternalStream { .. } => panic!("unexpected ExternalStream"),
PipelineData::Empty => panic!("unexpected Empty"),
}
Ok(())
}
#[test]
fn read_pipeline_data_list_stream() -> Result<(), ShellError> {
let test = TestCase::new();
let mut manager = TestInterfaceManager::new(&test);
let data = (0..100).map(Value::test_int).collect::<Vec<_>>();
for value in &data {
test.add(StreamMessage::Data(7, value.clone().into()));
}
test.add(StreamMessage::End(7));
let header = PipelineDataHeader::ListStream(ListStreamInfo { id: 7 });
let pipe = manager.read_pipeline_data(header, None)?;
assert!(
matches!(pipe, PipelineData::ListStream(..)),
"unexpected PipelineData: {pipe:?}"
);
// need to consume input
manager.consume_all()?;
let mut count = 0;
for (expected, read) in data.into_iter().zip(pipe) {
assert_eq!(expected, read);
count += 1;
}
assert_eq!(100, count);
assert!(test.has_unconsumed_write());
Ok(())
}
#[test]
fn read_pipeline_data_external_stream() -> Result<(), ShellError> {
let test = TestCase::new();
let mut manager = TestInterfaceManager::new(&test);
let iterations = 100;
let out_pattern = b"hello".to_vec();
let err_pattern = vec![5, 4, 3, 2];
test.add(StreamMessage::Data(14, Value::test_int(1).into()));
for _ in 0..iterations {
test.add(StreamMessage::Data(
12,
StreamData::Raw(Ok(out_pattern.clone())),
));
test.add(StreamMessage::Data(
13,
StreamData::Raw(Ok(err_pattern.clone())),
));
}
test.add(StreamMessage::End(12));
test.add(StreamMessage::End(13));
test.add(StreamMessage::End(14));
let test_span = Span::new(10, 13);
let header = PipelineDataHeader::ExternalStream(ExternalStreamInfo {
span: test_span,
stdout: Some(RawStreamInfo {
id: 12,
is_binary: false,
known_size: Some((out_pattern.len() * iterations) as u64),
}),
stderr: Some(RawStreamInfo {
id: 13,
is_binary: true,
known_size: None,
}),
exit_code: Some(ListStreamInfo { id: 14 }),
trim_end_newline: true,
});
let pipe = manager.read_pipeline_data(header, None)?;
// need to consume input
manager.consume_all()?;
match pipe {
PipelineData::ExternalStream {
stdout,
stderr,
exit_code,
span,
metadata,
trim_end_newline,
} => {
let stdout = stdout.expect("stdout is None");
let stderr = stderr.expect("stderr is None");
let exit_code = exit_code.expect("exit_code is None");
assert_eq!(test_span, span);
assert!(
metadata.is_some(),
"expected metadata to be Some due to prepare_pipeline_data()"
);
assert!(trim_end_newline);
assert!(!stdout.is_binary);
assert!(stderr.is_binary);
assert_eq!(
Some((out_pattern.len() * iterations) as u64),
stdout.known_size
);
assert_eq!(None, stderr.known_size);
// check the streams
let mut count = 0;
for chunk in stdout.stream {
assert_eq!(out_pattern, chunk?);
count += 1;
}
assert_eq!(iterations, count, "stdout length");
let mut count = 0;
for chunk in stderr.stream {
assert_eq!(err_pattern, chunk?);
count += 1;
}
assert_eq!(iterations, count, "stderr length");
assert_eq!(vec![Value::test_int(1)], exit_code.collect::<Vec<_>>());
}
_ => panic!("unexpected PipelineData: {pipe:?}"),
}
// Don't need to check exactly what was written, just be sure that there is some output
assert!(test.has_unconsumed_write());
Ok(())
}
#[test]
fn read_pipeline_data_ctrlc() -> Result<(), ShellError> {
let manager = TestInterfaceManager::new(&TestCase::new());
let header = PipelineDataHeader::ListStream(ListStreamInfo { id: 0 });
let ctrlc = Default::default();
match manager.read_pipeline_data(header, Some(&ctrlc))? {
PipelineData::ListStream(
ListStream {
ctrlc: stream_ctrlc,
..
},
_,
) => {
assert!(Arc::ptr_eq(&ctrlc, &stream_ctrlc.expect("ctrlc not set")));
Ok(())
}
_ => panic!("Unexpected PipelineData, should have been ListStream"),
}
}
#[test]
fn read_pipeline_data_prepared_properly() -> Result<(), ShellError> {
let manager = TestInterfaceManager::new(&TestCase::new());
let header = PipelineDataHeader::ListStream(ListStreamInfo { id: 0 });
match manager.read_pipeline_data(header, None)? {
PipelineData::ListStream(_, meta) => match meta {
Some(PipelineMetadata { data_source }) => match data_source {
DataSource::FilePath(path) => {
assert_eq!(Path::new("/test/path"), path);
Ok(())
}
_ => panic!("wrong metadata: {data_source:?}"),
},
None => panic!("metadata not set"),
},
_ => panic!("Unexpected PipelineData, should have been ListStream"),
}
}
#[test]
fn write_pipeline_data_empty() -> Result<(), ShellError> {
let test = TestCase::new();
let manager = TestInterfaceManager::new(&test);
let interface = manager.get_interface();
let (header, writer) = interface.init_write_pipeline_data(PipelineData::Empty, &())?;
assert!(matches!(header, PipelineDataHeader::Empty));
writer.write()?;
assert!(
!test.has_unconsumed_write(),
"Empty shouldn't write any stream messages, test: {test:#?}"
);
Ok(())
}
#[test]
fn write_pipeline_data_value() -> Result<(), ShellError> {
let test = TestCase::new();
let manager = TestInterfaceManager::new(&test);
let interface = manager.get_interface();
let value = Value::test_int(7);
let (header, writer) =
interface.init_write_pipeline_data(PipelineData::Value(value.clone(), None), &())?;
match header {
PipelineDataHeader::Value(read_value) => assert_eq!(value, read_value),
_ => panic!("unexpected header: {header:?}"),
}
writer.write()?;
assert!(
!test.has_unconsumed_write(),
"Value shouldn't write any stream messages, test: {test:#?}"
);
Ok(())
}
#[test]
fn write_pipeline_data_prepared_properly() {
let manager = TestInterfaceManager::new(&TestCase::new());
let interface = manager.get_interface();
// Sending a binary should be an error in our test scenario
let value = Value::test_binary(vec![7, 8]);
match interface.init_write_pipeline_data(PipelineData::Value(value, None), &()) {
Ok(_) => panic!("prepare_pipeline_data was not called"),
Err(err) => {
assert_eq!(
ShellError::NushellFailed {
msg: "TEST can't send binary".into()
}
.to_string(),
err.to_string()
);
}
}
}
#[test]
fn write_pipeline_data_list_stream() -> Result<(), ShellError> {
let test = TestCase::new();
let manager = TestInterfaceManager::new(&test);
let interface = manager.get_interface();
let values = vec![
Value::test_int(40),
Value::test_bool(false),
Value::test_string("this is a test"),
];
// Set up pipeline data for a list stream
let pipe = PipelineData::ListStream(
ListStream::from_stream(values.clone().into_iter(), None),
None,
);
let (header, writer) = interface.init_write_pipeline_data(pipe, &())?;
let info = match header {
PipelineDataHeader::ListStream(info) => info,
_ => panic!("unexpected header: {header:?}"),
};
writer.write()?;
// Now make sure the stream messages have been written
for value in values {
match test.next_written().expect("unexpected end of stream") {
PluginOutput::Data(id, data) => {
assert_eq!(info.id, id, "Data id");
match data {
StreamData::List(read_value) => assert_eq!(value, read_value, "Data value"),
_ => panic!("unexpected Data: {data:?}"),
}
}
other => panic!("unexpected output: {other:?}"),
}
}
match test.next_written().expect("unexpected end of stream") {
PluginOutput::End(id) => {
assert_eq!(info.id, id, "End id");
}
other => panic!("unexpected output: {other:?}"),
}
assert!(!test.has_unconsumed_write());
Ok(())
}
#[test]
fn write_pipeline_data_external_stream() -> Result<(), ShellError> {
let test = TestCase::new();
let manager = TestInterfaceManager::new(&test);
let interface = manager.get_interface();
let stdout_bufs = vec![
b"hello".to_vec(),
b"world".to_vec(),
b"these are tests".to_vec(),
];
let stdout_len = stdout_bufs.iter().map(|b| b.len() as u64).sum::<u64>();
let stderr_bufs = vec![b"error messages".to_vec(), b"go here".to_vec()];
let exit_code = Value::test_int(7);
let span = Span::new(400, 500);
// Set up pipeline data for an external stream
let pipe = PipelineData::ExternalStream {
stdout: Some(RawStream::new(
Box::new(stdout_bufs.clone().into_iter().map(Ok)),
None,
span,
Some(stdout_len),
)),
stderr: Some(RawStream::new(
Box::new(stderr_bufs.clone().into_iter().map(Ok)),
None,
span,
None,
)),
exit_code: Some(ListStream::from_stream(
std::iter::once(exit_code.clone()),
None,
)),
span,
metadata: None,
trim_end_newline: true,
};
let (header, writer) = interface.init_write_pipeline_data(pipe, &())?;
let info = match header {
PipelineDataHeader::ExternalStream(info) => info,
_ => panic!("unexpected header: {header:?}"),
};
writer.write()?;
let stdout_info = info.stdout.as_ref().expect("stdout info is None");
let stderr_info = info.stderr.as_ref().expect("stderr info is None");
let exit_code_info = info.exit_code.as_ref().expect("exit code info is None");
assert_eq!(span, info.span);
assert!(info.trim_end_newline);
assert_eq!(Some(stdout_len), stdout_info.known_size);
assert_eq!(None, stderr_info.known_size);
// Now make sure the stream messages have been written
let mut stdout_iter = stdout_bufs.into_iter();
let mut stderr_iter = stderr_bufs.into_iter();
let mut exit_code_iter = std::iter::once(exit_code);
let mut stdout_ended = false;
let mut stderr_ended = false;
let mut exit_code_ended = false;
// There's no specific order these messages must come in with respect to how the streams are
// interleaved, but all of the data for each stream must be in its original order, and the
// End must come after all Data
for msg in test.written() {
match msg {
PluginOutput::Data(id, data) => {
if id == stdout_info.id {
let result: Result<Vec<u8>, ShellError> =
data.try_into().expect("wrong data in stdout stream");
assert_eq!(
stdout_iter.next().expect("too much data in stdout"),
result.expect("unexpected error in stdout stream")
);
} else if id == stderr_info.id {
let result: Result<Vec<u8>, ShellError> =
data.try_into().expect("wrong data in stderr stream");
assert_eq!(
stderr_iter.next().expect("too much data in stderr"),
result.expect("unexpected error in stderr stream")
);
} else if id == exit_code_info.id {
let code: Value = data.try_into().expect("wrong data in exit_code stream");
assert_eq!(
exit_code_iter.next().expect("too much data in exit_code"),
code
);
} else {
panic!("unrecognized stream id: {id}");
}
}
PluginOutput::End(id) => {
if id == stdout_info.id {
assert!(!stdout_ended, "double End of stdout");
assert!(stdout_iter.next().is_none(), "unexpected end of stdout");
stdout_ended = true;
} else if id == stderr_info.id {
assert!(!stderr_ended, "double End of stderr");
assert!(stderr_iter.next().is_none(), "unexpected end of stderr");
stderr_ended = true;
} else if id == exit_code_info.id {
assert!(!exit_code_ended, "double End of exit_code");
assert!(
exit_code_iter.next().is_none(),
"unexpected end of exit_code"
);
exit_code_ended = true;
} else {
panic!("unrecognized stream id: {id}");
}
}
other => panic!("unexpected output: {other:?}"),
}
}
assert!(stdout_ended, "stdout did not End");
assert!(stderr_ended, "stderr did not End");
assert!(exit_code_ended, "exit_code did not End");
Ok(())
}


@@ -0,0 +1,24 @@
//! Functionality and types shared between the plugin and the engine, other than protocol types.
//!
//! If you are writing a plugin, you probably don't need this crate. We make fewer stability
//! guarantees for this crate's interface than for `nu_plugin`.
pub mod util;
mod communication_mode;
mod interface;
mod serializers;
pub use communication_mode::{
ClientCommunicationIo, CommunicationMode, PreparedServerCommunication, ServerCommunicationIo,
};
pub use interface::{
stream::{FromShellError, StreamManager, StreamManagerHandle, StreamReader, StreamWriter},
Interface, InterfaceManager, PipelineDataWriter, PluginRead, PluginWrite,
};
pub use serializers::{
json::JsonSerializer, msgpack::MsgPackSerializer, Encoder, EncodingType, PluginEncoder,
};
#[doc(hidden)]
pub use interface::test_util as interface_test_util;


@@ -0,0 +1,133 @@
use nu_plugin_protocol::{PluginInput, PluginOutput};
use nu_protocol::ShellError;
use serde::Deserialize;
use crate::{Encoder, PluginEncoder};
/// A `PluginEncoder` that enables the plugin to communicate with Nushell with JSON
/// serialized data.
///
/// When serializing, each message is followed by a newline, but the newline is not required for
/// deserialization. The output is not pretty-printed and contains no newlines within a message,
/// so a plugin may choose to separate its own messages by newline if that is more convenient.
#[derive(Clone, Copy, Debug)]
pub struct JsonSerializer;
impl PluginEncoder for JsonSerializer {
fn name(&self) -> &str {
"json"
}
}
impl Encoder<PluginInput> for JsonSerializer {
fn encode(
&self,
plugin_input: &PluginInput,
writer: &mut impl std::io::Write,
) -> Result<(), nu_protocol::ShellError> {
serde_json::to_writer(&mut *writer, plugin_input).map_err(json_encode_err)?;
writer.write_all(b"\n").map_err(|err| ShellError::IOError {
msg: err.to_string(),
})
}
fn decode(
&self,
reader: &mut impl std::io::BufRead,
) -> Result<Option<PluginInput>, nu_protocol::ShellError> {
let mut de = serde_json::Deserializer::from_reader(reader);
PluginInput::deserialize(&mut de)
.map(Some)
.or_else(json_decode_err)
}
}
impl Encoder<PluginOutput> for JsonSerializer {
fn encode(
&self,
plugin_output: &PluginOutput,
writer: &mut impl std::io::Write,
) -> Result<(), ShellError> {
serde_json::to_writer(&mut *writer, plugin_output).map_err(json_encode_err)?;
writer.write_all(b"\n").map_err(|err| ShellError::IOError {
msg: err.to_string(),
})
}
fn decode(
&self,
reader: &mut impl std::io::BufRead,
) -> Result<Option<PluginOutput>, ShellError> {
let mut de = serde_json::Deserializer::from_reader(reader);
PluginOutput::deserialize(&mut de)
.map(Some)
.or_else(json_decode_err)
}
}
/// Handle a `serde_json` encode error.
fn json_encode_err(err: serde_json::Error) -> ShellError {
if err.is_io() {
ShellError::IOError {
msg: err.to_string(),
}
} else {
ShellError::PluginFailedToEncode {
msg: err.to_string(),
}
}
}
/// Handle a `serde_json` decode error. Returns `Ok(None)` on eof.
fn json_decode_err<T>(err: serde_json::Error) -> Result<Option<T>, ShellError> {
if err.is_eof() {
Ok(None)
} else if err.is_io() {
Err(ShellError::IOError {
msg: err.to_string(),
})
} else {
Err(ShellError::PluginFailedToDecode {
msg: err.to_string(),
})
}
}
#[cfg(test)]
mod tests {
use super::*;
crate::serializers::tests::generate_tests!(JsonSerializer {});
#[test]
fn json_ends_in_newline() {
let mut out = vec![];
JsonSerializer {}
.encode(&PluginInput::Call(0, PluginCall::Signature), &mut out)
.expect("serialization error");
let string = std::str::from_utf8(&out).expect("utf-8 error");
assert!(
string.ends_with('\n'),
"doesn't end with newline: {:?}",
string
);
}
#[test]
fn json_has_no_other_newlines() {
let mut out = vec![];
// use something deeply nested, to try to trigger any pretty printing
let output = PluginOutput::Data(
0,
StreamData::List(Value::test_list(vec![
Value::test_int(4),
// in case escaping failed
Value::test_string("newline\ncontaining\nstring"),
])),
);
JsonSerializer {}
.encode(&output, &mut out)
.expect("serialization error");
let string = std::str::from_utf8(&out).expect("utf-8 error");
assert_eq!(1, string.chars().filter(|ch| *ch == '\n').count());
}
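    // A sketch, not part of the original tests: since the trailing newline is only a convenience
    // separator, two messages encoded back-to-back into one buffer can be decoded in sequence
    // from the same reader. The test name is illustrative only.
    #[test]
    fn json_decodes_consecutive_messages() {
        let encoder = JsonSerializer {};
        let mut out: Vec<u8> = Vec::new();
        for id in 0..2 {
            encoder
                .encode(&PluginInput::Call(id, PluginCall::Signature), &mut out)
                .expect("serialization error");
        }
        let mut reader = out.as_slice();
        for expected_id in 0..2 {
            let decoded: Option<PluginInput> = encoder.decode(&mut reader).expect("decode error");
            match decoded {
                Some(PluginInput::Call(id, PluginCall::Signature)) => assert_eq!(expected_id, id),
                other => panic!("unexpected decode result: {other:?}"),
            }
        }
    }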
}


@@ -0,0 +1,71 @@
use nu_plugin_protocol::{PluginInput, PluginOutput};
use nu_protocol::ShellError;
pub mod json;
pub mod msgpack;
#[cfg(test)]
mod tests;
/// Encoder for a specific message type. Usually implemented on [`PluginInput`]
/// and [`PluginOutput`].
pub trait Encoder<T>: Clone + Send + Sync {
/// Serialize a value in the [`PluginEncoder`]'s format
///
/// Returns [`ShellError::IOError`] if there was a problem writing, or
/// [`ShellError::PluginFailedToEncode`] for a serialization error.
fn encode(&self, data: &T, writer: &mut impl std::io::Write) -> Result<(), ShellError>;
/// Deserialize a value from the [`PluginEncoder`]'s format
///
/// Returns `None` if there is no more output to receive.
///
/// Returns [`ShellError::IOError`] if there was a problem reading, or
/// [`ShellError::PluginFailedToDecode`] for a deserialization error.
fn decode(&self, reader: &mut impl std::io::BufRead) -> Result<Option<T>, ShellError>;
}
/// Encoding scheme that defines a plugin's communication protocol with Nu
pub trait PluginEncoder: Encoder<PluginInput> + Encoder<PluginOutput> {
/// The name of the encoder (e.g., `json`)
fn name(&self) -> &str;
}
/// Enum that supports all of the plugin serialization formats.
#[derive(Clone, Copy, Debug)]
pub enum EncodingType {
Json(json::JsonSerializer),
MsgPack(msgpack::MsgPackSerializer),
}
impl EncodingType {
/// Determine the plugin encoding type from the provided byte string (either `b"json"` or
/// `b"msgpack"`).
pub fn try_from_bytes(bytes: &[u8]) -> Option<Self> {
match bytes {
b"json" => Some(Self::Json(json::JsonSerializer {})),
b"msgpack" => Some(Self::MsgPack(msgpack::MsgPackSerializer {})),
_ => None,
}
}
}
impl<T> Encoder<T> for EncodingType
where
json::JsonSerializer: Encoder<T>,
msgpack::MsgPackSerializer: Encoder<T>,
{
fn encode(&self, data: &T, writer: &mut impl std::io::Write) -> Result<(), ShellError> {
match self {
EncodingType::Json(encoder) => encoder.encode(data, writer),
EncodingType::MsgPack(encoder) => encoder.encode(data, writer),
}
}
fn decode(&self, reader: &mut impl std::io::BufRead) -> Result<Option<T>, ShellError> {
match self {
EncodingType::Json(encoder) => encoder.decode(reader),
EncodingType::MsgPack(encoder) => encoder.decode(reader),
}
}
}
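// A sketch, not part of the original file: the byte string a plugin reports (`b"json"` or
// `b"msgpack"`) selects a serializer via `try_from_bytes`, and the same `Encoder` calls then
// work regardless of which format was chosen. The test name is illustrative only.
#[cfg(test)]
#[test]
fn encoding_type_round_trip_sketch() -> Result<(), ShellError> {
    use nu_plugin_protocol::PluginCall;
    for name in [&b"json"[..], &b"msgpack"[..]] {
        let encoder = EncodingType::try_from_bytes(name).expect("known encoding");
        let mut buffer: Vec<u8> = Vec::new();
        encoder.encode(&PluginInput::Call(0, PluginCall::Signature), &mut buffer)?;
        match encoder.decode(&mut buffer.as_slice())? {
            Some(PluginInput::Call(0, PluginCall::Signature)) => (),
            other => panic!("unexpected decode result for {name:?}: {other:?}"),
        }
    }
    // Unknown encoding names are rejected rather than guessed
    assert!(EncodingType::try_from_bytes(b"unknown").is_none());
    Ok(())
}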


@@ -0,0 +1,108 @@
use std::io::ErrorKind;
use nu_plugin_protocol::{PluginInput, PluginOutput};
use nu_protocol::ShellError;
use serde::Deserialize;
use crate::{Encoder, PluginEncoder};
/// A `PluginEncoder` that enables the plugin to communicate with Nushell with MsgPack
/// serialized data.
///
/// Each message is written as a MessagePack object. There is no message envelope or separator.
#[derive(Clone, Copy, Debug)]
pub struct MsgPackSerializer;
impl PluginEncoder for MsgPackSerializer {
fn name(&self) -> &str {
"msgpack"
}
}
impl Encoder<PluginInput> for MsgPackSerializer {
fn encode(
&self,
plugin_input: &PluginInput,
writer: &mut impl std::io::Write,
) -> Result<(), nu_protocol::ShellError> {
rmp_serde::encode::write_named(writer, plugin_input).map_err(rmp_encode_err)
}
fn decode(
&self,
reader: &mut impl std::io::BufRead,
) -> Result<Option<PluginInput>, ShellError> {
let mut de = rmp_serde::Deserializer::new(reader);
PluginInput::deserialize(&mut de)
.map(Some)
.or_else(rmp_decode_err)
}
}
impl Encoder<PluginOutput> for MsgPackSerializer {
fn encode(
&self,
plugin_output: &PluginOutput,
writer: &mut impl std::io::Write,
) -> Result<(), ShellError> {
rmp_serde::encode::write_named(writer, plugin_output).map_err(rmp_encode_err)
}
fn decode(
&self,
reader: &mut impl std::io::BufRead,
) -> Result<Option<PluginOutput>, ShellError> {
let mut de = rmp_serde::Deserializer::new(reader);
PluginOutput::deserialize(&mut de)
.map(Some)
.or_else(rmp_decode_err)
}
}
/// Handle a msgpack encode error
fn rmp_encode_err(err: rmp_serde::encode::Error) -> ShellError {
match err {
rmp_serde::encode::Error::InvalidValueWrite(_) => {
// I/O error
ShellError::IOError {
msg: err.to_string(),
}
}
_ => {
// Something else
ShellError::PluginFailedToEncode {
msg: err.to_string(),
}
}
}
}
/// Handle a msgpack decode error. Returns `Ok(None)` on eof
fn rmp_decode_err<T>(err: rmp_serde::decode::Error) -> Result<Option<T>, ShellError> {
match err {
rmp_serde::decode::Error::InvalidMarkerRead(err)
| rmp_serde::decode::Error::InvalidDataRead(err) => {
if matches!(err.kind(), ErrorKind::UnexpectedEof) {
// EOF
Ok(None)
} else {
// I/O error
Err(ShellError::IOError {
msg: err.to_string(),
})
}
}
_ => {
// Something else
Err(ShellError::PluginFailedToDecode {
msg: err.to_string(),
})
}
}
}
#[cfg(test)]
mod tests {
use super::*;
crate::serializers::tests::generate_tests!(MsgPackSerializer {});
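    // A sketch, not part of the original tests: messages are written back-to-back with no
    // separator, so two encoded messages in one buffer decode in sequence from the same reader.
    // The test name is illustrative only.
    #[test]
    fn msgpack_decodes_consecutive_messages() {
        use nu_plugin_protocol::PluginCall;
        let encoder = MsgPackSerializer {};
        let mut buffer: Vec<u8> = Vec::new();
        for id in 0..2 {
            encoder
                .encode(&PluginInput::Call(id, PluginCall::Signature), &mut buffer)
                .expect("encode error");
        }
        let mut reader = buffer.as_slice();
        for expected_id in 0..2 {
            let decoded: Option<PluginInput> = encoder.decode(&mut reader).expect("decode error");
            match decoded {
                Some(PluginInput::Call(id, PluginCall::Signature)) => assert_eq!(expected_id, id),
                other => panic!("unexpected decode result: {other:?}"),
            }
        }
    }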
}


@@ -0,0 +1,557 @@
macro_rules! generate_tests {
($encoder:expr) => {
use nu_plugin_protocol::{
CallInfo, CustomValueOp, EvaluatedCall, PipelineDataHeader, PluginCall,
PluginCallResponse, PluginCustomValue, PluginInput, PluginOption, PluginOutput,
StreamData,
};
use nu_protocol::{
LabeledError, PluginSignature, Signature, Span, Spanned, SyntaxShape, Value,
};
#[test]
fn decode_eof() {
let mut buffer: &[u8] = &[];
let encoder = $encoder;
let result: Option<PluginInput> = encoder
.decode(&mut buffer)
.expect("eof should not result in an error");
assert!(result.is_none(), "decode result: {result:?}");
let result: Option<PluginOutput> = encoder
.decode(&mut buffer)
.expect("eof should not result in an error");
assert!(result.is_none(), "decode result: {result:?}");
}
#[test]
fn decode_io_error() {
struct ErrorProducer;
impl std::io::Read for ErrorProducer {
fn read(&mut self, _buf: &mut [u8]) -> std::io::Result<usize> {
Err(std::io::Error::from(std::io::ErrorKind::ConnectionReset))
}
}
let encoder = $encoder;
let mut buffered = std::io::BufReader::new(ErrorProducer);
match Encoder::<PluginInput>::decode(&encoder, &mut buffered) {
Ok(_) => panic!("decode: i/o error was not passed through"),
Err(ShellError::IOError { .. }) => (), // okay
Err(other) => panic!(
"decode: got other error, should have been a \
ShellError::IOError: {other:?}"
),
}
match Encoder::<PluginOutput>::decode(&encoder, &mut buffered) {
Ok(_) => panic!("decode: i/o error was not passed through"),
Err(ShellError::IOError { .. }) => (), // okay
Err(other) => panic!(
"decode: got other error, should have been a \
ShellError::IOError: {other:?}"
),
}
}
#[test]
fn decode_gibberish() {
// just a sequence of bytes that shouldn't be valid in anything we use
let gibberish: &[u8] = &[
0, 80, 74, 85, 117, 122, 86, 100, 74, 115, 20, 104, 55, 98, 67, 203, 83, 85, 77,
112, 74, 79, 254, 71, 80,
];
let encoder = $encoder;
let mut buffered = std::io::BufReader::new(&gibberish[..]);
match Encoder::<PluginInput>::decode(&encoder, &mut buffered) {
Ok(value) => panic!("decode: parsed successfully => {value:?}"),
Err(ShellError::PluginFailedToDecode { .. }) => (), // okay
Err(other) => panic!(
"decode: got other error, should have been a \
ShellError::PluginFailedToDecode: {other:?}"
),
}
let mut buffered = std::io::BufReader::new(&gibberish[..]);
match Encoder::<PluginOutput>::decode(&encoder, &mut buffered) {
Ok(value) => panic!("decode: parsed successfully => {value:?}"),
Err(ShellError::PluginFailedToDecode { .. }) => (), // okay
Err(other) => panic!(
"decode: got other error, should have been a \
ShellError::PluginFailedToDecode: {other:?}"
),
}
}
#[test]
fn call_round_trip_signature() {
let plugin_call = PluginCall::Signature;
let plugin_input = PluginInput::Call(0, plugin_call);
let encoder = $encoder;
let mut buffer: Vec<u8> = Vec::new();
encoder
.encode(&plugin_input, &mut buffer)
.expect("unable to serialize message");
let returned = encoder
.decode(&mut buffer.as_slice())
.expect("unable to deserialize message")
.expect("eof");
match returned {
PluginInput::Call(0, PluginCall::Signature) => {}
_ => panic!("decoded into wrong value: {returned:?}"),
}
}
#[test]
fn call_round_trip_run() {
let name = "test".to_string();
let input = Value::bool(false, Span::new(1, 20));
let call = EvaluatedCall {
head: Span::new(0, 10),
positional: vec![
Value::float(1.0, Span::new(0, 10)),
Value::string("something", Span::new(0, 10)),
],
named: vec![(
Spanned {
item: "name".to_string(),
span: Span::new(0, 10),
},
Some(Value::float(1.0, Span::new(0, 10))),
)],
};
let plugin_call = PluginCall::Run(CallInfo {
name: name.clone(),
call: call.clone(),
input: PipelineDataHeader::Value(input.clone()),
});
let plugin_input = PluginInput::Call(1, plugin_call);
let encoder = $encoder;
let mut buffer: Vec<u8> = Vec::new();
encoder
.encode(&plugin_input, &mut buffer)
.expect("unable to serialize message");
let returned = encoder
.decode(&mut buffer.as_slice())
.expect("unable to deserialize message")
.expect("eof");
match returned {
PluginInput::Call(1, PluginCall::Run(call_info)) => {
assert_eq!(name, call_info.name);
assert_eq!(PipelineDataHeader::Value(input), call_info.input);
assert_eq!(call.head, call_info.call.head);
assert_eq!(call.positional.len(), call_info.call.positional.len());
call.positional
.iter()
.zip(call_info.call.positional.iter())
.for_each(|(lhs, rhs)| assert_eq!(lhs, rhs));
call.named
.iter()
.zip(call_info.call.named.iter())
.for_each(|(lhs, rhs)| {
// Comparing the keys
assert_eq!(lhs.0.item, rhs.0.item);
match (&lhs.1, &rhs.1) {
(None, None) => {}
(Some(a), Some(b)) => assert_eq!(a, b),
_ => panic!("not matching values"),
}
});
}
_ => panic!("decoded into wrong value: {returned:?}"),
}
}
#[test]
fn call_round_trip_customvalueop() {
let data = vec![1, 2, 3, 4, 5, 6, 7];
let span = Span::new(0, 20);
let custom_value_op = PluginCall::CustomValueOp(
Spanned {
item: PluginCustomValue::new("Foo".into(), data.clone(), false),
span,
},
CustomValueOp::ToBaseValue,
);
let plugin_input = PluginInput::Call(2, custom_value_op);
let encoder = $encoder;
let mut buffer: Vec<u8> = Vec::new();
encoder
.encode(&plugin_input, &mut buffer)
.expect("unable to serialize message");
let returned = encoder
.decode(&mut buffer.as_slice())
.expect("unable to deserialize message")
.expect("eof");
match returned {
PluginInput::Call(2, PluginCall::CustomValueOp(val, op)) => {
assert_eq!("Foo", val.item.name());
assert_eq!(data, val.item.data());
assert_eq!(span, val.span);
#[allow(unreachable_patterns)]
match op {
CustomValueOp::ToBaseValue => (),
_ => panic!("wrong op: {op:?}"),
}
}
_ => panic!("decoded into wrong value: {returned:?}"),
}
}
#[test]
fn response_round_trip_signature() {
let signature = PluginSignature::new(
Signature::build("nu-plugin")
.required("first", SyntaxShape::String, "first required")
.required("second", SyntaxShape::Int, "second required")
.required_named("first-named", SyntaxShape::String, "first named", Some('f'))
.required_named(
"second-named",
SyntaxShape::String,
"second named",
Some('s'),
)
.rest("remaining", SyntaxShape::Int, "remaining"),
vec![],
);
let response = PluginCallResponse::Signature(vec![signature.clone()]);
let output = PluginOutput::CallResponse(3, response);
let encoder = $encoder;
let mut buffer: Vec<u8> = Vec::new();
encoder
.encode(&output, &mut buffer)
.expect("unable to serialize message");
let returned = encoder
.decode(&mut buffer.as_slice())
.expect("unable to deserialize message")
.expect("eof");
match returned {
PluginOutput::CallResponse(
3,
PluginCallResponse::Signature(returned_signature),
) => {
assert_eq!(returned_signature.len(), 1);
assert_eq!(signature.sig.name, returned_signature[0].sig.name);
assert_eq!(signature.sig.usage, returned_signature[0].sig.usage);
assert_eq!(
signature.sig.extra_usage,
returned_signature[0].sig.extra_usage
);
assert_eq!(signature.sig.is_filter, returned_signature[0].sig.is_filter);
signature
.sig
.required_positional
.iter()
.zip(returned_signature[0].sig.required_positional.iter())
.for_each(|(lhs, rhs)| assert_eq!(lhs, rhs));
signature
.sig
.optional_positional
.iter()
.zip(returned_signature[0].sig.optional_positional.iter())
.for_each(|(lhs, rhs)| assert_eq!(lhs, rhs));
signature
.sig
.named
.iter()
.zip(returned_signature[0].sig.named.iter())
.for_each(|(lhs, rhs)| assert_eq!(lhs, rhs));
assert_eq!(
signature.sig.rest_positional,
returned_signature[0].sig.rest_positional,
);
}
_ => panic!("decoded into wrong value: {returned:?}"),
}
}
#[test]
fn response_round_trip_value() {
let value = Value::int(10, Span::new(2, 30));
let response = PluginCallResponse::value(value.clone());
let output = PluginOutput::CallResponse(4, response);
let encoder = $encoder;
let mut buffer: Vec<u8> = Vec::new();
encoder
.encode(&output, &mut buffer)
.expect("unable to serialize message");
let returned = encoder
.decode(&mut buffer.as_slice())
.expect("unable to deserialize message")
.expect("eof");
match returned {
PluginOutput::CallResponse(
4,
PluginCallResponse::PipelineData(PipelineDataHeader::Value(returned_value)),
) => {
assert_eq!(value, returned_value)
}
_ => panic!("decoded into wrong value: {returned:?}"),
}
}
#[test]
fn response_round_trip_plugin_custom_value() {
let name = "test";
let data = vec![1, 2, 3, 4, 5];
let span = Span::new(2, 30);
let value = Value::custom(
Box::new(PluginCustomValue::new(name.into(), data.clone(), true)),
span,
);
let response = PluginCallResponse::PipelineData(PipelineDataHeader::Value(value));
let output = PluginOutput::CallResponse(5, response);
let encoder = $encoder;
let mut buffer: Vec<u8> = Vec::new();
encoder
.encode(&output, &mut buffer)
.expect("unable to serialize message");
let returned = encoder
.decode(&mut buffer.as_slice())
.expect("unable to deserialize message")
.expect("eof");
match returned {
PluginOutput::CallResponse(
5,
PluginCallResponse::PipelineData(PipelineDataHeader::Value(returned_value)),
) => {
assert_eq!(span, returned_value.span());
if let Some(plugin_val) = returned_value
.as_custom_value()
.unwrap()
.as_any()
.downcast_ref::<PluginCustomValue>()
{
assert_eq!(name, plugin_val.name());
assert_eq!(data, plugin_val.data());
assert!(plugin_val.notify_on_drop());
} else {
panic!("returned CustomValue is not a PluginCustomValue");
}
}
_ => panic!("decoded into wrong value: {returned:?}"),
}
}
#[test]
fn response_round_trip_error() {
let error = LabeledError::new("label")
.with_code("test::error")
.with_url("https://example.org/test/error")
.with_help("some help")
.with_label("msg", Span::new(2, 30))
.with_inner(ShellError::IOError {
msg: "io error".into(),
});
let response = PluginCallResponse::Error(error.clone());
let output = PluginOutput::CallResponse(6, response);
let encoder = $encoder;
let mut buffer: Vec<u8> = Vec::new();
encoder
.encode(&output, &mut buffer)
.expect("unable to serialize message");
let returned = encoder
.decode(&mut buffer.as_slice())
.expect("unable to deserialize message")
.expect("eof");
match returned {
PluginOutput::CallResponse(6, PluginCallResponse::Error(msg)) => {
assert_eq!(error, msg)
}
_ => panic!("decoded into wrong value: {returned:?}"),
}
}
#[test]
fn response_round_trip_error_none() {
let error = LabeledError::new("error");
let response = PluginCallResponse::Error(error.clone());
let output = PluginOutput::CallResponse(7, response);
let encoder = $encoder;
let mut buffer: Vec<u8> = Vec::new();
encoder
.encode(&output, &mut buffer)
.expect("unable to serialize message");
let returned = encoder
.decode(&mut buffer.as_slice())
.expect("unable to deserialize message")
.expect("eof");
match returned {
PluginOutput::CallResponse(7, PluginCallResponse::Error(msg)) => {
assert_eq!(error, msg)
}
_ => panic!("decoded into wrong value: {returned:?}"),
}
}
#[test]
fn input_round_trip_stream_data_list() {
let span = Span::new(12, 30);
let item = Value::int(1, span);
let stream_data = StreamData::List(item.clone());
let plugin_input = PluginInput::Data(0, stream_data);
let encoder = $encoder;
let mut buffer: Vec<u8> = Vec::new();
encoder
.encode(&plugin_input, &mut buffer)
.expect("unable to serialize message");
let returned = encoder
.decode(&mut buffer.as_slice())
.expect("unable to deserialize message")
.expect("eof");
match returned {
PluginInput::Data(id, StreamData::List(list_data)) => {
assert_eq!(0, id);
assert_eq!(item, list_data);
}
_ => panic!("decoded into wrong value: {returned:?}"),
}
}
#[test]
fn input_round_trip_stream_data_raw() {
let data = b"Hello world";
let stream_data = StreamData::Raw(Ok(data.to_vec()));
let plugin_input = PluginInput::Data(1, stream_data);
let encoder = $encoder;
let mut buffer: Vec<u8> = Vec::new();
encoder
.encode(&plugin_input, &mut buffer)
.expect("unable to serialize message");
let returned = encoder
.decode(&mut buffer.as_slice())
.expect("unable to deserialize message")
.expect("eof");
match returned {
PluginInput::Data(id, StreamData::Raw(bytes)) => {
assert_eq!(1, id);
match bytes {
Ok(bytes) => assert_eq!(data, &bytes[..]),
Err(err) => panic!("decoded into error variant: {err:?}"),
}
}
_ => panic!("decoded into wrong value: {returned:?}"),
}
}
#[test]
fn output_round_trip_stream_data_list() {
let span = Span::new(12, 30);
let item = Value::int(1, span);
let stream_data = StreamData::List(item.clone());
let plugin_output = PluginOutput::Data(4, stream_data);
let encoder = $encoder;
let mut buffer: Vec<u8> = Vec::new();
encoder
.encode(&plugin_output, &mut buffer)
.expect("unable to serialize message");
let returned = encoder
.decode(&mut buffer.as_slice())
.expect("unable to deserialize message")
.expect("eof");
match returned {
PluginOutput::Data(id, StreamData::List(list_data)) => {
assert_eq!(4, id);
assert_eq!(item, list_data);
}
_ => panic!("decoded into wrong value: {returned:?}"),
}
}
#[test]
fn output_round_trip_stream_data_raw() {
let data = b"Hello world";
let stream_data = StreamData::Raw(Ok(data.to_vec()));
let plugin_output = PluginOutput::Data(5, stream_data);
let encoder = $encoder;
let mut buffer: Vec<u8> = Vec::new();
encoder
.encode(&plugin_output, &mut buffer)
.expect("unable to serialize message");
let returned = encoder
.decode(&mut buffer.as_slice())
.expect("unable to deserialize message")
.expect("eof");
match returned {
PluginOutput::Data(id, StreamData::Raw(bytes)) => {
assert_eq!(5, id);
match bytes {
Ok(bytes) => assert_eq!(data, &bytes[..]),
Err(err) => panic!("decoded into error variant: {err:?}"),
}
}
_ => panic!("decoded into wrong value: {returned:?}"),
}
}
#[test]
fn output_round_trip_option() {
let plugin_output = PluginOutput::Option(PluginOption::GcDisabled(true));
let encoder = $encoder;
let mut buffer: Vec<u8> = Vec::new();
encoder
.encode(&plugin_output, &mut buffer)
.expect("unable to serialize message");
let returned = encoder
.decode(&mut buffer.as_slice())
.expect("unable to deserialize message")
.expect("eof");
match returned {
PluginOutput::Option(PluginOption::GcDisabled(disabled)) => {
assert!(disabled);
}
_ => panic!("decoded into wrong value: {returned:?}"),
}
}
};
}
pub(crate) use generate_tests;


@@ -0,0 +1,7 @@
mod sequence;
mod waitable;
mod with_custom_values_in;
pub use sequence::Sequence;
pub use waitable::*;
pub use with_custom_values_in::with_custom_values_in;


@@ -0,0 +1,64 @@
use nu_protocol::ShellError;
use std::sync::atomic::{AtomicUsize, Ordering::Relaxed};
/// Implements an atomically incrementing sequential series of numbers
#[derive(Debug, Default)]
pub struct Sequence(AtomicUsize);
impl Sequence {
/// Return the next available id from a sequence, returning an error on overflow
#[track_caller]
pub fn next(&self) -> Result<usize, ShellError> {
// It's safe to use Relaxed ordering here, as no other memory operations depend on this
// value having been set.
//
// We only avoid `fetch_add` so that we can check for overflow, as wrapping the identifier
// around would lead to a serious bug - however unlikely that is.
self.0
.fetch_update(Relaxed, Relaxed, |current| current.checked_add(1))
.map_err(|_| ShellError::NushellFailedHelp {
msg: "an accumulator for identifiers overflowed".into(),
help: format!("see {}", std::panic::Location::caller()),
})
}
}
#[test]
fn output_is_sequential() {
let sequence = Sequence::default();
for (expected, generated) in (0..1000).zip(std::iter::repeat_with(|| sequence.next())) {
assert_eq!(expected, generated.expect("error in sequence"));
}
}
#[test]
fn output_is_unique_even_under_contention() {
let sequence = Sequence::default();
std::thread::scope(|scope| {
// Spawn four threads, all advancing the sequence simultaneously
let threads = (0..4)
.map(|_| {
scope.spawn(|| {
(0..100000)
.map(|_| sequence.next())
.collect::<Result<Vec<_>, _>>()
})
})
.collect::<Vec<_>>();
// Collect all of the results into a single flat vec
let mut results = threads
.into_iter()
.flat_map(|thread| thread.join().expect("panicked").expect("error"))
.collect::<Vec<usize>>();
// Check uniqueness
results.sort();
let initial_length = results.len();
results.dedup();
let deduplicated_length = results.len();
assert_eq!(initial_length, deduplicated_length);
})
}
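// A sketch, not part of the original tests: when the counter would overflow, `next` reports an
// error instead of wrapping the identifier around to zero. The test name is illustrative only.
#[test]
fn next_errors_on_overflow() {
    // Start the counter at the maximum value so the next checked_add fails
    let sequence = Sequence(AtomicUsize::new(usize::MAX));
    assert!(sequence.next().is_err(), "overflow should be an error");
}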


@@ -0,0 +1,181 @@
use std::sync::{
atomic::{AtomicBool, Ordering},
Arc, Condvar, Mutex, MutexGuard, PoisonError,
};
use nu_protocol::ShellError;
/// A shared container that may be empty, and allows threads to block until it has a value.
///
/// This side is read-only - use [`WaitableMut`] on threads that might write a value.
#[derive(Debug, Clone)]
pub struct Waitable<T: Clone + Send> {
shared: Arc<WaitableShared<T>>,
}
#[derive(Debug)]
pub struct WaitableMut<T: Clone + Send> {
shared: Arc<WaitableShared<T>>,
}
#[derive(Debug)]
struct WaitableShared<T: Clone + Send> {
is_set: AtomicBool,
mutex: Mutex<SyncState<T>>,
condvar: Condvar,
}
#[derive(Debug)]
struct SyncState<T: Clone + Send> {
writers: usize,
value: Option<T>,
}
#[track_caller]
fn fail_if_poisoned<'a, T>(
result: Result<MutexGuard<'a, T>, PoisonError<MutexGuard<'a, T>>>,
) -> Result<MutexGuard<'a, T>, ShellError> {
match result {
Ok(guard) => Ok(guard),
Err(_) => Err(ShellError::NushellFailedHelp {
msg: "Waitable mutex poisoned".into(),
help: std::panic::Location::caller().to_string(),
}),
}
}
impl<T: Clone + Send> WaitableMut<T> {
/// Create a new empty `WaitableMut`. Call [`.reader()`] to get a [`Waitable`].
pub fn new() -> WaitableMut<T> {
WaitableMut {
shared: Arc::new(WaitableShared {
is_set: AtomicBool::new(false),
mutex: Mutex::new(SyncState {
writers: 1,
value: None,
}),
condvar: Condvar::new(),
}),
}
}
pub fn reader(&self) -> Waitable<T> {
Waitable {
shared: self.shared.clone(),
}
}
/// Set the value and let waiting threads know.
#[track_caller]
pub fn set(&self, value: T) -> Result<(), ShellError> {
let mut sync_state = fail_if_poisoned(self.shared.mutex.lock())?;
self.shared.is_set.store(true, Ordering::SeqCst);
sync_state.value = Some(value);
self.shared.condvar.notify_all();
Ok(())
}
}
impl<T: Clone + Send> Default for WaitableMut<T> {
fn default() -> Self {
Self::new()
}
}
impl<T: Clone + Send> Clone for WaitableMut<T> {
fn clone(&self) -> Self {
let shared = self.shared.clone();
shared
.mutex
.lock()
.expect("failed to lock mutex to increment writers")
.writers += 1;
WaitableMut { shared }
}
}
impl<T: Clone + Send> Drop for WaitableMut<T> {
fn drop(&mut self) {
// Decrement writers...
if let Ok(mut sync_state) = self.shared.mutex.lock() {
sync_state.writers = sync_state
.writers
.checked_sub(1)
.expect("would decrement writers below zero");
}
// and notify waiting threads so they have a chance to see it.
self.shared.condvar.notify_all();
}
}
impl<T: Clone + Send> Waitable<T> {
/// Wait for a value to be available and then clone it.
///
/// Returns `Ok(None)` if there are no writers left that could possibly place a value.
#[track_caller]
pub fn get(&self) -> Result<Option<T>, ShellError> {
let sync_state = fail_if_poisoned(self.shared.mutex.lock())?;
if let Some(value) = sync_state.value.clone() {
Ok(Some(value))
} else if sync_state.writers == 0 {
// There can't possibly be a value written, so no point in waiting.
Ok(None)
} else {
let sync_state = fail_if_poisoned(
self.shared
.condvar
.wait_while(sync_state, |g| g.writers > 0 && g.value.is_none()),
)?;
Ok(sync_state.value.clone())
}
}
/// Clone the value if one is available, but don't wait if not.
#[track_caller]
pub fn try_get(&self) -> Result<Option<T>, ShellError> {
let sync_state = fail_if_poisoned(self.shared.mutex.lock())?;
Ok(sync_state.value.clone())
}
/// Returns true if a value is available.
#[track_caller]
pub fn is_set(&self) -> bool {
self.shared.is_set.load(Ordering::SeqCst)
}
}
#[test]
fn set_from_other_thread() -> Result<(), ShellError> {
let waitable_mut = WaitableMut::new();
let waitable = waitable_mut.reader();
assert!(!waitable.is_set());
std::thread::spawn(move || {
waitable_mut.set(42).expect("error on set");
});
assert_eq!(Some(42), waitable.get()?);
assert_eq!(Some(42), waitable.try_get()?);
assert!(waitable.is_set());
Ok(())
}
#[test]
fn dont_deadlock_if_waiting_without_writer() {
use std::time::Duration;
let (tx, rx) = std::sync::mpsc::channel();
let writer = WaitableMut::<()>::new();
let waitable = writer.reader();
// Ensure there are no writers
drop(writer);
std::thread::spawn(move || {
let _ = tx.send(waitable.get());
});
let result = rx
.recv_timeout(Duration::from_secs(10))
.expect("timed out")
.expect("error");
assert!(result.is_none());
}
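// A sketch, not part of the original tests: cloning the writer keeps the writer count above
// zero, so dropping the original `WaitableMut` does not make readers give up early. The test
// name is illustrative only.
#[test]
fn cloned_writer_keeps_waitable_alive() -> Result<(), ShellError> {
    let writer = WaitableMut::new();
    let waitable = writer.reader();
    let second_writer = writer.clone();
    drop(writer);
    // No value has been set yet, but try_get doesn't block
    assert_eq!(None, waitable.try_get()?);
    std::thread::spawn(move || {
        second_writer.set(42).expect("error on set");
    });
    // get blocks until the remaining writer sets the value
    assert_eq!(Some(42), waitable.get()?);
    Ok(())
}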


@@ -0,0 +1,96 @@
use nu_protocol::{CustomValue, IntoSpanned, ShellError, Spanned, Value};
/// Do something with all [`CustomValue`]s recursively within a `Value`. This is not limited to
/// plugin custom values.
///
/// `LazyRecord`s will be collected to plain values for completeness.
pub fn with_custom_values_in<E>(
value: &mut Value,
mut f: impl FnMut(Spanned<&mut Box<dyn CustomValue>>) -> Result<(), E>,
) -> Result<(), E>
where
E: From<ShellError>,
{
value.recurse_mut(&mut |value| {
let span = value.span();
match value {
Value::Custom { val, .. } => {
// Operate on a CustomValue.
f(val.into_spanned(span))
}
// LazyRecord would be a problem for us, since it could return something different the
// next time it is queried, and we have to collect it anyway to serialize it. Collect it
// in place, and then use the result.
Value::LazyRecord { val, .. } => {
*value = val.collect()?;
Ok(())
}
_ => Ok(()),
}
})
}
#[test]
fn find_custom_values() {
use nu_plugin_protocol::test_util::test_plugin_custom_value;
use nu_protocol::{engine::Closure, record, LazyRecord, Span};
#[derive(Debug, Clone)]
struct Lazy;
impl<'a> LazyRecord<'a> for Lazy {
fn column_names(&'a self) -> Vec<&'a str> {
vec!["custom", "plain"]
}
fn get_column_value(&self, column: &str) -> Result<Value, ShellError> {
Ok(match column {
"custom" => Value::test_custom_value(Box::new(test_plugin_custom_value())),
"plain" => Value::test_int(42),
_ => unimplemented!(),
})
}
fn span(&self) -> Span {
Span::test_data()
}
fn clone_value(&self, span: Span) -> Value {
Value::lazy_record(Box::new(self.clone()), span)
}
}
let mut cv = Value::test_custom_value(Box::new(test_plugin_custom_value()));
let mut value = Value::test_record(record! {
"bare" => cv.clone(),
"list" => Value::test_list(vec![
cv.clone(),
Value::test_int(4),
]),
"closure" => Value::test_closure(
Closure {
block_id: 0,
captures: vec![(0, cv.clone()), (1, Value::test_string("foo"))]
}
),
"lazy" => Value::test_lazy_record(Box::new(Lazy)),
});
// Do with_custom_values_in, and count the number of custom values found
let mut found = 0;
with_custom_values_in::<ShellError>(&mut value, |_| {
found += 1;
Ok(())
})
.expect("error");
assert_eq!(4, found, "found in value");
// Try it on bare custom value too
found = 0;
with_custom_values_in::<ShellError>(&mut cv, |_| {
found += 1;
Ok(())
})
.expect("error");
assert_eq!(1, found, "bare custom value didn't work");
}
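// A sketch, not part of the original tests: errors returned by the callback propagate out of
// `with_custom_values_in` rather than being swallowed. The test name and error message are
// illustrative only.
#[test]
fn callback_errors_propagate() {
    use nu_plugin_protocol::test_util::test_plugin_custom_value;
    let mut value = Value::test_list(vec![
        Value::test_int(1),
        Value::test_custom_value(Box::new(test_plugin_custom_value())),
    ]);
    let result = with_custom_values_in::<ShellError>(&mut value, |_| {
        Err(ShellError::NushellFailed {
            msg: "example error".into(),
        })
    });
    assert!(result.is_err(), "error from the callback should propagate");
}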