diff --git a/crates/nu-command/src/experimental/job_spawn.rs b/crates/nu-command/src/experimental/job_spawn.rs index b0061a40fa..15b1a7faef 100644 --- a/crates/nu-command/src/experimental/job_spawn.rs +++ b/crates/nu-command/src/experimental/job_spawn.rs @@ -8,7 +8,7 @@ use std::{ use nu_engine::{command_prelude::*, ClosureEvalOnce}; use nu_protocol::{ - engine::{Closure, Job, Redirection, ThreadJob, WaitSignal}, + engine::{completion_signal, Closure, Job, Redirection, ThreadJob}, report_shell_error, OutDest, Signals, }; @@ -77,10 +77,10 @@ impl Command for JobSpawn { let jobs = job_state.jobs.clone(); let mut jobs = jobs.lock().expect("jobs lock is poisoned!"); - let on_termination = Arc::new(WaitSignal::new()); + let (complete, wait) = completion_signal(); let id = { - let thread_job = ThreadJob::new(job_signals, tag, on_termination.clone()); + let thread_job = ThreadJob::new(job_signals, tag, wait); job_state.current_thread_job = Some(thread_job.clone()); jobs.add_job(Job::Thread(thread_job)) }; @@ -104,7 +104,7 @@ impl Command for JobSpawn { Value::error(err, head) }); - on_termination.signal(result_value); + complete.complete(result_value); { let mut jobs = job_state.jobs.lock().expect("jobs lock is poisoned!"); diff --git a/crates/nu-command/src/experimental/job_wait.rs b/crates/nu-command/src/experimental/job_wait.rs index 50a2252ea9..d95065555e 100644 --- a/crates/nu-command/src/experimental/job_wait.rs +++ b/crates/nu-command/src/experimental/job_wait.rs @@ -60,12 +60,12 @@ impl Command for JobWait { } Some(Job::Thread(job)) => { - let on_termination = job.on_termination().clone(); + let waiter = job.on_termination().clone(); - // .join() blocks so we drop our mutex guard + // .wait() blocks so we drop our mutex guard drop(jobs); - let result = on_termination.join().clone().with_span(head); + let result = waiter.wait().clone().with_span(head); Ok(result.into_pipeline_data()) } diff --git a/crates/nu-protocol/src/engine/jobs.rs b/crates/nu-protocol/src/engine/jobs.rs index c9bbbe48f1..7f974b17c7 100644 --- a/crates/nu-protocol/src/engine/jobs.rs +++ b/crates/nu-protocol/src/engine/jobs.rs @@ -139,15 +139,11 @@ pub struct ThreadJob { signals: Signals, pids: Arc>>, tag: Option, - on_termination: Arc>, + on_termination: Waiter, } impl ThreadJob { - pub fn new( - signals: Signals, - tag: Option, - on_termination: Arc>, - ) -> Self { + pub fn new(signals: Signals, tag: Option, on_termination: Waiter) -> Self { ThreadJob { signals, pids: Arc::new(Mutex::new(HashSet::default())), @@ -201,7 +197,7 @@ impl ThreadJob { pids.remove(&pid); } - pub fn on_termination(&self) -> &Arc> { + pub fn on_termination(&self) -> &Waiter { return &self.on_termination; } } @@ -251,95 +247,129 @@ impl FrozenJob { use std::sync::OnceLock; -/// A synchronization primitive that allows multiple threads to wait for a single event to occur. +/// A synchronization primitive that allows multiple threads to wait for a single event to be completed. /// -/// Threads that call the [`join`] method will block until the [`signal`] method is called. -/// Once [`signal`] is called, all currently waiting threads will be woken up and will return from their `join` calls. -/// Subsequent calls to [`join`] will not block and will return immediately. +/// A Waiter/Completer pair is similar to a Receiver/Sender pair from std::sync::mpsc, with a few important differences: +/// - Only one value can only be sent/completed, subsequent completions are ignored +/// - Multiple threads can wait for the completion of an event (`Waiter` is `Clone` unlike `Receiver`) /// -/// The [`was_signaled`] method can be used to check if the signal has been emitted without blocking. -pub struct WaitSignal { - mutex: std::sync::Mutex, - value: std::sync::OnceLock, - var: std::sync::Condvar, +/// This type differs from `OnceLock` only in a few regards: +/// - It is split into `Waiter` and `Completer` halfs +/// - It allows users to `wait` on the completion event with a timeout +/// +/// Threads that call the [`wait`] method of the `Waiter` block until the [`complete`] method of a matching `Completer` is called. +/// Once [`complete`] is called, all currently waiting threads will be woken up and will return from their `wait` calls. +/// Subsequent calls to [`wait`] will not block and will return immediately. +/// +pub fn completion_signal() -> (Completer, Waiter) { + let inner = Arc::new(InnerWaitCompleteSignal::new()); + + return ( + Completer { + inner: inner.clone(), + }, + Waiter { inner }, + ); } -impl WaitSignal { - /// Creates a new `WaitSignal` in an unsignaled state. - /// - /// No threads will be woken up initially. +/// Waiter and Completer are effectively just `Arc` wrappers around this type. +struct InnerWaitCompleteSignal { + // One may ask: "Why the mutex and the convar"? + // It turns out OnceLock doesn't have a `wait_timeout` method, so + // we use the one from the condvar. + // + // We once again, assume acquire-release semamntics for Rust mutexes + mutex: std::sync::Mutex<()>, + var: std::sync::Condvar, + value: std::sync::OnceLock, +} + +impl InnerWaitCompleteSignal { pub fn new() -> Self { - WaitSignal { - mutex: std::sync::Mutex::new(false), + InnerWaitCompleteSignal { + mutex: std::sync::Mutex::new(()), value: OnceLock::new(), var: std::sync::Condvar::new(), } } +} - /// Blocks the current thread until this `WaitSignal` is signaled. +#[derive(Clone)] +pub struct Waiter { + inner: Arc>, +} + +pub struct Completer { + inner: Arc>, +} + +impl Waiter { + /// Blocks the current thread until a completion signal is sent. /// /// If the signal has already been emitted, this method returns immediately. /// - /// # Panics - /// - /// This method will panic if the underlying mutex is poisoned. This can happen if another - /// thread holding the mutex panics. - pub fn join(&self) -> &T { - let mut guard = self.mutex.lock().expect("mutex is poisoned!"); + pub fn wait(&self) -> &T { + let inner: &InnerWaitCompleteSignal = self.inner.as_ref(); - while !*guard { - match self.var.wait(guard) { - Ok(it) => guard = it, - Err(_) => panic!("mutex is poisoned!"), + let mut guard = inner.mutex.lock().expect("mutex is poisoned!"); + + loop { + match inner.value.get() { + None => match inner.var.wait(guard) { + Ok(it) => guard = it, + Err(_) => panic!("mutex is poisoned!"), + }, + Some(value) => return value, } } - - return self.value.get().unwrap(); } - /// Signals all threads currently waiting on this `WaitSignal`. - /// - /// This method sets the internal state to "signaled" and wakes up all threads that are blocked - /// in the [`join`] method. Subsequent calls to [`join`] from any thread will return immediately. - /// This operation has no effect if the signal has already been emitted. - pub fn signal(&self, value: T) { - let mut guard = self.mutex.lock().expect("mutex is poisoned!"); + // TODO: add wait_timeout - *guard = true; - let _ = self.value.set(value); - - self.var.notify_all(); - } - - /// Checks if this `WaitSignal` has been signaled. + /// Checks if this completion signal has been signaled. /// /// This method returns `true` if the [`signal`] method has been called at least once, /// and `false` otherwise. This method does not block the current thread. /// - /// # Panics - /// - /// This method will panic if the underlying mutex is poisoned. This can happen if another - /// thread holding the mutex panics. - pub fn was_signaled(&self) -> bool { - let guard = self.mutex.lock().expect("mutex is poisoned!"); + pub fn is_completed(&self) -> bool { + self.try_get().is_some() + } - *guard + /// Returns the completed value, or None if none was sent. + pub fn try_get(&self) -> Option<&T> { + let _guard = self.inner.mutex.lock().expect("mutex is poisoned!"); + + self.inner.value.get() + } +} + +impl Completer { + /// Signals all threads currently waiting on this completion signal. + /// + /// This method sets wakes up all threads that are blocked in the [`wait`] method + /// of an attached `Waiter`. Subsequent calls to [`wait`] from any thread will return immediately. + /// This operation has no effect if this completion signal has already been completed. + pub fn complete(&self, value: T) { + let inner: &InnerWaitCompleteSignal = self.inner.as_ref(); + + let mut _guard = inner.mutex.lock().expect("mutex is poisoned!"); + + let _ = inner.value.set(value); + + inner.var.notify_all(); } } -// TODO: move to testing directory #[cfg(test)] -mod test { +mod completion_signal_tests { use std::{ - sync::{mpsc, Arc}, + sync::mpsc, thread::{self, sleep}, time::Duration, }; - use pretty_assertions::assert_eq; - - use crate::engine::jobs::WaitSignal; + use crate::engine::completion_signal; fn run_with_timeout(duration: Duration, lambda: F) where @@ -363,52 +393,51 @@ mod test { send.send(false).expect("send failed"); }); - let result = recv.recv().expect("recv failed!"); + let ok = recv.recv().expect("recv failed!"); - assert!(result == true, "timeout!"); + assert!(ok, "got timeout!"); } #[test] - fn join_returns_when_signaled_from_another_thread() { + fn wait_returns_when_signaled_from_another_thread() { run_with_timeout(Duration::from_secs(1), || { - let signal = Arc::new(WaitSignal::new()); + let (complete, wait) = completion_signal(); - let thread_signal = signal.clone(); + let wait_ = wait.clone(); thread::spawn(move || { sleep(Duration::from_millis(200)); - assert!(!thread_signal.was_signaled()); - thread_signal.signal(123); + assert!(!wait_.is_completed()); + complete.complete(123); }); - let result = signal.join(); + let result = wait.wait(); - assert!(signal.was_signaled()); + assert!(wait.is_completed()); assert_eq!(*result, 123); }); } #[test] - fn join_works_from_multiple_threads() { + fn wait_works_from_multiple_threads() { run_with_timeout(Duration::from_secs(1), || { - let signal = Arc::new(WaitSignal::new()); - + let (complete, wait) = completion_signal(); let (send, recv) = mpsc::channel(); let thread_count = 4; for _ in 0..thread_count { - let signal_ = signal.clone(); + let wait_ = wait.clone(); let send_ = send.clone(); thread::spawn(move || { - let value = signal_.join(); + let value = wait_.wait(); send_.send(*value).expect("send failed"); }); } - signal.signal(321); + complete.complete(321); for _ in 0..thread_count { let result = recv.recv().expect("recv failed"); @@ -420,30 +449,30 @@ mod test { #[test] fn was_signaled_returns_false_when_struct_is_initalized() { - let signal = Arc::new(WaitSignal::<()>::new()); + let (_, wait) = completion_signal::<()>(); - assert!(!signal.was_signaled()) + assert!(!wait.is_completed()) } #[test] fn was_signaled_returns_true_when_signal_is_called() { - let signal = Arc::new(WaitSignal::new()); + let (complete, wait) = completion_signal(); - signal.signal(()); + complete.complete(()); - assert!(signal.was_signaled()) + assert!(wait.is_completed()) } #[test] - fn join_returns_when_own_thread_signals() { + fn wait_returns_when_own_thread_signals() { run_with_timeout(Duration::from_secs(1), || { - let signal = Arc::new(WaitSignal::new()); + let (complete, wait) = completion_signal(); - signal.signal(()); + complete.complete(()); - signal.join(); + wait.wait(); - assert!(signal.was_signaled()) + assert!(wait.is_completed()) }) } }