Refactor 'JobSignal' into Waiter and Completer

This commit is contained in:
Renan Ribeiro 2025-04-21 18:03:16 -03:00
parent 75db02781d
commit d3cbf728f2
3 changed files with 123 additions and 94 deletions

View File

@ -8,7 +8,7 @@ use std::{
use nu_engine::{command_prelude::*, ClosureEvalOnce}; use nu_engine::{command_prelude::*, ClosureEvalOnce};
use nu_protocol::{ use nu_protocol::{
engine::{Closure, Job, Redirection, ThreadJob, WaitSignal}, engine::{completion_signal, Closure, Job, Redirection, ThreadJob},
report_shell_error, OutDest, Signals, report_shell_error, OutDest, Signals,
}; };
@ -77,10 +77,10 @@ impl Command for JobSpawn {
let jobs = job_state.jobs.clone(); let jobs = job_state.jobs.clone();
let mut jobs = jobs.lock().expect("jobs lock is poisoned!"); let mut jobs = jobs.lock().expect("jobs lock is poisoned!");
let on_termination = Arc::new(WaitSignal::new()); let (complete, wait) = completion_signal();
let id = { let id = {
let thread_job = ThreadJob::new(job_signals, tag, on_termination.clone()); let thread_job = ThreadJob::new(job_signals, tag, wait);
job_state.current_thread_job = Some(thread_job.clone()); job_state.current_thread_job = Some(thread_job.clone());
jobs.add_job(Job::Thread(thread_job)) jobs.add_job(Job::Thread(thread_job))
}; };
@ -104,7 +104,7 @@ impl Command for JobSpawn {
Value::error(err, head) Value::error(err, head)
}); });
on_termination.signal(result_value); complete.complete(result_value);
{ {
let mut jobs = job_state.jobs.lock().expect("jobs lock is poisoned!"); let mut jobs = job_state.jobs.lock().expect("jobs lock is poisoned!");

View File

@ -60,12 +60,12 @@ impl Command for JobWait {
} }
Some(Job::Thread(job)) => { Some(Job::Thread(job)) => {
let on_termination = job.on_termination().clone(); let waiter = job.on_termination().clone();
// .join() blocks so we drop our mutex guard // .wait() blocks so we drop our mutex guard
drop(jobs); drop(jobs);
let result = on_termination.join().clone().with_span(head); let result = waiter.wait().clone().with_span(head);
Ok(result.into_pipeline_data()) Ok(result.into_pipeline_data())
} }

View File

@ -139,15 +139,11 @@ pub struct ThreadJob {
signals: Signals, signals: Signals,
pids: Arc<Mutex<HashSet<u32>>>, pids: Arc<Mutex<HashSet<u32>>>,
tag: Option<String>, tag: Option<String>,
on_termination: Arc<WaitSignal<Value>>, on_termination: Waiter<Value>,
} }
impl ThreadJob { impl ThreadJob {
pub fn new( pub fn new(signals: Signals, tag: Option<String>, on_termination: Waiter<Value>) -> Self {
signals: Signals,
tag: Option<String>,
on_termination: Arc<WaitSignal<Value>>,
) -> Self {
ThreadJob { ThreadJob {
signals, signals,
pids: Arc::new(Mutex::new(HashSet::default())), pids: Arc::new(Mutex::new(HashSet::default())),
@ -201,7 +197,7 @@ impl ThreadJob {
pids.remove(&pid); pids.remove(&pid);
} }
pub fn on_termination(&self) -> &Arc<WaitSignal<Value>> { pub fn on_termination(&self) -> &Waiter<Value> {
return &self.on_termination; return &self.on_termination;
} }
} }
@ -251,95 +247,129 @@ impl FrozenJob {
use std::sync::OnceLock; use std::sync::OnceLock;
/// A synchronization primitive that allows multiple threads to wait for a single event to occur. /// A synchronization primitive that allows multiple threads to wait for a single event to be completed.
/// ///
/// Threads that call the [`join`] method will block until the [`signal`] method is called. /// A Waiter/Completer pair is similar to a Receiver/Sender pair from std::sync::mpsc, with a few important differences:
/// Once [`signal`] is called, all currently waiting threads will be woken up and will return from their `join` calls. /// - Only one value can only be sent/completed, subsequent completions are ignored
/// Subsequent calls to [`join`] will not block and will return immediately. /// - Multiple threads can wait for the completion of an event (`Waiter` is `Clone` unlike `Receiver`)
/// ///
/// The [`was_signaled`] method can be used to check if the signal has been emitted without blocking. /// This type differs from `OnceLock` only in a few regards:
pub struct WaitSignal<T> { /// - It is split into `Waiter` and `Completer` halfs
mutex: std::sync::Mutex<bool>, /// - It allows users to `wait` on the completion event with a timeout
value: std::sync::OnceLock<T>, ///
var: std::sync::Condvar, /// Threads that call the [`wait`] method of the `Waiter` block until the [`complete`] method of a matching `Completer` is called.
/// Once [`complete`] is called, all currently waiting threads will be woken up and will return from their `wait` calls.
/// Subsequent calls to [`wait`] will not block and will return immediately.
///
pub fn completion_signal<T>() -> (Completer<T>, Waiter<T>) {
let inner = Arc::new(InnerWaitCompleteSignal::new());
return (
Completer {
inner: inner.clone(),
},
Waiter { inner },
);
} }
impl<T> WaitSignal<T> { /// Waiter and Completer are effectively just `Arc` wrappers around this type.
/// Creates a new `WaitSignal` in an unsignaled state. struct InnerWaitCompleteSignal<T> {
/// // One may ask: "Why the mutex and the convar"?
/// No threads will be woken up initially. // It turns out OnceLock doesn't have a `wait_timeout` method, so
// we use the one from the condvar.
//
// We once again, assume acquire-release semamntics for Rust mutexes
mutex: std::sync::Mutex<()>,
var: std::sync::Condvar,
value: std::sync::OnceLock<T>,
}
impl<T> InnerWaitCompleteSignal<T> {
pub fn new() -> Self { pub fn new() -> Self {
WaitSignal { InnerWaitCompleteSignal {
mutex: std::sync::Mutex::new(false), mutex: std::sync::Mutex::new(()),
value: OnceLock::new(), value: OnceLock::new(),
var: std::sync::Condvar::new(), var: std::sync::Condvar::new(),
} }
} }
}
/// Blocks the current thread until this `WaitSignal` is signaled. #[derive(Clone)]
pub struct Waiter<T> {
inner: Arc<InnerWaitCompleteSignal<T>>,
}
pub struct Completer<T> {
inner: Arc<InnerWaitCompleteSignal<T>>,
}
impl<T> Waiter<T> {
/// Blocks the current thread until a completion signal is sent.
/// ///
/// If the signal has already been emitted, this method returns immediately. /// If the signal has already been emitted, this method returns immediately.
/// ///
/// # Panics pub fn wait(&self) -> &T {
/// let inner: &InnerWaitCompleteSignal<T> = self.inner.as_ref();
/// This method will panic if the underlying mutex is poisoned. This can happen if another
/// thread holding the mutex panics.
pub fn join(&self) -> &T {
let mut guard = self.mutex.lock().expect("mutex is poisoned!");
while !*guard { let mut guard = inner.mutex.lock().expect("mutex is poisoned!");
match self.var.wait(guard) {
Ok(it) => guard = it, loop {
Err(_) => panic!("mutex is poisoned!"), match inner.value.get() {
None => match inner.var.wait(guard) {
Ok(it) => guard = it,
Err(_) => panic!("mutex is poisoned!"),
},
Some(value) => return value,
} }
} }
return self.value.get().unwrap();
} }
/// Signals all threads currently waiting on this `WaitSignal`. // TODO: add wait_timeout
///
/// This method sets the internal state to "signaled" and wakes up all threads that are blocked
/// in the [`join`] method. Subsequent calls to [`join`] from any thread will return immediately.
/// This operation has no effect if the signal has already been emitted.
pub fn signal(&self, value: T) {
let mut guard = self.mutex.lock().expect("mutex is poisoned!");
*guard = true; /// Checks if this completion signal has been signaled.
let _ = self.value.set(value);
self.var.notify_all();
}
/// Checks if this `WaitSignal` has been signaled.
/// ///
/// This method returns `true` if the [`signal`] method has been called at least once, /// This method returns `true` if the [`signal`] method has been called at least once,
/// and `false` otherwise. This method does not block the current thread. /// and `false` otherwise. This method does not block the current thread.
/// ///
/// # Panics pub fn is_completed(&self) -> bool {
/// self.try_get().is_some()
/// This method will panic if the underlying mutex is poisoned. This can happen if another }
/// thread holding the mutex panics.
pub fn was_signaled(&self) -> bool {
let guard = self.mutex.lock().expect("mutex is poisoned!");
*guard /// Returns the completed value, or None if none was sent.
pub fn try_get(&self) -> Option<&T> {
let _guard = self.inner.mutex.lock().expect("mutex is poisoned!");
self.inner.value.get()
}
}
impl<T> Completer<T> {
/// Signals all threads currently waiting on this completion signal.
///
/// This method sets wakes up all threads that are blocked in the [`wait`] method
/// of an attached `Waiter`. Subsequent calls to [`wait`] from any thread will return immediately.
/// This operation has no effect if this completion signal has already been completed.
pub fn complete(&self, value: T) {
let inner: &InnerWaitCompleteSignal<T> = self.inner.as_ref();
let mut _guard = inner.mutex.lock().expect("mutex is poisoned!");
let _ = inner.value.set(value);
inner.var.notify_all();
} }
} }
// TODO: move to testing directory
#[cfg(test)] #[cfg(test)]
mod test { mod completion_signal_tests {
use std::{ use std::{
sync::{mpsc, Arc}, sync::mpsc,
thread::{self, sleep}, thread::{self, sleep},
time::Duration, time::Duration,
}; };
use pretty_assertions::assert_eq; use crate::engine::completion_signal;
use crate::engine::jobs::WaitSignal;
fn run_with_timeout<F>(duration: Duration, lambda: F) fn run_with_timeout<F>(duration: Duration, lambda: F)
where where
@ -363,52 +393,51 @@ mod test {
send.send(false).expect("send failed"); send.send(false).expect("send failed");
}); });
let result = recv.recv().expect("recv failed!"); let ok = recv.recv().expect("recv failed!");
assert!(result == true, "timeout!"); assert!(ok, "got timeout!");
} }
#[test] #[test]
fn join_returns_when_signaled_from_another_thread() { fn wait_returns_when_signaled_from_another_thread() {
run_with_timeout(Duration::from_secs(1), || { run_with_timeout(Duration::from_secs(1), || {
let signal = Arc::new(WaitSignal::new()); let (complete, wait) = completion_signal();
let thread_signal = signal.clone(); let wait_ = wait.clone();
thread::spawn(move || { thread::spawn(move || {
sleep(Duration::from_millis(200)); sleep(Duration::from_millis(200));
assert!(!thread_signal.was_signaled()); assert!(!wait_.is_completed());
thread_signal.signal(123); complete.complete(123);
}); });
let result = signal.join(); let result = wait.wait();
assert!(signal.was_signaled()); assert!(wait.is_completed());
assert_eq!(*result, 123); assert_eq!(*result, 123);
}); });
} }
#[test] #[test]
fn join_works_from_multiple_threads() { fn wait_works_from_multiple_threads() {
run_with_timeout(Duration::from_secs(1), || { run_with_timeout(Duration::from_secs(1), || {
let signal = Arc::new(WaitSignal::new()); let (complete, wait) = completion_signal();
let (send, recv) = mpsc::channel(); let (send, recv) = mpsc::channel();
let thread_count = 4; let thread_count = 4;
for _ in 0..thread_count { for _ in 0..thread_count {
let signal_ = signal.clone(); let wait_ = wait.clone();
let send_ = send.clone(); let send_ = send.clone();
thread::spawn(move || { thread::spawn(move || {
let value = signal_.join(); let value = wait_.wait();
send_.send(*value).expect("send failed"); send_.send(*value).expect("send failed");
}); });
} }
signal.signal(321); complete.complete(321);
for _ in 0..thread_count { for _ in 0..thread_count {
let result = recv.recv().expect("recv failed"); let result = recv.recv().expect("recv failed");
@ -420,30 +449,30 @@ mod test {
#[test] #[test]
fn was_signaled_returns_false_when_struct_is_initalized() { fn was_signaled_returns_false_when_struct_is_initalized() {
let signal = Arc::new(WaitSignal::<()>::new()); let (_, wait) = completion_signal::<()>();
assert!(!signal.was_signaled()) assert!(!wait.is_completed())
} }
#[test] #[test]
fn was_signaled_returns_true_when_signal_is_called() { fn was_signaled_returns_true_when_signal_is_called() {
let signal = Arc::new(WaitSignal::new()); let (complete, wait) = completion_signal();
signal.signal(()); complete.complete(());
assert!(signal.was_signaled()) assert!(wait.is_completed())
} }
#[test] #[test]
fn join_returns_when_own_thread_signals() { fn wait_returns_when_own_thread_signals() {
run_with_timeout(Duration::from_secs(1), || { run_with_timeout(Duration::from_secs(1), || {
let signal = Arc::new(WaitSignal::new()); let (complete, wait) = completion_signal();
signal.signal(()); complete.complete(());
signal.join(); wait.wait();
assert!(signal.was_signaled()) assert!(wait.is_completed())
}) })
} }
} }