mirror of
https://github.com/nushell/nushell.git
synced 2025-08-09 13:26:01 +02:00
Split the plugin crate (#12563)
# Description This breaks `nu-plugin` up into four crates: - `nu-plugin-protocol`: just the type definitions for the protocol, no I/O. If someone wanted to wire up something more bare metal, maybe for async I/O, they could use this. - `nu-plugin-core`: the shared stuff between engine/plugin. Less stable interface. - `nu-plugin-engine`: everything required for the engine to talk to plugins. Less stable interface. - `nu-plugin`: everything required for the plugin to talk to the engine, what plugin developers use. Should be the most stable interface. No changes are made to the interface exposed by `nu-plugin` - it should all still be there. Re-exports from `nu-plugin-protocol` or `nu-plugin-core` are used as required. Plugins shouldn't ever have to use those crates directly. This should be somewhat faster to compile as `nu-plugin-engine` and `nu-plugin` can compile in parallel, and the engine doesn't need `nu-plugin` and plugins don't need `nu-plugin-engine` (except for test support), so that should reduce what needs to be compiled too. The only significant change here other than splitting stuff up was to break the `source` out of `PluginCustomValue` and create a new `PluginCustomValueWithSource` type that contains that instead. One bonus of that is we get rid of the option and it's now more type-safe, but it also means that the logic for that stuff (actually running the plugin for custom value ops) can live entirely within the `nu-plugin-engine` crate. # User-Facing Changes - New crates. - Added `local-socket` feature for `nu` to try to make it possible to compile without that support if needed. # Tests + Formatting - 🟢 `toolkit fmt` - 🟢 `toolkit clippy` - 🟢 `toolkit test` - 🟢 `toolkit test stdlib`
This commit is contained in:
@ -0,0 +1,84 @@
|
||||
use std::ffi::OsString;
|
||||
|
||||
#[cfg(test)]
|
||||
pub(crate) mod tests;
|
||||
|
||||
/// Generate a name to be used for a local socket specific to this `nu` process, described by the
|
||||
/// given `unique_id`, which should be unique to the purpose of the socket.
|
||||
///
|
||||
/// On Unix, this is a path, which should generally be 100 characters or less for compatibility. On
|
||||
/// Windows, this is a name within the `\\.\pipe` namespace.
|
||||
#[cfg(unix)]
|
||||
pub fn make_local_socket_name(unique_id: &str) -> OsString {
|
||||
// Prefer to put it in XDG_RUNTIME_DIR if set, since that's user-local
|
||||
let mut base = if let Some(runtime_dir) = std::env::var_os("XDG_RUNTIME_DIR") {
|
||||
std::path::PathBuf::from(runtime_dir)
|
||||
} else {
|
||||
// Use std::env::temp_dir() for portability, especially since on Android this is probably
|
||||
// not `/tmp`
|
||||
std::env::temp_dir()
|
||||
};
|
||||
let socket_name = format!("nu.{}.{}.sock", std::process::id(), unique_id);
|
||||
base.push(socket_name);
|
||||
base.into()
|
||||
}
|
||||
|
||||
/// Generate a name to be used for a local socket specific to this `nu` process, described by the
|
||||
/// given `unique_id`, which should be unique to the purpose of the socket.
|
||||
///
|
||||
/// On Unix, this is a path, which should generally be 100 characters or less for compatibility. On
|
||||
/// Windows, this is a name within the `\\.\pipe` namespace.
|
||||
#[cfg(windows)]
|
||||
pub fn make_local_socket_name(unique_id: &str) -> OsString {
|
||||
format!("nu.{}.{}", std::process::id(), unique_id).into()
|
||||
}
|
||||
|
||||
/// Determine if the error is just due to the listener not being ready yet in asynchronous mode
|
||||
#[cfg(not(windows))]
|
||||
pub fn is_would_block_err(err: &std::io::Error) -> bool {
|
||||
err.kind() == std::io::ErrorKind::WouldBlock
|
||||
}
|
||||
|
||||
/// Determine if the error is just due to the listener not being ready yet in asynchronous mode
|
||||
#[cfg(windows)]
|
||||
pub fn is_would_block_err(err: &std::io::Error) -> bool {
|
||||
err.kind() == std::io::ErrorKind::WouldBlock
|
||||
|| err.raw_os_error().is_some_and(|e| {
|
||||
// Windows returns this error when trying to accept a pipe in non-blocking mode
|
||||
e as i64 == windows::Win32::Foundation::ERROR_PIPE_LISTENING.0 as i64
|
||||
})
|
||||
}
|
||||
|
||||
/// Wraps the `interprocess` local socket stream for greater compatibility
|
||||
#[derive(Debug)]
|
||||
pub struct LocalSocketStream(pub interprocess::local_socket::LocalSocketStream);
|
||||
|
||||
impl From<interprocess::local_socket::LocalSocketStream> for LocalSocketStream {
|
||||
fn from(value: interprocess::local_socket::LocalSocketStream) -> Self {
|
||||
LocalSocketStream(value)
|
||||
}
|
||||
}
|
||||
|
||||
impl std::io::Read for LocalSocketStream {
|
||||
fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
|
||||
self.0.read(buf)
|
||||
}
|
||||
}
|
||||
|
||||
impl std::io::Write for LocalSocketStream {
|
||||
fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
|
||||
self.0.write(buf)
|
||||
}
|
||||
|
||||
fn flush(&mut self) -> std::io::Result<()> {
|
||||
// We don't actually flush the underlying socket on Windows. The flush operation on a
|
||||
// Windows named pipe actually synchronizes with read on the other side, and won't finish
|
||||
// until the other side is empty. This isn't how most of our other I/O methods work, so we
|
||||
// just won't do it. The BufWriter above this will have still made a write call with the
|
||||
// contents of the buffer, which should be good enough.
|
||||
if cfg!(not(windows)) {
|
||||
self.0.flush()?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
@ -0,0 +1,19 @@
|
||||
use super::make_local_socket_name;
|
||||
|
||||
#[test]
|
||||
fn local_socket_path_contains_pid() {
|
||||
let name = make_local_socket_name("test-string")
|
||||
.to_string_lossy()
|
||||
.into_owned();
|
||||
println!("{}", name);
|
||||
assert!(name.to_string().contains(&std::process::id().to_string()));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn local_socket_path_contains_provided_name() {
|
||||
let name = make_local_socket_name("test-string")
|
||||
.to_string_lossy()
|
||||
.into_owned();
|
||||
println!("{}", name);
|
||||
assert!(name.to_string().contains("test-string"));
|
||||
}
|
249
crates/nu-plugin-core/src/communication_mode/mod.rs
Normal file
249
crates/nu-plugin-core/src/communication_mode/mod.rs
Normal file
@ -0,0 +1,249 @@
|
||||
use std::ffi::OsStr;
|
||||
use std::io::{Stdin, Stdout};
|
||||
use std::process::{Child, ChildStdin, ChildStdout, Command, Stdio};
|
||||
|
||||
use nu_protocol::ShellError;
|
||||
|
||||
#[cfg(feature = "local-socket")]
|
||||
use interprocess::local_socket::LocalSocketListener;
|
||||
|
||||
#[cfg(feature = "local-socket")]
|
||||
mod local_socket;
|
||||
|
||||
#[cfg(feature = "local-socket")]
|
||||
use local_socket::*;
|
||||
|
||||
/// The type of communication used between the plugin and the engine.
|
||||
///
|
||||
/// `Stdio` is required to be supported by all plugins, and is attempted initially. If the
|
||||
/// `local-socket` feature is enabled and the plugin supports it, `LocalSocket` may be attempted.
|
||||
///
|
||||
/// Local socket communication has the benefit of not tying up stdio, so it's more compatible with
|
||||
/// plugins that want to take user input from the terminal in some way.
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum CommunicationMode {
|
||||
/// Communicate using `stdin` and `stdout`.
|
||||
Stdio,
|
||||
/// Communicate using an operating system-specific local socket.
|
||||
#[cfg(feature = "local-socket")]
|
||||
LocalSocket(std::ffi::OsString),
|
||||
}
|
||||
|
||||
impl CommunicationMode {
|
||||
/// Generate a new local socket communication mode based on the given plugin exe path.
|
||||
#[cfg(feature = "local-socket")]
|
||||
pub fn local_socket(plugin_exe: &std::path::Path) -> CommunicationMode {
|
||||
use std::hash::{Hash, Hasher};
|
||||
use std::time::SystemTime;
|
||||
|
||||
// Generate the unique ID based on the plugin path and the current time. The actual
|
||||
// algorithm here is not very important, we just want this to be relatively unique very
|
||||
// briefly. Using the default hasher in the stdlib means zero extra dependencies.
|
||||
let mut hasher = std::collections::hash_map::DefaultHasher::new();
|
||||
|
||||
plugin_exe.hash(&mut hasher);
|
||||
SystemTime::now().hash(&mut hasher);
|
||||
|
||||
let unique_id = format!("{:016x}", hasher.finish());
|
||||
|
||||
CommunicationMode::LocalSocket(make_local_socket_name(&unique_id))
|
||||
}
|
||||
|
||||
pub fn args(&self) -> Vec<&OsStr> {
|
||||
match self {
|
||||
CommunicationMode::Stdio => vec![OsStr::new("--stdio")],
|
||||
#[cfg(feature = "local-socket")]
|
||||
CommunicationMode::LocalSocket(path) => {
|
||||
vec![OsStr::new("--local-socket"), path.as_os_str()]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn setup_command_io(&self, command: &mut Command) {
|
||||
match self {
|
||||
CommunicationMode::Stdio => {
|
||||
// Both stdout and stdin are piped so we can receive information from the plugin
|
||||
command.stdin(Stdio::piped());
|
||||
command.stdout(Stdio::piped());
|
||||
}
|
||||
#[cfg(feature = "local-socket")]
|
||||
CommunicationMode::LocalSocket(_) => {
|
||||
// Stdio can be used by the plugin to talk to the terminal in local socket mode,
|
||||
// which is the big benefit
|
||||
command.stdin(Stdio::inherit());
|
||||
command.stdout(Stdio::inherit());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn serve(&self) -> Result<PreparedServerCommunication, ShellError> {
|
||||
match self {
|
||||
// Nothing to set up for stdio - we just take it from the child.
|
||||
CommunicationMode::Stdio => Ok(PreparedServerCommunication::Stdio),
|
||||
// For sockets: we need to create the server so that the child won't fail to connect.
|
||||
#[cfg(feature = "local-socket")]
|
||||
CommunicationMode::LocalSocket(name) => {
|
||||
let listener = LocalSocketListener::bind(name.as_os_str()).map_err(|err| {
|
||||
ShellError::IOError {
|
||||
msg: format!("failed to open socket for plugin: {err}"),
|
||||
}
|
||||
})?;
|
||||
Ok(PreparedServerCommunication::LocalSocket {
|
||||
name: name.clone(),
|
||||
listener,
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn connect_as_client(&self) -> Result<ClientCommunicationIo, ShellError> {
|
||||
match self {
|
||||
CommunicationMode::Stdio => Ok(ClientCommunicationIo::Stdio(
|
||||
std::io::stdin(),
|
||||
std::io::stdout(),
|
||||
)),
|
||||
#[cfg(feature = "local-socket")]
|
||||
CommunicationMode::LocalSocket(name) => {
|
||||
// Connect to the specified socket.
|
||||
let get_socket = || {
|
||||
use interprocess::local_socket as ls;
|
||||
ls::LocalSocketStream::connect(name.as_os_str())
|
||||
.map_err(|err| ShellError::IOError {
|
||||
msg: format!("failed to connect to socket: {err}"),
|
||||
})
|
||||
.map(LocalSocketStream::from)
|
||||
};
|
||||
// Reverse order from the server: read in, write out
|
||||
let read_in = get_socket()?;
|
||||
let write_out = get_socket()?;
|
||||
Ok(ClientCommunicationIo::LocalSocket { read_in, write_out })
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// The result of [`CommunicationMode::serve()`], which acts as an intermediate stage for
|
||||
/// communication modes that require some kind of socket binding to occur before the client process
|
||||
/// can be started. Call [`.connect()`] once the client process has been started.
|
||||
///
|
||||
/// The socket may be cleaned up on `Drop` if applicable.
|
||||
pub enum PreparedServerCommunication {
|
||||
/// Will take stdin and stdout from the process on [`.connect()`].
|
||||
Stdio,
|
||||
/// Contains the listener to accept connections on. On Unix, the socket is unlinked on `Drop`.
|
||||
#[cfg(feature = "local-socket")]
|
||||
LocalSocket {
|
||||
#[cfg_attr(windows, allow(dead_code))] // not used on Windows
|
||||
name: std::ffi::OsString,
|
||||
listener: LocalSocketListener,
|
||||
},
|
||||
}
|
||||
|
||||
impl PreparedServerCommunication {
|
||||
pub fn connect(&self, child: &mut Child) -> Result<ServerCommunicationIo, ShellError> {
|
||||
match self {
|
||||
PreparedServerCommunication::Stdio => {
|
||||
let stdin = child
|
||||
.stdin
|
||||
.take()
|
||||
.ok_or_else(|| ShellError::PluginFailedToLoad {
|
||||
msg: "Plugin missing stdin writer".into(),
|
||||
})?;
|
||||
|
||||
let stdout = child
|
||||
.stdout
|
||||
.take()
|
||||
.ok_or_else(|| ShellError::PluginFailedToLoad {
|
||||
msg: "Plugin missing stdout writer".into(),
|
||||
})?;
|
||||
|
||||
Ok(ServerCommunicationIo::Stdio(stdin, stdout))
|
||||
}
|
||||
#[cfg(feature = "local-socket")]
|
||||
PreparedServerCommunication::LocalSocket { listener, .. } => {
|
||||
use std::time::{Duration, Instant};
|
||||
|
||||
const RETRY_PERIOD: Duration = Duration::from_millis(1);
|
||||
const TIMEOUT: Duration = Duration::from_secs(10);
|
||||
|
||||
let start = Instant::now();
|
||||
|
||||
// Use a loop to try to get two clients from the listener: one for read (the plugin
|
||||
// output) and one for write (the plugin input)
|
||||
listener.set_nonblocking(true)?;
|
||||
let mut get_socket = || {
|
||||
let mut result = None;
|
||||
while let Ok(None) = child.try_wait() {
|
||||
match listener.accept() {
|
||||
Ok(stream) => {
|
||||
// Success! But make sure the stream is in blocking mode.
|
||||
stream.set_nonblocking(false)?;
|
||||
result = Some(stream);
|
||||
break;
|
||||
}
|
||||
Err(err) => {
|
||||
if !is_would_block_err(&err) {
|
||||
// `WouldBlock` is ok, just means it's not ready yet, but some other
|
||||
// kind of error should be reported
|
||||
return Err(err.into());
|
||||
}
|
||||
}
|
||||
}
|
||||
if Instant::now().saturating_duration_since(start) > TIMEOUT {
|
||||
return Err(ShellError::PluginFailedToLoad {
|
||||
msg: "Plugin timed out while waiting to connect to socket".into(),
|
||||
});
|
||||
} else {
|
||||
std::thread::sleep(RETRY_PERIOD);
|
||||
}
|
||||
}
|
||||
if let Some(stream) = result {
|
||||
Ok(LocalSocketStream(stream))
|
||||
} else {
|
||||
// The process may have exited
|
||||
Err(ShellError::PluginFailedToLoad {
|
||||
msg: "Plugin exited without connecting".into(),
|
||||
})
|
||||
}
|
||||
};
|
||||
// Input stream always comes before output
|
||||
let write_in = get_socket()?;
|
||||
let read_out = get_socket()?;
|
||||
Ok(ServerCommunicationIo::LocalSocket { read_out, write_in })
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for PreparedServerCommunication {
|
||||
fn drop(&mut self) {
|
||||
match self {
|
||||
#[cfg(all(unix, feature = "local-socket"))]
|
||||
PreparedServerCommunication::LocalSocket { name: path, .. } => {
|
||||
// Just try to remove the socket file, it's ok if this fails
|
||||
let _ = std::fs::remove_file(path);
|
||||
}
|
||||
_ => (),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// The required streams for communication from the engine side, i.e. the server in socket terms.
|
||||
pub enum ServerCommunicationIo {
|
||||
Stdio(ChildStdin, ChildStdout),
|
||||
#[cfg(feature = "local-socket")]
|
||||
LocalSocket {
|
||||
read_out: LocalSocketStream,
|
||||
write_in: LocalSocketStream,
|
||||
},
|
||||
}
|
||||
|
||||
/// The required streams for communication from the plugin side, i.e. the client in socket terms.
|
||||
pub enum ClientCommunicationIo {
|
||||
Stdio(Stdin, Stdout),
|
||||
#[cfg(feature = "local-socket")]
|
||||
LocalSocket {
|
||||
read_in: LocalSocketStream,
|
||||
write_out: LocalSocketStream,
|
||||
},
|
||||
}
|
453
crates/nu-plugin-core/src/interface/mod.rs
Normal file
453
crates/nu-plugin-core/src/interface/mod.rs
Normal file
@ -0,0 +1,453 @@
|
||||
//! Implements the stream multiplexing interface for both the plugin side and the engine side.
|
||||
|
||||
use nu_plugin_protocol::{
|
||||
ExternalStreamInfo, ListStreamInfo, PipelineDataHeader, RawStreamInfo, StreamMessage,
|
||||
};
|
||||
use nu_protocol::{ListStream, PipelineData, RawStream, ShellError};
|
||||
use std::{
|
||||
io::Write,
|
||||
sync::{
|
||||
atomic::{AtomicBool, Ordering::Relaxed},
|
||||
Arc, Mutex,
|
||||
},
|
||||
thread,
|
||||
};
|
||||
|
||||
pub mod stream;
|
||||
|
||||
use crate::{util::Sequence, Encoder};
|
||||
|
||||
use self::stream::{StreamManager, StreamManagerHandle, StreamWriter, WriteStreamMessage};
|
||||
|
||||
pub mod test_util;
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests;
|
||||
|
||||
/// The maximum number of list stream values to send without acknowledgement. This should be tuned
|
||||
/// with consideration for memory usage.
|
||||
const LIST_STREAM_HIGH_PRESSURE: i32 = 100;
|
||||
|
||||
/// The maximum number of raw stream buffers to send without acknowledgement. This should be tuned
|
||||
/// with consideration for memory usage.
|
||||
const RAW_STREAM_HIGH_PRESSURE: i32 = 50;
|
||||
|
||||
/// Read input/output from the stream.
|
||||
pub trait PluginRead<T> {
|
||||
/// Returns `Ok(None)` on end of stream.
|
||||
fn read(&mut self) -> Result<Option<T>, ShellError>;
|
||||
}
|
||||
|
||||
impl<R, E, T> PluginRead<T> for (R, E)
|
||||
where
|
||||
R: std::io::BufRead,
|
||||
E: Encoder<T>,
|
||||
{
|
||||
fn read(&mut self) -> Result<Option<T>, ShellError> {
|
||||
self.1.decode(&mut self.0)
|
||||
}
|
||||
}
|
||||
|
||||
impl<R, T> PluginRead<T> for &mut R
|
||||
where
|
||||
R: PluginRead<T>,
|
||||
{
|
||||
fn read(&mut self) -> Result<Option<T>, ShellError> {
|
||||
(**self).read()
|
||||
}
|
||||
}
|
||||
|
||||
/// Write input/output to the stream.
|
||||
///
|
||||
/// The write should be atomic, without interference from other threads.
|
||||
pub trait PluginWrite<T>: Send + Sync {
|
||||
fn write(&self, data: &T) -> Result<(), ShellError>;
|
||||
|
||||
/// Flush any internal buffers, if applicable.
|
||||
fn flush(&self) -> Result<(), ShellError>;
|
||||
|
||||
/// True if this output is stdout, so that plugins can avoid using stdout for their own purpose
|
||||
fn is_stdout(&self) -> bool {
|
||||
false
|
||||
}
|
||||
}
|
||||
|
||||
impl<E, T> PluginWrite<T> for (std::io::Stdout, E)
|
||||
where
|
||||
E: Encoder<T>,
|
||||
{
|
||||
fn write(&self, data: &T) -> Result<(), ShellError> {
|
||||
let mut lock = self.0.lock();
|
||||
self.1.encode(data, &mut lock)
|
||||
}
|
||||
|
||||
fn flush(&self) -> Result<(), ShellError> {
|
||||
self.0.lock().flush().map_err(|err| ShellError::IOError {
|
||||
msg: err.to_string(),
|
||||
})
|
||||
}
|
||||
|
||||
fn is_stdout(&self) -> bool {
|
||||
true
|
||||
}
|
||||
}
|
||||
|
||||
impl<W, E, T> PluginWrite<T> for (Mutex<W>, E)
|
||||
where
|
||||
W: std::io::Write + Send,
|
||||
E: Encoder<T>,
|
||||
{
|
||||
fn write(&self, data: &T) -> Result<(), ShellError> {
|
||||
let mut lock = self.0.lock().map_err(|_| ShellError::NushellFailed {
|
||||
msg: "writer mutex poisoned".into(),
|
||||
})?;
|
||||
self.1.encode(data, &mut *lock)
|
||||
}
|
||||
|
||||
fn flush(&self) -> Result<(), ShellError> {
|
||||
let mut lock = self.0.lock().map_err(|_| ShellError::NushellFailed {
|
||||
msg: "writer mutex poisoned".into(),
|
||||
})?;
|
||||
lock.flush().map_err(|err| ShellError::IOError {
|
||||
msg: err.to_string(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl<W, T> PluginWrite<T> for &W
|
||||
where
|
||||
W: PluginWrite<T>,
|
||||
{
|
||||
fn write(&self, data: &T) -> Result<(), ShellError> {
|
||||
(**self).write(data)
|
||||
}
|
||||
|
||||
fn flush(&self) -> Result<(), ShellError> {
|
||||
(**self).flush()
|
||||
}
|
||||
|
||||
fn is_stdout(&self) -> bool {
|
||||
(**self).is_stdout()
|
||||
}
|
||||
}
|
||||
|
||||
/// An interface manager handles I/O and state management for communication between a plugin and
|
||||
/// the engine. See `PluginInterfaceManager` in `nu-plugin-engine` for communication from the engine
|
||||
/// side to a plugin, or `EngineInterfaceManager` in `nu-plugin` for communication from the plugin
|
||||
/// side to the engine.
|
||||
///
|
||||
/// There is typically one [`InterfaceManager`] consuming input from a background thread, and
|
||||
/// managing shared state.
|
||||
pub trait InterfaceManager {
|
||||
/// The corresponding interface type.
|
||||
type Interface: Interface + 'static;
|
||||
|
||||
/// The input message type.
|
||||
type Input;
|
||||
|
||||
/// Make a new interface that communicates with this [`InterfaceManager`].
|
||||
fn get_interface(&self) -> Self::Interface;
|
||||
|
||||
/// Consume an input message.
|
||||
///
|
||||
/// When implementing, call [`.consume_stream_message()`] for any encapsulated
|
||||
/// [`StreamMessage`]s received.
|
||||
fn consume(&mut self, input: Self::Input) -> Result<(), ShellError>;
|
||||
|
||||
/// Get the [`StreamManager`] for handling operations related to stream messages.
|
||||
fn stream_manager(&self) -> &StreamManager;
|
||||
|
||||
/// Prepare [`PipelineData`] after reading. This is called by `read_pipeline_data()` as
|
||||
/// a hook so that values that need special handling can be taken care of.
|
||||
fn prepare_pipeline_data(&self, data: PipelineData) -> Result<PipelineData, ShellError>;
|
||||
|
||||
/// Consume an input stream message.
|
||||
///
|
||||
/// This method is provided for implementors to use.
|
||||
fn consume_stream_message(&mut self, message: StreamMessage) -> Result<(), ShellError> {
|
||||
self.stream_manager().handle_message(message)
|
||||
}
|
||||
|
||||
/// Generate `PipelineData` for reading a stream, given a [`PipelineDataHeader`] that was
|
||||
/// received from the other side.
|
||||
///
|
||||
/// This method is provided for implementors to use.
|
||||
fn read_pipeline_data(
|
||||
&self,
|
||||
header: PipelineDataHeader,
|
||||
ctrlc: Option<&Arc<AtomicBool>>,
|
||||
) -> Result<PipelineData, ShellError> {
|
||||
self.prepare_pipeline_data(match header {
|
||||
PipelineDataHeader::Empty => PipelineData::Empty,
|
||||
PipelineDataHeader::Value(value) => PipelineData::Value(value, None),
|
||||
PipelineDataHeader::ListStream(info) => {
|
||||
let handle = self.stream_manager().get_handle();
|
||||
let reader = handle.read_stream(info.id, self.get_interface())?;
|
||||
PipelineData::ListStream(ListStream::from_stream(reader, ctrlc.cloned()), None)
|
||||
}
|
||||
PipelineDataHeader::ExternalStream(info) => {
|
||||
let handle = self.stream_manager().get_handle();
|
||||
let span = info.span;
|
||||
let new_raw_stream = |raw_info: RawStreamInfo| {
|
||||
let reader = handle.read_stream(raw_info.id, self.get_interface())?;
|
||||
let mut stream =
|
||||
RawStream::new(Box::new(reader), ctrlc.cloned(), span, raw_info.known_size);
|
||||
stream.is_binary = raw_info.is_binary;
|
||||
Ok::<_, ShellError>(stream)
|
||||
};
|
||||
PipelineData::ExternalStream {
|
||||
stdout: info.stdout.map(new_raw_stream).transpose()?,
|
||||
stderr: info.stderr.map(new_raw_stream).transpose()?,
|
||||
exit_code: info
|
||||
.exit_code
|
||||
.map(|list_info| {
|
||||
handle
|
||||
.read_stream(list_info.id, self.get_interface())
|
||||
.map(|reader| ListStream::from_stream(reader, ctrlc.cloned()))
|
||||
})
|
||||
.transpose()?,
|
||||
span: info.span,
|
||||
metadata: None,
|
||||
trim_end_newline: info.trim_end_newline,
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
/// An interface provides an API for communicating with a plugin or the engine and facilitates
|
||||
/// stream I/O. See `PluginInterface` in `nu-plugin-engine` for the API from the engine side to a
|
||||
/// plugin, or `EngineInterface` in `nu-plugin` for the API from the plugin side to the engine.
|
||||
///
|
||||
/// There can be multiple copies of the interface managed by a single [`InterfaceManager`].
|
||||
pub trait Interface: Clone + Send {
|
||||
/// The output message type, which must be capable of encapsulating a [`StreamMessage`].
|
||||
type Output: From<StreamMessage>;
|
||||
|
||||
/// Any context required to construct [`PipelineData`]. Can be `()` if not needed.
|
||||
type DataContext;
|
||||
|
||||
/// Write an output message.
|
||||
fn write(&self, output: Self::Output) -> Result<(), ShellError>;
|
||||
|
||||
/// Flush the output buffer, so messages are visible to the other side.
|
||||
fn flush(&self) -> Result<(), ShellError>;
|
||||
|
||||
/// Get the sequence for generating new [`StreamId`](nu_plugin_protocol::StreamId)s.
|
||||
fn stream_id_sequence(&self) -> &Sequence;
|
||||
|
||||
/// Get the [`StreamManagerHandle`] for doing stream operations.
|
||||
fn stream_manager_handle(&self) -> &StreamManagerHandle;
|
||||
|
||||
/// Prepare [`PipelineData`] to be written. This is called by `init_write_pipeline_data()` as
|
||||
/// a hook so that values that need special handling can be taken care of.
|
||||
fn prepare_pipeline_data(
|
||||
&self,
|
||||
data: PipelineData,
|
||||
context: &Self::DataContext,
|
||||
) -> Result<PipelineData, ShellError>;
|
||||
|
||||
/// Initialize a write for [`PipelineData`]. This returns two parts: the header, which can be
|
||||
/// embedded in the particular message that references the stream, and a writer, which will
|
||||
/// write out all of the data in the pipeline when `.write()` is called.
|
||||
///
|
||||
/// Note that not all [`PipelineData`] starts a stream. You should call `write()` anyway, as
|
||||
/// it will automatically handle this case.
|
||||
///
|
||||
/// This method is provided for implementors to use.
|
||||
fn init_write_pipeline_data(
|
||||
&self,
|
||||
data: PipelineData,
|
||||
context: &Self::DataContext,
|
||||
) -> Result<(PipelineDataHeader, PipelineDataWriter<Self>), ShellError> {
|
||||
// Allocate a stream id and a writer
|
||||
let new_stream = |high_pressure_mark: i32| {
|
||||
// Get a free stream id
|
||||
let id = self.stream_id_sequence().next()?;
|
||||
// Create the writer
|
||||
let writer =
|
||||
self.stream_manager_handle()
|
||||
.write_stream(id, self.clone(), high_pressure_mark)?;
|
||||
Ok::<_, ShellError>((id, writer))
|
||||
};
|
||||
match self.prepare_pipeline_data(data, context)? {
|
||||
PipelineData::Value(value, _) => {
|
||||
Ok((PipelineDataHeader::Value(value), PipelineDataWriter::None))
|
||||
}
|
||||
PipelineData::Empty => Ok((PipelineDataHeader::Empty, PipelineDataWriter::None)),
|
||||
PipelineData::ListStream(stream, _) => {
|
||||
let (id, writer) = new_stream(LIST_STREAM_HIGH_PRESSURE)?;
|
||||
Ok((
|
||||
PipelineDataHeader::ListStream(ListStreamInfo { id }),
|
||||
PipelineDataWriter::ListStream(writer, stream),
|
||||
))
|
||||
}
|
||||
PipelineData::ExternalStream {
|
||||
stdout,
|
||||
stderr,
|
||||
exit_code,
|
||||
span,
|
||||
metadata: _,
|
||||
trim_end_newline,
|
||||
} => {
|
||||
// Create the writers and stream ids
|
||||
let stdout_stream = stdout
|
||||
.is_some()
|
||||
.then(|| new_stream(RAW_STREAM_HIGH_PRESSURE))
|
||||
.transpose()?;
|
||||
let stderr_stream = stderr
|
||||
.is_some()
|
||||
.then(|| new_stream(RAW_STREAM_HIGH_PRESSURE))
|
||||
.transpose()?;
|
||||
let exit_code_stream = exit_code
|
||||
.is_some()
|
||||
.then(|| new_stream(LIST_STREAM_HIGH_PRESSURE))
|
||||
.transpose()?;
|
||||
// Generate the header, with the stream ids
|
||||
let header = PipelineDataHeader::ExternalStream(ExternalStreamInfo {
|
||||
span,
|
||||
stdout: stdout
|
||||
.as_ref()
|
||||
.zip(stdout_stream.as_ref())
|
||||
.map(|(stream, (id, _))| RawStreamInfo::new(*id, stream)),
|
||||
stderr: stderr
|
||||
.as_ref()
|
||||
.zip(stderr_stream.as_ref())
|
||||
.map(|(stream, (id, _))| RawStreamInfo::new(*id, stream)),
|
||||
exit_code: exit_code_stream
|
||||
.as_ref()
|
||||
.map(|&(id, _)| ListStreamInfo { id }),
|
||||
trim_end_newline,
|
||||
});
|
||||
// Collect the writers
|
||||
let writer = PipelineDataWriter::ExternalStream {
|
||||
stdout: stdout_stream.map(|(_, writer)| writer).zip(stdout),
|
||||
stderr: stderr_stream.map(|(_, writer)| writer).zip(stderr),
|
||||
exit_code: exit_code_stream.map(|(_, writer)| writer).zip(exit_code),
|
||||
};
|
||||
Ok((header, writer))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> WriteStreamMessage for T
|
||||
where
|
||||
T: Interface,
|
||||
{
|
||||
fn write_stream_message(&mut self, msg: StreamMessage) -> Result<(), ShellError> {
|
||||
self.write(msg.into())
|
||||
}
|
||||
|
||||
fn flush(&mut self) -> Result<(), ShellError> {
|
||||
<Self as Interface>::flush(self)
|
||||
}
|
||||
}
|
||||
|
||||
/// Completes the write operation for a [`PipelineData`]. You must call
|
||||
/// [`PipelineDataWriter::write()`] to write all of the data contained within the streams.
|
||||
#[derive(Default)]
|
||||
#[must_use]
|
||||
pub enum PipelineDataWriter<W: WriteStreamMessage> {
|
||||
#[default]
|
||||
None,
|
||||
ListStream(StreamWriter<W>, ListStream),
|
||||
ExternalStream {
|
||||
stdout: Option<(StreamWriter<W>, RawStream)>,
|
||||
stderr: Option<(StreamWriter<W>, RawStream)>,
|
||||
exit_code: Option<(StreamWriter<W>, ListStream)>,
|
||||
},
|
||||
}
|
||||
|
||||
impl<W> PipelineDataWriter<W>
|
||||
where
|
||||
W: WriteStreamMessage + Send + 'static,
|
||||
{
|
||||
/// Write all of the data in each of the streams. This method waits for completion.
|
||||
pub fn write(self) -> Result<(), ShellError> {
|
||||
match self {
|
||||
// If no stream was contained in the PipelineData, do nothing.
|
||||
PipelineDataWriter::None => Ok(()),
|
||||
// Write a list stream.
|
||||
PipelineDataWriter::ListStream(mut writer, stream) => {
|
||||
writer.write_all(stream)?;
|
||||
Ok(())
|
||||
}
|
||||
// Write all three possible streams of an ExternalStream on separate threads.
|
||||
PipelineDataWriter::ExternalStream {
|
||||
stdout,
|
||||
stderr,
|
||||
exit_code,
|
||||
} => {
|
||||
thread::scope(|scope| {
|
||||
let stderr_thread = stderr
|
||||
.map(|(mut writer, stream)| {
|
||||
thread::Builder::new()
|
||||
.name("plugin stderr writer".into())
|
||||
.spawn_scoped(scope, move || {
|
||||
writer.write_all(raw_stream_iter(stream))
|
||||
})
|
||||
})
|
||||
.transpose()?;
|
||||
let exit_code_thread = exit_code
|
||||
.map(|(mut writer, stream)| {
|
||||
thread::Builder::new()
|
||||
.name("plugin exit_code writer".into())
|
||||
.spawn_scoped(scope, move || writer.write_all(stream))
|
||||
})
|
||||
.transpose()?;
|
||||
// Optimize for stdout: if only stdout is present, don't spawn any other
|
||||
// threads.
|
||||
if let Some((mut writer, stream)) = stdout {
|
||||
writer.write_all(raw_stream_iter(stream))?;
|
||||
}
|
||||
let panicked = |thread_name: &str| {
|
||||
Err(ShellError::NushellFailed {
|
||||
msg: format!(
|
||||
"{thread_name} thread panicked in PipelineDataWriter::write"
|
||||
),
|
||||
})
|
||||
};
|
||||
stderr_thread
|
||||
.map(|t| t.join().unwrap_or_else(|_| panicked("stderr")))
|
||||
.transpose()?;
|
||||
exit_code_thread
|
||||
.map(|t| t.join().unwrap_or_else(|_| panicked("exit_code")))
|
||||
.transpose()?;
|
||||
Ok(())
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Write all of the data in each of the streams. This method returns immediately; any necessary
|
||||
/// write will happen in the background. If a thread was spawned, its handle is returned.
|
||||
pub fn write_background(
|
||||
self,
|
||||
) -> Result<Option<thread::JoinHandle<Result<(), ShellError>>>, ShellError> {
|
||||
match self {
|
||||
PipelineDataWriter::None => Ok(None),
|
||||
_ => Ok(Some(
|
||||
thread::Builder::new()
|
||||
.name("plugin stream background writer".into())
|
||||
.spawn(move || {
|
||||
let result = self.write();
|
||||
if let Err(ref err) = result {
|
||||
// Assume that the background thread error probably won't be handled and log it
|
||||
// here just in case.
|
||||
log::warn!("Error while writing pipeline in background: {err}");
|
||||
}
|
||||
result
|
||||
})?,
|
||||
)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Custom iterator for [`RawStream`] that respects ctrlc, but still has binary chunks
|
||||
fn raw_stream_iter(stream: RawStream) -> impl Iterator<Item = Result<Vec<u8>, ShellError>> {
|
||||
let ctrlc = stream.ctrlc;
|
||||
stream
|
||||
.stream
|
||||
.take_while(move |_| ctrlc.as_ref().map(|b| !b.load(Relaxed)).unwrap_or(true))
|
||||
}
|
628
crates/nu-plugin-core/src/interface/stream/mod.rs
Normal file
628
crates/nu-plugin-core/src/interface/stream/mod.rs
Normal file
@ -0,0 +1,628 @@
|
||||
use nu_plugin_protocol::{StreamData, StreamId, StreamMessage};
|
||||
use nu_protocol::{ShellError, Span, Value};
|
||||
use std::{
|
||||
collections::{btree_map, BTreeMap},
|
||||
iter::FusedIterator,
|
||||
marker::PhantomData,
|
||||
sync::{mpsc, Arc, Condvar, Mutex, MutexGuard, Weak},
|
||||
};
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests;
|
||||
|
||||
/// Receives messages from a stream read from input by a [`StreamManager`].
|
||||
///
|
||||
/// The receiver reads for messages of type `Result<Option<StreamData>, ShellError>` from the
|
||||
/// channel, which is managed by a [`StreamManager`]. Signalling for end-of-stream is explicit
|
||||
/// through `Ok(Some)`.
|
||||
///
|
||||
/// Failing to receive is an error. When end-of-stream is received, the `receiver` is set to `None`
|
||||
/// and all further calls to `next()` return `None`.
|
||||
///
|
||||
/// The type `T` must implement [`FromShellError`], so that errors in the stream can be represented,
|
||||
/// and `TryFrom<StreamData>` to convert it to the correct type.
|
||||
///
|
||||
/// For each message read, it sends [`StreamMessage::Ack`] to the writer. When dropped,
|
||||
/// it sends [`StreamMessage::Drop`].
|
||||
#[derive(Debug)]
|
||||
pub struct StreamReader<T, W>
|
||||
where
|
||||
W: WriteStreamMessage,
|
||||
{
|
||||
id: StreamId,
|
||||
receiver: Option<mpsc::Receiver<Result<Option<StreamData>, ShellError>>>,
|
||||
writer: W,
|
||||
/// Iterator requires the item type to be fixed, so we have to keep it as part of the type,
|
||||
/// even though we're actually receiving dynamic data.
|
||||
marker: PhantomData<fn() -> T>,
|
||||
}
|
||||
|
||||
impl<T, W> StreamReader<T, W>
|
||||
where
|
||||
T: TryFrom<StreamData, Error = ShellError>,
|
||||
W: WriteStreamMessage,
|
||||
{
|
||||
/// Create a new StreamReader from parts
|
||||
fn new(
|
||||
id: StreamId,
|
||||
receiver: mpsc::Receiver<Result<Option<StreamData>, ShellError>>,
|
||||
writer: W,
|
||||
) -> StreamReader<T, W> {
|
||||
StreamReader {
|
||||
id,
|
||||
receiver: Some(receiver),
|
||||
writer,
|
||||
marker: PhantomData,
|
||||
}
|
||||
}
|
||||
|
||||
/// Receive a message from the channel, or return an error if:
|
||||
///
|
||||
/// * the channel couldn't be received from
|
||||
/// * an error was sent on the channel
|
||||
/// * the message received couldn't be converted to `T`
|
||||
pub fn recv(&mut self) -> Result<Option<T>, ShellError> {
|
||||
let connection_lost = || ShellError::GenericError {
|
||||
error: "Stream ended unexpectedly".into(),
|
||||
msg: "connection lost before explicit end of stream".into(),
|
||||
span: None,
|
||||
help: None,
|
||||
inner: vec![],
|
||||
};
|
||||
|
||||
if let Some(ref rx) = self.receiver {
|
||||
// Try to receive a message first
|
||||
let msg = match rx.try_recv() {
|
||||
Ok(msg) => msg?,
|
||||
Err(mpsc::TryRecvError::Empty) => {
|
||||
// The receiver doesn't have any messages waiting for us. It's possible that the
|
||||
// other side hasn't seen our acknowledgements. Let's flush the writer and then
|
||||
// wait
|
||||
self.writer.flush()?;
|
||||
rx.recv().map_err(|_| connection_lost())??
|
||||
}
|
||||
Err(mpsc::TryRecvError::Disconnected) => return Err(connection_lost()),
|
||||
};
|
||||
|
||||
if let Some(data) = msg {
|
||||
// Acknowledge the message
|
||||
self.writer
|
||||
.write_stream_message(StreamMessage::Ack(self.id))?;
|
||||
// Try to convert it into the correct type
|
||||
Ok(Some(data.try_into()?))
|
||||
} else {
|
||||
// Remove the receiver, so that future recv() calls always return Ok(None)
|
||||
self.receiver = None;
|
||||
Ok(None)
|
||||
}
|
||||
} else {
|
||||
// Closed already
|
||||
Ok(None)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T, W> Iterator for StreamReader<T, W>
|
||||
where
|
||||
T: FromShellError + TryFrom<StreamData, Error = ShellError>,
|
||||
W: WriteStreamMessage,
|
||||
{
|
||||
type Item = T;
|
||||
|
||||
fn next(&mut self) -> Option<T> {
|
||||
// Converting the error to the value here makes the implementation a lot easier
|
||||
match self.recv() {
|
||||
Ok(option) => option,
|
||||
Err(err) => {
|
||||
// Drop the receiver so we don't keep returning errors
|
||||
self.receiver = None;
|
||||
Some(T::from_shell_error(err))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Guaranteed not to return anything after the end
|
||||
impl<T, W> FusedIterator for StreamReader<T, W>
|
||||
where
|
||||
T: FromShellError + TryFrom<StreamData, Error = ShellError>,
|
||||
W: WriteStreamMessage,
|
||||
{
|
||||
}
|
||||
|
||||
impl<T, W> Drop for StreamReader<T, W>
|
||||
where
|
||||
W: WriteStreamMessage,
|
||||
{
|
||||
fn drop(&mut self) {
|
||||
if let Err(err) = self
|
||||
.writer
|
||||
.write_stream_message(StreamMessage::Drop(self.id))
|
||||
.and_then(|_| self.writer.flush())
|
||||
{
|
||||
log::warn!("Failed to send message to drop stream: {err}");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Values that can contain a `ShellError` to signal an error has occurred.
|
||||
pub trait FromShellError {
|
||||
fn from_shell_error(err: ShellError) -> Self;
|
||||
}
|
||||
|
||||
// For List streams.
|
||||
impl FromShellError for Value {
|
||||
fn from_shell_error(err: ShellError) -> Self {
|
||||
Value::error(err, Span::unknown())
|
||||
}
|
||||
}
|
||||
|
||||
// For Raw streams, mostly.
|
||||
impl<T> FromShellError for Result<T, ShellError> {
|
||||
fn from_shell_error(err: ShellError) -> Self {
|
||||
Err(err)
|
||||
}
|
||||
}
|
||||
|
||||
/// Writes messages to a stream, with flow control.
|
||||
///
|
||||
/// The `signal` contained
|
||||
#[derive(Debug)]
|
||||
pub struct StreamWriter<W: WriteStreamMessage> {
|
||||
id: StreamId,
|
||||
signal: Arc<StreamWriterSignal>,
|
||||
writer: W,
|
||||
ended: bool,
|
||||
}
|
||||
|
||||
impl<W> StreamWriter<W>
|
||||
where
|
||||
W: WriteStreamMessage,
|
||||
{
|
||||
fn new(id: StreamId, signal: Arc<StreamWriterSignal>, writer: W) -> StreamWriter<W> {
|
||||
StreamWriter {
|
||||
id,
|
||||
signal,
|
||||
writer,
|
||||
ended: false,
|
||||
}
|
||||
}
|
||||
|
||||
/// Check if the stream was dropped from the other end. Recommended to do this before calling
|
||||
/// [`.write()`], especially in a loop.
|
||||
pub fn is_dropped(&self) -> Result<bool, ShellError> {
|
||||
self.signal.is_dropped()
|
||||
}
|
||||
|
||||
/// Write a single piece of data to the stream.
|
||||
///
|
||||
/// Error if something failed with the write, or if [`.end()`] was already called
|
||||
/// previously.
|
||||
pub fn write(&mut self, data: impl Into<StreamData>) -> Result<(), ShellError> {
|
||||
if !self.ended {
|
||||
self.writer
|
||||
.write_stream_message(StreamMessage::Data(self.id, data.into()))?;
|
||||
// This implements flow control, so we don't write too many messages:
|
||||
if !self.signal.notify_sent()? {
|
||||
// Flush the output, and then wait for acknowledgements
|
||||
self.writer.flush()?;
|
||||
self.signal.wait_for_drain()
|
||||
} else {
|
||||
Ok(())
|
||||
}
|
||||
} else {
|
||||
Err(ShellError::GenericError {
|
||||
error: "Wrote to a stream after it ended".into(),
|
||||
msg: format!(
|
||||
"tried to write to stream {} after it was already ended",
|
||||
self.id
|
||||
),
|
||||
span: None,
|
||||
help: Some("this may be a bug in the nu-plugin crate".into()),
|
||||
inner: vec![],
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
/// Write a full iterator to the stream. Note that this doesn't end the stream, so you should
|
||||
/// still call [`.end()`].
|
||||
///
|
||||
/// If the stream is dropped from the other end, the iterator will not be fully consumed, and
|
||||
/// writing will terminate.
|
||||
///
|
||||
/// Returns `Ok(true)` if the iterator was fully consumed, or `Ok(false)` if a drop interrupted
|
||||
/// the stream from the other side.
|
||||
pub fn write_all<T>(&mut self, data: impl IntoIterator<Item = T>) -> Result<bool, ShellError>
|
||||
where
|
||||
T: Into<StreamData>,
|
||||
{
|
||||
// Check before starting
|
||||
if self.is_dropped()? {
|
||||
return Ok(false);
|
||||
}
|
||||
|
||||
for item in data {
|
||||
// Check again after each item is consumed from the iterator, just in case the iterator
|
||||
// takes a while to produce a value
|
||||
if self.is_dropped()? {
|
||||
return Ok(false);
|
||||
}
|
||||
self.write(item)?;
|
||||
}
|
||||
Ok(true)
|
||||
}
|
||||
|
||||
/// End the stream. Recommend doing this instead of relying on `Drop` so that you can catch the
|
||||
/// error.
|
||||
pub fn end(&mut self) -> Result<(), ShellError> {
|
||||
if !self.ended {
|
||||
// Set the flag first so we don't double-report in the Drop
|
||||
self.ended = true;
|
||||
self.writer
|
||||
.write_stream_message(StreamMessage::End(self.id))?;
|
||||
self.writer.flush()
|
||||
} else {
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<W> Drop for StreamWriter<W>
|
||||
where
|
||||
W: WriteStreamMessage,
|
||||
{
|
||||
fn drop(&mut self) {
|
||||
// Make sure we ended the stream
|
||||
if let Err(err) = self.end() {
|
||||
log::warn!("Error while ending stream in Drop for StreamWriter: {err}");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Stores stream state for a writer, and can be blocked on to wait for messages to be acknowledged.
|
||||
/// A key part of managing stream lifecycle and flow control.
|
||||
#[derive(Debug)]
|
||||
pub struct StreamWriterSignal {
|
||||
mutex: Mutex<StreamWriterSignalState>,
|
||||
change_cond: Condvar,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct StreamWriterSignalState {
|
||||
/// Stream has been dropped and consumer is no longer interested in any messages.
|
||||
dropped: bool,
|
||||
/// Number of messages that have been sent without acknowledgement.
|
||||
unacknowledged: i32,
|
||||
/// Max number of messages to send before waiting for acknowledgement.
|
||||
high_pressure_mark: i32,
|
||||
}
|
||||
|
||||
impl StreamWriterSignal {
|
||||
/// Create a new signal.
|
||||
///
|
||||
/// If `notify_sent()` is called more than `high_pressure_mark` times, it will wait until
|
||||
/// `notify_acknowledge()` is called by another thread enough times to bring the number of
|
||||
/// unacknowledged sent messages below that threshold.
|
||||
fn new(high_pressure_mark: i32) -> StreamWriterSignal {
|
||||
assert!(high_pressure_mark > 0);
|
||||
|
||||
StreamWriterSignal {
|
||||
mutex: Mutex::new(StreamWriterSignalState {
|
||||
dropped: false,
|
||||
unacknowledged: 0,
|
||||
high_pressure_mark,
|
||||
}),
|
||||
change_cond: Condvar::new(),
|
||||
}
|
||||
}
|
||||
|
||||
fn lock(&self) -> Result<MutexGuard<StreamWriterSignalState>, ShellError> {
|
||||
self.mutex.lock().map_err(|_| ShellError::NushellFailed {
|
||||
msg: "StreamWriterSignal mutex poisoned due to panic".into(),
|
||||
})
|
||||
}
|
||||
|
||||
/// True if the stream was dropped and the consumer is no longer interested in it. Indicates
|
||||
/// that no more messages should be sent, other than `End`.
|
||||
pub fn is_dropped(&self) -> Result<bool, ShellError> {
|
||||
Ok(self.lock()?.dropped)
|
||||
}
|
||||
|
||||
/// Notify the writers that the stream has been dropped, so they can stop writing.
|
||||
pub fn set_dropped(&self) -> Result<(), ShellError> {
|
||||
let mut state = self.lock()?;
|
||||
state.dropped = true;
|
||||
// Unblock the writers so they can terminate
|
||||
self.change_cond.notify_all();
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Track that a message has been sent. Returns `Ok(true)` if more messages can be sent,
|
||||
/// or `Ok(false)` if the high pressure mark has been reached and [`.wait_for_drain()`] should
|
||||
/// be called to block.
|
||||
pub fn notify_sent(&self) -> Result<bool, ShellError> {
|
||||
let mut state = self.lock()?;
|
||||
state.unacknowledged =
|
||||
state
|
||||
.unacknowledged
|
||||
.checked_add(1)
|
||||
.ok_or_else(|| ShellError::NushellFailed {
|
||||
msg: "Overflow in counter: too many unacknowledged messages".into(),
|
||||
})?;
|
||||
|
||||
Ok(state.unacknowledged < state.high_pressure_mark)
|
||||
}
|
||||
|
||||
/// Wait for acknowledgements before sending more data. Also returns if the stream is dropped.
|
||||
pub fn wait_for_drain(&self) -> Result<(), ShellError> {
|
||||
let mut state = self.lock()?;
|
||||
while !state.dropped && state.unacknowledged >= state.high_pressure_mark {
|
||||
state = self
|
||||
.change_cond
|
||||
.wait(state)
|
||||
.map_err(|_| ShellError::NushellFailed {
|
||||
msg: "StreamWriterSignal mutex poisoned due to panic".into(),
|
||||
})?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Notify the writers that a message has been acknowledged, so they can continue to write
|
||||
/// if they were waiting.
|
||||
pub fn notify_acknowledged(&self) -> Result<(), ShellError> {
|
||||
let mut state = self.lock()?;
|
||||
state.unacknowledged =
|
||||
state
|
||||
.unacknowledged
|
||||
.checked_sub(1)
|
||||
.ok_or_else(|| ShellError::NushellFailed {
|
||||
msg: "Underflow in counter: too many message acknowledgements".into(),
|
||||
})?;
|
||||
// Unblock the writer
|
||||
self.change_cond.notify_one();
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
/// A sink for a [`StreamMessage`]
|
||||
pub trait WriteStreamMessage {
|
||||
fn write_stream_message(&mut self, msg: StreamMessage) -> Result<(), ShellError>;
|
||||
fn flush(&mut self) -> Result<(), ShellError>;
|
||||
}
|
||||
|
||||
#[derive(Debug, Default)]
|
||||
struct StreamManagerState {
|
||||
reading_streams: BTreeMap<StreamId, mpsc::Sender<Result<Option<StreamData>, ShellError>>>,
|
||||
writing_streams: BTreeMap<StreamId, Weak<StreamWriterSignal>>,
|
||||
}
|
||||
|
||||
impl StreamManagerState {
|
||||
/// Lock the state, or return a [`ShellError`] if the mutex is poisoned.
|
||||
fn lock(
|
||||
state: &Mutex<StreamManagerState>,
|
||||
) -> Result<MutexGuard<StreamManagerState>, ShellError> {
|
||||
state.lock().map_err(|_| ShellError::NushellFailed {
|
||||
msg: "StreamManagerState mutex poisoned due to a panic".into(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct StreamManager {
|
||||
state: Arc<Mutex<StreamManagerState>>,
|
||||
}
|
||||
|
||||
impl StreamManager {
|
||||
/// Create a new StreamManager.
|
||||
pub fn new() -> StreamManager {
|
||||
StreamManager {
|
||||
state: Default::default(),
|
||||
}
|
||||
}
|
||||
|
||||
fn lock(&self) -> Result<MutexGuard<StreamManagerState>, ShellError> {
|
||||
StreamManagerState::lock(&self.state)
|
||||
}
|
||||
|
||||
/// Create a new handle to the StreamManager for registering streams.
|
||||
pub fn get_handle(&self) -> StreamManagerHandle {
|
||||
StreamManagerHandle {
|
||||
state: Arc::downgrade(&self.state),
|
||||
}
|
||||
}
|
||||
|
||||
/// Process a stream message, and update internal state accordingly.
|
||||
pub fn handle_message(&self, message: StreamMessage) -> Result<(), ShellError> {
|
||||
let mut state = self.lock()?;
|
||||
match message {
|
||||
StreamMessage::Data(id, data) => {
|
||||
if let Some(sender) = state.reading_streams.get(&id) {
|
||||
// We should ignore the error on send. This just means the reader has dropped,
|
||||
// but it will have sent a Drop message to the other side, and we will receive
|
||||
// an End message at which point we can remove the channel.
|
||||
let _ = sender.send(Ok(Some(data)));
|
||||
Ok(())
|
||||
} else {
|
||||
Err(ShellError::PluginFailedToDecode {
|
||||
msg: format!("received Data for unknown stream {id}"),
|
||||
})
|
||||
}
|
||||
}
|
||||
StreamMessage::End(id) => {
|
||||
if let Some(sender) = state.reading_streams.remove(&id) {
|
||||
// We should ignore the error on the send, because the reader might have dropped
|
||||
// already
|
||||
let _ = sender.send(Ok(None));
|
||||
Ok(())
|
||||
} else {
|
||||
Err(ShellError::PluginFailedToDecode {
|
||||
msg: format!("received End for unknown stream {id}"),
|
||||
})
|
||||
}
|
||||
}
|
||||
StreamMessage::Drop(id) => {
|
||||
if let Some(signal) = state.writing_streams.remove(&id) {
|
||||
if let Some(signal) = signal.upgrade() {
|
||||
// This will wake blocked writers so they can stop writing, so it's ok
|
||||
signal.set_dropped()?;
|
||||
}
|
||||
}
|
||||
// It's possible that the stream has already finished writing and we don't have it
|
||||
// anymore, so we fall through to Ok
|
||||
Ok(())
|
||||
}
|
||||
StreamMessage::Ack(id) => {
|
||||
if let Some(signal) = state.writing_streams.get(&id) {
|
||||
if let Some(signal) = signal.upgrade() {
|
||||
// This will wake up a blocked writer
|
||||
signal.notify_acknowledged()?;
|
||||
} else {
|
||||
// We know it doesn't exist, so might as well remove it
|
||||
state.writing_streams.remove(&id);
|
||||
}
|
||||
}
|
||||
// It's possible that the stream has already finished writing and we don't have it
|
||||
// anymore, so we fall through to Ok
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Broadcast an error to all stream readers. This is useful for error propagation.
|
||||
pub fn broadcast_read_error(&self, error: ShellError) -> Result<(), ShellError> {
|
||||
let state = self.lock()?;
|
||||
for channel in state.reading_streams.values() {
|
||||
// Ignore send errors.
|
||||
let _ = channel.send(Err(error.clone()));
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// If the `StreamManager` is dropped, we should let all of the stream writers know that they
|
||||
// won't be able to write anymore. We don't need to do anything about the readers though
|
||||
// because they'll know when the `Sender` is dropped automatically
|
||||
fn drop_all_writers(&self) -> Result<(), ShellError> {
|
||||
let mut state = self.lock()?;
|
||||
let writers = std::mem::take(&mut state.writing_streams);
|
||||
for (_, signal) in writers {
|
||||
if let Some(signal) = signal.upgrade() {
|
||||
// more important that we send to all than handling an error
|
||||
let _ = signal.set_dropped();
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for StreamManager {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for StreamManager {
|
||||
fn drop(&mut self) {
|
||||
if let Err(err) = self.drop_all_writers() {
|
||||
log::warn!("error during Drop for StreamManager: {}", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// A [`StreamManagerHandle`] supports operations for interacting with the [`StreamManager`].
|
||||
///
|
||||
/// Streams can be registered for reading, returning a [`StreamReader`], or for writing, returning
|
||||
/// a [`StreamWriter`].
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct StreamManagerHandle {
|
||||
state: Weak<Mutex<StreamManagerState>>,
|
||||
}
|
||||
|
||||
impl StreamManagerHandle {
|
||||
/// Because the handle only has a weak reference to the [`StreamManager`] state, we have to
|
||||
/// first try to upgrade to a strong reference and then lock. This function wraps those two
|
||||
/// operations together, handling errors appropriately.
|
||||
fn with_lock<T, F>(&self, f: F) -> Result<T, ShellError>
|
||||
where
|
||||
F: FnOnce(MutexGuard<StreamManagerState>) -> Result<T, ShellError>,
|
||||
{
|
||||
let upgraded = self
|
||||
.state
|
||||
.upgrade()
|
||||
.ok_or_else(|| ShellError::NushellFailed {
|
||||
msg: "StreamManager is no longer alive".into(),
|
||||
})?;
|
||||
let guard = upgraded.lock().map_err(|_| ShellError::NushellFailed {
|
||||
msg: "StreamManagerState mutex poisoned due to a panic".into(),
|
||||
})?;
|
||||
f(guard)
|
||||
}
|
||||
|
||||
/// Register a new stream for reading, and return a [`StreamReader`] that can be used to iterate
|
||||
/// on the values received. A [`StreamMessage`] writer is required for writing control messages
|
||||
/// back to the producer.
|
||||
pub fn read_stream<T, W>(
|
||||
&self,
|
||||
id: StreamId,
|
||||
writer: W,
|
||||
) -> Result<StreamReader<T, W>, ShellError>
|
||||
where
|
||||
T: TryFrom<StreamData, Error = ShellError>,
|
||||
W: WriteStreamMessage,
|
||||
{
|
||||
let (tx, rx) = mpsc::channel();
|
||||
self.with_lock(|mut state| {
|
||||
// Must be exclusive
|
||||
if let btree_map::Entry::Vacant(e) = state.reading_streams.entry(id) {
|
||||
e.insert(tx);
|
||||
Ok(())
|
||||
} else {
|
||||
Err(ShellError::GenericError {
|
||||
error: format!("Failed to acquire reader for stream {id}"),
|
||||
msg: "tried to get a reader for a stream that's already being read".into(),
|
||||
span: None,
|
||||
help: Some("this may be a bug in the nu-plugin crate".into()),
|
||||
inner: vec![],
|
||||
})
|
||||
}
|
||||
})?;
|
||||
Ok(StreamReader::new(id, rx, writer))
|
||||
}
|
||||
|
||||
/// Register a new stream for writing, and return a [`StreamWriter`] that can be used to send
|
||||
/// data to the stream.
|
||||
///
|
||||
/// The `high_pressure_mark` value controls how many messages can be written without receiving
|
||||
/// an acknowledgement before any further attempts to write will wait for the consumer to
|
||||
/// acknowledge them. This prevents overwhelming the reader.
|
||||
pub fn write_stream<W>(
|
||||
&self,
|
||||
id: StreamId,
|
||||
writer: W,
|
||||
high_pressure_mark: i32,
|
||||
) -> Result<StreamWriter<W>, ShellError>
|
||||
where
|
||||
W: WriteStreamMessage,
|
||||
{
|
||||
let signal = Arc::new(StreamWriterSignal::new(high_pressure_mark));
|
||||
self.with_lock(|mut state| {
|
||||
// Remove dead writing streams
|
||||
state
|
||||
.writing_streams
|
||||
.retain(|_, signal| signal.strong_count() > 0);
|
||||
// Must be exclusive
|
||||
if let btree_map::Entry::Vacant(e) = state.writing_streams.entry(id) {
|
||||
e.insert(Arc::downgrade(&signal));
|
||||
Ok(())
|
||||
} else {
|
||||
Err(ShellError::GenericError {
|
||||
error: format!("Failed to acquire writer for stream {id}"),
|
||||
msg: "tried to get a writer for a stream that's already being written".into(),
|
||||
span: None,
|
||||
help: Some("this may be a bug in the nu-plugin crate".into()),
|
||||
inner: vec![],
|
||||
})
|
||||
}
|
||||
})?;
|
||||
Ok(StreamWriter::new(id, signal, writer))
|
||||
}
|
||||
}
|
550
crates/nu-plugin-core/src/interface/stream/tests.rs
Normal file
550
crates/nu-plugin-core/src/interface/stream/tests.rs
Normal file
@ -0,0 +1,550 @@
|
||||
use std::{
|
||||
sync::{
|
||||
atomic::{AtomicBool, Ordering::Relaxed},
|
||||
mpsc, Arc,
|
||||
},
|
||||
time::{Duration, Instant},
|
||||
};
|
||||
|
||||
use super::{StreamManager, StreamReader, StreamWriter, StreamWriterSignal, WriteStreamMessage};
|
||||
use nu_plugin_protocol::{StreamData, StreamMessage};
|
||||
use nu_protocol::{ShellError, Value};
|
||||
|
||||
// Should be long enough to definitely complete any quick operation, but not so long that tests are
|
||||
// slow to complete. 10 ms is a pretty long time
|
||||
const WAIT_DURATION: Duration = Duration::from_millis(10);
|
||||
|
||||
// Maximum time to wait for a condition to be true
|
||||
const MAX_WAIT_DURATION: Duration = Duration::from_millis(500);
|
||||
|
||||
/// Wait for a condition to be true, or panic if the duration exceeds MAX_WAIT_DURATION
|
||||
#[track_caller]
|
||||
fn wait_for_condition(mut cond: impl FnMut() -> bool, message: &str) {
|
||||
// Early check
|
||||
if cond() {
|
||||
return;
|
||||
}
|
||||
|
||||
let start = Instant::now();
|
||||
loop {
|
||||
std::thread::sleep(Duration::from_millis(10));
|
||||
|
||||
if cond() {
|
||||
return;
|
||||
}
|
||||
|
||||
let elapsed = Instant::now().saturating_duration_since(start);
|
||||
if elapsed > MAX_WAIT_DURATION {
|
||||
panic!(
|
||||
"{message}: Waited {:.2}sec, which is more than the maximum of {:.2}sec",
|
||||
elapsed.as_secs_f64(),
|
||||
MAX_WAIT_DURATION.as_secs_f64(),
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Default)]
|
||||
struct TestSink(Vec<StreamMessage>);
|
||||
|
||||
impl WriteStreamMessage for TestSink {
|
||||
fn write_stream_message(&mut self, msg: StreamMessage) -> Result<(), ShellError> {
|
||||
self.0.push(msg);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn flush(&mut self) -> Result<(), ShellError> {
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl WriteStreamMessage for mpsc::Sender<StreamMessage> {
|
||||
fn write_stream_message(&mut self, msg: StreamMessage) -> Result<(), ShellError> {
|
||||
self.send(msg).map_err(|err| ShellError::NushellFailed {
|
||||
msg: err.to_string(),
|
||||
})
|
||||
}
|
||||
|
||||
fn flush(&mut self) -> Result<(), ShellError> {
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn reader_recv_list_messages() -> Result<(), ShellError> {
|
||||
let (tx, rx) = mpsc::channel();
|
||||
let mut reader = StreamReader::new(0, rx, TestSink::default());
|
||||
|
||||
tx.send(Ok(Some(StreamData::List(Value::test_int(5)))))
|
||||
.unwrap();
|
||||
drop(tx);
|
||||
|
||||
assert_eq!(Some(Value::test_int(5)), reader.recv()?);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn list_reader_recv_wrong_type() -> Result<(), ShellError> {
|
||||
let (tx, rx) = mpsc::channel();
|
||||
let mut reader = StreamReader::<Value, _>::new(0, rx, TestSink::default());
|
||||
|
||||
tx.send(Ok(Some(StreamData::Raw(Ok(vec![10, 20])))))
|
||||
.unwrap();
|
||||
tx.send(Ok(Some(StreamData::List(Value::test_nothing()))))
|
||||
.unwrap();
|
||||
drop(tx);
|
||||
|
||||
reader.recv().expect_err("should be an error");
|
||||
reader.recv().expect("should be able to recover");
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn reader_recv_raw_messages() -> Result<(), ShellError> {
|
||||
let (tx, rx) = mpsc::channel();
|
||||
let mut reader =
|
||||
StreamReader::<Result<Vec<u8>, ShellError>, _>::new(0, rx, TestSink::default());
|
||||
|
||||
tx.send(Ok(Some(StreamData::Raw(Ok(vec![10, 20])))))
|
||||
.unwrap();
|
||||
drop(tx);
|
||||
|
||||
assert_eq!(Some(vec![10, 20]), reader.recv()?.transpose()?);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn raw_reader_recv_wrong_type() -> Result<(), ShellError> {
|
||||
let (tx, rx) = mpsc::channel();
|
||||
let mut reader =
|
||||
StreamReader::<Result<Vec<u8>, ShellError>, _>::new(0, rx, TestSink::default());
|
||||
|
||||
tx.send(Ok(Some(StreamData::List(Value::test_nothing()))))
|
||||
.unwrap();
|
||||
tx.send(Ok(Some(StreamData::Raw(Ok(vec![10, 20])))))
|
||||
.unwrap();
|
||||
drop(tx);
|
||||
|
||||
reader.recv().expect_err("should be an error");
|
||||
reader.recv().expect("should be able to recover");
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn reader_recv_acknowledge() -> Result<(), ShellError> {
|
||||
let (tx, rx) = mpsc::channel();
|
||||
let mut reader = StreamReader::<Value, _>::new(0, rx, TestSink::default());
|
||||
|
||||
tx.send(Ok(Some(StreamData::List(Value::test_int(5)))))
|
||||
.unwrap();
|
||||
tx.send(Ok(Some(StreamData::List(Value::test_int(6)))))
|
||||
.unwrap();
|
||||
drop(tx);
|
||||
|
||||
reader.recv()?;
|
||||
reader.recv()?;
|
||||
let wrote = &reader.writer.0;
|
||||
assert!(wrote.len() >= 2);
|
||||
assert!(
|
||||
matches!(wrote[0], StreamMessage::Ack(0)),
|
||||
"0 = {:?}",
|
||||
wrote[0]
|
||||
);
|
||||
assert!(
|
||||
matches!(wrote[1], StreamMessage::Ack(0)),
|
||||
"1 = {:?}",
|
||||
wrote[1]
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn reader_recv_end_of_stream() -> Result<(), ShellError> {
|
||||
let (tx, rx) = mpsc::channel();
|
||||
let mut reader = StreamReader::<Value, _>::new(0, rx, TestSink::default());
|
||||
|
||||
tx.send(Ok(Some(StreamData::List(Value::test_int(5)))))
|
||||
.unwrap();
|
||||
tx.send(Ok(None)).unwrap();
|
||||
drop(tx);
|
||||
|
||||
assert!(reader.recv()?.is_some(), "actual message");
|
||||
assert!(reader.recv()?.is_none(), "on close");
|
||||
assert!(reader.recv()?.is_none(), "after close");
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn reader_iter_fuse_on_error() -> Result<(), ShellError> {
|
||||
let (tx, rx) = mpsc::channel();
|
||||
let mut reader = StreamReader::<Value, _>::new(0, rx, TestSink::default());
|
||||
|
||||
drop(tx); // should cause error, because we didn't explicitly signal the end
|
||||
|
||||
assert!(
|
||||
reader.next().is_some_and(|e| e.is_error()),
|
||||
"should be error the first time"
|
||||
);
|
||||
assert!(reader.next().is_none(), "should be closed the second time");
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn reader_drop() {
|
||||
let (_tx, rx) = mpsc::channel();
|
||||
|
||||
// Flag set if drop message is received.
|
||||
struct Check(Arc<AtomicBool>);
|
||||
|
||||
impl WriteStreamMessage for Check {
|
||||
fn write_stream_message(&mut self, msg: StreamMessage) -> Result<(), ShellError> {
|
||||
assert!(matches!(msg, StreamMessage::Drop(1)), "got {:?}", msg);
|
||||
self.0.store(true, Relaxed);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn flush(&mut self) -> Result<(), ShellError> {
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
let flag = Arc::new(AtomicBool::new(false));
|
||||
|
||||
let reader = StreamReader::<Value, _>::new(1, rx, Check(flag.clone()));
|
||||
drop(reader);
|
||||
|
||||
assert!(flag.load(Relaxed));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn writer_write_all_stops_if_dropped() -> Result<(), ShellError> {
|
||||
let signal = Arc::new(StreamWriterSignal::new(20));
|
||||
let id = 1337;
|
||||
let mut writer = StreamWriter::new(id, signal.clone(), TestSink::default());
|
||||
|
||||
// Simulate this by having it consume a stream that will actually do the drop halfway through
|
||||
let iter = (0..5).map(Value::test_int).chain({
|
||||
let mut n = 5;
|
||||
std::iter::from_fn(move || {
|
||||
// produces numbers 5..10, but drops for the first one
|
||||
if n == 5 {
|
||||
signal.set_dropped().unwrap();
|
||||
}
|
||||
if n < 10 {
|
||||
let value = Value::test_int(n);
|
||||
n += 1;
|
||||
Some(value)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
});
|
||||
|
||||
writer.write_all(iter)?;
|
||||
|
||||
assert!(writer.is_dropped()?);
|
||||
|
||||
let wrote = &writer.writer.0;
|
||||
assert_eq!(5, wrote.len(), "length wrong: {wrote:?}");
|
||||
|
||||
for (n, message) in (0..5).zip(wrote) {
|
||||
match message {
|
||||
StreamMessage::Data(msg_id, StreamData::List(value)) => {
|
||||
assert_eq!(id, *msg_id, "id");
|
||||
assert_eq!(Value::test_int(n), *value, "value");
|
||||
}
|
||||
other => panic!("unexpected message: {other:?}"),
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn writer_end() -> Result<(), ShellError> {
|
||||
let signal = Arc::new(StreamWriterSignal::new(20));
|
||||
let mut writer = StreamWriter::new(9001, signal.clone(), TestSink::default());
|
||||
|
||||
writer.end()?;
|
||||
writer
|
||||
.write(Value::test_int(2))
|
||||
.expect_err("shouldn't be able to write after end");
|
||||
writer.end().expect("end twice should be ok");
|
||||
|
||||
let wrote = &writer.writer.0;
|
||||
assert!(
|
||||
matches!(wrote.last(), Some(StreamMessage::End(9001))),
|
||||
"didn't write end message: {wrote:?}"
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn signal_set_dropped() -> Result<(), ShellError> {
|
||||
let signal = StreamWriterSignal::new(4);
|
||||
assert!(!signal.is_dropped()?);
|
||||
signal.set_dropped()?;
|
||||
assert!(signal.is_dropped()?);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn signal_notify_sent_false_if_unacknowledged() -> Result<(), ShellError> {
|
||||
let signal = StreamWriterSignal::new(2);
|
||||
assert!(signal.notify_sent()?);
|
||||
for _ in 0..100 {
|
||||
assert!(!signal.notify_sent()?);
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn signal_notify_sent_never_false_if_flowing() -> Result<(), ShellError> {
|
||||
let signal = StreamWriterSignal::new(1);
|
||||
for _ in 0..100 {
|
||||
signal.notify_acknowledged()?;
|
||||
}
|
||||
for _ in 0..100 {
|
||||
assert!(signal.notify_sent()?);
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn signal_wait_for_drain_blocks_on_unacknowledged() -> Result<(), ShellError> {
|
||||
let signal = StreamWriterSignal::new(50);
|
||||
std::thread::scope(|scope| {
|
||||
let spawned = scope.spawn(|| {
|
||||
for _ in 0..100 {
|
||||
if !signal.notify_sent()? {
|
||||
signal.wait_for_drain()?;
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
});
|
||||
std::thread::sleep(WAIT_DURATION);
|
||||
assert!(!spawned.is_finished(), "didn't block");
|
||||
for _ in 0..100 {
|
||||
signal.notify_acknowledged()?;
|
||||
}
|
||||
wait_for_condition(|| spawned.is_finished(), "blocked at end");
|
||||
spawned.join().unwrap()
|
||||
})
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn signal_wait_for_drain_unblocks_on_dropped() -> Result<(), ShellError> {
|
||||
let signal = StreamWriterSignal::new(1);
|
||||
std::thread::scope(|scope| {
|
||||
let spawned = scope.spawn(|| {
|
||||
while !signal.is_dropped()? {
|
||||
if !signal.notify_sent()? {
|
||||
signal.wait_for_drain()?;
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
});
|
||||
std::thread::sleep(WAIT_DURATION);
|
||||
assert!(!spawned.is_finished(), "didn't block");
|
||||
signal.set_dropped()?;
|
||||
wait_for_condition(|| spawned.is_finished(), "still blocked at end");
|
||||
spawned.join().unwrap()
|
||||
})
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn stream_manager_single_stream_read_scenario() -> Result<(), ShellError> {
|
||||
let manager = StreamManager::new();
|
||||
let handle = manager.get_handle();
|
||||
let (tx, rx) = mpsc::channel();
|
||||
let readable = handle.read_stream::<Value, _>(2, tx)?;
|
||||
|
||||
let expected_values = vec![Value::test_int(40), Value::test_string("hello")];
|
||||
|
||||
for value in &expected_values {
|
||||
manager.handle_message(StreamMessage::Data(2, value.clone().into()))?;
|
||||
}
|
||||
manager.handle_message(StreamMessage::End(2))?;
|
||||
|
||||
let values = readable.collect::<Vec<Value>>();
|
||||
|
||||
assert_eq!(expected_values, values);
|
||||
|
||||
// Now check the sent messages on consumption
|
||||
// Should be Ack for each message, then Drop
|
||||
for _ in &expected_values {
|
||||
match rx.try_recv().expect("failed to receive Ack") {
|
||||
StreamMessage::Ack(2) => (),
|
||||
other => panic!("should have been an Ack: {other:?}"),
|
||||
}
|
||||
}
|
||||
match rx.try_recv().expect("failed to receive Drop") {
|
||||
StreamMessage::Drop(2) => (),
|
||||
other => panic!("should have been a Drop: {other:?}"),
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn stream_manager_multi_stream_read_scenario() -> Result<(), ShellError> {
|
||||
let manager = StreamManager::new();
|
||||
let handle = manager.get_handle();
|
||||
let (tx, rx) = mpsc::channel();
|
||||
let readable_list = handle.read_stream::<Value, _>(2, tx.clone())?;
|
||||
let readable_raw = handle.read_stream::<Result<Vec<u8>, _>, _>(3, tx)?;
|
||||
|
||||
let expected_values = (1..100).map(Value::test_int).collect::<Vec<_>>();
|
||||
let expected_raw_buffers = (1..100).map(|n| vec![n]).collect::<Vec<Vec<u8>>>();
|
||||
|
||||
for (value, buf) in expected_values.iter().zip(&expected_raw_buffers) {
|
||||
manager.handle_message(StreamMessage::Data(2, value.clone().into()))?;
|
||||
manager.handle_message(StreamMessage::Data(3, StreamData::Raw(Ok(buf.clone()))))?;
|
||||
}
|
||||
manager.handle_message(StreamMessage::End(2))?;
|
||||
manager.handle_message(StreamMessage::End(3))?;
|
||||
|
||||
let values = readable_list.collect::<Vec<Value>>();
|
||||
let bufs = readable_raw.collect::<Result<Vec<Vec<u8>>, _>>()?;
|
||||
|
||||
for (expected_value, value) in expected_values.iter().zip(&values) {
|
||||
assert_eq!(expected_value, value, "in List stream");
|
||||
}
|
||||
for (expected_buf, buf) in expected_raw_buffers.iter().zip(&bufs) {
|
||||
assert_eq!(expected_buf, buf, "in Raw stream");
|
||||
}
|
||||
|
||||
// Now check the sent messages on consumption
|
||||
// Should be Ack for each message, then Drop
|
||||
for _ in &expected_values {
|
||||
match rx.try_recv().expect("failed to receive Ack") {
|
||||
StreamMessage::Ack(2) => (),
|
||||
other => panic!("should have been an Ack(2): {other:?}"),
|
||||
}
|
||||
}
|
||||
match rx.try_recv().expect("failed to receive Drop") {
|
||||
StreamMessage::Drop(2) => (),
|
||||
other => panic!("should have been a Drop(2): {other:?}"),
|
||||
}
|
||||
for _ in &expected_values {
|
||||
match rx.try_recv().expect("failed to receive Ack") {
|
||||
StreamMessage::Ack(3) => (),
|
||||
other => panic!("should have been an Ack(3): {other:?}"),
|
||||
}
|
||||
}
|
||||
match rx.try_recv().expect("failed to receive Drop") {
|
||||
StreamMessage::Drop(3) => (),
|
||||
other => panic!("should have been a Drop(3): {other:?}"),
|
||||
}
|
||||
|
||||
// Should be end of stream
|
||||
assert!(
|
||||
rx.try_recv().is_err(),
|
||||
"more messages written to stream than expected"
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn stream_manager_write_scenario() -> Result<(), ShellError> {
|
||||
let manager = StreamManager::new();
|
||||
let handle = manager.get_handle();
|
||||
let (tx, rx) = mpsc::channel();
|
||||
let mut writable = handle.write_stream(4, tx, 100)?;
|
||||
|
||||
let expected_values = vec![b"hello".to_vec(), b"world".to_vec(), b"test".to_vec()];
|
||||
|
||||
for value in &expected_values {
|
||||
writable.write(Ok::<_, ShellError>(value.clone()))?;
|
||||
}
|
||||
|
||||
// Now try signalling ack
|
||||
assert_eq!(
|
||||
expected_values.len() as i32,
|
||||
writable.signal.lock()?.unacknowledged,
|
||||
"unacknowledged initial count",
|
||||
);
|
||||
manager.handle_message(StreamMessage::Ack(4))?;
|
||||
assert_eq!(
|
||||
expected_values.len() as i32 - 1,
|
||||
writable.signal.lock()?.unacknowledged,
|
||||
"unacknowledged post-Ack count",
|
||||
);
|
||||
|
||||
// ...and Drop
|
||||
manager.handle_message(StreamMessage::Drop(4))?;
|
||||
assert!(writable.is_dropped()?);
|
||||
|
||||
// Drop the StreamWriter...
|
||||
drop(writable);
|
||||
|
||||
// now check what was actually written
|
||||
for value in &expected_values {
|
||||
match rx.try_recv().expect("failed to receive Data") {
|
||||
StreamMessage::Data(4, StreamData::Raw(Ok(received))) => {
|
||||
assert_eq!(*value, received);
|
||||
}
|
||||
other @ StreamMessage::Data(..) => panic!("wrong Data for {value:?}: {other:?}"),
|
||||
other => panic!("should have been Data: {other:?}"),
|
||||
}
|
||||
}
|
||||
match rx.try_recv().expect("failed to receive End") {
|
||||
StreamMessage::End(4) => (),
|
||||
other => panic!("should have been End: {other:?}"),
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn stream_manager_broadcast_read_error() -> Result<(), ShellError> {
|
||||
let manager = StreamManager::new();
|
||||
let handle = manager.get_handle();
|
||||
let mut readable0 = handle.read_stream::<Value, _>(0, TestSink::default())?;
|
||||
let mut readable1 = handle.read_stream::<Result<Vec<u8>, _>, _>(1, TestSink::default())?;
|
||||
|
||||
let error = ShellError::PluginFailedToDecode {
|
||||
msg: "test decode error".into(),
|
||||
};
|
||||
|
||||
manager.broadcast_read_error(error.clone())?;
|
||||
drop(manager);
|
||||
|
||||
assert_eq!(
|
||||
error.to_string(),
|
||||
readable0
|
||||
.recv()
|
||||
.transpose()
|
||||
.expect("nothing received from readable0")
|
||||
.expect_err("not an error received from readable0")
|
||||
.to_string()
|
||||
);
|
||||
assert_eq!(
|
||||
error.to_string(),
|
||||
readable1
|
||||
.next()
|
||||
.expect("nothing received from readable1")
|
||||
.expect_err("not an error received from readable1")
|
||||
.to_string()
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn stream_manager_drop_writers_on_drop() -> Result<(), ShellError> {
|
||||
let manager = StreamManager::new();
|
||||
let handle = manager.get_handle();
|
||||
let writable = handle.write_stream(4, TestSink::default(), 100)?;
|
||||
|
||||
assert!(!writable.is_dropped()?);
|
||||
|
||||
drop(manager);
|
||||
|
||||
assert!(writable.is_dropped()?);
|
||||
|
||||
Ok(())
|
||||
}
|
134
crates/nu-plugin-core/src/interface/test_util.rs
Normal file
134
crates/nu-plugin-core/src/interface/test_util.rs
Normal file
@ -0,0 +1,134 @@
|
||||
use nu_protocol::ShellError;
|
||||
use std::{
|
||||
collections::VecDeque,
|
||||
sync::{Arc, Mutex},
|
||||
};
|
||||
|
||||
use crate::{PluginRead, PluginWrite};
|
||||
|
||||
const FAILED: &str = "failed to lock TestCase";
|
||||
|
||||
/// Mock read/write helper for the engine and plugin interfaces.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct TestCase<I, O> {
|
||||
r#in: Arc<Mutex<TestData<I>>>,
|
||||
out: Arc<Mutex<TestData<O>>>,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct TestData<T> {
|
||||
data: VecDeque<T>,
|
||||
error: Option<ShellError>,
|
||||
flushed: bool,
|
||||
}
|
||||
|
||||
impl<T> Default for TestData<T> {
|
||||
fn default() -> Self {
|
||||
TestData {
|
||||
data: VecDeque::new(),
|
||||
error: None,
|
||||
flushed: false,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<I, O> PluginRead<I> for TestCase<I, O> {
|
||||
fn read(&mut self) -> Result<Option<I>, ShellError> {
|
||||
let mut lock = self.r#in.lock().expect(FAILED);
|
||||
if let Some(err) = lock.error.take() {
|
||||
Err(err)
|
||||
} else {
|
||||
Ok(lock.data.pop_front())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<I, O> PluginWrite<O> for TestCase<I, O>
|
||||
where
|
||||
I: Send + Clone,
|
||||
O: Send + Clone,
|
||||
{
|
||||
fn write(&self, data: &O) -> Result<(), ShellError> {
|
||||
let mut lock = self.out.lock().expect(FAILED);
|
||||
lock.flushed = false;
|
||||
|
||||
if let Some(err) = lock.error.take() {
|
||||
Err(err)
|
||||
} else {
|
||||
lock.data.push_back(data.clone());
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
fn flush(&self) -> Result<(), ShellError> {
|
||||
let mut lock = self.out.lock().expect(FAILED);
|
||||
lock.flushed = true;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
impl<I, O> TestCase<I, O> {
|
||||
pub fn new() -> TestCase<I, O> {
|
||||
TestCase {
|
||||
r#in: Default::default(),
|
||||
out: Default::default(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Clear the read buffer.
|
||||
pub fn clear(&self) {
|
||||
self.r#in.lock().expect(FAILED).data.truncate(0);
|
||||
}
|
||||
|
||||
/// Add input that will be read by the interface.
|
||||
pub fn add(&self, input: impl Into<I>) {
|
||||
self.r#in.lock().expect(FAILED).data.push_back(input.into());
|
||||
}
|
||||
|
||||
/// Add multiple inputs that will be read by the interface.
|
||||
pub fn extend(&self, inputs: impl IntoIterator<Item = I>) {
|
||||
self.r#in.lock().expect(FAILED).data.extend(inputs);
|
||||
}
|
||||
|
||||
/// Return an error from the next read operation.
|
||||
pub fn set_read_error(&self, err: ShellError) {
|
||||
self.r#in.lock().expect(FAILED).error = Some(err);
|
||||
}
|
||||
|
||||
/// Return an error from the next write operation.
|
||||
pub fn set_write_error(&self, err: ShellError) {
|
||||
self.out.lock().expect(FAILED).error = Some(err);
|
||||
}
|
||||
|
||||
/// Get the next output that was written.
|
||||
pub fn next_written(&self) -> Option<O> {
|
||||
self.out.lock().expect(FAILED).data.pop_front()
|
||||
}
|
||||
|
||||
/// Iterator over written data.
|
||||
pub fn written(&self) -> impl Iterator<Item = O> + '_ {
|
||||
std::iter::from_fn(|| self.next_written())
|
||||
}
|
||||
|
||||
/// Returns true if the writer was flushed after the last write operation.
|
||||
pub fn was_flushed(&self) -> bool {
|
||||
self.out.lock().expect(FAILED).flushed
|
||||
}
|
||||
|
||||
/// Returns true if the reader has unconsumed reads.
|
||||
pub fn has_unconsumed_read(&self) -> bool {
|
||||
!self.r#in.lock().expect(FAILED).data.is_empty()
|
||||
}
|
||||
|
||||
/// Returns true if the writer has unconsumed writes.
|
||||
pub fn has_unconsumed_write(&self) -> bool {
|
||||
!self.out.lock().expect(FAILED).data.is_empty()
|
||||
}
|
||||
}
|
||||
|
||||
impl<I, O> Default for TestCase<I, O> {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
573
crates/nu-plugin-core/src/interface/tests.rs
Normal file
573
crates/nu-plugin-core/src/interface/tests.rs
Normal file
@ -0,0 +1,573 @@
|
||||
use crate::util::Sequence;
|
||||
|
||||
use super::{
|
||||
stream::{StreamManager, StreamManagerHandle},
|
||||
test_util::TestCase,
|
||||
Interface, InterfaceManager, PluginRead, PluginWrite,
|
||||
};
|
||||
use nu_plugin_protocol::{
|
||||
ExternalStreamInfo, ListStreamInfo, PipelineDataHeader, PluginInput, PluginOutput,
|
||||
RawStreamInfo, StreamData, StreamMessage,
|
||||
};
|
||||
use nu_protocol::{
|
||||
DataSource, ListStream, PipelineData, PipelineMetadata, RawStream, ShellError, Span, Value,
|
||||
};
|
||||
use std::{path::Path, sync::Arc};
|
||||
|
||||
fn test_metadata() -> PipelineMetadata {
|
||||
PipelineMetadata {
|
||||
data_source: DataSource::FilePath("/test/path".into()),
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct TestInterfaceManager {
|
||||
stream_manager: StreamManager,
|
||||
test: TestCase<PluginInput, PluginOutput>,
|
||||
seq: Arc<Sequence>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
struct TestInterface {
|
||||
stream_manager_handle: StreamManagerHandle,
|
||||
test: TestCase<PluginInput, PluginOutput>,
|
||||
seq: Arc<Sequence>,
|
||||
}
|
||||
|
||||
impl TestInterfaceManager {
|
||||
fn new(test: &TestCase<PluginInput, PluginOutput>) -> TestInterfaceManager {
|
||||
TestInterfaceManager {
|
||||
stream_manager: StreamManager::new(),
|
||||
test: test.clone(),
|
||||
seq: Arc::new(Sequence::default()),
|
||||
}
|
||||
}
|
||||
|
||||
fn consume_all(&mut self) -> Result<(), ShellError> {
|
||||
while let Some(msg) = self.test.read()? {
|
||||
self.consume(msg)?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl InterfaceManager for TestInterfaceManager {
|
||||
type Interface = TestInterface;
|
||||
type Input = PluginInput;
|
||||
|
||||
fn get_interface(&self) -> Self::Interface {
|
||||
TestInterface {
|
||||
stream_manager_handle: self.stream_manager.get_handle(),
|
||||
test: self.test.clone(),
|
||||
seq: self.seq.clone(),
|
||||
}
|
||||
}
|
||||
|
||||
fn consume(&mut self, input: Self::Input) -> Result<(), ShellError> {
|
||||
match input {
|
||||
PluginInput::Data(..)
|
||||
| PluginInput::End(..)
|
||||
| PluginInput::Drop(..)
|
||||
| PluginInput::Ack(..) => self.consume_stream_message(
|
||||
input
|
||||
.try_into()
|
||||
.expect("failed to convert message to StreamMessage"),
|
||||
),
|
||||
_ => unimplemented!(),
|
||||
}
|
||||
}
|
||||
|
||||
fn stream_manager(&self) -> &StreamManager {
|
||||
&self.stream_manager
|
||||
}
|
||||
|
||||
fn prepare_pipeline_data(&self, data: PipelineData) -> Result<PipelineData, ShellError> {
|
||||
Ok(data.set_metadata(Some(test_metadata())))
|
||||
}
|
||||
}
|
||||
|
||||
impl Interface for TestInterface {
|
||||
type Output = PluginOutput;
|
||||
type DataContext = ();
|
||||
|
||||
fn write(&self, output: Self::Output) -> Result<(), ShellError> {
|
||||
self.test.write(&output)
|
||||
}
|
||||
|
||||
fn flush(&self) -> Result<(), ShellError> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn stream_id_sequence(&self) -> &Sequence {
|
||||
&self.seq
|
||||
}
|
||||
|
||||
fn stream_manager_handle(&self) -> &StreamManagerHandle {
|
||||
&self.stream_manager_handle
|
||||
}
|
||||
|
||||
fn prepare_pipeline_data(
|
||||
&self,
|
||||
data: PipelineData,
|
||||
_context: &(),
|
||||
) -> Result<PipelineData, ShellError> {
|
||||
// Add an arbitrary check to the data to verify this is being called
|
||||
match data {
|
||||
PipelineData::Value(Value::Binary { .. }, None) => Err(ShellError::NushellFailed {
|
||||
msg: "TEST can't send binary".into(),
|
||||
}),
|
||||
_ => Ok(data),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn read_pipeline_data_empty() -> Result<(), ShellError> {
|
||||
let manager = TestInterfaceManager::new(&TestCase::new());
|
||||
let header = PipelineDataHeader::Empty;
|
||||
|
||||
assert!(matches!(
|
||||
manager.read_pipeline_data(header, None)?,
|
||||
PipelineData::Empty
|
||||
));
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn read_pipeline_data_value() -> Result<(), ShellError> {
|
||||
let manager = TestInterfaceManager::new(&TestCase::new());
|
||||
let value = Value::test_int(4);
|
||||
let header = PipelineDataHeader::Value(value.clone());
|
||||
|
||||
match manager.read_pipeline_data(header, None)? {
|
||||
PipelineData::Value(read_value, _) => assert_eq!(value, read_value),
|
||||
PipelineData::ListStream(_, _) => panic!("unexpected ListStream"),
|
||||
PipelineData::ExternalStream { .. } => panic!("unexpected ExternalStream"),
|
||||
PipelineData::Empty => panic!("unexpected Empty"),
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn read_pipeline_data_list_stream() -> Result<(), ShellError> {
|
||||
let test = TestCase::new();
|
||||
let mut manager = TestInterfaceManager::new(&test);
|
||||
|
||||
let data = (0..100).map(Value::test_int).collect::<Vec<_>>();
|
||||
|
||||
for value in &data {
|
||||
test.add(StreamMessage::Data(7, value.clone().into()));
|
||||
}
|
||||
test.add(StreamMessage::End(7));
|
||||
|
||||
let header = PipelineDataHeader::ListStream(ListStreamInfo { id: 7 });
|
||||
|
||||
let pipe = manager.read_pipeline_data(header, None)?;
|
||||
assert!(
|
||||
matches!(pipe, PipelineData::ListStream(..)),
|
||||
"unexpected PipelineData: {pipe:?}"
|
||||
);
|
||||
|
||||
// need to consume input
|
||||
manager.consume_all()?;
|
||||
|
||||
let mut count = 0;
|
||||
for (expected, read) in data.into_iter().zip(pipe) {
|
||||
assert_eq!(expected, read);
|
||||
count += 1;
|
||||
}
|
||||
assert_eq!(100, count);
|
||||
|
||||
assert!(test.has_unconsumed_write());
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn read_pipeline_data_external_stream() -> Result<(), ShellError> {
|
||||
let test = TestCase::new();
|
||||
let mut manager = TestInterfaceManager::new(&test);
|
||||
|
||||
let iterations = 100;
|
||||
let out_pattern = b"hello".to_vec();
|
||||
let err_pattern = vec![5, 4, 3, 2];
|
||||
|
||||
test.add(StreamMessage::Data(14, Value::test_int(1).into()));
|
||||
for _ in 0..iterations {
|
||||
test.add(StreamMessage::Data(
|
||||
12,
|
||||
StreamData::Raw(Ok(out_pattern.clone())),
|
||||
));
|
||||
test.add(StreamMessage::Data(
|
||||
13,
|
||||
StreamData::Raw(Ok(err_pattern.clone())),
|
||||
));
|
||||
}
|
||||
test.add(StreamMessage::End(12));
|
||||
test.add(StreamMessage::End(13));
|
||||
test.add(StreamMessage::End(14));
|
||||
|
||||
let test_span = Span::new(10, 13);
|
||||
let header = PipelineDataHeader::ExternalStream(ExternalStreamInfo {
|
||||
span: test_span,
|
||||
stdout: Some(RawStreamInfo {
|
||||
id: 12,
|
||||
is_binary: false,
|
||||
known_size: Some((out_pattern.len() * iterations) as u64),
|
||||
}),
|
||||
stderr: Some(RawStreamInfo {
|
||||
id: 13,
|
||||
is_binary: true,
|
||||
known_size: None,
|
||||
}),
|
||||
exit_code: Some(ListStreamInfo { id: 14 }),
|
||||
trim_end_newline: true,
|
||||
});
|
||||
|
||||
let pipe = manager.read_pipeline_data(header, None)?;
|
||||
|
||||
// need to consume input
|
||||
manager.consume_all()?;
|
||||
|
||||
match pipe {
|
||||
PipelineData::ExternalStream {
|
||||
stdout,
|
||||
stderr,
|
||||
exit_code,
|
||||
span,
|
||||
metadata,
|
||||
trim_end_newline,
|
||||
} => {
|
||||
let stdout = stdout.expect("stdout is None");
|
||||
let stderr = stderr.expect("stderr is None");
|
||||
let exit_code = exit_code.expect("exit_code is None");
|
||||
assert_eq!(test_span, span);
|
||||
assert!(
|
||||
metadata.is_some(),
|
||||
"expected metadata to be Some due to prepare_pipeline_data()"
|
||||
);
|
||||
assert!(trim_end_newline);
|
||||
|
||||
assert!(!stdout.is_binary);
|
||||
assert!(stderr.is_binary);
|
||||
|
||||
assert_eq!(
|
||||
Some((out_pattern.len() * iterations) as u64),
|
||||
stdout.known_size
|
||||
);
|
||||
assert_eq!(None, stderr.known_size);
|
||||
|
||||
// check the streams
|
||||
let mut count = 0;
|
||||
for chunk in stdout.stream {
|
||||
assert_eq!(out_pattern, chunk?);
|
||||
count += 1;
|
||||
}
|
||||
assert_eq!(iterations, count, "stdout length");
|
||||
let mut count = 0;
|
||||
|
||||
for chunk in stderr.stream {
|
||||
assert_eq!(err_pattern, chunk?);
|
||||
count += 1;
|
||||
}
|
||||
assert_eq!(iterations, count, "stderr length");
|
||||
|
||||
assert_eq!(vec![Value::test_int(1)], exit_code.collect::<Vec<_>>());
|
||||
}
|
||||
_ => panic!("unexpected PipelineData: {pipe:?}"),
|
||||
}
|
||||
|
||||
// Don't need to check exactly what was written, just be sure that there is some output
|
||||
assert!(test.has_unconsumed_write());
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn read_pipeline_data_ctrlc() -> Result<(), ShellError> {
|
||||
let manager = TestInterfaceManager::new(&TestCase::new());
|
||||
let header = PipelineDataHeader::ListStream(ListStreamInfo { id: 0 });
|
||||
let ctrlc = Default::default();
|
||||
match manager.read_pipeline_data(header, Some(&ctrlc))? {
|
||||
PipelineData::ListStream(
|
||||
ListStream {
|
||||
ctrlc: stream_ctrlc,
|
||||
..
|
||||
},
|
||||
_,
|
||||
) => {
|
||||
assert!(Arc::ptr_eq(&ctrlc, &stream_ctrlc.expect("ctrlc not set")));
|
||||
Ok(())
|
||||
}
|
||||
_ => panic!("Unexpected PipelineData, should have been ListStream"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn read_pipeline_data_prepared_properly() -> Result<(), ShellError> {
|
||||
let manager = TestInterfaceManager::new(&TestCase::new());
|
||||
let header = PipelineDataHeader::ListStream(ListStreamInfo { id: 0 });
|
||||
match manager.read_pipeline_data(header, None)? {
|
||||
PipelineData::ListStream(_, meta) => match meta {
|
||||
Some(PipelineMetadata { data_source }) => match data_source {
|
||||
DataSource::FilePath(path) => {
|
||||
assert_eq!(Path::new("/test/path"), path);
|
||||
Ok(())
|
||||
}
|
||||
_ => panic!("wrong metadata: {data_source:?}"),
|
||||
},
|
||||
None => panic!("metadata not set"),
|
||||
},
|
||||
_ => panic!("Unexpected PipelineData, should have been ListStream"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn write_pipeline_data_empty() -> Result<(), ShellError> {
|
||||
let test = TestCase::new();
|
||||
let manager = TestInterfaceManager::new(&test);
|
||||
let interface = manager.get_interface();
|
||||
|
||||
let (header, writer) = interface.init_write_pipeline_data(PipelineData::Empty, &())?;
|
||||
|
||||
assert!(matches!(header, PipelineDataHeader::Empty));
|
||||
|
||||
writer.write()?;
|
||||
|
||||
assert!(
|
||||
!test.has_unconsumed_write(),
|
||||
"Empty shouldn't write any stream messages, test: {test:#?}"
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn write_pipeline_data_value() -> Result<(), ShellError> {
|
||||
let test = TestCase::new();
|
||||
let manager = TestInterfaceManager::new(&test);
|
||||
let interface = manager.get_interface();
|
||||
let value = Value::test_int(7);
|
||||
|
||||
let (header, writer) =
|
||||
interface.init_write_pipeline_data(PipelineData::Value(value.clone(), None), &())?;
|
||||
|
||||
match header {
|
||||
PipelineDataHeader::Value(read_value) => assert_eq!(value, read_value),
|
||||
_ => panic!("unexpected header: {header:?}"),
|
||||
}
|
||||
|
||||
writer.write()?;
|
||||
|
||||
assert!(
|
||||
!test.has_unconsumed_write(),
|
||||
"Value shouldn't write any stream messages, test: {test:#?}"
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn write_pipeline_data_prepared_properly() {
|
||||
let manager = TestInterfaceManager::new(&TestCase::new());
|
||||
let interface = manager.get_interface();
|
||||
|
||||
// Sending a binary should be an error in our test scenario
|
||||
let value = Value::test_binary(vec![7, 8]);
|
||||
|
||||
match interface.init_write_pipeline_data(PipelineData::Value(value, None), &()) {
|
||||
Ok(_) => panic!("prepare_pipeline_data was not called"),
|
||||
Err(err) => {
|
||||
assert_eq!(
|
||||
ShellError::NushellFailed {
|
||||
msg: "TEST can't send binary".into()
|
||||
}
|
||||
.to_string(),
|
||||
err.to_string()
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn write_pipeline_data_list_stream() -> Result<(), ShellError> {
|
||||
let test = TestCase::new();
|
||||
let manager = TestInterfaceManager::new(&test);
|
||||
let interface = manager.get_interface();
|
||||
|
||||
let values = vec![
|
||||
Value::test_int(40),
|
||||
Value::test_bool(false),
|
||||
Value::test_string("this is a test"),
|
||||
];
|
||||
|
||||
// Set up pipeline data for a list stream
|
||||
let pipe = PipelineData::ListStream(
|
||||
ListStream::from_stream(values.clone().into_iter(), None),
|
||||
None,
|
||||
);
|
||||
|
||||
let (header, writer) = interface.init_write_pipeline_data(pipe, &())?;
|
||||
|
||||
let info = match header {
|
||||
PipelineDataHeader::ListStream(info) => info,
|
||||
_ => panic!("unexpected header: {header:?}"),
|
||||
};
|
||||
|
||||
writer.write()?;
|
||||
|
||||
// Now make sure the stream messages have been written
|
||||
for value in values {
|
||||
match test.next_written().expect("unexpected end of stream") {
|
||||
PluginOutput::Data(id, data) => {
|
||||
assert_eq!(info.id, id, "Data id");
|
||||
match data {
|
||||
StreamData::List(read_value) => assert_eq!(value, read_value, "Data value"),
|
||||
_ => panic!("unexpected Data: {data:?}"),
|
||||
}
|
||||
}
|
||||
other => panic!("unexpected output: {other:?}"),
|
||||
}
|
||||
}
|
||||
|
||||
match test.next_written().expect("unexpected end of stream") {
|
||||
PluginOutput::End(id) => {
|
||||
assert_eq!(info.id, id, "End id");
|
||||
}
|
||||
other => panic!("unexpected output: {other:?}"),
|
||||
}
|
||||
|
||||
assert!(!test.has_unconsumed_write());
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn write_pipeline_data_external_stream() -> Result<(), ShellError> {
|
||||
let test = TestCase::new();
|
||||
let manager = TestInterfaceManager::new(&test);
|
||||
let interface = manager.get_interface();
|
||||
|
||||
let stdout_bufs = vec![
|
||||
b"hello".to_vec(),
|
||||
b"world".to_vec(),
|
||||
b"these are tests".to_vec(),
|
||||
];
|
||||
let stdout_len = stdout_bufs.iter().map(|b| b.len() as u64).sum::<u64>();
|
||||
let stderr_bufs = vec![b"error messages".to_vec(), b"go here".to_vec()];
|
||||
let exit_code = Value::test_int(7);
|
||||
|
||||
let span = Span::new(400, 500);
|
||||
|
||||
// Set up pipeline data for an external stream
|
||||
let pipe = PipelineData::ExternalStream {
|
||||
stdout: Some(RawStream::new(
|
||||
Box::new(stdout_bufs.clone().into_iter().map(Ok)),
|
||||
None,
|
||||
span,
|
||||
Some(stdout_len),
|
||||
)),
|
||||
stderr: Some(RawStream::new(
|
||||
Box::new(stderr_bufs.clone().into_iter().map(Ok)),
|
||||
None,
|
||||
span,
|
||||
None,
|
||||
)),
|
||||
exit_code: Some(ListStream::from_stream(
|
||||
std::iter::once(exit_code.clone()),
|
||||
None,
|
||||
)),
|
||||
span,
|
||||
metadata: None,
|
||||
trim_end_newline: true,
|
||||
};
|
||||
|
||||
let (header, writer) = interface.init_write_pipeline_data(pipe, &())?;
|
||||
|
||||
let info = match header {
|
||||
PipelineDataHeader::ExternalStream(info) => info,
|
||||
_ => panic!("unexpected header: {header:?}"),
|
||||
};
|
||||
|
||||
writer.write()?;
|
||||
|
||||
let stdout_info = info.stdout.as_ref().expect("stdout info is None");
|
||||
let stderr_info = info.stderr.as_ref().expect("stderr info is None");
|
||||
let exit_code_info = info.exit_code.as_ref().expect("exit code info is None");
|
||||
|
||||
assert_eq!(span, info.span);
|
||||
assert!(info.trim_end_newline);
|
||||
|
||||
assert_eq!(Some(stdout_len), stdout_info.known_size);
|
||||
assert_eq!(None, stderr_info.known_size);
|
||||
|
||||
// Now make sure the stream messages have been written
|
||||
let mut stdout_iter = stdout_bufs.into_iter();
|
||||
let mut stderr_iter = stderr_bufs.into_iter();
|
||||
let mut exit_code_iter = std::iter::once(exit_code);
|
||||
|
||||
let mut stdout_ended = false;
|
||||
let mut stderr_ended = false;
|
||||
let mut exit_code_ended = false;
|
||||
|
||||
// There's no specific order these messages must come in with respect to how the streams are
|
||||
// interleaved, but all of the data for each stream must be in its original order, and the
|
||||
// End must come after all Data
|
||||
for msg in test.written() {
|
||||
match msg {
|
||||
PluginOutput::Data(id, data) => {
|
||||
if id == stdout_info.id {
|
||||
let result: Result<Vec<u8>, ShellError> =
|
||||
data.try_into().expect("wrong data in stdout stream");
|
||||
assert_eq!(
|
||||
stdout_iter.next().expect("too much data in stdout"),
|
||||
result.expect("unexpected error in stdout stream")
|
||||
);
|
||||
} else if id == stderr_info.id {
|
||||
let result: Result<Vec<u8>, ShellError> =
|
||||
data.try_into().expect("wrong data in stderr stream");
|
||||
assert_eq!(
|
||||
stderr_iter.next().expect("too much data in stderr"),
|
||||
result.expect("unexpected error in stderr stream")
|
||||
);
|
||||
} else if id == exit_code_info.id {
|
||||
let code: Value = data.try_into().expect("wrong data in stderr stream");
|
||||
assert_eq!(
|
||||
exit_code_iter.next().expect("too much data in stderr"),
|
||||
code
|
||||
);
|
||||
} else {
|
||||
panic!("unrecognized stream id: {id}");
|
||||
}
|
||||
}
|
||||
PluginOutput::End(id) => {
|
||||
if id == stdout_info.id {
|
||||
assert!(!stdout_ended, "double End of stdout");
|
||||
assert!(stdout_iter.next().is_none(), "unexpected end of stdout");
|
||||
stdout_ended = true;
|
||||
} else if id == stderr_info.id {
|
||||
assert!(!stderr_ended, "double End of stderr");
|
||||
assert!(stderr_iter.next().is_none(), "unexpected end of stderr");
|
||||
stderr_ended = true;
|
||||
} else if id == exit_code_info.id {
|
||||
assert!(!exit_code_ended, "double End of exit_code");
|
||||
assert!(
|
||||
exit_code_iter.next().is_none(),
|
||||
"unexpected end of exit_code"
|
||||
);
|
||||
exit_code_ended = true;
|
||||
} else {
|
||||
panic!("unrecognized stream id: {id}");
|
||||
}
|
||||
}
|
||||
other => panic!("unexpected output: {other:?}"),
|
||||
}
|
||||
}
|
||||
|
||||
assert!(stdout_ended, "stdout did not End");
|
||||
assert!(stderr_ended, "stderr did not End");
|
||||
assert!(exit_code_ended, "exit_code did not End");
|
||||
|
||||
Ok(())
|
||||
}
|
24
crates/nu-plugin-core/src/lib.rs
Normal file
24
crates/nu-plugin-core/src/lib.rs
Normal file
@ -0,0 +1,24 @@
|
||||
//! Functionality and types shared between the plugin and the engine, other than protocol types.
|
||||
//!
|
||||
//! If you are writing a plugin, you probably don't need this crate. We will make fewer guarantees
|
||||
//! for the stability of the interface of this crate than for `nu_plugin`.
|
||||
|
||||
pub mod util;
|
||||
|
||||
mod communication_mode;
|
||||
mod interface;
|
||||
mod serializers;
|
||||
|
||||
pub use communication_mode::{
|
||||
ClientCommunicationIo, CommunicationMode, PreparedServerCommunication, ServerCommunicationIo,
|
||||
};
|
||||
pub use interface::{
|
||||
stream::{FromShellError, StreamManager, StreamManagerHandle, StreamReader, StreamWriter},
|
||||
Interface, InterfaceManager, PipelineDataWriter, PluginRead, PluginWrite,
|
||||
};
|
||||
pub use serializers::{
|
||||
json::JsonSerializer, msgpack::MsgPackSerializer, Encoder, EncodingType, PluginEncoder,
|
||||
};
|
||||
|
||||
#[doc(hidden)]
|
||||
pub use interface::test_util as interface_test_util;
|
133
crates/nu-plugin-core/src/serializers/json.rs
Normal file
133
crates/nu-plugin-core/src/serializers/json.rs
Normal file
@ -0,0 +1,133 @@
|
||||
use nu_plugin_protocol::{PluginInput, PluginOutput};
|
||||
use nu_protocol::ShellError;
|
||||
use serde::Deserialize;
|
||||
|
||||
use crate::{Encoder, PluginEncoder};
|
||||
|
||||
/// A `PluginEncoder` that enables the plugin to communicate with Nushell with JSON
|
||||
/// serialized data.
|
||||
///
|
||||
/// Each message in the stream is followed by a newline when serializing, but is not required for
|
||||
/// deserialization. The output is not pretty printed and each object does not contain newlines.
|
||||
/// If it is more convenient, a plugin may choose to separate messages by newline.
|
||||
#[derive(Clone, Copy, Debug)]
|
||||
pub struct JsonSerializer;
|
||||
|
||||
impl PluginEncoder for JsonSerializer {
|
||||
fn name(&self) -> &str {
|
||||
"json"
|
||||
}
|
||||
}
|
||||
|
||||
impl Encoder<PluginInput> for JsonSerializer {
|
||||
fn encode(
|
||||
&self,
|
||||
plugin_input: &PluginInput,
|
||||
writer: &mut impl std::io::Write,
|
||||
) -> Result<(), nu_protocol::ShellError> {
|
||||
serde_json::to_writer(&mut *writer, plugin_input).map_err(json_encode_err)?;
|
||||
writer.write_all(b"\n").map_err(|err| ShellError::IOError {
|
||||
msg: err.to_string(),
|
||||
})
|
||||
}
|
||||
|
||||
fn decode(
|
||||
&self,
|
||||
reader: &mut impl std::io::BufRead,
|
||||
) -> Result<Option<PluginInput>, nu_protocol::ShellError> {
|
||||
let mut de = serde_json::Deserializer::from_reader(reader);
|
||||
PluginInput::deserialize(&mut de)
|
||||
.map(Some)
|
||||
.or_else(json_decode_err)
|
||||
}
|
||||
}
|
||||
|
||||
impl Encoder<PluginOutput> for JsonSerializer {
|
||||
fn encode(
|
||||
&self,
|
||||
plugin_output: &PluginOutput,
|
||||
writer: &mut impl std::io::Write,
|
||||
) -> Result<(), ShellError> {
|
||||
serde_json::to_writer(&mut *writer, plugin_output).map_err(json_encode_err)?;
|
||||
writer.write_all(b"\n").map_err(|err| ShellError::IOError {
|
||||
msg: err.to_string(),
|
||||
})
|
||||
}
|
||||
|
||||
fn decode(
|
||||
&self,
|
||||
reader: &mut impl std::io::BufRead,
|
||||
) -> Result<Option<PluginOutput>, ShellError> {
|
||||
let mut de = serde_json::Deserializer::from_reader(reader);
|
||||
PluginOutput::deserialize(&mut de)
|
||||
.map(Some)
|
||||
.or_else(json_decode_err)
|
||||
}
|
||||
}
|
||||
|
||||
/// Handle a `serde_json` encode error.
|
||||
fn json_encode_err(err: serde_json::Error) -> ShellError {
|
||||
if err.is_io() {
|
||||
ShellError::IOError {
|
||||
msg: err.to_string(),
|
||||
}
|
||||
} else {
|
||||
ShellError::PluginFailedToEncode {
|
||||
msg: err.to_string(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Handle a `serde_json` decode error. Returns `Ok(None)` on eof.
|
||||
fn json_decode_err<T>(err: serde_json::Error) -> Result<Option<T>, ShellError> {
|
||||
if err.is_eof() {
|
||||
Ok(None)
|
||||
} else if err.is_io() {
|
||||
Err(ShellError::IOError {
|
||||
msg: err.to_string(),
|
||||
})
|
||||
} else {
|
||||
Err(ShellError::PluginFailedToDecode {
|
||||
msg: err.to_string(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
crate::serializers::tests::generate_tests!(JsonSerializer {});
|
||||
|
||||
#[test]
|
||||
fn json_ends_in_newline() {
|
||||
let mut out = vec![];
|
||||
JsonSerializer {}
|
||||
.encode(&PluginInput::Call(0, PluginCall::Signature), &mut out)
|
||||
.expect("serialization error");
|
||||
let string = std::str::from_utf8(&out).expect("utf-8 error");
|
||||
assert!(
|
||||
string.ends_with('\n'),
|
||||
"doesn't end with newline: {:?}",
|
||||
string
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn json_has_no_other_newlines() {
|
||||
let mut out = vec![];
|
||||
// use something deeply nested, to try to trigger any pretty printing
|
||||
let output = PluginOutput::Data(
|
||||
0,
|
||||
StreamData::List(Value::test_list(vec![
|
||||
Value::test_int(4),
|
||||
// in case escaping failed
|
||||
Value::test_string("newline\ncontaining\nstring"),
|
||||
])),
|
||||
);
|
||||
JsonSerializer {}
|
||||
.encode(&output, &mut out)
|
||||
.expect("serialization error");
|
||||
let string = std::str::from_utf8(&out).expect("utf-8 error");
|
||||
assert_eq!(1, string.chars().filter(|ch| *ch == '\n').count());
|
||||
}
|
||||
}
|
71
crates/nu-plugin-core/src/serializers/mod.rs
Normal file
71
crates/nu-plugin-core/src/serializers/mod.rs
Normal file
@ -0,0 +1,71 @@
|
||||
use nu_plugin_protocol::{PluginInput, PluginOutput};
|
||||
use nu_protocol::ShellError;
|
||||
|
||||
pub mod json;
|
||||
pub mod msgpack;
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests;
|
||||
|
||||
/// Encoder for a specific message type. Usually implemented on [`PluginInput`]
|
||||
/// and [`PluginOutput`].
|
||||
pub trait Encoder<T>: Clone + Send + Sync {
|
||||
/// Serialize a value in the [`PluginEncoder`]s format
|
||||
///
|
||||
/// Returns [`ShellError::IOError`] if there was a problem writing, or
|
||||
/// [`ShellError::PluginFailedToEncode`] for a serialization error.
|
||||
fn encode(&self, data: &T, writer: &mut impl std::io::Write) -> Result<(), ShellError>;
|
||||
|
||||
/// Deserialize a value from the [`PluginEncoder`]'s format
|
||||
///
|
||||
/// Returns `None` if there is no more output to receive.
|
||||
///
|
||||
/// Returns [`ShellError::IOError`] if there was a problem reading, or
|
||||
/// [`ShellError::PluginFailedToDecode`] for a deserialization error.
|
||||
fn decode(&self, reader: &mut impl std::io::BufRead) -> Result<Option<T>, ShellError>;
|
||||
}
|
||||
|
||||
/// Encoding scheme that defines a plugin's communication protocol with Nu
|
||||
pub trait PluginEncoder: Encoder<PluginInput> + Encoder<PluginOutput> {
|
||||
/// The name of the encoder (e.g., `json`)
|
||||
fn name(&self) -> &str;
|
||||
}
|
||||
|
||||
/// Enum that supports all of the plugin serialization formats.
|
||||
#[derive(Clone, Copy, Debug)]
|
||||
pub enum EncodingType {
|
||||
Json(json::JsonSerializer),
|
||||
MsgPack(msgpack::MsgPackSerializer),
|
||||
}
|
||||
|
||||
impl EncodingType {
|
||||
/// Determine the plugin encoding type from the provided byte string (either `b"json"` or
|
||||
/// `b"msgpack"`).
|
||||
pub fn try_from_bytes(bytes: &[u8]) -> Option<Self> {
|
||||
match bytes {
|
||||
b"json" => Some(Self::Json(json::JsonSerializer {})),
|
||||
b"msgpack" => Some(Self::MsgPack(msgpack::MsgPackSerializer {})),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> Encoder<T> for EncodingType
|
||||
where
|
||||
json::JsonSerializer: Encoder<T>,
|
||||
msgpack::MsgPackSerializer: Encoder<T>,
|
||||
{
|
||||
fn encode(&self, data: &T, writer: &mut impl std::io::Write) -> Result<(), ShellError> {
|
||||
match self {
|
||||
EncodingType::Json(encoder) => encoder.encode(data, writer),
|
||||
EncodingType::MsgPack(encoder) => encoder.encode(data, writer),
|
||||
}
|
||||
}
|
||||
|
||||
fn decode(&self, reader: &mut impl std::io::BufRead) -> Result<Option<T>, ShellError> {
|
||||
match self {
|
||||
EncodingType::Json(encoder) => encoder.decode(reader),
|
||||
EncodingType::MsgPack(encoder) => encoder.decode(reader),
|
||||
}
|
||||
}
|
||||
}
|
108
crates/nu-plugin-core/src/serializers/msgpack.rs
Normal file
108
crates/nu-plugin-core/src/serializers/msgpack.rs
Normal file
@ -0,0 +1,108 @@
|
||||
use std::io::ErrorKind;
|
||||
|
||||
use nu_plugin_protocol::{PluginInput, PluginOutput};
|
||||
use nu_protocol::ShellError;
|
||||
use serde::Deserialize;
|
||||
|
||||
use crate::{Encoder, PluginEncoder};
|
||||
|
||||
/// A `PluginEncoder` that enables the plugin to communicate with Nushell with MsgPack
|
||||
/// serialized data.
|
||||
///
|
||||
/// Each message is written as a MessagePack object. There is no message envelope or separator.
|
||||
#[derive(Clone, Copy, Debug)]
|
||||
pub struct MsgPackSerializer;
|
||||
|
||||
impl PluginEncoder for MsgPackSerializer {
|
||||
fn name(&self) -> &str {
|
||||
"msgpack"
|
||||
}
|
||||
}
|
||||
|
||||
impl Encoder<PluginInput> for MsgPackSerializer {
|
||||
fn encode(
|
||||
&self,
|
||||
plugin_input: &PluginInput,
|
||||
writer: &mut impl std::io::Write,
|
||||
) -> Result<(), nu_protocol::ShellError> {
|
||||
rmp_serde::encode::write_named(writer, plugin_input).map_err(rmp_encode_err)
|
||||
}
|
||||
|
||||
fn decode(
|
||||
&self,
|
||||
reader: &mut impl std::io::BufRead,
|
||||
) -> Result<Option<PluginInput>, ShellError> {
|
||||
let mut de = rmp_serde::Deserializer::new(reader);
|
||||
PluginInput::deserialize(&mut de)
|
||||
.map(Some)
|
||||
.or_else(rmp_decode_err)
|
||||
}
|
||||
}
|
||||
|
||||
impl Encoder<PluginOutput> for MsgPackSerializer {
|
||||
fn encode(
|
||||
&self,
|
||||
plugin_output: &PluginOutput,
|
||||
writer: &mut impl std::io::Write,
|
||||
) -> Result<(), ShellError> {
|
||||
rmp_serde::encode::write_named(writer, plugin_output).map_err(rmp_encode_err)
|
||||
}
|
||||
|
||||
fn decode(
|
||||
&self,
|
||||
reader: &mut impl std::io::BufRead,
|
||||
) -> Result<Option<PluginOutput>, ShellError> {
|
||||
let mut de = rmp_serde::Deserializer::new(reader);
|
||||
PluginOutput::deserialize(&mut de)
|
||||
.map(Some)
|
||||
.or_else(rmp_decode_err)
|
||||
}
|
||||
}
|
||||
|
||||
/// Handle a msgpack encode error
|
||||
fn rmp_encode_err(err: rmp_serde::encode::Error) -> ShellError {
|
||||
match err {
|
||||
rmp_serde::encode::Error::InvalidValueWrite(_) => {
|
||||
// I/O error
|
||||
ShellError::IOError {
|
||||
msg: err.to_string(),
|
||||
}
|
||||
}
|
||||
_ => {
|
||||
// Something else
|
||||
ShellError::PluginFailedToEncode {
|
||||
msg: err.to_string(),
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Handle a msgpack decode error. Returns `Ok(None)` on eof
|
||||
fn rmp_decode_err<T>(err: rmp_serde::decode::Error) -> Result<Option<T>, ShellError> {
|
||||
match err {
|
||||
rmp_serde::decode::Error::InvalidMarkerRead(err)
|
||||
| rmp_serde::decode::Error::InvalidDataRead(err) => {
|
||||
if matches!(err.kind(), ErrorKind::UnexpectedEof) {
|
||||
// EOF
|
||||
Ok(None)
|
||||
} else {
|
||||
// I/O error
|
||||
Err(ShellError::IOError {
|
||||
msg: err.to_string(),
|
||||
})
|
||||
}
|
||||
}
|
||||
_ => {
|
||||
// Something else
|
||||
Err(ShellError::PluginFailedToDecode {
|
||||
msg: err.to_string(),
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
crate::serializers::tests::generate_tests!(MsgPackSerializer {});
|
||||
}
|
557
crates/nu-plugin-core/src/serializers/tests.rs
Normal file
557
crates/nu-plugin-core/src/serializers/tests.rs
Normal file
@ -0,0 +1,557 @@
|
||||
macro_rules! generate_tests {
|
||||
($encoder:expr) => {
|
||||
use nu_plugin_protocol::{
|
||||
CallInfo, CustomValueOp, EvaluatedCall, PipelineDataHeader, PluginCall,
|
||||
PluginCallResponse, PluginCustomValue, PluginInput, PluginOption, PluginOutput,
|
||||
StreamData,
|
||||
};
|
||||
use nu_protocol::{
|
||||
LabeledError, PluginSignature, Signature, Span, Spanned, SyntaxShape, Value,
|
||||
};
|
||||
|
||||
#[test]
|
||||
fn decode_eof() {
|
||||
let mut buffer: &[u8] = &[];
|
||||
let encoder = $encoder;
|
||||
let result: Option<PluginInput> = encoder
|
||||
.decode(&mut buffer)
|
||||
.expect("eof should not result in an error");
|
||||
assert!(result.is_none(), "decode result: {result:?}");
|
||||
let result: Option<PluginOutput> = encoder
|
||||
.decode(&mut buffer)
|
||||
.expect("eof should not result in an error");
|
||||
assert!(result.is_none(), "decode result: {result:?}");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn decode_io_error() {
|
||||
struct ErrorProducer;
|
||||
impl std::io::Read for ErrorProducer {
|
||||
fn read(&mut self, _buf: &mut [u8]) -> std::io::Result<usize> {
|
||||
Err(std::io::Error::from(std::io::ErrorKind::ConnectionReset))
|
||||
}
|
||||
}
|
||||
let encoder = $encoder;
|
||||
let mut buffered = std::io::BufReader::new(ErrorProducer);
|
||||
match Encoder::<PluginInput>::decode(&encoder, &mut buffered) {
|
||||
Ok(_) => panic!("decode: i/o error was not passed through"),
|
||||
Err(ShellError::IOError { .. }) => (), // okay
|
||||
Err(other) => panic!(
|
||||
"decode: got other error, should have been a \
|
||||
ShellError::IOError: {other:?}"
|
||||
),
|
||||
}
|
||||
match Encoder::<PluginOutput>::decode(&encoder, &mut buffered) {
|
||||
Ok(_) => panic!("decode: i/o error was not passed through"),
|
||||
Err(ShellError::IOError { .. }) => (), // okay
|
||||
Err(other) => panic!(
|
||||
"decode: got other error, should have been a \
|
||||
ShellError::IOError: {other:?}"
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn decode_gibberish() {
|
||||
// just a sequence of bytes that shouldn't be valid in anything we use
|
||||
let gibberish: &[u8] = &[
|
||||
0, 80, 74, 85, 117, 122, 86, 100, 74, 115, 20, 104, 55, 98, 67, 203, 83, 85, 77,
|
||||
112, 74, 79, 254, 71, 80,
|
||||
];
|
||||
let encoder = $encoder;
|
||||
|
||||
let mut buffered = std::io::BufReader::new(&gibberish[..]);
|
||||
match Encoder::<PluginInput>::decode(&encoder, &mut buffered) {
|
||||
Ok(value) => panic!("decode: parsed successfully => {value:?}"),
|
||||
Err(ShellError::PluginFailedToDecode { .. }) => (), // okay
|
||||
Err(other) => panic!(
|
||||
"decode: got other error, should have been a \
|
||||
ShellError::PluginFailedToDecode: {other:?}"
|
||||
),
|
||||
}
|
||||
|
||||
let mut buffered = std::io::BufReader::new(&gibberish[..]);
|
||||
match Encoder::<PluginOutput>::decode(&encoder, &mut buffered) {
|
||||
Ok(value) => panic!("decode: parsed successfully => {value:?}"),
|
||||
Err(ShellError::PluginFailedToDecode { .. }) => (), // okay
|
||||
Err(other) => panic!(
|
||||
"decode: got other error, should have been a \
|
||||
ShellError::PluginFailedToDecode: {other:?}"
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn call_round_trip_signature() {
|
||||
let plugin_call = PluginCall::Signature;
|
||||
let plugin_input = PluginInput::Call(0, plugin_call);
|
||||
let encoder = $encoder;
|
||||
|
||||
let mut buffer: Vec<u8> = Vec::new();
|
||||
encoder
|
||||
.encode(&plugin_input, &mut buffer)
|
||||
.expect("unable to serialize message");
|
||||
let returned = encoder
|
||||
.decode(&mut buffer.as_slice())
|
||||
.expect("unable to deserialize message")
|
||||
.expect("eof");
|
||||
|
||||
match returned {
|
||||
PluginInput::Call(0, PluginCall::Signature) => {}
|
||||
_ => panic!("decoded into wrong value: {returned:?}"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn call_round_trip_run() {
|
||||
let name = "test".to_string();
|
||||
|
||||
let input = Value::bool(false, Span::new(1, 20));
|
||||
|
||||
let call = EvaluatedCall {
|
||||
head: Span::new(0, 10),
|
||||
positional: vec![
|
||||
Value::float(1.0, Span::new(0, 10)),
|
||||
Value::string("something", Span::new(0, 10)),
|
||||
],
|
||||
named: vec![(
|
||||
Spanned {
|
||||
item: "name".to_string(),
|
||||
span: Span::new(0, 10),
|
||||
},
|
||||
Some(Value::float(1.0, Span::new(0, 10))),
|
||||
)],
|
||||
};
|
||||
|
||||
let plugin_call = PluginCall::Run(CallInfo {
|
||||
name: name.clone(),
|
||||
call: call.clone(),
|
||||
input: PipelineDataHeader::Value(input.clone()),
|
||||
});
|
||||
|
||||
let plugin_input = PluginInput::Call(1, plugin_call);
|
||||
|
||||
let encoder = $encoder;
|
||||
let mut buffer: Vec<u8> = Vec::new();
|
||||
encoder
|
||||
.encode(&plugin_input, &mut buffer)
|
||||
.expect("unable to serialize message");
|
||||
let returned = encoder
|
||||
.decode(&mut buffer.as_slice())
|
||||
.expect("unable to deserialize message")
|
||||
.expect("eof");
|
||||
|
||||
match returned {
|
||||
PluginInput::Call(1, PluginCall::Run(call_info)) => {
|
||||
assert_eq!(name, call_info.name);
|
||||
assert_eq!(PipelineDataHeader::Value(input), call_info.input);
|
||||
assert_eq!(call.head, call_info.call.head);
|
||||
assert_eq!(call.positional.len(), call_info.call.positional.len());
|
||||
|
||||
call.positional
|
||||
.iter()
|
||||
.zip(call_info.call.positional.iter())
|
||||
.for_each(|(lhs, rhs)| assert_eq!(lhs, rhs));
|
||||
|
||||
call.named
|
||||
.iter()
|
||||
.zip(call_info.call.named.iter())
|
||||
.for_each(|(lhs, rhs)| {
|
||||
// Comparing the keys
|
||||
assert_eq!(lhs.0.item, rhs.0.item);
|
||||
|
||||
match (&lhs.1, &rhs.1) {
|
||||
(None, None) => {}
|
||||
(Some(a), Some(b)) => assert_eq!(a, b),
|
||||
_ => panic!("not matching values"),
|
||||
}
|
||||
});
|
||||
}
|
||||
_ => panic!("decoded into wrong value: {returned:?}"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn call_round_trip_customvalueop() {
|
||||
let data = vec![1, 2, 3, 4, 5, 6, 7];
|
||||
let span = Span::new(0, 20);
|
||||
|
||||
let custom_value_op = PluginCall::CustomValueOp(
|
||||
Spanned {
|
||||
item: PluginCustomValue::new("Foo".into(), data.clone(), false),
|
||||
span,
|
||||
},
|
||||
CustomValueOp::ToBaseValue,
|
||||
);
|
||||
|
||||
let plugin_input = PluginInput::Call(2, custom_value_op);
|
||||
|
||||
let encoder = $encoder;
|
||||
let mut buffer: Vec<u8> = Vec::new();
|
||||
encoder
|
||||
.encode(&plugin_input, &mut buffer)
|
||||
.expect("unable to serialize message");
|
||||
let returned = encoder
|
||||
.decode(&mut buffer.as_slice())
|
||||
.expect("unable to deserialize message")
|
||||
.expect("eof");
|
||||
|
||||
match returned {
|
||||
PluginInput::Call(2, PluginCall::CustomValueOp(val, op)) => {
|
||||
assert_eq!("Foo", val.item.name());
|
||||
assert_eq!(data, val.item.data());
|
||||
assert_eq!(span, val.span);
|
||||
#[allow(unreachable_patterns)]
|
||||
match op {
|
||||
CustomValueOp::ToBaseValue => (),
|
||||
_ => panic!("wrong op: {op:?}"),
|
||||
}
|
||||
}
|
||||
_ => panic!("decoded into wrong value: {returned:?}"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn response_round_trip_signature() {
|
||||
let signature = PluginSignature::new(
|
||||
Signature::build("nu-plugin")
|
||||
.required("first", SyntaxShape::String, "first required")
|
||||
.required("second", SyntaxShape::Int, "second required")
|
||||
.required_named("first-named", SyntaxShape::String, "first named", Some('f'))
|
||||
.required_named(
|
||||
"second-named",
|
||||
SyntaxShape::String,
|
||||
"second named",
|
||||
Some('s'),
|
||||
)
|
||||
.rest("remaining", SyntaxShape::Int, "remaining"),
|
||||
vec![],
|
||||
);
|
||||
|
||||
let response = PluginCallResponse::Signature(vec![signature.clone()]);
|
||||
let output = PluginOutput::CallResponse(3, response);
|
||||
|
||||
let encoder = $encoder;
|
||||
let mut buffer: Vec<u8> = Vec::new();
|
||||
encoder
|
||||
.encode(&output, &mut buffer)
|
||||
.expect("unable to serialize message");
|
||||
let returned = encoder
|
||||
.decode(&mut buffer.as_slice())
|
||||
.expect("unable to deserialize message")
|
||||
.expect("eof");
|
||||
|
||||
match returned {
|
||||
PluginOutput::CallResponse(
|
||||
3,
|
||||
PluginCallResponse::Signature(returned_signature),
|
||||
) => {
|
||||
assert_eq!(returned_signature.len(), 1);
|
||||
assert_eq!(signature.sig.name, returned_signature[0].sig.name);
|
||||
assert_eq!(signature.sig.usage, returned_signature[0].sig.usage);
|
||||
assert_eq!(
|
||||
signature.sig.extra_usage,
|
||||
returned_signature[0].sig.extra_usage
|
||||
);
|
||||
assert_eq!(signature.sig.is_filter, returned_signature[0].sig.is_filter);
|
||||
|
||||
signature
|
||||
.sig
|
||||
.required_positional
|
||||
.iter()
|
||||
.zip(returned_signature[0].sig.required_positional.iter())
|
||||
.for_each(|(lhs, rhs)| assert_eq!(lhs, rhs));
|
||||
|
||||
signature
|
||||
.sig
|
||||
.optional_positional
|
||||
.iter()
|
||||
.zip(returned_signature[0].sig.optional_positional.iter())
|
||||
.for_each(|(lhs, rhs)| assert_eq!(lhs, rhs));
|
||||
|
||||
signature
|
||||
.sig
|
||||
.named
|
||||
.iter()
|
||||
.zip(returned_signature[0].sig.named.iter())
|
||||
.for_each(|(lhs, rhs)| assert_eq!(lhs, rhs));
|
||||
|
||||
assert_eq!(
|
||||
signature.sig.rest_positional,
|
||||
returned_signature[0].sig.rest_positional,
|
||||
);
|
||||
}
|
||||
_ => panic!("decoded into wrong value: {returned:?}"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn response_round_trip_value() {
|
||||
let value = Value::int(10, Span::new(2, 30));
|
||||
|
||||
let response = PluginCallResponse::value(value.clone());
|
||||
let output = PluginOutput::CallResponse(4, response);
|
||||
|
||||
let encoder = $encoder;
|
||||
let mut buffer: Vec<u8> = Vec::new();
|
||||
encoder
|
||||
.encode(&output, &mut buffer)
|
||||
.expect("unable to serialize message");
|
||||
let returned = encoder
|
||||
.decode(&mut buffer.as_slice())
|
||||
.expect("unable to deserialize message")
|
||||
.expect("eof");
|
||||
|
||||
match returned {
|
||||
PluginOutput::CallResponse(
|
||||
4,
|
||||
PluginCallResponse::PipelineData(PipelineDataHeader::Value(returned_value)),
|
||||
) => {
|
||||
assert_eq!(value, returned_value)
|
||||
}
|
||||
_ => panic!("decoded into wrong value: {returned:?}"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn response_round_trip_plugin_custom_value() {
|
||||
let name = "test";
|
||||
|
||||
let data = vec![1, 2, 3, 4, 5];
|
||||
let span = Span::new(2, 30);
|
||||
|
||||
let value = Value::custom(
|
||||
Box::new(PluginCustomValue::new(name.into(), data.clone(), true)),
|
||||
span,
|
||||
);
|
||||
|
||||
let response = PluginCallResponse::PipelineData(PipelineDataHeader::Value(value));
|
||||
let output = PluginOutput::CallResponse(5, response);
|
||||
|
||||
let encoder = $encoder;
|
||||
let mut buffer: Vec<u8> = Vec::new();
|
||||
encoder
|
||||
.encode(&output, &mut buffer)
|
||||
.expect("unable to serialize message");
|
||||
let returned = encoder
|
||||
.decode(&mut buffer.as_slice())
|
||||
.expect("unable to deserialize message")
|
||||
.expect("eof");
|
||||
|
||||
match returned {
|
||||
PluginOutput::CallResponse(
|
||||
5,
|
||||
PluginCallResponse::PipelineData(PipelineDataHeader::Value(returned_value)),
|
||||
) => {
|
||||
assert_eq!(span, returned_value.span());
|
||||
|
||||
if let Some(plugin_val) = returned_value
|
||||
.as_custom_value()
|
||||
.unwrap()
|
||||
.as_any()
|
||||
.downcast_ref::<PluginCustomValue>()
|
||||
{
|
||||
assert_eq!(name, plugin_val.name());
|
||||
assert_eq!(data, plugin_val.data());
|
||||
assert!(plugin_val.notify_on_drop());
|
||||
} else {
|
||||
panic!("returned CustomValue is not a PluginCustomValue");
|
||||
}
|
||||
}
|
||||
_ => panic!("decoded into wrong value: {returned:?}"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn response_round_trip_error() {
|
||||
let error = LabeledError::new("label")
|
||||
.with_code("test::error")
|
||||
.with_url("https://example.org/test/error")
|
||||
.with_help("some help")
|
||||
.with_label("msg", Span::new(2, 30))
|
||||
.with_inner(ShellError::IOError {
|
||||
msg: "io error".into(),
|
||||
});
|
||||
|
||||
let response = PluginCallResponse::Error(error.clone());
|
||||
let output = PluginOutput::CallResponse(6, response);
|
||||
|
||||
let encoder = $encoder;
|
||||
let mut buffer: Vec<u8> = Vec::new();
|
||||
encoder
|
||||
.encode(&output, &mut buffer)
|
||||
.expect("unable to serialize message");
|
||||
let returned = encoder
|
||||
.decode(&mut buffer.as_slice())
|
||||
.expect("unable to deserialize message")
|
||||
.expect("eof");
|
||||
|
||||
match returned {
|
||||
PluginOutput::CallResponse(6, PluginCallResponse::Error(msg)) => {
|
||||
assert_eq!(error, msg)
|
||||
}
|
||||
_ => panic!("decoded into wrong value: {returned:?}"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn response_round_trip_error_none() {
|
||||
let error = LabeledError::new("error");
|
||||
let response = PluginCallResponse::Error(error.clone());
|
||||
let output = PluginOutput::CallResponse(7, response);
|
||||
|
||||
let encoder = $encoder;
|
||||
let mut buffer: Vec<u8> = Vec::new();
|
||||
encoder
|
||||
.encode(&output, &mut buffer)
|
||||
.expect("unable to serialize message");
|
||||
let returned = encoder
|
||||
.decode(&mut buffer.as_slice())
|
||||
.expect("unable to deserialize message")
|
||||
.expect("eof");
|
||||
|
||||
match returned {
|
||||
PluginOutput::CallResponse(7, PluginCallResponse::Error(msg)) => {
|
||||
assert_eq!(error, msg)
|
||||
}
|
||||
_ => panic!("decoded into wrong value: {returned:?}"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn input_round_trip_stream_data_list() {
|
||||
let span = Span::new(12, 30);
|
||||
let item = Value::int(1, span);
|
||||
|
||||
let stream_data = StreamData::List(item.clone());
|
||||
let plugin_input = PluginInput::Data(0, stream_data);
|
||||
|
||||
let encoder = $encoder;
|
||||
let mut buffer: Vec<u8> = Vec::new();
|
||||
encoder
|
||||
.encode(&plugin_input, &mut buffer)
|
||||
.expect("unable to serialize message");
|
||||
let returned = encoder
|
||||
.decode(&mut buffer.as_slice())
|
||||
.expect("unable to deserialize message")
|
||||
.expect("eof");
|
||||
|
||||
match returned {
|
||||
PluginInput::Data(id, StreamData::List(list_data)) => {
|
||||
assert_eq!(0, id);
|
||||
assert_eq!(item, list_data);
|
||||
}
|
||||
_ => panic!("decoded into wrong value: {returned:?}"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn input_round_trip_stream_data_raw() {
|
||||
let data = b"Hello world";
|
||||
|
||||
let stream_data = StreamData::Raw(Ok(data.to_vec()));
|
||||
let plugin_input = PluginInput::Data(1, stream_data);
|
||||
|
||||
let encoder = $encoder;
|
||||
let mut buffer: Vec<u8> = Vec::new();
|
||||
encoder
|
||||
.encode(&plugin_input, &mut buffer)
|
||||
.expect("unable to serialize message");
|
||||
let returned = encoder
|
||||
.decode(&mut buffer.as_slice())
|
||||
.expect("unable to deserialize message")
|
||||
.expect("eof");
|
||||
|
||||
match returned {
|
||||
PluginInput::Data(id, StreamData::Raw(bytes)) => {
|
||||
assert_eq!(1, id);
|
||||
match bytes {
|
||||
Ok(bytes) => assert_eq!(data, &bytes[..]),
|
||||
Err(err) => panic!("decoded into error variant: {err:?}"),
|
||||
}
|
||||
}
|
||||
_ => panic!("decoded into wrong value: {returned:?}"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn output_round_trip_stream_data_list() {
|
||||
let span = Span::new(12, 30);
|
||||
let item = Value::int(1, span);
|
||||
|
||||
let stream_data = StreamData::List(item.clone());
|
||||
let plugin_output = PluginOutput::Data(4, stream_data);
|
||||
|
||||
let encoder = $encoder;
|
||||
let mut buffer: Vec<u8> = Vec::new();
|
||||
encoder
|
||||
.encode(&plugin_output, &mut buffer)
|
||||
.expect("unable to serialize message");
|
||||
let returned = encoder
|
||||
.decode(&mut buffer.as_slice())
|
||||
.expect("unable to deserialize message")
|
||||
.expect("eof");
|
||||
|
||||
match returned {
|
||||
PluginOutput::Data(id, StreamData::List(list_data)) => {
|
||||
assert_eq!(4, id);
|
||||
assert_eq!(item, list_data);
|
||||
}
|
||||
_ => panic!("decoded into wrong value: {returned:?}"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn output_round_trip_stream_data_raw() {
|
||||
let data = b"Hello world";
|
||||
|
||||
let stream_data = StreamData::Raw(Ok(data.to_vec()));
|
||||
let plugin_output = PluginOutput::Data(5, stream_data);
|
||||
|
||||
let encoder = $encoder;
|
||||
let mut buffer: Vec<u8> = Vec::new();
|
||||
encoder
|
||||
.encode(&plugin_output, &mut buffer)
|
||||
.expect("unable to serialize message");
|
||||
let returned = encoder
|
||||
.decode(&mut buffer.as_slice())
|
||||
.expect("unable to deserialize message")
|
||||
.expect("eof");
|
||||
|
||||
match returned {
|
||||
PluginOutput::Data(id, StreamData::Raw(bytes)) => {
|
||||
assert_eq!(5, id);
|
||||
match bytes {
|
||||
Ok(bytes) => assert_eq!(data, &bytes[..]),
|
||||
Err(err) => panic!("decoded into error variant: {err:?}"),
|
||||
}
|
||||
}
|
||||
_ => panic!("decoded into wrong value: {returned:?}"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn output_round_trip_option() {
|
||||
let plugin_output = PluginOutput::Option(PluginOption::GcDisabled(true));
|
||||
|
||||
let encoder = $encoder;
|
||||
let mut buffer: Vec<u8> = Vec::new();
|
||||
encoder
|
||||
.encode(&plugin_output, &mut buffer)
|
||||
.expect("unable to serialize message");
|
||||
let returned = encoder
|
||||
.decode(&mut buffer.as_slice())
|
||||
.expect("unable to deserialize message")
|
||||
.expect("eof");
|
||||
|
||||
match returned {
|
||||
PluginOutput::Option(PluginOption::GcDisabled(disabled)) => {
|
||||
assert!(disabled);
|
||||
}
|
||||
_ => panic!("decoded into wrong value: {returned:?}"),
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
pub(crate) use generate_tests;
|
7
crates/nu-plugin-core/src/util/mod.rs
Normal file
7
crates/nu-plugin-core/src/util/mod.rs
Normal file
@ -0,0 +1,7 @@
|
||||
mod sequence;
|
||||
mod waitable;
|
||||
mod with_custom_values_in;
|
||||
|
||||
pub use sequence::Sequence;
|
||||
pub use waitable::*;
|
||||
pub use with_custom_values_in::with_custom_values_in;
|
64
crates/nu-plugin-core/src/util/sequence.rs
Normal file
64
crates/nu-plugin-core/src/util/sequence.rs
Normal file
@ -0,0 +1,64 @@
|
||||
use nu_protocol::ShellError;
|
||||
use std::sync::atomic::{AtomicUsize, Ordering::Relaxed};
|
||||
|
||||
/// Implements an atomically incrementing sequential series of numbers
|
||||
#[derive(Debug, Default)]
|
||||
pub struct Sequence(AtomicUsize);
|
||||
|
||||
impl Sequence {
|
||||
/// Return the next available id from a sequence, returning an error on overflow
|
||||
#[track_caller]
|
||||
pub fn next(&self) -> Result<usize, ShellError> {
|
||||
// It's totally safe to use Relaxed ordering here, as there aren't other memory operations
|
||||
// that depend on this value having been set for safety
|
||||
//
|
||||
// We're only not using `fetch_add` so that we can check for overflow, as wrapping with the
|
||||
// identifier would lead to a serious bug - however unlikely that is.
|
||||
self.0
|
||||
.fetch_update(Relaxed, Relaxed, |current| current.checked_add(1))
|
||||
.map_err(|_| ShellError::NushellFailedHelp {
|
||||
msg: "an accumulator for identifiers overflowed".into(),
|
||||
help: format!("see {}", std::panic::Location::caller()),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn output_is_sequential() {
|
||||
let sequence = Sequence::default();
|
||||
|
||||
for (expected, generated) in (0..1000).zip(std::iter::repeat_with(|| sequence.next())) {
|
||||
assert_eq!(expected, generated.expect("error in sequence"));
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn output_is_unique_even_under_contention() {
|
||||
let sequence = Sequence::default();
|
||||
|
||||
std::thread::scope(|scope| {
|
||||
// Spawn four threads, all advancing the sequence simultaneously
|
||||
let threads = (0..4)
|
||||
.map(|_| {
|
||||
scope.spawn(|| {
|
||||
(0..100000)
|
||||
.map(|_| sequence.next())
|
||||
.collect::<Result<Vec<_>, _>>()
|
||||
})
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
// Collect all of the results into a single flat vec
|
||||
let mut results = threads
|
||||
.into_iter()
|
||||
.flat_map(|thread| thread.join().expect("panicked").expect("error"))
|
||||
.collect::<Vec<usize>>();
|
||||
|
||||
// Check uniqueness
|
||||
results.sort();
|
||||
let initial_length = results.len();
|
||||
results.dedup();
|
||||
let deduplicated_length = results.len();
|
||||
assert_eq!(initial_length, deduplicated_length);
|
||||
})
|
||||
}
|
181
crates/nu-plugin-core/src/util/waitable.rs
Normal file
181
crates/nu-plugin-core/src/util/waitable.rs
Normal file
@ -0,0 +1,181 @@
|
||||
use std::sync::{
|
||||
atomic::{AtomicBool, Ordering},
|
||||
Arc, Condvar, Mutex, MutexGuard, PoisonError,
|
||||
};
|
||||
|
||||
use nu_protocol::ShellError;
|
||||
|
||||
/// A shared container that may be empty, and allows threads to block until it has a value.
|
||||
///
|
||||
/// This side is read-only - use [`WaitableMut`] on threads that might write a value.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct Waitable<T: Clone + Send> {
|
||||
shared: Arc<WaitableShared<T>>,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct WaitableMut<T: Clone + Send> {
|
||||
shared: Arc<WaitableShared<T>>,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct WaitableShared<T: Clone + Send> {
|
||||
is_set: AtomicBool,
|
||||
mutex: Mutex<SyncState<T>>,
|
||||
condvar: Condvar,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct SyncState<T: Clone + Send> {
|
||||
writers: usize,
|
||||
value: Option<T>,
|
||||
}
|
||||
|
||||
#[track_caller]
|
||||
fn fail_if_poisoned<'a, T>(
|
||||
result: Result<MutexGuard<'a, T>, PoisonError<MutexGuard<'a, T>>>,
|
||||
) -> Result<MutexGuard<'a, T>, ShellError> {
|
||||
match result {
|
||||
Ok(guard) => Ok(guard),
|
||||
Err(_) => Err(ShellError::NushellFailedHelp {
|
||||
msg: "Waitable mutex poisoned".into(),
|
||||
help: std::panic::Location::caller().to_string(),
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: Clone + Send> WaitableMut<T> {
|
||||
/// Create a new empty `WaitableMut`. Call [`.reader()`] to get [`Waitable`].
|
||||
pub fn new() -> WaitableMut<T> {
|
||||
WaitableMut {
|
||||
shared: Arc::new(WaitableShared {
|
||||
is_set: AtomicBool::new(false),
|
||||
mutex: Mutex::new(SyncState {
|
||||
writers: 1,
|
||||
value: None,
|
||||
}),
|
||||
condvar: Condvar::new(),
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn reader(&self) -> Waitable<T> {
|
||||
Waitable {
|
||||
shared: self.shared.clone(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Set the value and let waiting threads know.
|
||||
#[track_caller]
|
||||
pub fn set(&self, value: T) -> Result<(), ShellError> {
|
||||
let mut sync_state = fail_if_poisoned(self.shared.mutex.lock())?;
|
||||
self.shared.is_set.store(true, Ordering::SeqCst);
|
||||
sync_state.value = Some(value);
|
||||
self.shared.condvar.notify_all();
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: Clone + Send> Default for WaitableMut<T> {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: Clone + Send> Clone for WaitableMut<T> {
|
||||
fn clone(&self) -> Self {
|
||||
let shared = self.shared.clone();
|
||||
shared
|
||||
.mutex
|
||||
.lock()
|
||||
.expect("failed to lock mutex to increment writers")
|
||||
.writers += 1;
|
||||
WaitableMut { shared }
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: Clone + Send> Drop for WaitableMut<T> {
|
||||
fn drop(&mut self) {
|
||||
// Decrement writers...
|
||||
if let Ok(mut sync_state) = self.shared.mutex.lock() {
|
||||
sync_state.writers = sync_state
|
||||
.writers
|
||||
.checked_sub(1)
|
||||
.expect("would decrement writers below zero");
|
||||
}
|
||||
// and notify waiting threads so they have a chance to see it.
|
||||
self.shared.condvar.notify_all();
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: Clone + Send> Waitable<T> {
|
||||
/// Wait for a value to be available and then clone it.
|
||||
///
|
||||
/// Returns `Ok(None)` if there are no writers left that could possibly place a value.
|
||||
#[track_caller]
|
||||
pub fn get(&self) -> Result<Option<T>, ShellError> {
|
||||
let sync_state = fail_if_poisoned(self.shared.mutex.lock())?;
|
||||
if let Some(value) = sync_state.value.clone() {
|
||||
Ok(Some(value))
|
||||
} else if sync_state.writers == 0 {
|
||||
// There can't possibly be a value written, so no point in waiting.
|
||||
Ok(None)
|
||||
} else {
|
||||
let sync_state = fail_if_poisoned(
|
||||
self.shared
|
||||
.condvar
|
||||
.wait_while(sync_state, |g| g.writers > 0 && g.value.is_none()),
|
||||
)?;
|
||||
Ok(sync_state.value.clone())
|
||||
}
|
||||
}
|
||||
|
||||
/// Clone the value if one is available, but don't wait if not.
|
||||
#[track_caller]
|
||||
pub fn try_get(&self) -> Result<Option<T>, ShellError> {
|
||||
let sync_state = fail_if_poisoned(self.shared.mutex.lock())?;
|
||||
Ok(sync_state.value.clone())
|
||||
}
|
||||
|
||||
/// Returns true if value is available.
|
||||
#[track_caller]
|
||||
pub fn is_set(&self) -> bool {
|
||||
self.shared.is_set.load(Ordering::SeqCst)
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn set_from_other_thread() -> Result<(), ShellError> {
|
||||
let waitable_mut = WaitableMut::new();
|
||||
let waitable = waitable_mut.reader();
|
||||
|
||||
assert!(!waitable.is_set());
|
||||
|
||||
std::thread::spawn(move || {
|
||||
waitable_mut.set(42).expect("error on set");
|
||||
});
|
||||
|
||||
assert_eq!(Some(42), waitable.get()?);
|
||||
assert_eq!(Some(42), waitable.try_get()?);
|
||||
assert!(waitable.is_set());
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn dont_deadlock_if_waiting_without_writer() {
|
||||
use std::time::Duration;
|
||||
|
||||
let (tx, rx) = std::sync::mpsc::channel();
|
||||
let writer = WaitableMut::<()>::new();
|
||||
let waitable = writer.reader();
|
||||
// Ensure there are no writers
|
||||
drop(writer);
|
||||
std::thread::spawn(move || {
|
||||
let _ = tx.send(waitable.get());
|
||||
});
|
||||
let result = rx
|
||||
.recv_timeout(Duration::from_secs(10))
|
||||
.expect("timed out")
|
||||
.expect("error");
|
||||
assert!(result.is_none());
|
||||
}
|
96
crates/nu-plugin-core/src/util/with_custom_values_in.rs
Normal file
96
crates/nu-plugin-core/src/util/with_custom_values_in.rs
Normal file
@ -0,0 +1,96 @@
|
||||
use nu_protocol::{CustomValue, IntoSpanned, ShellError, Spanned, Value};
|
||||
|
||||
/// Do something with all [`CustomValue`]s recursively within a `Value`. This is not limited to
|
||||
/// plugin custom values.
|
||||
///
|
||||
/// `LazyRecord`s will be collected to plain values for completeness.
|
||||
pub fn with_custom_values_in<E>(
|
||||
value: &mut Value,
|
||||
mut f: impl FnMut(Spanned<&mut Box<dyn CustomValue>>) -> Result<(), E>,
|
||||
) -> Result<(), E>
|
||||
where
|
||||
E: From<ShellError>,
|
||||
{
|
||||
value.recurse_mut(&mut |value| {
|
||||
let span = value.span();
|
||||
match value {
|
||||
Value::Custom { val, .. } => {
|
||||
// Operate on a CustomValue.
|
||||
f(val.into_spanned(span))
|
||||
}
|
||||
// LazyRecord would be a problem for us, since it could return something else the
|
||||
// next time, and we have to collect it anyway to serialize it. Collect it in place,
|
||||
// and then use the result
|
||||
Value::LazyRecord { val, .. } => {
|
||||
*value = val.collect()?;
|
||||
Ok(())
|
||||
}
|
||||
_ => Ok(()),
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn find_custom_values() {
|
||||
use nu_plugin_protocol::test_util::test_plugin_custom_value;
|
||||
use nu_protocol::{engine::Closure, record, LazyRecord, Span};
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
struct Lazy;
|
||||
impl<'a> LazyRecord<'a> for Lazy {
|
||||
fn column_names(&'a self) -> Vec<&'a str> {
|
||||
vec!["custom", "plain"]
|
||||
}
|
||||
|
||||
fn get_column_value(&self, column: &str) -> Result<Value, ShellError> {
|
||||
Ok(match column {
|
||||
"custom" => Value::test_custom_value(Box::new(test_plugin_custom_value())),
|
||||
"plain" => Value::test_int(42),
|
||||
_ => unimplemented!(),
|
||||
})
|
||||
}
|
||||
|
||||
fn span(&self) -> Span {
|
||||
Span::test_data()
|
||||
}
|
||||
|
||||
fn clone_value(&self, span: Span) -> Value {
|
||||
Value::lazy_record(Box::new(self.clone()), span)
|
||||
}
|
||||
}
|
||||
|
||||
let mut cv = Value::test_custom_value(Box::new(test_plugin_custom_value()));
|
||||
|
||||
let mut value = Value::test_record(record! {
|
||||
"bare" => cv.clone(),
|
||||
"list" => Value::test_list(vec![
|
||||
cv.clone(),
|
||||
Value::test_int(4),
|
||||
]),
|
||||
"closure" => Value::test_closure(
|
||||
Closure {
|
||||
block_id: 0,
|
||||
captures: vec![(0, cv.clone()), (1, Value::test_string("foo"))]
|
||||
}
|
||||
),
|
||||
"lazy" => Value::test_lazy_record(Box::new(Lazy)),
|
||||
});
|
||||
|
||||
// Do with_custom_values_in, and count the number of custom values found
|
||||
let mut found = 0;
|
||||
with_custom_values_in::<ShellError>(&mut value, |_| {
|
||||
found += 1;
|
||||
Ok(())
|
||||
})
|
||||
.expect("error");
|
||||
assert_eq!(4, found, "found in value");
|
||||
|
||||
// Try it on bare custom value too
|
||||
found = 0;
|
||||
with_custom_values_in::<ShellError>(&mut cv, |_| {
|
||||
found += 1;
|
||||
Ok(())
|
||||
})
|
||||
.expect("error");
|
||||
assert_eq!(1, found, "bare custom value didn't work");
|
||||
}
|
Reference in New Issue
Block a user