Split the plugin crate (#12563)
# Description

This breaks `nu-plugin` up into four crates:

- `nu-plugin-protocol`: just the type definitions for the protocol, no I/O. If someone wanted to wire up something more bare metal, maybe for async I/O, they could use this.
- `nu-plugin-core`: the shared stuff between engine/plugin. Less stable interface.
- `nu-plugin-engine`: everything required for the engine to talk to plugins. Less stable interface.
- `nu-plugin`: everything required for the plugin to talk to the engine, which is what plugin developers use. This should be the most stable interface.

No changes are made to the interface exposed by `nu-plugin` - it should all still be there. Re-exports from `nu-plugin-protocol` or `nu-plugin-core` are used as required. Plugins shouldn't ever have to use those crates directly.

This should also be somewhat faster to compile, as `nu-plugin-engine` and `nu-plugin` can compile in parallel. The engine doesn't need `nu-plugin` and plugins don't need `nu-plugin-engine` (except for test support), so that should reduce what needs to be compiled too.

The only significant change here, other than splitting things up, was to break the `source` out of `PluginCustomValue` and create a new `PluginCustomValueWithSource` type that contains it instead. One bonus is that we get rid of the option and it's now more type-safe, and it also means that the logic for that stuff (actually running the plugin for custom value ops) can live entirely within the `nu-plugin-engine` crate.

# User-Facing Changes

- New crates.
- Added a `local-socket` feature for `nu` to make it possible to compile without that support if needed.

# Tests + Formatting

- 🟢 `toolkit fmt`
- 🟢 `toolkit clippy`
- 🟢 `toolkit test`
- 🟢 `toolkit test stdlib`
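For illustration, here is a rough sketch of the `PluginCustomValueWithSource` idea described above. These are not the actual nushell definitions; `PluginSource` and the field layout are assumed here purely to show why dropping the `Option` makes the engine-side code more type-safe:

```rust
use std::sync::Arc;

// Hypothetical stand-ins; the real definitions live in the nushell crates.
struct PluginSource;      // identifies which plugin produced a value
struct PluginCustomValue; // protocol-level custom value, with no source field

// Engine-side wrapper: the source is required rather than an Option, so the
// logic that actually runs the plugin for custom value operations can demand
// it at the type level and stay entirely inside `nu-plugin-engine`.
struct PluginCustomValueWithSource {
    value: PluginCustomValue,
    source: Arc<PluginSource>,
}
```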
crates/nu-plugin-core/src/interface/mod.rs (new file, 453 lines)
@@ -0,0 +1,453 @@
//! Implements the stream multiplexing interface for both the plugin side and the engine side.

use nu_plugin_protocol::{
    ExternalStreamInfo, ListStreamInfo, PipelineDataHeader, RawStreamInfo, StreamMessage,
};
use nu_protocol::{ListStream, PipelineData, RawStream, ShellError};
use std::{
    io::Write,
    sync::{
        atomic::{AtomicBool, Ordering::Relaxed},
        Arc, Mutex,
    },
    thread,
};

pub mod stream;

use crate::{util::Sequence, Encoder};

use self::stream::{StreamManager, StreamManagerHandle, StreamWriter, WriteStreamMessage};

pub mod test_util;

#[cfg(test)]
mod tests;

/// The maximum number of list stream values to send without acknowledgement. This should be tuned
/// with consideration for memory usage.
const LIST_STREAM_HIGH_PRESSURE: i32 = 100;

/// The maximum number of raw stream buffers to send without acknowledgement. This should be tuned
/// with consideration for memory usage.
const RAW_STREAM_HIGH_PRESSURE: i32 = 50;

/// Read input/output from the stream.
pub trait PluginRead<T> {
    /// Returns `Ok(None)` on end of stream.
    fn read(&mut self) -> Result<Option<T>, ShellError>;
}

impl<R, E, T> PluginRead<T> for (R, E)
where
    R: std::io::BufRead,
    E: Encoder<T>,
{
    fn read(&mut self) -> Result<Option<T>, ShellError> {
        self.1.decode(&mut self.0)
    }
}

impl<R, T> PluginRead<T> for &mut R
where
    R: PluginRead<T>,
{
    fn read(&mut self) -> Result<Option<T>, ShellError> {
        (**self).read()
    }
}

/// Write input/output to the stream.
///
/// The write should be atomic, without interference from other threads.
pub trait PluginWrite<T>: Send + Sync {
    fn write(&self, data: &T) -> Result<(), ShellError>;

    /// Flush any internal buffers, if applicable.
    fn flush(&self) -> Result<(), ShellError>;

    /// True if this output is stdout, so that plugins can avoid using stdout for their own purpose
    fn is_stdout(&self) -> bool {
        false
    }
}

impl<E, T> PluginWrite<T> for (std::io::Stdout, E)
where
    E: Encoder<T>,
{
    fn write(&self, data: &T) -> Result<(), ShellError> {
        let mut lock = self.0.lock();
        self.1.encode(data, &mut lock)
    }

    fn flush(&self) -> Result<(), ShellError> {
        self.0.lock().flush().map_err(|err| ShellError::IOError {
            msg: err.to_string(),
        })
    }

    fn is_stdout(&self) -> bool {
        true
    }
}

impl<W, E, T> PluginWrite<T> for (Mutex<W>, E)
where
    W: std::io::Write + Send,
    E: Encoder<T>,
{
    fn write(&self, data: &T) -> Result<(), ShellError> {
        let mut lock = self.0.lock().map_err(|_| ShellError::NushellFailed {
            msg: "writer mutex poisoned".into(),
        })?;
        self.1.encode(data, &mut *lock)
    }

    fn flush(&self) -> Result<(), ShellError> {
        let mut lock = self.0.lock().map_err(|_| ShellError::NushellFailed {
            msg: "writer mutex poisoned".into(),
        })?;
        lock.flush().map_err(|err| ShellError::IOError {
            msg: err.to_string(),
        })
    }
}

impl<W, T> PluginWrite<T> for &W
where
    W: PluginWrite<T>,
{
    fn write(&self, data: &T) -> Result<(), ShellError> {
        (**self).write(data)
    }

    fn flush(&self) -> Result<(), ShellError> {
        (**self).flush()
    }

    fn is_stdout(&self) -> bool {
        (**self).is_stdout()
    }
}

/// An interface manager handles I/O and state management for communication between a plugin and
/// the engine. See `PluginInterfaceManager` in `nu-plugin-engine` for communication from the engine
/// side to a plugin, or `EngineInterfaceManager` in `nu-plugin` for communication from the plugin
/// side to the engine.
///
/// There is typically one [`InterfaceManager`] consuming input from a background thread, and
/// managing shared state.
pub trait InterfaceManager {
    /// The corresponding interface type.
    type Interface: Interface + 'static;

    /// The input message type.
    type Input;

    /// Make a new interface that communicates with this [`InterfaceManager`].
    fn get_interface(&self) -> Self::Interface;

    /// Consume an input message.
    ///
    /// When implementing, call [`.consume_stream_message()`] for any encapsulated
    /// [`StreamMessage`]s received.
    fn consume(&mut self, input: Self::Input) -> Result<(), ShellError>;

    /// Get the [`StreamManager`] for handling operations related to stream messages.
    fn stream_manager(&self) -> &StreamManager;

    /// Prepare [`PipelineData`] after reading. This is called by `read_pipeline_data()` as
    /// a hook so that values that need special handling can be taken care of.
    fn prepare_pipeline_data(&self, data: PipelineData) -> Result<PipelineData, ShellError>;

    /// Consume an input stream message.
    ///
    /// This method is provided for implementors to use.
    fn consume_stream_message(&mut self, message: StreamMessage) -> Result<(), ShellError> {
        self.stream_manager().handle_message(message)
    }

    /// Generate `PipelineData` for reading a stream, given a [`PipelineDataHeader`] that was
    /// received from the other side.
    ///
    /// This method is provided for implementors to use.
    fn read_pipeline_data(
        &self,
        header: PipelineDataHeader,
        ctrlc: Option<&Arc<AtomicBool>>,
    ) -> Result<PipelineData, ShellError> {
        self.prepare_pipeline_data(match header {
            PipelineDataHeader::Empty => PipelineData::Empty,
            PipelineDataHeader::Value(value) => PipelineData::Value(value, None),
            PipelineDataHeader::ListStream(info) => {
                let handle = self.stream_manager().get_handle();
                let reader = handle.read_stream(info.id, self.get_interface())?;
                PipelineData::ListStream(ListStream::from_stream(reader, ctrlc.cloned()), None)
            }
            PipelineDataHeader::ExternalStream(info) => {
                let handle = self.stream_manager().get_handle();
                let span = info.span;
                let new_raw_stream = |raw_info: RawStreamInfo| {
                    let reader = handle.read_stream(raw_info.id, self.get_interface())?;
                    let mut stream =
                        RawStream::new(Box::new(reader), ctrlc.cloned(), span, raw_info.known_size);
                    stream.is_binary = raw_info.is_binary;
                    Ok::<_, ShellError>(stream)
                };
                PipelineData::ExternalStream {
                    stdout: info.stdout.map(new_raw_stream).transpose()?,
                    stderr: info.stderr.map(new_raw_stream).transpose()?,
                    exit_code: info
                        .exit_code
                        .map(|list_info| {
                            handle
                                .read_stream(list_info.id, self.get_interface())
                                .map(|reader| ListStream::from_stream(reader, ctrlc.cloned()))
                        })
                        .transpose()?,
                    span: info.span,
                    metadata: None,
                    trim_end_newline: info.trim_end_newline,
                }
            }
        })
    }
}

/// An interface provides an API for communicating with a plugin or the engine and facilitates
/// stream I/O. See `PluginInterface` in `nu-plugin-engine` for the API from the engine side to a
/// plugin, or `EngineInterface` in `nu-plugin` for the API from the plugin side to the engine.
///
/// There can be multiple copies of the interface managed by a single [`InterfaceManager`].
pub trait Interface: Clone + Send {
    /// The output message type, which must be capable of encapsulating a [`StreamMessage`].
    type Output: From<StreamMessage>;

    /// Any context required to construct [`PipelineData`]. Can be `()` if not needed.
    type DataContext;

    /// Write an output message.
    fn write(&self, output: Self::Output) -> Result<(), ShellError>;

    /// Flush the output buffer, so messages are visible to the other side.
    fn flush(&self) -> Result<(), ShellError>;

    /// Get the sequence for generating new [`StreamId`](nu_plugin_protocol::StreamId)s.
    fn stream_id_sequence(&self) -> &Sequence;

    /// Get the [`StreamManagerHandle`] for doing stream operations.
    fn stream_manager_handle(&self) -> &StreamManagerHandle;

    /// Prepare [`PipelineData`] to be written. This is called by `init_write_pipeline_data()` as
    /// a hook so that values that need special handling can be taken care of.
    fn prepare_pipeline_data(
        &self,
        data: PipelineData,
        context: &Self::DataContext,
    ) -> Result<PipelineData, ShellError>;

    /// Initialize a write for [`PipelineData`]. This returns two parts: the header, which can be
    /// embedded in the particular message that references the stream, and a writer, which will
    /// write out all of the data in the pipeline when `.write()` is called.
    ///
    /// Note that not all [`PipelineData`] starts a stream. You should call `write()` anyway, as
    /// it will automatically handle this case.
    ///
    /// This method is provided for implementors to use.
    fn init_write_pipeline_data(
        &self,
        data: PipelineData,
        context: &Self::DataContext,
    ) -> Result<(PipelineDataHeader, PipelineDataWriter<Self>), ShellError> {
        // Allocate a stream id and a writer
        let new_stream = |high_pressure_mark: i32| {
            // Get a free stream id
            let id = self.stream_id_sequence().next()?;
            // Create the writer
            let writer =
                self.stream_manager_handle()
                    .write_stream(id, self.clone(), high_pressure_mark)?;
            Ok::<_, ShellError>((id, writer))
        };
        match self.prepare_pipeline_data(data, context)? {
            PipelineData::Value(value, _) => {
                Ok((PipelineDataHeader::Value(value), PipelineDataWriter::None))
            }
            PipelineData::Empty => Ok((PipelineDataHeader::Empty, PipelineDataWriter::None)),
            PipelineData::ListStream(stream, _) => {
                let (id, writer) = new_stream(LIST_STREAM_HIGH_PRESSURE)?;
                Ok((
                    PipelineDataHeader::ListStream(ListStreamInfo { id }),
                    PipelineDataWriter::ListStream(writer, stream),
                ))
            }
            PipelineData::ExternalStream {
                stdout,
                stderr,
                exit_code,
                span,
                metadata: _,
                trim_end_newline,
            } => {
                // Create the writers and stream ids
                let stdout_stream = stdout
                    .is_some()
                    .then(|| new_stream(RAW_STREAM_HIGH_PRESSURE))
                    .transpose()?;
                let stderr_stream = stderr
                    .is_some()
                    .then(|| new_stream(RAW_STREAM_HIGH_PRESSURE))
                    .transpose()?;
                let exit_code_stream = exit_code
                    .is_some()
                    .then(|| new_stream(LIST_STREAM_HIGH_PRESSURE))
                    .transpose()?;
                // Generate the header, with the stream ids
                let header = PipelineDataHeader::ExternalStream(ExternalStreamInfo {
                    span,
                    stdout: stdout
                        .as_ref()
                        .zip(stdout_stream.as_ref())
                        .map(|(stream, (id, _))| RawStreamInfo::new(*id, stream)),
                    stderr: stderr
                        .as_ref()
                        .zip(stderr_stream.as_ref())
                        .map(|(stream, (id, _))| RawStreamInfo::new(*id, stream)),
                    exit_code: exit_code_stream
                        .as_ref()
                        .map(|&(id, _)| ListStreamInfo { id }),
                    trim_end_newline,
                });
                // Collect the writers
                let writer = PipelineDataWriter::ExternalStream {
                    stdout: stdout_stream.map(|(_, writer)| writer).zip(stdout),
                    stderr: stderr_stream.map(|(_, writer)| writer).zip(stderr),
                    exit_code: exit_code_stream.map(|(_, writer)| writer).zip(exit_code),
                };
                Ok((header, writer))
            }
        }
    }
}

impl<T> WriteStreamMessage for T
where
    T: Interface,
{
    fn write_stream_message(&mut self, msg: StreamMessage) -> Result<(), ShellError> {
        self.write(msg.into())
    }

    fn flush(&mut self) -> Result<(), ShellError> {
        <Self as Interface>::flush(self)
    }
}

/// Completes the write operation for a [`PipelineData`]. You must call
/// [`PipelineDataWriter::write()`] to write all of the data contained within the streams.
#[derive(Default)]
#[must_use]
pub enum PipelineDataWriter<W: WriteStreamMessage> {
    #[default]
    None,
    ListStream(StreamWriter<W>, ListStream),
    ExternalStream {
        stdout: Option<(StreamWriter<W>, RawStream)>,
        stderr: Option<(StreamWriter<W>, RawStream)>,
        exit_code: Option<(StreamWriter<W>, ListStream)>,
    },
}

impl<W> PipelineDataWriter<W>
where
    W: WriteStreamMessage + Send + 'static,
{
    /// Write all of the data in each of the streams. This method waits for completion.
    pub fn write(self) -> Result<(), ShellError> {
        match self {
            // If no stream was contained in the PipelineData, do nothing.
            PipelineDataWriter::None => Ok(()),
            // Write a list stream.
            PipelineDataWriter::ListStream(mut writer, stream) => {
                writer.write_all(stream)?;
                Ok(())
            }
            // Write all three possible streams of an ExternalStream on separate threads.
            PipelineDataWriter::ExternalStream {
                stdout,
                stderr,
                exit_code,
            } => {
                thread::scope(|scope| {
                    let stderr_thread = stderr
                        .map(|(mut writer, stream)| {
                            thread::Builder::new()
                                .name("plugin stderr writer".into())
                                .spawn_scoped(scope, move || {
                                    writer.write_all(raw_stream_iter(stream))
                                })
                        })
                        .transpose()?;
                    let exit_code_thread = exit_code
                        .map(|(mut writer, stream)| {
                            thread::Builder::new()
                                .name("plugin exit_code writer".into())
                                .spawn_scoped(scope, move || writer.write_all(stream))
                        })
                        .transpose()?;
                    // Optimize for stdout: if only stdout is present, don't spawn any other
                    // threads.
                    if let Some((mut writer, stream)) = stdout {
                        writer.write_all(raw_stream_iter(stream))?;
                    }
                    let panicked = |thread_name: &str| {
                        Err(ShellError::NushellFailed {
                            msg: format!(
                                "{thread_name} thread panicked in PipelineDataWriter::write"
                            ),
                        })
                    };
                    stderr_thread
                        .map(|t| t.join().unwrap_or_else(|_| panicked("stderr")))
                        .transpose()?;
                    exit_code_thread
                        .map(|t| t.join().unwrap_or_else(|_| panicked("exit_code")))
                        .transpose()?;
                    Ok(())
                })
            }
        }
    }

    /// Write all of the data in each of the streams. This method returns immediately; any necessary
    /// write will happen in the background. If a thread was spawned, its handle is returned.
    pub fn write_background(
        self,
    ) -> Result<Option<thread::JoinHandle<Result<(), ShellError>>>, ShellError> {
        match self {
            PipelineDataWriter::None => Ok(None),
            _ => Ok(Some(
                thread::Builder::new()
                    .name("plugin stream background writer".into())
                    .spawn(move || {
                        let result = self.write();
                        if let Err(ref err) = result {
                            // Assume that the background thread error probably won't be handled and log it
                            // here just in case.
                            log::warn!("Error while writing pipeline in background: {err}");
                        }
                        result
                    })?,
            )),
        }
    }
}

/// Custom iterator for [`RawStream`] that respects ctrlc, but still has binary chunks
fn raw_stream_iter(stream: RawStream) -> impl Iterator<Item = Result<Vec<u8>, ShellError>> {
    let ctrlc = stream.ctrlc;
    stream
        .stream
        .take_while(move |_| ctrlc.as_ref().map(|b| !b.load(Relaxed)).unwrap_or(true))
}
crates/nu-plugin-core/src/interface/stream/mod.rs (new file, 628 lines)
@@ -0,0 +1,628 @@
use nu_plugin_protocol::{StreamData, StreamId, StreamMessage};
use nu_protocol::{ShellError, Span, Value};
use std::{
    collections::{btree_map, BTreeMap},
    iter::FusedIterator,
    marker::PhantomData,
    sync::{mpsc, Arc, Condvar, Mutex, MutexGuard, Weak},
};

#[cfg(test)]
mod tests;

/// Receives messages from a stream read from input by a [`StreamManager`].
///
/// The receiver reads for messages of type `Result<Option<StreamData>, ShellError>` from the
/// channel, which is managed by a [`StreamManager`]. Signalling for end-of-stream is explicit
/// through `Ok(None)`.
///
/// Failing to receive is an error. When end-of-stream is received, the `receiver` is set to `None`
/// and all further calls to `next()` return `None`.
///
/// The type `T` must implement [`FromShellError`], so that errors in the stream can be represented,
/// and `TryFrom<StreamData>` to convert it to the correct type.
///
/// For each message read, it sends [`StreamMessage::Ack`] to the writer. When dropped,
/// it sends [`StreamMessage::Drop`].
#[derive(Debug)]
pub struct StreamReader<T, W>
where
    W: WriteStreamMessage,
{
    id: StreamId,
    receiver: Option<mpsc::Receiver<Result<Option<StreamData>, ShellError>>>,
    writer: W,
    /// Iterator requires the item type to be fixed, so we have to keep it as part of the type,
    /// even though we're actually receiving dynamic data.
    marker: PhantomData<fn() -> T>,
}

impl<T, W> StreamReader<T, W>
where
    T: TryFrom<StreamData, Error = ShellError>,
    W: WriteStreamMessage,
{
    /// Create a new StreamReader from parts
    fn new(
        id: StreamId,
        receiver: mpsc::Receiver<Result<Option<StreamData>, ShellError>>,
        writer: W,
    ) -> StreamReader<T, W> {
        StreamReader {
            id,
            receiver: Some(receiver),
            writer,
            marker: PhantomData,
        }
    }

    /// Receive a message from the channel, or return an error if:
    ///
    /// * the channel couldn't be received from
    /// * an error was sent on the channel
    /// * the message received couldn't be converted to `T`
    pub fn recv(&mut self) -> Result<Option<T>, ShellError> {
        let connection_lost = || ShellError::GenericError {
            error: "Stream ended unexpectedly".into(),
            msg: "connection lost before explicit end of stream".into(),
            span: None,
            help: None,
            inner: vec![],
        };

        if let Some(ref rx) = self.receiver {
            // Try to receive a message first
            let msg = match rx.try_recv() {
                Ok(msg) => msg?,
                Err(mpsc::TryRecvError::Empty) => {
                    // The receiver doesn't have any messages waiting for us. It's possible that the
                    // other side hasn't seen our acknowledgements. Let's flush the writer and then
                    // wait
                    self.writer.flush()?;
                    rx.recv().map_err(|_| connection_lost())??
                }
                Err(mpsc::TryRecvError::Disconnected) => return Err(connection_lost()),
            };

            if let Some(data) = msg {
                // Acknowledge the message
                self.writer
                    .write_stream_message(StreamMessage::Ack(self.id))?;
                // Try to convert it into the correct type
                Ok(Some(data.try_into()?))
            } else {
                // Remove the receiver, so that future recv() calls always return Ok(None)
                self.receiver = None;
                Ok(None)
            }
        } else {
            // Closed already
            Ok(None)
        }
    }
}

impl<T, W> Iterator for StreamReader<T, W>
where
    T: FromShellError + TryFrom<StreamData, Error = ShellError>,
    W: WriteStreamMessage,
{
    type Item = T;

    fn next(&mut self) -> Option<T> {
        // Converting the error to the value here makes the implementation a lot easier
        match self.recv() {
            Ok(option) => option,
            Err(err) => {
                // Drop the receiver so we don't keep returning errors
                self.receiver = None;
                Some(T::from_shell_error(err))
            }
        }
    }
}

// Guaranteed not to return anything after the end
impl<T, W> FusedIterator for StreamReader<T, W>
where
    T: FromShellError + TryFrom<StreamData, Error = ShellError>,
    W: WriteStreamMessage,
{
}

impl<T, W> Drop for StreamReader<T, W>
where
    W: WriteStreamMessage,
{
    fn drop(&mut self) {
        if let Err(err) = self
            .writer
            .write_stream_message(StreamMessage::Drop(self.id))
            .and_then(|_| self.writer.flush())
        {
            log::warn!("Failed to send message to drop stream: {err}");
        }
    }
}

/// Values that can contain a `ShellError` to signal an error has occurred.
pub trait FromShellError {
    fn from_shell_error(err: ShellError) -> Self;
}

// For List streams.
impl FromShellError for Value {
    fn from_shell_error(err: ShellError) -> Self {
        Value::error(err, Span::unknown())
    }
}

// For Raw streams, mostly.
impl<T> FromShellError for Result<T, ShellError> {
    fn from_shell_error(err: ShellError) -> Self {
        Err(err)
    }
}

/// Writes messages to a stream, with flow control.
///
/// The `signal` contained
#[derive(Debug)]
pub struct StreamWriter<W: WriteStreamMessage> {
    id: StreamId,
    signal: Arc<StreamWriterSignal>,
    writer: W,
    ended: bool,
}

impl<W> StreamWriter<W>
where
    W: WriteStreamMessage,
{
    fn new(id: StreamId, signal: Arc<StreamWriterSignal>, writer: W) -> StreamWriter<W> {
        StreamWriter {
            id,
            signal,
            writer,
            ended: false,
        }
    }

    /// Check if the stream was dropped from the other end. Recommended to do this before calling
    /// [`.write()`], especially in a loop.
    pub fn is_dropped(&self) -> Result<bool, ShellError> {
        self.signal.is_dropped()
    }

    /// Write a single piece of data to the stream.
    ///
    /// Error if something failed with the write, or if [`.end()`] was already called
    /// previously.
    pub fn write(&mut self, data: impl Into<StreamData>) -> Result<(), ShellError> {
        if !self.ended {
            self.writer
                .write_stream_message(StreamMessage::Data(self.id, data.into()))?;
            // This implements flow control, so we don't write too many messages:
            if !self.signal.notify_sent()? {
                // Flush the output, and then wait for acknowledgements
                self.writer.flush()?;
                self.signal.wait_for_drain()
            } else {
                Ok(())
            }
        } else {
            Err(ShellError::GenericError {
                error: "Wrote to a stream after it ended".into(),
                msg: format!(
                    "tried to write to stream {} after it was already ended",
                    self.id
                ),
                span: None,
                help: Some("this may be a bug in the nu-plugin crate".into()),
                inner: vec![],
            })
        }
    }

    /// Write a full iterator to the stream. Note that this doesn't end the stream, so you should
    /// still call [`.end()`].
    ///
    /// If the stream is dropped from the other end, the iterator will not be fully consumed, and
    /// writing will terminate.
    ///
    /// Returns `Ok(true)` if the iterator was fully consumed, or `Ok(false)` if a drop interrupted
    /// the stream from the other side.
    pub fn write_all<T>(&mut self, data: impl IntoIterator<Item = T>) -> Result<bool, ShellError>
    where
        T: Into<StreamData>,
    {
        // Check before starting
        if self.is_dropped()? {
            return Ok(false);
        }

        for item in data {
            // Check again after each item is consumed from the iterator, just in case the iterator
            // takes a while to produce a value
            if self.is_dropped()? {
                return Ok(false);
            }
            self.write(item)?;
        }
        Ok(true)
    }

    /// End the stream. Recommend doing this instead of relying on `Drop` so that you can catch the
    /// error.
    pub fn end(&mut self) -> Result<(), ShellError> {
        if !self.ended {
            // Set the flag first so we don't double-report in the Drop
            self.ended = true;
            self.writer
                .write_stream_message(StreamMessage::End(self.id))?;
            self.writer.flush()
        } else {
            Ok(())
        }
    }
}

impl<W> Drop for StreamWriter<W>
where
    W: WriteStreamMessage,
{
    fn drop(&mut self) {
        // Make sure we ended the stream
        if let Err(err) = self.end() {
            log::warn!("Error while ending stream in Drop for StreamWriter: {err}");
        }
    }
}

/// Stores stream state for a writer, and can be blocked on to wait for messages to be acknowledged.
/// A key part of managing stream lifecycle and flow control.
#[derive(Debug)]
pub struct StreamWriterSignal {
    mutex: Mutex<StreamWriterSignalState>,
    change_cond: Condvar,
}

#[derive(Debug)]
pub struct StreamWriterSignalState {
    /// Stream has been dropped and consumer is no longer interested in any messages.
    dropped: bool,
    /// Number of messages that have been sent without acknowledgement.
    unacknowledged: i32,
    /// Max number of messages to send before waiting for acknowledgement.
    high_pressure_mark: i32,
}

impl StreamWriterSignal {
    /// Create a new signal.
    ///
    /// If `notify_sent()` is called more than `high_pressure_mark` times, it will wait until
    /// `notify_acknowledge()` is called by another thread enough times to bring the number of
    /// unacknowledged sent messages below that threshold.
    fn new(high_pressure_mark: i32) -> StreamWriterSignal {
        assert!(high_pressure_mark > 0);

        StreamWriterSignal {
            mutex: Mutex::new(StreamWriterSignalState {
                dropped: false,
                unacknowledged: 0,
                high_pressure_mark,
            }),
            change_cond: Condvar::new(),
        }
    }

    fn lock(&self) -> Result<MutexGuard<StreamWriterSignalState>, ShellError> {
        self.mutex.lock().map_err(|_| ShellError::NushellFailed {
            msg: "StreamWriterSignal mutex poisoned due to panic".into(),
        })
    }

    /// True if the stream was dropped and the consumer is no longer interested in it. Indicates
    /// that no more messages should be sent, other than `End`.
    pub fn is_dropped(&self) -> Result<bool, ShellError> {
        Ok(self.lock()?.dropped)
    }

    /// Notify the writers that the stream has been dropped, so they can stop writing.
    pub fn set_dropped(&self) -> Result<(), ShellError> {
        let mut state = self.lock()?;
        state.dropped = true;
        // Unblock the writers so they can terminate
        self.change_cond.notify_all();
        Ok(())
    }

    /// Track that a message has been sent. Returns `Ok(true)` if more messages can be sent,
    /// or `Ok(false)` if the high pressure mark has been reached and [`.wait_for_drain()`] should
    /// be called to block.
    pub fn notify_sent(&self) -> Result<bool, ShellError> {
        let mut state = self.lock()?;
        state.unacknowledged =
            state
                .unacknowledged
                .checked_add(1)
                .ok_or_else(|| ShellError::NushellFailed {
                    msg: "Overflow in counter: too many unacknowledged messages".into(),
                })?;

        Ok(state.unacknowledged < state.high_pressure_mark)
    }

    /// Wait for acknowledgements before sending more data. Also returns if the stream is dropped.
    pub fn wait_for_drain(&self) -> Result<(), ShellError> {
        let mut state = self.lock()?;
        while !state.dropped && state.unacknowledged >= state.high_pressure_mark {
            state = self
                .change_cond
                .wait(state)
                .map_err(|_| ShellError::NushellFailed {
                    msg: "StreamWriterSignal mutex poisoned due to panic".into(),
                })?;
        }
        Ok(())
    }

    /// Notify the writers that a message has been acknowledged, so they can continue to write
    /// if they were waiting.
    pub fn notify_acknowledged(&self) -> Result<(), ShellError> {
        let mut state = self.lock()?;
        state.unacknowledged =
            state
                .unacknowledged
                .checked_sub(1)
                .ok_or_else(|| ShellError::NushellFailed {
                    msg: "Underflow in counter: too many message acknowledgements".into(),
                })?;
        // Unblock the writer
        self.change_cond.notify_one();
        Ok(())
    }
}

/// A sink for a [`StreamMessage`]
pub trait WriteStreamMessage {
    fn write_stream_message(&mut self, msg: StreamMessage) -> Result<(), ShellError>;
    fn flush(&mut self) -> Result<(), ShellError>;
}

#[derive(Debug, Default)]
struct StreamManagerState {
    reading_streams: BTreeMap<StreamId, mpsc::Sender<Result<Option<StreamData>, ShellError>>>,
    writing_streams: BTreeMap<StreamId, Weak<StreamWriterSignal>>,
}

impl StreamManagerState {
    /// Lock the state, or return a [`ShellError`] if the mutex is poisoned.
    fn lock(
        state: &Mutex<StreamManagerState>,
    ) -> Result<MutexGuard<StreamManagerState>, ShellError> {
        state.lock().map_err(|_| ShellError::NushellFailed {
            msg: "StreamManagerState mutex poisoned due to a panic".into(),
        })
    }
}

#[derive(Debug)]
pub struct StreamManager {
    state: Arc<Mutex<StreamManagerState>>,
}

impl StreamManager {
    /// Create a new StreamManager.
    pub fn new() -> StreamManager {
        StreamManager {
            state: Default::default(),
        }
    }

    fn lock(&self) -> Result<MutexGuard<StreamManagerState>, ShellError> {
        StreamManagerState::lock(&self.state)
    }

    /// Create a new handle to the StreamManager for registering streams.
    pub fn get_handle(&self) -> StreamManagerHandle {
        StreamManagerHandle {
            state: Arc::downgrade(&self.state),
        }
    }

    /// Process a stream message, and update internal state accordingly.
    pub fn handle_message(&self, message: StreamMessage) -> Result<(), ShellError> {
        let mut state = self.lock()?;
        match message {
            StreamMessage::Data(id, data) => {
                if let Some(sender) = state.reading_streams.get(&id) {
                    // We should ignore the error on send. This just means the reader has dropped,
                    // but it will have sent a Drop message to the other side, and we will receive
                    // an End message at which point we can remove the channel.
                    let _ = sender.send(Ok(Some(data)));
                    Ok(())
                } else {
                    Err(ShellError::PluginFailedToDecode {
                        msg: format!("received Data for unknown stream {id}"),
                    })
                }
            }
            StreamMessage::End(id) => {
                if let Some(sender) = state.reading_streams.remove(&id) {
                    // We should ignore the error on the send, because the reader might have dropped
                    // already
                    let _ = sender.send(Ok(None));
                    Ok(())
                } else {
                    Err(ShellError::PluginFailedToDecode {
                        msg: format!("received End for unknown stream {id}"),
                    })
                }
            }
            StreamMessage::Drop(id) => {
                if let Some(signal) = state.writing_streams.remove(&id) {
                    if let Some(signal) = signal.upgrade() {
                        // This will wake blocked writers so they can stop writing, so it's ok
                        signal.set_dropped()?;
                    }
                }
                // It's possible that the stream has already finished writing and we don't have it
                // anymore, so we fall through to Ok
                Ok(())
            }
            StreamMessage::Ack(id) => {
                if let Some(signal) = state.writing_streams.get(&id) {
                    if let Some(signal) = signal.upgrade() {
                        // This will wake up a blocked writer
                        signal.notify_acknowledged()?;
                    } else {
                        // We know it doesn't exist, so might as well remove it
                        state.writing_streams.remove(&id);
                    }
                }
                // It's possible that the stream has already finished writing and we don't have it
                // anymore, so we fall through to Ok
                Ok(())
            }
        }
    }

    /// Broadcast an error to all stream readers. This is useful for error propagation.
    pub fn broadcast_read_error(&self, error: ShellError) -> Result<(), ShellError> {
        let state = self.lock()?;
        for channel in state.reading_streams.values() {
            // Ignore send errors.
            let _ = channel.send(Err(error.clone()));
        }
        Ok(())
    }

    // If the `StreamManager` is dropped, we should let all of the stream writers know that they
    // won't be able to write anymore. We don't need to do anything about the readers though
    // because they'll know when the `Sender` is dropped automatically
    fn drop_all_writers(&self) -> Result<(), ShellError> {
        let mut state = self.lock()?;
        let writers = std::mem::take(&mut state.writing_streams);
        for (_, signal) in writers {
            if let Some(signal) = signal.upgrade() {
                // more important that we send to all than handling an error
                let _ = signal.set_dropped();
            }
        }
        Ok(())
    }
}

impl Default for StreamManager {
    fn default() -> Self {
        Self::new()
    }
}

impl Drop for StreamManager {
    fn drop(&mut self) {
        if let Err(err) = self.drop_all_writers() {
            log::warn!("error during Drop for StreamManager: {}", err)
        }
    }
}

/// A [`StreamManagerHandle`] supports operations for interacting with the [`StreamManager`].
///
/// Streams can be registered for reading, returning a [`StreamReader`], or for writing, returning
/// a [`StreamWriter`].
#[derive(Debug, Clone)]
pub struct StreamManagerHandle {
    state: Weak<Mutex<StreamManagerState>>,
}

impl StreamManagerHandle {
    /// Because the handle only has a weak reference to the [`StreamManager`] state, we have to
    /// first try to upgrade to a strong reference and then lock. This function wraps those two
    /// operations together, handling errors appropriately.
    fn with_lock<T, F>(&self, f: F) -> Result<T, ShellError>
    where
        F: FnOnce(MutexGuard<StreamManagerState>) -> Result<T, ShellError>,
    {
        let upgraded = self
            .state
            .upgrade()
            .ok_or_else(|| ShellError::NushellFailed {
                msg: "StreamManager is no longer alive".into(),
            })?;
        let guard = upgraded.lock().map_err(|_| ShellError::NushellFailed {
            msg: "StreamManagerState mutex poisoned due to a panic".into(),
        })?;
        f(guard)
    }

    /// Register a new stream for reading, and return a [`StreamReader`] that can be used to iterate
    /// on the values received. A [`StreamMessage`] writer is required for writing control messages
    /// back to the producer.
    pub fn read_stream<T, W>(
        &self,
        id: StreamId,
        writer: W,
    ) -> Result<StreamReader<T, W>, ShellError>
    where
        T: TryFrom<StreamData, Error = ShellError>,
        W: WriteStreamMessage,
    {
        let (tx, rx) = mpsc::channel();
        self.with_lock(|mut state| {
            // Must be exclusive
            if let btree_map::Entry::Vacant(e) = state.reading_streams.entry(id) {
                e.insert(tx);
                Ok(())
            } else {
                Err(ShellError::GenericError {
                    error: format!("Failed to acquire reader for stream {id}"),
                    msg: "tried to get a reader for a stream that's already being read".into(),
                    span: None,
                    help: Some("this may be a bug in the nu-plugin crate".into()),
                    inner: vec![],
                })
            }
        })?;
        Ok(StreamReader::new(id, rx, writer))
    }

    /// Register a new stream for writing, and return a [`StreamWriter`] that can be used to send
    /// data to the stream.
    ///
    /// The `high_pressure_mark` value controls how many messages can be written without receiving
    /// an acknowledgement before any further attempts to write will wait for the consumer to
    /// acknowledge them. This prevents overwhelming the reader.
    pub fn write_stream<W>(
        &self,
        id: StreamId,
        writer: W,
        high_pressure_mark: i32,
    ) -> Result<StreamWriter<W>, ShellError>
    where
        W: WriteStreamMessage,
    {
        let signal = Arc::new(StreamWriterSignal::new(high_pressure_mark));
        self.with_lock(|mut state| {
            // Remove dead writing streams
            state
                .writing_streams
                .retain(|_, signal| signal.strong_count() > 0);
            // Must be exclusive
            if let btree_map::Entry::Vacant(e) = state.writing_streams.entry(id) {
                e.insert(Arc::downgrade(&signal));
                Ok(())
            } else {
                Err(ShellError::GenericError {
                    error: format!("Failed to acquire writer for stream {id}"),
                    msg: "tried to get a writer for a stream that's already being written".into(),
                    span: None,
                    help: Some("this may be a bug in the nu-plugin crate".into()),
                    inner: vec![],
                })
            }
        })?;
        Ok(StreamWriter::new(id, signal, writer))
    }
}
crates/nu-plugin-core/src/interface/stream/tests.rs (new file, 550 lines)
@@ -0,0 +1,550 @@
|
||||
use std::{
|
||||
sync::{
|
||||
atomic::{AtomicBool, Ordering::Relaxed},
|
||||
mpsc, Arc,
|
||||
},
|
||||
time::{Duration, Instant},
|
||||
};
|
||||
|
||||
use super::{StreamManager, StreamReader, StreamWriter, StreamWriterSignal, WriteStreamMessage};
|
||||
use nu_plugin_protocol::{StreamData, StreamMessage};
|
||||
use nu_protocol::{ShellError, Value};
|
||||
|
||||
// Should be long enough to definitely complete any quick operation, but not so long that tests are
|
||||
// slow to complete. 10 ms is a pretty long time
|
||||
const WAIT_DURATION: Duration = Duration::from_millis(10);
|
||||
|
||||
// Maximum time to wait for a condition to be true
|
||||
const MAX_WAIT_DURATION: Duration = Duration::from_millis(500);
|
||||
|
||||
/// Wait for a condition to be true, or panic if the duration exceeds MAX_WAIT_DURATION
|
||||
#[track_caller]
|
||||
fn wait_for_condition(mut cond: impl FnMut() -> bool, message: &str) {
|
||||
// Early check
|
||||
if cond() {
|
||||
return;
|
||||
}
|
||||
|
||||
let start = Instant::now();
|
||||
loop {
|
||||
std::thread::sleep(Duration::from_millis(10));
|
||||
|
||||
if cond() {
|
||||
return;
|
||||
}
|
||||
|
||||
let elapsed = Instant::now().saturating_duration_since(start);
|
||||
if elapsed > MAX_WAIT_DURATION {
|
||||
panic!(
|
||||
"{message}: Waited {:.2}sec, which is more than the maximum of {:.2}sec",
|
||||
elapsed.as_secs_f64(),
|
||||
MAX_WAIT_DURATION.as_secs_f64(),
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Default)]
|
||||
struct TestSink(Vec<StreamMessage>);
|
||||
|
||||
impl WriteStreamMessage for TestSink {
|
||||
fn write_stream_message(&mut self, msg: StreamMessage) -> Result<(), ShellError> {
|
||||
self.0.push(msg);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn flush(&mut self) -> Result<(), ShellError> {
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl WriteStreamMessage for mpsc::Sender<StreamMessage> {
|
||||
fn write_stream_message(&mut self, msg: StreamMessage) -> Result<(), ShellError> {
|
||||
self.send(msg).map_err(|err| ShellError::NushellFailed {
|
||||
msg: err.to_string(),
|
||||
})
|
||||
}
|
||||
|
||||
fn flush(&mut self) -> Result<(), ShellError> {
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn reader_recv_list_messages() -> Result<(), ShellError> {
|
||||
let (tx, rx) = mpsc::channel();
|
||||
let mut reader = StreamReader::new(0, rx, TestSink::default());
|
||||
|
||||
tx.send(Ok(Some(StreamData::List(Value::test_int(5)))))
|
||||
.unwrap();
|
||||
drop(tx);
|
||||
|
||||
assert_eq!(Some(Value::test_int(5)), reader.recv()?);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn list_reader_recv_wrong_type() -> Result<(), ShellError> {
|
||||
let (tx, rx) = mpsc::channel();
|
||||
let mut reader = StreamReader::<Value, _>::new(0, rx, TestSink::default());
|
||||
|
||||
tx.send(Ok(Some(StreamData::Raw(Ok(vec![10, 20])))))
|
||||
.unwrap();
|
||||
tx.send(Ok(Some(StreamData::List(Value::test_nothing()))))
|
||||
.unwrap();
|
||||
drop(tx);
|
||||
|
||||
reader.recv().expect_err("should be an error");
|
||||
reader.recv().expect("should be able to recover");
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn reader_recv_raw_messages() -> Result<(), ShellError> {
|
||||
let (tx, rx) = mpsc::channel();
|
||||
let mut reader =
|
||||
StreamReader::<Result<Vec<u8>, ShellError>, _>::new(0, rx, TestSink::default());
|
||||
|
||||
tx.send(Ok(Some(StreamData::Raw(Ok(vec![10, 20])))))
|
||||
.unwrap();
|
||||
drop(tx);
|
||||
|
||||
assert_eq!(Some(vec![10, 20]), reader.recv()?.transpose()?);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn raw_reader_recv_wrong_type() -> Result<(), ShellError> {
|
||||
let (tx, rx) = mpsc::channel();
|
||||
let mut reader =
|
||||
StreamReader::<Result<Vec<u8>, ShellError>, _>::new(0, rx, TestSink::default());
|
||||
|
||||
tx.send(Ok(Some(StreamData::List(Value::test_nothing()))))
|
||||
.unwrap();
|
||||
tx.send(Ok(Some(StreamData::Raw(Ok(vec![10, 20])))))
|
||||
.unwrap();
|
||||
drop(tx);
|
||||
|
||||
reader.recv().expect_err("should be an error");
|
||||
reader.recv().expect("should be able to recover");
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn reader_recv_acknowledge() -> Result<(), ShellError> {
|
||||
let (tx, rx) = mpsc::channel();
|
||||
let mut reader = StreamReader::<Value, _>::new(0, rx, TestSink::default());
|
||||
|
||||
tx.send(Ok(Some(StreamData::List(Value::test_int(5)))))
|
||||
.unwrap();
|
||||
tx.send(Ok(Some(StreamData::List(Value::test_int(6)))))
|
||||
.unwrap();
|
||||
drop(tx);
|
||||
|
||||
reader.recv()?;
|
||||
reader.recv()?;
|
||||
let wrote = &reader.writer.0;
|
||||
assert!(wrote.len() >= 2);
|
||||
assert!(
|
||||
matches!(wrote[0], StreamMessage::Ack(0)),
|
||||
"0 = {:?}",
|
||||
wrote[0]
|
||||
);
|
||||
assert!(
|
||||
matches!(wrote[1], StreamMessage::Ack(0)),
|
||||
"1 = {:?}",
|
||||
wrote[1]
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn reader_recv_end_of_stream() -> Result<(), ShellError> {
|
||||
let (tx, rx) = mpsc::channel();
|
||||
let mut reader = StreamReader::<Value, _>::new(0, rx, TestSink::default());
|
||||
|
||||
tx.send(Ok(Some(StreamData::List(Value::test_int(5)))))
|
||||
.unwrap();
|
||||
tx.send(Ok(None)).unwrap();
|
||||
drop(tx);
|
||||
|
||||
assert!(reader.recv()?.is_some(), "actual message");
|
||||
assert!(reader.recv()?.is_none(), "on close");
|
||||
assert!(reader.recv()?.is_none(), "after close");
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn reader_iter_fuse_on_error() -> Result<(), ShellError> {
|
||||
let (tx, rx) = mpsc::channel();
|
||||
let mut reader = StreamReader::<Value, _>::new(0, rx, TestSink::default());
|
||||
|
||||
drop(tx); // should cause error, because we didn't explicitly signal the end
|
||||
|
||||
assert!(
|
||||
reader.next().is_some_and(|e| e.is_error()),
|
||||
"should be error the first time"
|
||||
);
|
||||
assert!(reader.next().is_none(), "should be closed the second time");
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn reader_drop() {
|
||||
let (_tx, rx) = mpsc::channel();
|
||||
|
||||
// Flag set if drop message is received.
|
||||
struct Check(Arc<AtomicBool>);
|
||||
|
||||
impl WriteStreamMessage for Check {
|
||||
fn write_stream_message(&mut self, msg: StreamMessage) -> Result<(), ShellError> {
|
||||
assert!(matches!(msg, StreamMessage::Drop(1)), "got {:?}", msg);
|
||||
self.0.store(true, Relaxed);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn flush(&mut self) -> Result<(), ShellError> {
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
let flag = Arc::new(AtomicBool::new(false));
|
||||
|
||||
let reader = StreamReader::<Value, _>::new(1, rx, Check(flag.clone()));
|
||||
drop(reader);
|
||||
|
||||
assert!(flag.load(Relaxed));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn writer_write_all_stops_if_dropped() -> Result<(), ShellError> {
|
||||
let signal = Arc::new(StreamWriterSignal::new(20));
|
||||
let id = 1337;
|
||||
let mut writer = StreamWriter::new(id, signal.clone(), TestSink::default());
|
||||
|
||||
// Simulate this by having it consume a stream that will actually do the drop halfway through
|
||||
let iter = (0..5).map(Value::test_int).chain({
|
||||
let mut n = 5;
|
||||
std::iter::from_fn(move || {
|
||||
// produces numbers 5..10, but drops for the first one
|
||||
if n == 5 {
|
||||
signal.set_dropped().unwrap();
|
||||
}
|
||||
if n < 10 {
|
||||
let value = Value::test_int(n);
|
||||
n += 1;
|
||||
Some(value)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
});
|
||||
|
||||
writer.write_all(iter)?;
|
||||
|
||||
assert!(writer.is_dropped()?);
|
||||
|
||||
let wrote = &writer.writer.0;
|
||||
assert_eq!(5, wrote.len(), "length wrong: {wrote:?}");
|
||||
|
||||
for (n, message) in (0..5).zip(wrote) {
|
||||
match message {
|
||||
StreamMessage::Data(msg_id, StreamData::List(value)) => {
|
||||
assert_eq!(id, *msg_id, "id");
|
||||
assert_eq!(Value::test_int(n), *value, "value");
|
||||
}
|
||||
other => panic!("unexpected message: {other:?}"),
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn writer_end() -> Result<(), ShellError> {
|
||||
let signal = Arc::new(StreamWriterSignal::new(20));
|
||||
let mut writer = StreamWriter::new(9001, signal.clone(), TestSink::default());
|
||||
|
||||
writer.end()?;
|
||||
writer
|
||||
.write(Value::test_int(2))
|
||||
.expect_err("shouldn't be able to write after end");
|
||||
writer.end().expect("end twice should be ok");
|
||||
|
||||
let wrote = &writer.writer.0;
|
||||
assert!(
|
||||
matches!(wrote.last(), Some(StreamMessage::End(9001))),
|
||||
"didn't write end message: {wrote:?}"
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn signal_set_dropped() -> Result<(), ShellError> {
|
||||
let signal = StreamWriterSignal::new(4);
|
||||
assert!(!signal.is_dropped()?);
|
||||
signal.set_dropped()?;
|
||||
assert!(signal.is_dropped()?);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn signal_notify_sent_false_if_unacknowledged() -> Result<(), ShellError> {
|
||||
let signal = StreamWriterSignal::new(2);
|
||||
assert!(signal.notify_sent()?);
|
||||
for _ in 0..100 {
|
||||
assert!(!signal.notify_sent()?);
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn signal_notify_sent_never_false_if_flowing() -> Result<(), ShellError> {
|
||||
let signal = StreamWriterSignal::new(1);
|
||||
for _ in 0..100 {
|
||||
signal.notify_acknowledged()?;
|
||||
}
|
||||
for _ in 0..100 {
|
||||
assert!(signal.notify_sent()?);
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn signal_wait_for_drain_blocks_on_unacknowledged() -> Result<(), ShellError> {
|
||||
let signal = StreamWriterSignal::new(50);
|
||||
std::thread::scope(|scope| {
|
||||
let spawned = scope.spawn(|| {
|
||||
for _ in 0..100 {
|
||||
if !signal.notify_sent()? {
|
||||
signal.wait_for_drain()?;
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
});
|
||||
std::thread::sleep(WAIT_DURATION);
|
||||
assert!(!spawned.is_finished(), "didn't block");
|
||||
for _ in 0..100 {
|
||||
signal.notify_acknowledged()?;
|
||||
}
|
||||
wait_for_condition(|| spawned.is_finished(), "blocked at end");
|
||||
spawned.join().unwrap()
|
||||
})
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn signal_wait_for_drain_unblocks_on_dropped() -> Result<(), ShellError> {
|
||||
let signal = StreamWriterSignal::new(1);
|
||||
std::thread::scope(|scope| {
|
||||
let spawned = scope.spawn(|| {
|
||||
while !signal.is_dropped()? {
|
||||
if !signal.notify_sent()? {
|
||||
signal.wait_for_drain()?;
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
});
|
||||
std::thread::sleep(WAIT_DURATION);
|
||||
assert!(!spawned.is_finished(), "didn't block");
|
||||
signal.set_dropped()?;
|
||||
wait_for_condition(|| spawned.is_finished(), "still blocked at end");
|
||||
spawned.join().unwrap()
|
||||
})
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn stream_manager_single_stream_read_scenario() -> Result<(), ShellError> {
|
||||
let manager = StreamManager::new();
|
||||
let handle = manager.get_handle();
|
||||
let (tx, rx) = mpsc::channel();
|
||||
let readable = handle.read_stream::<Value, _>(2, tx)?;
|
||||
|
||||
let expected_values = vec![Value::test_int(40), Value::test_string("hello")];
|
||||
|
||||
for value in &expected_values {
|
||||
manager.handle_message(StreamMessage::Data(2, value.clone().into()))?;
|
||||
}
|
||||
manager.handle_message(StreamMessage::End(2))?;
|
||||
|
||||
let values = readable.collect::<Vec<Value>>();
|
||||
|
||||
assert_eq!(expected_values, values);
|
||||
|
||||
// Now check the sent messages on consumption
|
||||
// Should be Ack for each message, then Drop
|
||||
for _ in &expected_values {
|
||||
match rx.try_recv().expect("failed to receive Ack") {
|
||||
StreamMessage::Ack(2) => (),
|
||||
other => panic!("should have been an Ack: {other:?}"),
|
||||
}
|
||||
}
|
||||
match rx.try_recv().expect("failed to receive Drop") {
|
||||
StreamMessage::Drop(2) => (),
|
||||
other => panic!("should have been a Drop: {other:?}"),
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn stream_manager_multi_stream_read_scenario() -> Result<(), ShellError> {
|
||||
let manager = StreamManager::new();
|
||||
let handle = manager.get_handle();
|
||||
let (tx, rx) = mpsc::channel();
|
||||
let readable_list = handle.read_stream::<Value, _>(2, tx.clone())?;
|
||||
let readable_raw = handle.read_stream::<Result<Vec<u8>, _>, _>(3, tx)?;
|
||||
|
||||
let expected_values = (1..100).map(Value::test_int).collect::<Vec<_>>();
|
||||
let expected_raw_buffers = (1..100).map(|n| vec![n]).collect::<Vec<Vec<u8>>>();
|
||||
|
||||
for (value, buf) in expected_values.iter().zip(&expected_raw_buffers) {
|
||||
manager.handle_message(StreamMessage::Data(2, value.clone().into()))?;
|
||||
manager.handle_message(StreamMessage::Data(3, StreamData::Raw(Ok(buf.clone()))))?;
|
||||
}
|
||||
manager.handle_message(StreamMessage::End(2))?;
|
||||
manager.handle_message(StreamMessage::End(3))?;
|
||||
|
||||
let values = readable_list.collect::<Vec<Value>>();
|
||||
let bufs = readable_raw.collect::<Result<Vec<Vec<u8>>, _>>()?;
|
||||
|
||||
for (expected_value, value) in expected_values.iter().zip(&values) {
|
||||
assert_eq!(expected_value, value, "in List stream");
|
||||
}
|
||||
for (expected_buf, buf) in expected_raw_buffers.iter().zip(&bufs) {
|
||||
assert_eq!(expected_buf, buf, "in Raw stream");
|
||||
}
|
||||
|
||||
// Now check the sent messages on consumption
|
||||
// Should be Ack for each message, then Drop
|
||||
for _ in &expected_values {
|
||||
match rx.try_recv().expect("failed to receive Ack") {
|
||||
StreamMessage::Ack(2) => (),
|
||||
other => panic!("should have been an Ack(2): {other:?}"),
|
||||
}
|
||||
}
|
||||
match rx.try_recv().expect("failed to receive Drop") {
|
||||
StreamMessage::Drop(2) => (),
|
||||
other => panic!("should have been a Drop(2): {other:?}"),
|
||||
}
|
||||
for _ in &expected_values {
|
||||
match rx.try_recv().expect("failed to receive Ack") {
|
||||
StreamMessage::Ack(3) => (),
|
||||
other => panic!("should have been an Ack(3): {other:?}"),
|
||||
}
|
||||
}
|
||||
match rx.try_recv().expect("failed to receive Drop") {
|
||||
StreamMessage::Drop(3) => (),
|
||||
other => panic!("should have been a Drop(3): {other:?}"),
|
||||
}
|
||||
|
||||
// Should be end of stream
|
||||
assert!(
|
||||
rx.try_recv().is_err(),
|
||||
"more messages written to stream than expected"
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
fn stream_manager_write_scenario() -> Result<(), ShellError> {
    let manager = StreamManager::new();
    let handle = manager.get_handle();
    let (tx, rx) = mpsc::channel();
    let mut writable = handle.write_stream(4, tx, 100)?;

    let expected_values = vec![b"hello".to_vec(), b"world".to_vec(), b"test".to_vec()];

    for value in &expected_values {
        writable.write(Ok::<_, ShellError>(value.clone()))?;
    }

    // Now try signalling ack
    assert_eq!(
        expected_values.len() as i32,
        writable.signal.lock()?.unacknowledged,
        "unacknowledged initial count",
    );
    manager.handle_message(StreamMessage::Ack(4))?;
    assert_eq!(
        expected_values.len() as i32 - 1,
        writable.signal.lock()?.unacknowledged,
        "unacknowledged post-Ack count",
    );

    // ...and Drop
    manager.handle_message(StreamMessage::Drop(4))?;
    assert!(writable.is_dropped()?);

    // Drop the StreamWriter...
    drop(writable);

    // now check what was actually written
    for value in &expected_values {
        match rx.try_recv().expect("failed to receive Data") {
            StreamMessage::Data(4, StreamData::Raw(Ok(received))) => {
                assert_eq!(*value, received);
            }
            other @ StreamMessage::Data(..) => panic!("wrong Data for {value:?}: {other:?}"),
            other => panic!("should have been Data: {other:?}"),
        }
    }
    match rx.try_recv().expect("failed to receive End") {
        StreamMessage::End(4) => (),
        other => panic!("should have been End: {other:?}"),
    }

    Ok(())
}

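// A read error (e.g. a failure to decode input) should be broadcast to every reader currently
// registered with the manager, so both readers below observe the same ShellError.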
#[test]
fn stream_manager_broadcast_read_error() -> Result<(), ShellError> {
    let manager = StreamManager::new();
    let handle = manager.get_handle();
    let mut readable0 = handle.read_stream::<Value, _>(0, TestSink::default())?;
    let mut readable1 = handle.read_stream::<Result<Vec<u8>, _>, _>(1, TestSink::default())?;

    let error = ShellError::PluginFailedToDecode {
        msg: "test decode error".into(),
    };

    manager.broadcast_read_error(error.clone())?;
    drop(manager);

    assert_eq!(
        error.to_string(),
        readable0
            .recv()
            .transpose()
            .expect("nothing received from readable0")
            .expect_err("not an error received from readable0")
            .to_string()
    );
    assert_eq!(
        error.to_string(),
        readable1
            .next()
            .expect("nothing received from readable1")
            .expect_err("not an error received from readable1")
            .to_string()
    );
    Ok(())
}

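// Dropping the StreamManager should mark any outstanding writers as dropped (is_dropped()
// returns true), so a producer can stop writing once the consuming side is gone.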
#[test]
fn stream_manager_drop_writers_on_drop() -> Result<(), ShellError> {
    let manager = StreamManager::new();
    let handle = manager.get_handle();
    let writable = handle.write_stream(4, TestSink::default(), 100)?;

    assert!(!writable.is_dropped()?);

    drop(manager);

    assert!(writable.is_dropped()?);

    Ok(())
}
134
crates/nu-plugin-core/src/interface/test_util.rs
Normal file
@ -0,0 +1,134 @@
use nu_protocol::ShellError;
use std::{
    collections::VecDeque,
    sync::{Arc, Mutex},
};

use crate::{PluginRead, PluginWrite};

const FAILED: &str = "failed to lock TestCase";

/// Mock read/write helper for the engine and plugin interfaces.
#[derive(Debug, Clone)]
pub struct TestCase<I, O> {
    r#in: Arc<Mutex<TestData<I>>>,
    out: Arc<Mutex<TestData<O>>>,
}

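// Typical use in the interface tests (an illustrative sketch, not part of the API itself):
// queue input with `add`, drive the interface under test, then inspect what it wrote back
// with `next_written` / `written`, e.g.:
//
//     let test: TestCase<PluginInput, PluginOutput> = TestCase::new();
//     test.add(StreamMessage::End(0));
//     // ...run the code under test against `test`...
//     assert!(test.next_written().is_some());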
#[derive(Debug)]
pub struct TestData<T> {
    data: VecDeque<T>,
    error: Option<ShellError>,
    flushed: bool,
}

impl<T> Default for TestData<T> {
    fn default() -> Self {
        TestData {
            data: VecDeque::new(),
            error: None,
            flushed: false,
        }
    }
}

impl<I, O> PluginRead<I> for TestCase<I, O> {
    fn read(&mut self) -> Result<Option<I>, ShellError> {
        let mut lock = self.r#in.lock().expect(FAILED);
        if let Some(err) = lock.error.take() {
            Err(err)
        } else {
            Ok(lock.data.pop_front())
        }
    }
}

impl<I, O> PluginWrite<O> for TestCase<I, O>
where
    I: Send + Clone,
    O: Send + Clone,
{
    fn write(&self, data: &O) -> Result<(), ShellError> {
        let mut lock = self.out.lock().expect(FAILED);
        lock.flushed = false;

        if let Some(err) = lock.error.take() {
            Err(err)
        } else {
            lock.data.push_back(data.clone());
            Ok(())
        }
    }

    fn flush(&self) -> Result<(), ShellError> {
        let mut lock = self.out.lock().expect(FAILED);
        lock.flushed = true;
        Ok(())
    }
}

#[allow(dead_code)]
impl<I, O> TestCase<I, O> {
    pub fn new() -> TestCase<I, O> {
        TestCase {
            r#in: Default::default(),
            out: Default::default(),
        }
    }

    /// Clear the read buffer.
    pub fn clear(&self) {
        self.r#in.lock().expect(FAILED).data.truncate(0);
    }

    /// Add input that will be read by the interface.
    pub fn add(&self, input: impl Into<I>) {
        self.r#in.lock().expect(FAILED).data.push_back(input.into());
    }

    /// Add multiple inputs that will be read by the interface.
    pub fn extend(&self, inputs: impl IntoIterator<Item = I>) {
        self.r#in.lock().expect(FAILED).data.extend(inputs);
    }

    /// Return an error from the next read operation.
    pub fn set_read_error(&self, err: ShellError) {
        self.r#in.lock().expect(FAILED).error = Some(err);
    }

    /// Return an error from the next write operation.
    pub fn set_write_error(&self, err: ShellError) {
        self.out.lock().expect(FAILED).error = Some(err);
    }

    /// Get the next output that was written.
    pub fn next_written(&self) -> Option<O> {
        self.out.lock().expect(FAILED).data.pop_front()
    }

    /// Iterator over written data.
    pub fn written(&self) -> impl Iterator<Item = O> + '_ {
        std::iter::from_fn(|| self.next_written())
    }

    /// Returns true if the writer was flushed after the last write operation.
    pub fn was_flushed(&self) -> bool {
        self.out.lock().expect(FAILED).flushed
    }

    /// Returns true if the reader has unconsumed reads.
    pub fn has_unconsumed_read(&self) -> bool {
        !self.r#in.lock().expect(FAILED).data.is_empty()
    }

    /// Returns true if the writer has unconsumed writes.
    pub fn has_unconsumed_write(&self) -> bool {
        !self.out.lock().expect(FAILED).data.is_empty()
    }
}

impl<I, O> Default for TestCase<I, O> {
    fn default() -> Self {
        Self::new()
    }
}
573
crates/nu-plugin-core/src/interface/tests.rs
Normal file
@ -0,0 +1,573 @@
use crate::util::Sequence;

use super::{
    stream::{StreamManager, StreamManagerHandle},
    test_util::TestCase,
    Interface, InterfaceManager, PluginRead, PluginWrite,
};
use nu_plugin_protocol::{
    ExternalStreamInfo, ListStreamInfo, PipelineDataHeader, PluginInput, PluginOutput,
    RawStreamInfo, StreamData, StreamMessage,
};
use nu_protocol::{
    DataSource, ListStream, PipelineData, PipelineMetadata, RawStream, ShellError, Span, Value,
};
use std::{path::Path, sync::Arc};

fn test_metadata() -> PipelineMetadata {
    PipelineMetadata {
        data_source: DataSource::FilePath("/test/path".into()),
    }
}

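// TestInterfaceManager / TestInterface are minimal implementations of the InterfaceManager
// and Interface traits, wired over TestCase so the generic plumbing (stream registration,
// pipeline data reading/writing, prepare_pipeline_data hooks) can be exercised without a
// real engine or plugin on the other end.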
#[derive(Debug)]
struct TestInterfaceManager {
    stream_manager: StreamManager,
    test: TestCase<PluginInput, PluginOutput>,
    seq: Arc<Sequence>,
}

#[derive(Debug, Clone)]
struct TestInterface {
    stream_manager_handle: StreamManagerHandle,
    test: TestCase<PluginInput, PluginOutput>,
    seq: Arc<Sequence>,
}

impl TestInterfaceManager {
    fn new(test: &TestCase<PluginInput, PluginOutput>) -> TestInterfaceManager {
        TestInterfaceManager {
            stream_manager: StreamManager::new(),
            test: test.clone(),
            seq: Arc::new(Sequence::default()),
        }
    }

    fn consume_all(&mut self) -> Result<(), ShellError> {
        while let Some(msg) = self.test.read()? {
            self.consume(msg)?;
        }
        Ok(())
    }
}

impl InterfaceManager for TestInterfaceManager {
    type Interface = TestInterface;
    type Input = PluginInput;

    fn get_interface(&self) -> Self::Interface {
        TestInterface {
            stream_manager_handle: self.stream_manager.get_handle(),
            test: self.test.clone(),
            seq: self.seq.clone(),
        }
    }

    fn consume(&mut self, input: Self::Input) -> Result<(), ShellError> {
        match input {
            PluginInput::Data(..)
            | PluginInput::End(..)
            | PluginInput::Drop(..)
            | PluginInput::Ack(..) => self.consume_stream_message(
                input
                    .try_into()
                    .expect("failed to convert message to StreamMessage"),
            ),
            _ => unimplemented!(),
        }
    }

    fn stream_manager(&self) -> &StreamManager {
        &self.stream_manager
    }

    fn prepare_pipeline_data(&self, data: PipelineData) -> Result<PipelineData, ShellError> {
        Ok(data.set_metadata(Some(test_metadata())))
    }
}

impl Interface for TestInterface {
    type Output = PluginOutput;
    type DataContext = ();

    fn write(&self, output: Self::Output) -> Result<(), ShellError> {
        self.test.write(&output)
    }

    fn flush(&self) -> Result<(), ShellError> {
        Ok(())
    }

    fn stream_id_sequence(&self) -> &Sequence {
        &self.seq
    }

    fn stream_manager_handle(&self) -> &StreamManagerHandle {
        &self.stream_manager_handle
    }

    fn prepare_pipeline_data(
        &self,
        data: PipelineData,
        _context: &(),
    ) -> Result<PipelineData, ShellError> {
        // Add an arbitrary check to the data to verify this is being called
        match data {
            PipelineData::Value(Value::Binary { .. }, None) => Err(ShellError::NushellFailed {
                msg: "TEST can't send binary".into(),
            }),
            _ => Ok(data),
        }
    }
}

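// The tests below exercise both directions of the generic interface. Roughly (an illustrative
// sketch, not a new API): on the write side,
//
//     let (header, writer) = interface.init_write_pipeline_data(data, &())?;
//     writer.write()?; // for stream data, this emits the Data/End stream messages
//
// and on the read side, a header plus the queued stream messages are turned back into
// PipelineData:
//
//     let pipe = manager.read_pipeline_data(header, None)?;
//     manager.consume_all()?; // helper above: feed queued input into the registered streams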
#[test]
fn read_pipeline_data_empty() -> Result<(), ShellError> {
    let manager = TestInterfaceManager::new(&TestCase::new());
    let header = PipelineDataHeader::Empty;

    assert!(matches!(
        manager.read_pipeline_data(header, None)?,
        PipelineData::Empty
    ));
    Ok(())
}

#[test]
fn read_pipeline_data_value() -> Result<(), ShellError> {
    let manager = TestInterfaceManager::new(&TestCase::new());
    let value = Value::test_int(4);
    let header = PipelineDataHeader::Value(value.clone());

    match manager.read_pipeline_data(header, None)? {
        PipelineData::Value(read_value, _) => assert_eq!(value, read_value),
        PipelineData::ListStream(_, _) => panic!("unexpected ListStream"),
        PipelineData::ExternalStream { .. } => panic!("unexpected ExternalStream"),
        PipelineData::Empty => panic!("unexpected Empty"),
    }

    Ok(())
}

#[test]
fn read_pipeline_data_list_stream() -> Result<(), ShellError> {
    let test = TestCase::new();
    let mut manager = TestInterfaceManager::new(&test);

    let data = (0..100).map(Value::test_int).collect::<Vec<_>>();

    for value in &data {
        test.add(StreamMessage::Data(7, value.clone().into()));
    }
    test.add(StreamMessage::End(7));

    let header = PipelineDataHeader::ListStream(ListStreamInfo { id: 7 });

    let pipe = manager.read_pipeline_data(header, None)?;
    assert!(
        matches!(pipe, PipelineData::ListStream(..)),
        "unexpected PipelineData: {pipe:?}"
    );

    // need to consume input
    manager.consume_all()?;

    let mut count = 0;
    for (expected, read) in data.into_iter().zip(pipe) {
        assert_eq!(expected, read);
        count += 1;
    }
    assert_eq!(100, count);

    assert!(test.has_unconsumed_write());

    Ok(())
}

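// An ExternalStream header can reference several streams at once. Here stream 12 carries
// stdout (raw), stream 13 carries stderr (raw, marked binary), and stream 14 carries the exit
// code as a list stream; all three are registered from the single header below.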
#[test]
fn read_pipeline_data_external_stream() -> Result<(), ShellError> {
    let test = TestCase::new();
    let mut manager = TestInterfaceManager::new(&test);

    let iterations = 100;
    let out_pattern = b"hello".to_vec();
    let err_pattern = vec![5, 4, 3, 2];

    test.add(StreamMessage::Data(14, Value::test_int(1).into()));
    for _ in 0..iterations {
        test.add(StreamMessage::Data(
            12,
            StreamData::Raw(Ok(out_pattern.clone())),
        ));
        test.add(StreamMessage::Data(
            13,
            StreamData::Raw(Ok(err_pattern.clone())),
        ));
    }
    test.add(StreamMessage::End(12));
    test.add(StreamMessage::End(13));
    test.add(StreamMessage::End(14));

    let test_span = Span::new(10, 13);
    let header = PipelineDataHeader::ExternalStream(ExternalStreamInfo {
        span: test_span,
        stdout: Some(RawStreamInfo {
            id: 12,
            is_binary: false,
            known_size: Some((out_pattern.len() * iterations) as u64),
        }),
        stderr: Some(RawStreamInfo {
            id: 13,
            is_binary: true,
            known_size: None,
        }),
        exit_code: Some(ListStreamInfo { id: 14 }),
        trim_end_newline: true,
    });

    let pipe = manager.read_pipeline_data(header, None)?;

    // need to consume input
    manager.consume_all()?;

    match pipe {
        PipelineData::ExternalStream {
            stdout,
            stderr,
            exit_code,
            span,
            metadata,
            trim_end_newline,
        } => {
            let stdout = stdout.expect("stdout is None");
            let stderr = stderr.expect("stderr is None");
            let exit_code = exit_code.expect("exit_code is None");
            assert_eq!(test_span, span);
            assert!(
                metadata.is_some(),
                "expected metadata to be Some due to prepare_pipeline_data()"
            );
            assert!(trim_end_newline);

            assert!(!stdout.is_binary);
            assert!(stderr.is_binary);

            assert_eq!(
                Some((out_pattern.len() * iterations) as u64),
                stdout.known_size
            );
            assert_eq!(None, stderr.known_size);

            // check the streams
            let mut count = 0;
            for chunk in stdout.stream {
                assert_eq!(out_pattern, chunk?);
                count += 1;
            }
            assert_eq!(iterations, count, "stdout length");
            let mut count = 0;

            for chunk in stderr.stream {
                assert_eq!(err_pattern, chunk?);
                count += 1;
            }
            assert_eq!(iterations, count, "stderr length");

            assert_eq!(vec![Value::test_int(1)], exit_code.collect::<Vec<_>>());
        }
        _ => panic!("unexpected PipelineData: {pipe:?}"),
    }

    // Don't need to check exactly what was written, just be sure that there is some output
    assert!(test.has_unconsumed_write());

    Ok(())
}

#[test]
fn read_pipeline_data_ctrlc() -> Result<(), ShellError> {
    let manager = TestInterfaceManager::new(&TestCase::new());
    let header = PipelineDataHeader::ListStream(ListStreamInfo { id: 0 });
    let ctrlc = Default::default();
    match manager.read_pipeline_data(header, Some(&ctrlc))? {
        PipelineData::ListStream(
            ListStream {
                ctrlc: stream_ctrlc,
                ..
            },
            _,
        ) => {
            assert!(Arc::ptr_eq(&ctrlc, &stream_ctrlc.expect("ctrlc not set")));
            Ok(())
        }
        _ => panic!("Unexpected PipelineData, should have been ListStream"),
    }
}

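// The manager's prepare_pipeline_data implementation above attaches test_metadata() to
// everything it reads, so the metadata on the resulting ListStream should point at the
// FilePath "/test/path".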
#[test]
fn read_pipeline_data_prepared_properly() -> Result<(), ShellError> {
    let manager = TestInterfaceManager::new(&TestCase::new());
    let header = PipelineDataHeader::ListStream(ListStreamInfo { id: 0 });
    match manager.read_pipeline_data(header, None)? {
        PipelineData::ListStream(_, meta) => match meta {
            Some(PipelineMetadata { data_source }) => match data_source {
                DataSource::FilePath(path) => {
                    assert_eq!(Path::new("/test/path"), path);
                    Ok(())
                }
                _ => panic!("wrong metadata: {data_source:?}"),
            },
            None => panic!("metadata not set"),
        },
        _ => panic!("Unexpected PipelineData, should have been ListStream"),
    }
}

#[test]
fn write_pipeline_data_empty() -> Result<(), ShellError> {
    let test = TestCase::new();
    let manager = TestInterfaceManager::new(&test);
    let interface = manager.get_interface();

    let (header, writer) = interface.init_write_pipeline_data(PipelineData::Empty, &())?;

    assert!(matches!(header, PipelineDataHeader::Empty));

    writer.write()?;

    assert!(
        !test.has_unconsumed_write(),
        "Empty shouldn't write any stream messages, test: {test:#?}"
    );

    Ok(())
}

#[test]
fn write_pipeline_data_value() -> Result<(), ShellError> {
    let test = TestCase::new();
    let manager = TestInterfaceManager::new(&test);
    let interface = manager.get_interface();
    let value = Value::test_int(7);

    let (header, writer) =
        interface.init_write_pipeline_data(PipelineData::Value(value.clone(), None), &())?;

    match header {
        PipelineDataHeader::Value(read_value) => assert_eq!(value, read_value),
        _ => panic!("unexpected header: {header:?}"),
    }

    writer.write()?;

    assert!(
        !test.has_unconsumed_write(),
        "Value shouldn't write any stream messages, test: {test:#?}"
    );

    Ok(())
}

#[test]
fn write_pipeline_data_prepared_properly() {
    let manager = TestInterfaceManager::new(&TestCase::new());
    let interface = manager.get_interface();

    // Sending a binary should be an error in our test scenario
    let value = Value::test_binary(vec![7, 8]);

    match interface.init_write_pipeline_data(PipelineData::Value(value, None), &()) {
        Ok(_) => panic!("prepare_pipeline_data was not called"),
        Err(err) => {
            assert_eq!(
                ShellError::NushellFailed {
                    msg: "TEST can't send binary".into()
                }
                .to_string(),
                err.to_string()
            );
        }
    }
}

#[test]
fn write_pipeline_data_list_stream() -> Result<(), ShellError> {
    let test = TestCase::new();
    let manager = TestInterfaceManager::new(&test);
    let interface = manager.get_interface();

    let values = vec![
        Value::test_int(40),
        Value::test_bool(false),
        Value::test_string("this is a test"),
    ];

    // Set up pipeline data for a list stream
    let pipe = PipelineData::ListStream(
        ListStream::from_stream(values.clone().into_iter(), None),
        None,
    );

    let (header, writer) = interface.init_write_pipeline_data(pipe, &())?;

    let info = match header {
        PipelineDataHeader::ListStream(info) => info,
        _ => panic!("unexpected header: {header:?}"),
    };

    writer.write()?;

    // Now make sure the stream messages have been written
    for value in values {
        match test.next_written().expect("unexpected end of stream") {
            PluginOutput::Data(id, data) => {
                assert_eq!(info.id, id, "Data id");
                match data {
                    StreamData::List(read_value) => assert_eq!(value, read_value, "Data value"),
                    _ => panic!("unexpected Data: {data:?}"),
                }
            }
            other => panic!("unexpected output: {other:?}"),
        }
    }

    match test.next_written().expect("unexpected end of stream") {
        PluginOutput::End(id) => {
            assert_eq!(info.id, id, "End id");
        }
        other => panic!("unexpected output: {other:?}"),
    }

    assert!(!test.has_unconsumed_write());

    Ok(())
}

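// The three sub-streams (stdout, stderr, exit code) may interleave in any order when written,
// so the checks below track each stream id independently and only require per-stream ordering
// plus an End after all of that stream's Data.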
#[test]
fn write_pipeline_data_external_stream() -> Result<(), ShellError> {
    let test = TestCase::new();
    let manager = TestInterfaceManager::new(&test);
    let interface = manager.get_interface();

    let stdout_bufs = vec![
        b"hello".to_vec(),
        b"world".to_vec(),
        b"these are tests".to_vec(),
    ];
    let stdout_len = stdout_bufs.iter().map(|b| b.len() as u64).sum::<u64>();
    let stderr_bufs = vec![b"error messages".to_vec(), b"go here".to_vec()];
    let exit_code = Value::test_int(7);

    let span = Span::new(400, 500);

    // Set up pipeline data for an external stream
    let pipe = PipelineData::ExternalStream {
        stdout: Some(RawStream::new(
            Box::new(stdout_bufs.clone().into_iter().map(Ok)),
            None,
            span,
            Some(stdout_len),
        )),
        stderr: Some(RawStream::new(
            Box::new(stderr_bufs.clone().into_iter().map(Ok)),
            None,
            span,
            None,
        )),
        exit_code: Some(ListStream::from_stream(
            std::iter::once(exit_code.clone()),
            None,
        )),
        span,
        metadata: None,
        trim_end_newline: true,
    };

    let (header, writer) = interface.init_write_pipeline_data(pipe, &())?;

    let info = match header {
        PipelineDataHeader::ExternalStream(info) => info,
        _ => panic!("unexpected header: {header:?}"),
    };

    writer.write()?;

    let stdout_info = info.stdout.as_ref().expect("stdout info is None");
    let stderr_info = info.stderr.as_ref().expect("stderr info is None");
    let exit_code_info = info.exit_code.as_ref().expect("exit code info is None");

    assert_eq!(span, info.span);
    assert!(info.trim_end_newline);

    assert_eq!(Some(stdout_len), stdout_info.known_size);
    assert_eq!(None, stderr_info.known_size);

    // Now make sure the stream messages have been written
    let mut stdout_iter = stdout_bufs.into_iter();
    let mut stderr_iter = stderr_bufs.into_iter();
    let mut exit_code_iter = std::iter::once(exit_code);

    let mut stdout_ended = false;
    let mut stderr_ended = false;
    let mut exit_code_ended = false;

    // There's no specific order these messages must come in with respect to how the streams are
    // interleaved, but all of the data for each stream must be in its original order, and the
    // End must come after all Data
    for msg in test.written() {
        match msg {
            PluginOutput::Data(id, data) => {
                if id == stdout_info.id {
                    let result: Result<Vec<u8>, ShellError> =
                        data.try_into().expect("wrong data in stdout stream");
                    assert_eq!(
                        stdout_iter.next().expect("too much data in stdout"),
                        result.expect("unexpected error in stdout stream")
                    );
                } else if id == stderr_info.id {
                    let result: Result<Vec<u8>, ShellError> =
                        data.try_into().expect("wrong data in stderr stream");
                    assert_eq!(
                        stderr_iter.next().expect("too much data in stderr"),
                        result.expect("unexpected error in stderr stream")
                    );
                } else if id == exit_code_info.id {
                    let code: Value = data.try_into().expect("wrong data in exit_code stream");
                    assert_eq!(
                        exit_code_iter.next().expect("too much data in exit_code"),
                        code
                    );
                } else {
                    panic!("unrecognized stream id: {id}");
                }
            }
            PluginOutput::End(id) => {
                if id == stdout_info.id {
                    assert!(!stdout_ended, "double End of stdout");
                    assert!(stdout_iter.next().is_none(), "unexpected end of stdout");
                    stdout_ended = true;
                } else if id == stderr_info.id {
                    assert!(!stderr_ended, "double End of stderr");
                    assert!(stderr_iter.next().is_none(), "unexpected end of stderr");
                    stderr_ended = true;
                } else if id == exit_code_info.id {
                    assert!(!exit_code_ended, "double End of exit_code");
                    assert!(
                        exit_code_iter.next().is_none(),
                        "unexpected end of exit_code"
                    );
                    exit_code_ended = true;
                } else {
                    panic!("unrecognized stream id: {id}");
                }
            }
            other => panic!("unexpected output: {other:?}"),
        }
    }

    assert!(stdout_ended, "stdout did not End");
    assert!(stderr_ended, "stderr did not End");
    assert!(exit_code_ended, "exit_code did not End");

    Ok(())
}