Split the plugin crate (#12563)

# Description

This breaks `nu-plugin` up into four crates:

- `nu-plugin-protocol`: just the type definitions for the protocol, no
I/O. If someone wanted to wire up something more bare metal, maybe for
async I/O, they could use this.
- `nu-plugin-core`: the shared stuff between engine/plugin. Less stable
interface.
- `nu-plugin-engine`: everything required for the engine to talk to
plugins. Less stable interface.
- `nu-plugin`: everything required for the plugin to talk to the engine,
what plugin developers use. Should be the most stable interface.

No changes are made to the interface exposed by `nu-plugin` - it should
all still be there. Re-exports from `nu-plugin-protocol` or
`nu-plugin-core` are used as required. Plugins shouldn't ever have to
use those crates directly.

This should be somewhat faster to compile as `nu-plugin-engine` and
`nu-plugin` can compile in parallel, and the engine doesn't need
`nu-plugin` and plugins don't need `nu-plugin-engine` (except for test
support), so that should reduce what needs to be compiled too.

The only significant change here other than splitting stuff up was to
break the `source` out of `PluginCustomValue` and create a new
`PluginCustomValueWithSource` type that contains that instead. One bonus
of that is we get rid of the option and it's now more type-safe, but it
also means that the logic for that stuff (actually running the plugin
for custom value ops) can live entirely within the `nu-plugin-engine`
crate.

# User-Facing Changes
- New crates.
- Added `local-socket` feature for `nu` to try to make it possible to
compile without that support if needed.

# Tests + Formatting
- 🟢 `toolkit fmt`
- 🟢 `toolkit clippy`
- 🟢 `toolkit test`
- 🟢 `toolkit test stdlib`
This commit is contained in:
Devyn Cairns
2024-04-27 10:08:12 -07:00
committed by GitHub
parent 884d5312bb
commit 0c4d5330ee
74 changed files with 3514 additions and 3110 deletions

View File

@ -0,0 +1,628 @@
use nu_plugin_protocol::{StreamData, StreamId, StreamMessage};
use nu_protocol::{ShellError, Span, Value};
use std::{
collections::{btree_map, BTreeMap},
iter::FusedIterator,
marker::PhantomData,
sync::{mpsc, Arc, Condvar, Mutex, MutexGuard, Weak},
};
#[cfg(test)]
mod tests;
/// Receives messages from a stream read from input by a [`StreamManager`].
///
/// The receiver reads for messages of type `Result<Option<StreamData>, ShellError>` from the
/// channel, which is managed by a [`StreamManager`]. Signalling for end-of-stream is explicit
/// through `Ok(Some)`.
///
/// Failing to receive is an error. When end-of-stream is received, the `receiver` is set to `None`
/// and all further calls to `next()` return `None`.
///
/// The type `T` must implement [`FromShellError`], so that errors in the stream can be represented,
/// and `TryFrom<StreamData>` to convert it to the correct type.
///
/// For each message read, it sends [`StreamMessage::Ack`] to the writer. When dropped,
/// it sends [`StreamMessage::Drop`].
#[derive(Debug)]
pub struct StreamReader<T, W>
where
W: WriteStreamMessage,
{
id: StreamId,
receiver: Option<mpsc::Receiver<Result<Option<StreamData>, ShellError>>>,
writer: W,
/// Iterator requires the item type to be fixed, so we have to keep it as part of the type,
/// even though we're actually receiving dynamic data.
marker: PhantomData<fn() -> T>,
}
impl<T, W> StreamReader<T, W>
where
T: TryFrom<StreamData, Error = ShellError>,
W: WriteStreamMessage,
{
/// Create a new StreamReader from parts
fn new(
id: StreamId,
receiver: mpsc::Receiver<Result<Option<StreamData>, ShellError>>,
writer: W,
) -> StreamReader<T, W> {
StreamReader {
id,
receiver: Some(receiver),
writer,
marker: PhantomData,
}
}
/// Receive a message from the channel, or return an error if:
///
/// * the channel couldn't be received from
/// * an error was sent on the channel
/// * the message received couldn't be converted to `T`
pub fn recv(&mut self) -> Result<Option<T>, ShellError> {
let connection_lost = || ShellError::GenericError {
error: "Stream ended unexpectedly".into(),
msg: "connection lost before explicit end of stream".into(),
span: None,
help: None,
inner: vec![],
};
if let Some(ref rx) = self.receiver {
// Try to receive a message first
let msg = match rx.try_recv() {
Ok(msg) => msg?,
Err(mpsc::TryRecvError::Empty) => {
// The receiver doesn't have any messages waiting for us. It's possible that the
// other side hasn't seen our acknowledgements. Let's flush the writer and then
// wait
self.writer.flush()?;
rx.recv().map_err(|_| connection_lost())??
}
Err(mpsc::TryRecvError::Disconnected) => return Err(connection_lost()),
};
if let Some(data) = msg {
// Acknowledge the message
self.writer
.write_stream_message(StreamMessage::Ack(self.id))?;
// Try to convert it into the correct type
Ok(Some(data.try_into()?))
} else {
// Remove the receiver, so that future recv() calls always return Ok(None)
self.receiver = None;
Ok(None)
}
} else {
// Closed already
Ok(None)
}
}
}
impl<T, W> Iterator for StreamReader<T, W>
where
T: FromShellError + TryFrom<StreamData, Error = ShellError>,
W: WriteStreamMessage,
{
type Item = T;
fn next(&mut self) -> Option<T> {
// Converting the error to the value here makes the implementation a lot easier
match self.recv() {
Ok(option) => option,
Err(err) => {
// Drop the receiver so we don't keep returning errors
self.receiver = None;
Some(T::from_shell_error(err))
}
}
}
}
// Guaranteed not to return anything after the end
impl<T, W> FusedIterator for StreamReader<T, W>
where
T: FromShellError + TryFrom<StreamData, Error = ShellError>,
W: WriteStreamMessage,
{
}
impl<T, W> Drop for StreamReader<T, W>
where
W: WriteStreamMessage,
{
fn drop(&mut self) {
if let Err(err) = self
.writer
.write_stream_message(StreamMessage::Drop(self.id))
.and_then(|_| self.writer.flush())
{
log::warn!("Failed to send message to drop stream: {err}");
}
}
}
/// Values that can contain a `ShellError` to signal an error has occurred.
pub trait FromShellError {
fn from_shell_error(err: ShellError) -> Self;
}
// For List streams.
impl FromShellError for Value {
fn from_shell_error(err: ShellError) -> Self {
Value::error(err, Span::unknown())
}
}
// For Raw streams, mostly.
impl<T> FromShellError for Result<T, ShellError> {
fn from_shell_error(err: ShellError) -> Self {
Err(err)
}
}
/// Writes messages to a stream, with flow control.
///
/// The `signal` contained
#[derive(Debug)]
pub struct StreamWriter<W: WriteStreamMessage> {
id: StreamId,
signal: Arc<StreamWriterSignal>,
writer: W,
ended: bool,
}
impl<W> StreamWriter<W>
where
W: WriteStreamMessage,
{
fn new(id: StreamId, signal: Arc<StreamWriterSignal>, writer: W) -> StreamWriter<W> {
StreamWriter {
id,
signal,
writer,
ended: false,
}
}
/// Check if the stream was dropped from the other end. Recommended to do this before calling
/// [`.write()`], especially in a loop.
pub fn is_dropped(&self) -> Result<bool, ShellError> {
self.signal.is_dropped()
}
/// Write a single piece of data to the stream.
///
/// Error if something failed with the write, or if [`.end()`] was already called
/// previously.
pub fn write(&mut self, data: impl Into<StreamData>) -> Result<(), ShellError> {
if !self.ended {
self.writer
.write_stream_message(StreamMessage::Data(self.id, data.into()))?;
// This implements flow control, so we don't write too many messages:
if !self.signal.notify_sent()? {
// Flush the output, and then wait for acknowledgements
self.writer.flush()?;
self.signal.wait_for_drain()
} else {
Ok(())
}
} else {
Err(ShellError::GenericError {
error: "Wrote to a stream after it ended".into(),
msg: format!(
"tried to write to stream {} after it was already ended",
self.id
),
span: None,
help: Some("this may be a bug in the nu-plugin crate".into()),
inner: vec![],
})
}
}
/// Write a full iterator to the stream. Note that this doesn't end the stream, so you should
/// still call [`.end()`].
///
/// If the stream is dropped from the other end, the iterator will not be fully consumed, and
/// writing will terminate.
///
/// Returns `Ok(true)` if the iterator was fully consumed, or `Ok(false)` if a drop interrupted
/// the stream from the other side.
pub fn write_all<T>(&mut self, data: impl IntoIterator<Item = T>) -> Result<bool, ShellError>
where
T: Into<StreamData>,
{
// Check before starting
if self.is_dropped()? {
return Ok(false);
}
for item in data {
// Check again after each item is consumed from the iterator, just in case the iterator
// takes a while to produce a value
if self.is_dropped()? {
return Ok(false);
}
self.write(item)?;
}
Ok(true)
}
/// End the stream. Recommend doing this instead of relying on `Drop` so that you can catch the
/// error.
pub fn end(&mut self) -> Result<(), ShellError> {
if !self.ended {
// Set the flag first so we don't double-report in the Drop
self.ended = true;
self.writer
.write_stream_message(StreamMessage::End(self.id))?;
self.writer.flush()
} else {
Ok(())
}
}
}
impl<W> Drop for StreamWriter<W>
where
W: WriteStreamMessage,
{
fn drop(&mut self) {
// Make sure we ended the stream
if let Err(err) = self.end() {
log::warn!("Error while ending stream in Drop for StreamWriter: {err}");
}
}
}
/// Stores stream state for a writer, and can be blocked on to wait for messages to be acknowledged.
/// A key part of managing stream lifecycle and flow control.
#[derive(Debug)]
pub struct StreamWriterSignal {
mutex: Mutex<StreamWriterSignalState>,
change_cond: Condvar,
}
#[derive(Debug)]
pub struct StreamWriterSignalState {
/// Stream has been dropped and consumer is no longer interested in any messages.
dropped: bool,
/// Number of messages that have been sent without acknowledgement.
unacknowledged: i32,
/// Max number of messages to send before waiting for acknowledgement.
high_pressure_mark: i32,
}
impl StreamWriterSignal {
/// Create a new signal.
///
/// If `notify_sent()` is called more than `high_pressure_mark` times, it will wait until
/// `notify_acknowledge()` is called by another thread enough times to bring the number of
/// unacknowledged sent messages below that threshold.
fn new(high_pressure_mark: i32) -> StreamWriterSignal {
assert!(high_pressure_mark > 0);
StreamWriterSignal {
mutex: Mutex::new(StreamWriterSignalState {
dropped: false,
unacknowledged: 0,
high_pressure_mark,
}),
change_cond: Condvar::new(),
}
}
fn lock(&self) -> Result<MutexGuard<StreamWriterSignalState>, ShellError> {
self.mutex.lock().map_err(|_| ShellError::NushellFailed {
msg: "StreamWriterSignal mutex poisoned due to panic".into(),
})
}
/// True if the stream was dropped and the consumer is no longer interested in it. Indicates
/// that no more messages should be sent, other than `End`.
pub fn is_dropped(&self) -> Result<bool, ShellError> {
Ok(self.lock()?.dropped)
}
/// Notify the writers that the stream has been dropped, so they can stop writing.
pub fn set_dropped(&self) -> Result<(), ShellError> {
let mut state = self.lock()?;
state.dropped = true;
// Unblock the writers so they can terminate
self.change_cond.notify_all();
Ok(())
}
/// Track that a message has been sent. Returns `Ok(true)` if more messages can be sent,
/// or `Ok(false)` if the high pressure mark has been reached and [`.wait_for_drain()`] should
/// be called to block.
pub fn notify_sent(&self) -> Result<bool, ShellError> {
let mut state = self.lock()?;
state.unacknowledged =
state
.unacknowledged
.checked_add(1)
.ok_or_else(|| ShellError::NushellFailed {
msg: "Overflow in counter: too many unacknowledged messages".into(),
})?;
Ok(state.unacknowledged < state.high_pressure_mark)
}
/// Wait for acknowledgements before sending more data. Also returns if the stream is dropped.
pub fn wait_for_drain(&self) -> Result<(), ShellError> {
let mut state = self.lock()?;
while !state.dropped && state.unacknowledged >= state.high_pressure_mark {
state = self
.change_cond
.wait(state)
.map_err(|_| ShellError::NushellFailed {
msg: "StreamWriterSignal mutex poisoned due to panic".into(),
})?;
}
Ok(())
}
/// Notify the writers that a message has been acknowledged, so they can continue to write
/// if they were waiting.
pub fn notify_acknowledged(&self) -> Result<(), ShellError> {
let mut state = self.lock()?;
state.unacknowledged =
state
.unacknowledged
.checked_sub(1)
.ok_or_else(|| ShellError::NushellFailed {
msg: "Underflow in counter: too many message acknowledgements".into(),
})?;
// Unblock the writer
self.change_cond.notify_one();
Ok(())
}
}
/// A sink for a [`StreamMessage`]
pub trait WriteStreamMessage {
fn write_stream_message(&mut self, msg: StreamMessage) -> Result<(), ShellError>;
fn flush(&mut self) -> Result<(), ShellError>;
}
#[derive(Debug, Default)]
struct StreamManagerState {
reading_streams: BTreeMap<StreamId, mpsc::Sender<Result<Option<StreamData>, ShellError>>>,
writing_streams: BTreeMap<StreamId, Weak<StreamWriterSignal>>,
}
impl StreamManagerState {
/// Lock the state, or return a [`ShellError`] if the mutex is poisoned.
fn lock(
state: &Mutex<StreamManagerState>,
) -> Result<MutexGuard<StreamManagerState>, ShellError> {
state.lock().map_err(|_| ShellError::NushellFailed {
msg: "StreamManagerState mutex poisoned due to a panic".into(),
})
}
}
#[derive(Debug)]
pub struct StreamManager {
state: Arc<Mutex<StreamManagerState>>,
}
impl StreamManager {
/// Create a new StreamManager.
pub fn new() -> StreamManager {
StreamManager {
state: Default::default(),
}
}
fn lock(&self) -> Result<MutexGuard<StreamManagerState>, ShellError> {
StreamManagerState::lock(&self.state)
}
/// Create a new handle to the StreamManager for registering streams.
pub fn get_handle(&self) -> StreamManagerHandle {
StreamManagerHandle {
state: Arc::downgrade(&self.state),
}
}
/// Process a stream message, and update internal state accordingly.
pub fn handle_message(&self, message: StreamMessage) -> Result<(), ShellError> {
let mut state = self.lock()?;
match message {
StreamMessage::Data(id, data) => {
if let Some(sender) = state.reading_streams.get(&id) {
// We should ignore the error on send. This just means the reader has dropped,
// but it will have sent a Drop message to the other side, and we will receive
// an End message at which point we can remove the channel.
let _ = sender.send(Ok(Some(data)));
Ok(())
} else {
Err(ShellError::PluginFailedToDecode {
msg: format!("received Data for unknown stream {id}"),
})
}
}
StreamMessage::End(id) => {
if let Some(sender) = state.reading_streams.remove(&id) {
// We should ignore the error on the send, because the reader might have dropped
// already
let _ = sender.send(Ok(None));
Ok(())
} else {
Err(ShellError::PluginFailedToDecode {
msg: format!("received End for unknown stream {id}"),
})
}
}
StreamMessage::Drop(id) => {
if let Some(signal) = state.writing_streams.remove(&id) {
if let Some(signal) = signal.upgrade() {
// This will wake blocked writers so they can stop writing, so it's ok
signal.set_dropped()?;
}
}
// It's possible that the stream has already finished writing and we don't have it
// anymore, so we fall through to Ok
Ok(())
}
StreamMessage::Ack(id) => {
if let Some(signal) = state.writing_streams.get(&id) {
if let Some(signal) = signal.upgrade() {
// This will wake up a blocked writer
signal.notify_acknowledged()?;
} else {
// We know it doesn't exist, so might as well remove it
state.writing_streams.remove(&id);
}
}
// It's possible that the stream has already finished writing and we don't have it
// anymore, so we fall through to Ok
Ok(())
}
}
}
/// Broadcast an error to all stream readers. This is useful for error propagation.
pub fn broadcast_read_error(&self, error: ShellError) -> Result<(), ShellError> {
let state = self.lock()?;
for channel in state.reading_streams.values() {
// Ignore send errors.
let _ = channel.send(Err(error.clone()));
}
Ok(())
}
// If the `StreamManager` is dropped, we should let all of the stream writers know that they
// won't be able to write anymore. We don't need to do anything about the readers though
// because they'll know when the `Sender` is dropped automatically
fn drop_all_writers(&self) -> Result<(), ShellError> {
let mut state = self.lock()?;
let writers = std::mem::take(&mut state.writing_streams);
for (_, signal) in writers {
if let Some(signal) = signal.upgrade() {
// more important that we send to all than handling an error
let _ = signal.set_dropped();
}
}
Ok(())
}
}
impl Default for StreamManager {
fn default() -> Self {
Self::new()
}
}
impl Drop for StreamManager {
fn drop(&mut self) {
if let Err(err) = self.drop_all_writers() {
log::warn!("error during Drop for StreamManager: {}", err)
}
}
}
/// A [`StreamManagerHandle`] supports operations for interacting with the [`StreamManager`].
///
/// Streams can be registered for reading, returning a [`StreamReader`], or for writing, returning
/// a [`StreamWriter`].
#[derive(Debug, Clone)]
pub struct StreamManagerHandle {
state: Weak<Mutex<StreamManagerState>>,
}
impl StreamManagerHandle {
/// Because the handle only has a weak reference to the [`StreamManager`] state, we have to
/// first try to upgrade to a strong reference and then lock. This function wraps those two
/// operations together, handling errors appropriately.
fn with_lock<T, F>(&self, f: F) -> Result<T, ShellError>
where
F: FnOnce(MutexGuard<StreamManagerState>) -> Result<T, ShellError>,
{
let upgraded = self
.state
.upgrade()
.ok_or_else(|| ShellError::NushellFailed {
msg: "StreamManager is no longer alive".into(),
})?;
let guard = upgraded.lock().map_err(|_| ShellError::NushellFailed {
msg: "StreamManagerState mutex poisoned due to a panic".into(),
})?;
f(guard)
}
/// Register a new stream for reading, and return a [`StreamReader`] that can be used to iterate
/// on the values received. A [`StreamMessage`] writer is required for writing control messages
/// back to the producer.
pub fn read_stream<T, W>(
&self,
id: StreamId,
writer: W,
) -> Result<StreamReader<T, W>, ShellError>
where
T: TryFrom<StreamData, Error = ShellError>,
W: WriteStreamMessage,
{
let (tx, rx) = mpsc::channel();
self.with_lock(|mut state| {
// Must be exclusive
if let btree_map::Entry::Vacant(e) = state.reading_streams.entry(id) {
e.insert(tx);
Ok(())
} else {
Err(ShellError::GenericError {
error: format!("Failed to acquire reader for stream {id}"),
msg: "tried to get a reader for a stream that's already being read".into(),
span: None,
help: Some("this may be a bug in the nu-plugin crate".into()),
inner: vec![],
})
}
})?;
Ok(StreamReader::new(id, rx, writer))
}
/// Register a new stream for writing, and return a [`StreamWriter`] that can be used to send
/// data to the stream.
///
/// The `high_pressure_mark` value controls how many messages can be written without receiving
/// an acknowledgement before any further attempts to write will wait for the consumer to
/// acknowledge them. This prevents overwhelming the reader.
pub fn write_stream<W>(
&self,
id: StreamId,
writer: W,
high_pressure_mark: i32,
) -> Result<StreamWriter<W>, ShellError>
where
W: WriteStreamMessage,
{
let signal = Arc::new(StreamWriterSignal::new(high_pressure_mark));
self.with_lock(|mut state| {
// Remove dead writing streams
state
.writing_streams
.retain(|_, signal| signal.strong_count() > 0);
// Must be exclusive
if let btree_map::Entry::Vacant(e) = state.writing_streams.entry(id) {
e.insert(Arc::downgrade(&signal));
Ok(())
} else {
Err(ShellError::GenericError {
error: format!("Failed to acquire writer for stream {id}"),
msg: "tried to get a writer for a stream that's already being written".into(),
span: None,
help: Some("this may be a bug in the nu-plugin crate".into()),
inner: vec![],
})
}
})?;
Ok(StreamWriter::new(id, signal, writer))
}
}

View File

@ -0,0 +1,550 @@
use std::{
sync::{
atomic::{AtomicBool, Ordering::Relaxed},
mpsc, Arc,
},
time::{Duration, Instant},
};
use super::{StreamManager, StreamReader, StreamWriter, StreamWriterSignal, WriteStreamMessage};
use nu_plugin_protocol::{StreamData, StreamMessage};
use nu_protocol::{ShellError, Value};
// Should be long enough to definitely complete any quick operation, but not so long that tests are
// slow to complete. 10 ms is a pretty long time
const WAIT_DURATION: Duration = Duration::from_millis(10);
// Maximum time to wait for a condition to be true
const MAX_WAIT_DURATION: Duration = Duration::from_millis(500);
/// Wait for a condition to be true, or panic if the duration exceeds MAX_WAIT_DURATION
#[track_caller]
fn wait_for_condition(mut cond: impl FnMut() -> bool, message: &str) {
// Early check
if cond() {
return;
}
let start = Instant::now();
loop {
std::thread::sleep(Duration::from_millis(10));
if cond() {
return;
}
let elapsed = Instant::now().saturating_duration_since(start);
if elapsed > MAX_WAIT_DURATION {
panic!(
"{message}: Waited {:.2}sec, which is more than the maximum of {:.2}sec",
elapsed.as_secs_f64(),
MAX_WAIT_DURATION.as_secs_f64(),
);
}
}
}
#[derive(Debug, Clone, Default)]
struct TestSink(Vec<StreamMessage>);
impl WriteStreamMessage for TestSink {
fn write_stream_message(&mut self, msg: StreamMessage) -> Result<(), ShellError> {
self.0.push(msg);
Ok(())
}
fn flush(&mut self) -> Result<(), ShellError> {
Ok(())
}
}
impl WriteStreamMessage for mpsc::Sender<StreamMessage> {
fn write_stream_message(&mut self, msg: StreamMessage) -> Result<(), ShellError> {
self.send(msg).map_err(|err| ShellError::NushellFailed {
msg: err.to_string(),
})
}
fn flush(&mut self) -> Result<(), ShellError> {
Ok(())
}
}
#[test]
fn reader_recv_list_messages() -> Result<(), ShellError> {
let (tx, rx) = mpsc::channel();
let mut reader = StreamReader::new(0, rx, TestSink::default());
tx.send(Ok(Some(StreamData::List(Value::test_int(5)))))
.unwrap();
drop(tx);
assert_eq!(Some(Value::test_int(5)), reader.recv()?);
Ok(())
}
#[test]
fn list_reader_recv_wrong_type() -> Result<(), ShellError> {
let (tx, rx) = mpsc::channel();
let mut reader = StreamReader::<Value, _>::new(0, rx, TestSink::default());
tx.send(Ok(Some(StreamData::Raw(Ok(vec![10, 20])))))
.unwrap();
tx.send(Ok(Some(StreamData::List(Value::test_nothing()))))
.unwrap();
drop(tx);
reader.recv().expect_err("should be an error");
reader.recv().expect("should be able to recover");
Ok(())
}
#[test]
fn reader_recv_raw_messages() -> Result<(), ShellError> {
let (tx, rx) = mpsc::channel();
let mut reader =
StreamReader::<Result<Vec<u8>, ShellError>, _>::new(0, rx, TestSink::default());
tx.send(Ok(Some(StreamData::Raw(Ok(vec![10, 20])))))
.unwrap();
drop(tx);
assert_eq!(Some(vec![10, 20]), reader.recv()?.transpose()?);
Ok(())
}
#[test]
fn raw_reader_recv_wrong_type() -> Result<(), ShellError> {
let (tx, rx) = mpsc::channel();
let mut reader =
StreamReader::<Result<Vec<u8>, ShellError>, _>::new(0, rx, TestSink::default());
tx.send(Ok(Some(StreamData::List(Value::test_nothing()))))
.unwrap();
tx.send(Ok(Some(StreamData::Raw(Ok(vec![10, 20])))))
.unwrap();
drop(tx);
reader.recv().expect_err("should be an error");
reader.recv().expect("should be able to recover");
Ok(())
}
#[test]
fn reader_recv_acknowledge() -> Result<(), ShellError> {
let (tx, rx) = mpsc::channel();
let mut reader = StreamReader::<Value, _>::new(0, rx, TestSink::default());
tx.send(Ok(Some(StreamData::List(Value::test_int(5)))))
.unwrap();
tx.send(Ok(Some(StreamData::List(Value::test_int(6)))))
.unwrap();
drop(tx);
reader.recv()?;
reader.recv()?;
let wrote = &reader.writer.0;
assert!(wrote.len() >= 2);
assert!(
matches!(wrote[0], StreamMessage::Ack(0)),
"0 = {:?}",
wrote[0]
);
assert!(
matches!(wrote[1], StreamMessage::Ack(0)),
"1 = {:?}",
wrote[1]
);
Ok(())
}
#[test]
fn reader_recv_end_of_stream() -> Result<(), ShellError> {
let (tx, rx) = mpsc::channel();
let mut reader = StreamReader::<Value, _>::new(0, rx, TestSink::default());
tx.send(Ok(Some(StreamData::List(Value::test_int(5)))))
.unwrap();
tx.send(Ok(None)).unwrap();
drop(tx);
assert!(reader.recv()?.is_some(), "actual message");
assert!(reader.recv()?.is_none(), "on close");
assert!(reader.recv()?.is_none(), "after close");
Ok(())
}
#[test]
fn reader_iter_fuse_on_error() -> Result<(), ShellError> {
let (tx, rx) = mpsc::channel();
let mut reader = StreamReader::<Value, _>::new(0, rx, TestSink::default());
drop(tx); // should cause error, because we didn't explicitly signal the end
assert!(
reader.next().is_some_and(|e| e.is_error()),
"should be error the first time"
);
assert!(reader.next().is_none(), "should be closed the second time");
Ok(())
}
#[test]
fn reader_drop() {
let (_tx, rx) = mpsc::channel();
// Flag set if drop message is received.
struct Check(Arc<AtomicBool>);
impl WriteStreamMessage for Check {
fn write_stream_message(&mut self, msg: StreamMessage) -> Result<(), ShellError> {
assert!(matches!(msg, StreamMessage::Drop(1)), "got {:?}", msg);
self.0.store(true, Relaxed);
Ok(())
}
fn flush(&mut self) -> Result<(), ShellError> {
Ok(())
}
}
let flag = Arc::new(AtomicBool::new(false));
let reader = StreamReader::<Value, _>::new(1, rx, Check(flag.clone()));
drop(reader);
assert!(flag.load(Relaxed));
}
#[test]
fn writer_write_all_stops_if_dropped() -> Result<(), ShellError> {
let signal = Arc::new(StreamWriterSignal::new(20));
let id = 1337;
let mut writer = StreamWriter::new(id, signal.clone(), TestSink::default());
// Simulate this by having it consume a stream that will actually do the drop halfway through
let iter = (0..5).map(Value::test_int).chain({
let mut n = 5;
std::iter::from_fn(move || {
// produces numbers 5..10, but drops for the first one
if n == 5 {
signal.set_dropped().unwrap();
}
if n < 10 {
let value = Value::test_int(n);
n += 1;
Some(value)
} else {
None
}
})
});
writer.write_all(iter)?;
assert!(writer.is_dropped()?);
let wrote = &writer.writer.0;
assert_eq!(5, wrote.len(), "length wrong: {wrote:?}");
for (n, message) in (0..5).zip(wrote) {
match message {
StreamMessage::Data(msg_id, StreamData::List(value)) => {
assert_eq!(id, *msg_id, "id");
assert_eq!(Value::test_int(n), *value, "value");
}
other => panic!("unexpected message: {other:?}"),
}
}
Ok(())
}
#[test]
fn writer_end() -> Result<(), ShellError> {
let signal = Arc::new(StreamWriterSignal::new(20));
let mut writer = StreamWriter::new(9001, signal.clone(), TestSink::default());
writer.end()?;
writer
.write(Value::test_int(2))
.expect_err("shouldn't be able to write after end");
writer.end().expect("end twice should be ok");
let wrote = &writer.writer.0;
assert!(
matches!(wrote.last(), Some(StreamMessage::End(9001))),
"didn't write end message: {wrote:?}"
);
Ok(())
}
#[test]
fn signal_set_dropped() -> Result<(), ShellError> {
let signal = StreamWriterSignal::new(4);
assert!(!signal.is_dropped()?);
signal.set_dropped()?;
assert!(signal.is_dropped()?);
Ok(())
}
#[test]
fn signal_notify_sent_false_if_unacknowledged() -> Result<(), ShellError> {
let signal = StreamWriterSignal::new(2);
assert!(signal.notify_sent()?);
for _ in 0..100 {
assert!(!signal.notify_sent()?);
}
Ok(())
}
#[test]
fn signal_notify_sent_never_false_if_flowing() -> Result<(), ShellError> {
let signal = StreamWriterSignal::new(1);
for _ in 0..100 {
signal.notify_acknowledged()?;
}
for _ in 0..100 {
assert!(signal.notify_sent()?);
}
Ok(())
}
#[test]
fn signal_wait_for_drain_blocks_on_unacknowledged() -> Result<(), ShellError> {
let signal = StreamWriterSignal::new(50);
std::thread::scope(|scope| {
let spawned = scope.spawn(|| {
for _ in 0..100 {
if !signal.notify_sent()? {
signal.wait_for_drain()?;
}
}
Ok(())
});
std::thread::sleep(WAIT_DURATION);
assert!(!spawned.is_finished(), "didn't block");
for _ in 0..100 {
signal.notify_acknowledged()?;
}
wait_for_condition(|| spawned.is_finished(), "blocked at end");
spawned.join().unwrap()
})
}
#[test]
fn signal_wait_for_drain_unblocks_on_dropped() -> Result<(), ShellError> {
let signal = StreamWriterSignal::new(1);
std::thread::scope(|scope| {
let spawned = scope.spawn(|| {
while !signal.is_dropped()? {
if !signal.notify_sent()? {
signal.wait_for_drain()?;
}
}
Ok(())
});
std::thread::sleep(WAIT_DURATION);
assert!(!spawned.is_finished(), "didn't block");
signal.set_dropped()?;
wait_for_condition(|| spawned.is_finished(), "still blocked at end");
spawned.join().unwrap()
})
}
#[test]
fn stream_manager_single_stream_read_scenario() -> Result<(), ShellError> {
let manager = StreamManager::new();
let handle = manager.get_handle();
let (tx, rx) = mpsc::channel();
let readable = handle.read_stream::<Value, _>(2, tx)?;
let expected_values = vec![Value::test_int(40), Value::test_string("hello")];
for value in &expected_values {
manager.handle_message(StreamMessage::Data(2, value.clone().into()))?;
}
manager.handle_message(StreamMessage::End(2))?;
let values = readable.collect::<Vec<Value>>();
assert_eq!(expected_values, values);
// Now check the sent messages on consumption
// Should be Ack for each message, then Drop
for _ in &expected_values {
match rx.try_recv().expect("failed to receive Ack") {
StreamMessage::Ack(2) => (),
other => panic!("should have been an Ack: {other:?}"),
}
}
match rx.try_recv().expect("failed to receive Drop") {
StreamMessage::Drop(2) => (),
other => panic!("should have been a Drop: {other:?}"),
}
Ok(())
}
#[test]
fn stream_manager_multi_stream_read_scenario() -> Result<(), ShellError> {
let manager = StreamManager::new();
let handle = manager.get_handle();
let (tx, rx) = mpsc::channel();
let readable_list = handle.read_stream::<Value, _>(2, tx.clone())?;
let readable_raw = handle.read_stream::<Result<Vec<u8>, _>, _>(3, tx)?;
let expected_values = (1..100).map(Value::test_int).collect::<Vec<_>>();
let expected_raw_buffers = (1..100).map(|n| vec![n]).collect::<Vec<Vec<u8>>>();
for (value, buf) in expected_values.iter().zip(&expected_raw_buffers) {
manager.handle_message(StreamMessage::Data(2, value.clone().into()))?;
manager.handle_message(StreamMessage::Data(3, StreamData::Raw(Ok(buf.clone()))))?;
}
manager.handle_message(StreamMessage::End(2))?;
manager.handle_message(StreamMessage::End(3))?;
let values = readable_list.collect::<Vec<Value>>();
let bufs = readable_raw.collect::<Result<Vec<Vec<u8>>, _>>()?;
for (expected_value, value) in expected_values.iter().zip(&values) {
assert_eq!(expected_value, value, "in List stream");
}
for (expected_buf, buf) in expected_raw_buffers.iter().zip(&bufs) {
assert_eq!(expected_buf, buf, "in Raw stream");
}
// Now check the sent messages on consumption
// Should be Ack for each message, then Drop
for _ in &expected_values {
match rx.try_recv().expect("failed to receive Ack") {
StreamMessage::Ack(2) => (),
other => panic!("should have been an Ack(2): {other:?}"),
}
}
match rx.try_recv().expect("failed to receive Drop") {
StreamMessage::Drop(2) => (),
other => panic!("should have been a Drop(2): {other:?}"),
}
for _ in &expected_values {
match rx.try_recv().expect("failed to receive Ack") {
StreamMessage::Ack(3) => (),
other => panic!("should have been an Ack(3): {other:?}"),
}
}
match rx.try_recv().expect("failed to receive Drop") {
StreamMessage::Drop(3) => (),
other => panic!("should have been a Drop(3): {other:?}"),
}
// Should be end of stream
assert!(
rx.try_recv().is_err(),
"more messages written to stream than expected"
);
Ok(())
}
#[test]
fn stream_manager_write_scenario() -> Result<(), ShellError> {
let manager = StreamManager::new();
let handle = manager.get_handle();
let (tx, rx) = mpsc::channel();
let mut writable = handle.write_stream(4, tx, 100)?;
let expected_values = vec![b"hello".to_vec(), b"world".to_vec(), b"test".to_vec()];
for value in &expected_values {
writable.write(Ok::<_, ShellError>(value.clone()))?;
}
// Now try signalling ack
assert_eq!(
expected_values.len() as i32,
writable.signal.lock()?.unacknowledged,
"unacknowledged initial count",
);
manager.handle_message(StreamMessage::Ack(4))?;
assert_eq!(
expected_values.len() as i32 - 1,
writable.signal.lock()?.unacknowledged,
"unacknowledged post-Ack count",
);
// ...and Drop
manager.handle_message(StreamMessage::Drop(4))?;
assert!(writable.is_dropped()?);
// Drop the StreamWriter...
drop(writable);
// now check what was actually written
for value in &expected_values {
match rx.try_recv().expect("failed to receive Data") {
StreamMessage::Data(4, StreamData::Raw(Ok(received))) => {
assert_eq!(*value, received);
}
other @ StreamMessage::Data(..) => panic!("wrong Data for {value:?}: {other:?}"),
other => panic!("should have been Data: {other:?}"),
}
}
match rx.try_recv().expect("failed to receive End") {
StreamMessage::End(4) => (),
other => panic!("should have been End: {other:?}"),
}
Ok(())
}
#[test]
fn stream_manager_broadcast_read_error() -> Result<(), ShellError> {
let manager = StreamManager::new();
let handle = manager.get_handle();
let mut readable0 = handle.read_stream::<Value, _>(0, TestSink::default())?;
let mut readable1 = handle.read_stream::<Result<Vec<u8>, _>, _>(1, TestSink::default())?;
let error = ShellError::PluginFailedToDecode {
msg: "test decode error".into(),
};
manager.broadcast_read_error(error.clone())?;
drop(manager);
assert_eq!(
error.to_string(),
readable0
.recv()
.transpose()
.expect("nothing received from readable0")
.expect_err("not an error received from readable0")
.to_string()
);
assert_eq!(
error.to_string(),
readable1
.next()
.expect("nothing received from readable1")
.expect_err("not an error received from readable1")
.to_string()
);
Ok(())
}
#[test]
fn stream_manager_drop_writers_on_drop() -> Result<(), ShellError> {
let manager = StreamManager::new();
let handle = manager.get_handle();
let writable = handle.write_stream(4, TestSink::default(), 100)?;
assert!(!writable.is_dropped()?);
drop(manager);
assert!(writable.is_dropped()?);
Ok(())
}