diff --git a/src/sync/condvar.rs b/src/sync/condvar.rs new file mode 100644 index 0000000..67507f3 --- /dev/null +++ b/src/sync/condvar.rs @@ -0,0 +1,417 @@ +use std::fmt; +use std::pin::Pin; +use std::time::Duration; + +use super::mutex::{guard_lock, MutexGuard}; +use crate::future::{timeout, Future}; +use crate::sync::WakerSet; +use crate::task::{Context, Poll}; + +#[derive(Debug, PartialEq, Eq, Copy, Clone)] +pub struct WaitTimeoutResult(bool); + +/// A type indicating whether a timed wait on a condition variable returned due to a time out or +/// not +impl WaitTimeoutResult { + /// Returns `true` if the wait was known to have timed out. + pub fn timed_out(self) -> bool { + self.0 + } +} + +/// A Condition Variable +/// +/// This type is an async version of [`std::sync::Mutex`]. +/// +/// [`std::sync::Condvar`]: https://doc.rust-lang.org/std/sync/struct.Condvar.html +/// +/// # Examples +/// +/// ``` +/// # async_std::task::block_on(async { +/// # +/// use std::sync::Arc; +/// +/// use async_std::sync::{Mutex, Condvar}; +/// use async_std::task; +/// +/// let pair = Arc::new((Mutex::new(false), Condvar::new())); +/// let pair2 = pair.clone(); +/// +/// // Inside of our lock, spawn a new thread, and then wait for it to start. +/// task::spawn(async move { +/// let (lock, cvar) = &*pair2; +/// let mut started = lock.lock().await; +/// *started = true; +/// // We notify the condvar that the value has changed. +/// cvar.notify_one(); +/// }); +/// +/// // Wait for the thread to start up. +/// let (lock, cvar) = &*pair; +/// let mut started = lock.lock().await; +/// while !*started { +/// started = cvar.wait(started).await; +/// } +/// +/// # }) +/// ``` +pub struct Condvar { + wakers: WakerSet, +} + +unsafe impl Send for Condvar {} +unsafe impl Sync for Condvar {} + +impl Default for Condvar { + fn default() -> Self { + Condvar::new() + } +} + +impl Condvar { + /// Creates a new condition variable + /// + /// # Examples + /// + /// ``` + /// use async_std::sync::Condvar; + /// + /// let cvar = Condvar::new(); + /// ``` + pub fn new() -> Self { + Condvar { + wakers: WakerSet::new(), + } + } + + /// Blocks the current task until this condition variable receives a notification. + /// + /// Unlike the std equivalent, this does not check that a single mutex is used at runtime. + /// However, as a best practice avoid using with multiple mutexes. + /// + /// # Examples + /// + /// ``` + /// # async_std::task::block_on(async { + /// use std::sync::Arc; + /// + /// use async_std::sync::{Mutex, Condvar}; + /// use async_std::task; + /// + /// let pair = Arc::new((Mutex::new(false), Condvar::new())); + /// let pair2 = pair.clone(); + /// + /// task::spawn(async move { + /// let (lock, cvar) = &*pair2; + /// let mut started = lock.lock().await; + /// *started = true; + /// // We notify the condvar that the value has changed. + /// cvar.notify_one(); + /// }); + /// + /// // Wait for the thread to start up. + /// let (lock, cvar) = &*pair; + /// let mut started = lock.lock().await; + /// while !*started { + /// started = cvar.wait(started).await; + /// } + /// # }) + /// ``` + #[allow(clippy::needless_lifetimes)] + pub async fn wait<'a, T>(&self, guard: MutexGuard<'a, T>) -> MutexGuard<'a, T> { + let mutex = guard_lock(&guard); + + self.await_notify(guard).await; + + mutex.lock().await + } + + fn await_notify<'a, T>(&self, guard: MutexGuard<'a, T>) -> AwaitNotify<'_, 'a, T> { + AwaitNotify { + cond: self, + guard: Some(guard), + key: None, + } + } + + /// Blocks the current taks until this condition variable receives a notification and the + /// required condition is met. Spurious wakeups are ignored and this function will only + /// return once the condition has been met. + /// + /// # Examples + /// + /// ``` + /// # async_std::task::block_on(async { + /// # + /// use std::sync::Arc; + /// + /// use async_std::sync::{Mutex, Condvar}; + /// use async_std::task; + /// + /// let pair = Arc::new((Mutex::new(false), Condvar::new())); + /// let pair2 = pair.clone(); + /// + /// task::spawn(async move { + /// let (lock, cvar) = &*pair2; + /// let mut started = lock.lock().await; + /// *started = true; + /// // We notify the condvar that the value has changed. + /// cvar.notify_one(); + /// }); + /// + /// // Wait for the thread to start up. + /// let (lock, cvar) = &*pair; + /// // As long as the value inside the `Mutex` is `false`, we wait. + /// let _guard = cvar.wait_until(lock.lock().await, |started| { *started }).await; + /// # + /// # }) + /// ``` + #[allow(clippy::needless_lifetimes)] + pub async fn wait_until<'a, T, F>( + &self, + mut guard: MutexGuard<'a, T>, + mut condition: F, + ) -> MutexGuard<'a, T> + where + F: FnMut(&mut T) -> bool, + { + while !condition(&mut *guard) { + guard = self.wait(guard).await; + } + guard + } + + /// Waits on this condition variable for a notification, timing out after a specified duration. + /// + /// For these reasons `Condvar::wait_timeout_until` is recommended in most cases. + /// + /// # Examples + /// + /// ``` + /// # async_std::task::block_on(async { + /// # + /// use std::sync::Arc; + /// use std::time::Duration; + /// + /// use async_std::sync::{Mutex, Condvar}; + /// use async_std::task; + /// + /// let pair = Arc::new((Mutex::new(false), Condvar::new())); + /// let pair2 = pair.clone(); + /// + /// task::spawn(async move { + /// let (lock, cvar) = &*pair2; + /// let mut started = lock.lock().await; + /// *started = true; + /// // We notify the condvar that the value has changed. + /// cvar.notify_one(); + /// }); + /// + /// // wait for the thread to start up + /// let (lock, cvar) = &*pair; + /// let mut started = lock.lock().await; + /// loop { + /// let result = cvar.wait_timeout(started, Duration::from_millis(10)).await; + /// started = result.0; + /// if *started == true { + /// // We received the notification and the value has been updated, we can leave. + /// break + /// } + /// } + /// # + /// # }) + /// ``` + #[allow(clippy::needless_lifetimes)] + pub async fn wait_timeout<'a, T>( + &self, + guard: MutexGuard<'a, T>, + dur: Duration, + ) -> (MutexGuard<'a, T>, WaitTimeoutResult) { + let mutex = guard_lock(&guard); + match timeout(dur, self.wait(guard)).await { + Ok(guard) => (guard, WaitTimeoutResult(false)), + Err(_) => (mutex.lock().await, WaitTimeoutResult(true)), + } + } + + /// Waits on this condition variable for a notification, timing out after a specified duration. + /// Spurious wakes will not cause this function to return. + /// + /// # Examples + /// ``` + /// # async_std::task::block_on(async { + /// use std::sync::Arc; + /// use std::time::Duration; + /// + /// use async_std::sync::{Mutex, Condvar}; + /// use async_std::task; + /// + /// let pair = Arc::new((Mutex::new(false), Condvar::new())); + /// let pair2 = pair.clone(); + /// + /// task::spawn(async move { + /// let (lock, cvar) = &*pair2; + /// let mut started = lock.lock().await; + /// *started = true; + /// // We notify the condvar that the value has changed. + /// cvar.notify_one(); + /// }); + /// + /// // wait for the thread to start up + /// let (lock, cvar) = &*pair; + /// let result = cvar.wait_timeout_until( + /// lock.lock().await, + /// Duration::from_millis(100), + /// |&mut started| started, + /// ).await; + /// if result.1.timed_out() { + /// // timed-out without the condition ever evaluating to true. + /// } + /// // access the locked mutex via result.0 + /// # }); + /// ``` + #[allow(clippy::needless_lifetimes)] + pub async fn wait_timeout_until<'a, T, F>( + &self, + guard: MutexGuard<'a, T>, + dur: Duration, + condition: F, + ) -> (MutexGuard<'a, T>, WaitTimeoutResult) + where + F: FnMut(&mut T) -> bool, + { + let mutex = guard_lock(&guard); + match timeout(dur, self.wait_until(guard, condition)).await { + Ok(guard) => (guard, WaitTimeoutResult(false)), + Err(_) => (mutex.lock().await, WaitTimeoutResult(true)), + } + } + + /// Wakes up one blocked task on this condvar. + /// + /// # Examples + /// + /// ``` + /// # fn main() { async_std::task::block_on(async { + /// use std::sync::Arc; + /// + /// use async_std::sync::{Mutex, Condvar}; + /// use async_std::task; + /// + /// let pair = Arc::new((Mutex::new(false), Condvar::new())); + /// let pair2 = pair.clone(); + /// + /// task::spawn(async move { + /// let (lock, cvar) = &*pair2; + /// let mut started = lock.lock().await; + /// *started = true; + /// // We notify the condvar that the value has changed. + /// cvar.notify_one(); + /// }); + /// + /// // Wait for the thread to start up. + /// let (lock, cvar) = &*pair; + /// let mut started = lock.lock().await; + /// while !*started { + /// started = cvar.wait(started).await; + /// } + /// # }) } + /// ``` + pub fn notify_one(&self) { + self.wakers.notify_one(); + } + + /// Wakes up all blocked tasks on this condvar. + /// + /// # Examples + /// ``` + /// # fn main() { async_std::task::block_on(async { + /// # + /// use std::sync::Arc; + /// + /// use async_std::sync::{Mutex, Condvar}; + /// use async_std::task; + /// + /// let pair = Arc::new((Mutex::new(false), Condvar::new())); + /// let pair2 = pair.clone(); + /// + /// task::spawn(async move { + /// let (lock, cvar) = &*pair2; + /// let mut started = lock.lock().await; + /// *started = true; + /// // We notify the condvar that the value has changed. + /// cvar.notify_all(); + /// }); + /// + /// // Wait for the thread to start up. + /// let (lock, cvar) = &*pair; + /// let mut started = lock.lock().await; + /// // As long as the value inside the `Mutex` is `false`, we wait. + /// while !*started { + /// started = cvar.wait(started).await; + /// } + /// # + /// # }) } + /// ``` + pub fn notify_all(&self) { + self.wakers.notify_all(); + } +} + +impl fmt::Debug for Condvar { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.pad("Condvar { .. }") + } +} + +/// A future that waits for another task to notify the condition variable. +/// +/// This is an internal future that `wait` and `wait_until` await on. +struct AwaitNotify<'a, 'b, T> { + /// The condition variable that we are waiting on + cond: &'a Condvar, + /// The lock used with `cond`. + /// This will be released the first time the future is polled, + /// after registering the context to be notified. + guard: Option>, + /// A key into the conditions variable's `WakerSet`. + /// This is set to the index of the `Waker` for the context each time + /// the future is polled and not completed. + key: Option, +} + +impl<'a, 'b, T> Future for AwaitNotify<'a, 'b, T> { + type Output = (); + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + match self.guard.take() { + Some(_) => { + self.key = Some(self.cond.wakers.insert(cx)); + // the guard is dropped when we return, which frees the lock + Poll::Pending + } + None => { + if let Some(key) = self.key { + if self.cond.wakers.remove_if_notified(key, cx) { + self.key = None; + Poll::Ready(()) + } else { + Poll::Pending + } + } else { + // This should only happen if it is polled twice after receiving a notification + Poll::Ready(()) + } + } + } + } +} + +impl<'a, 'b, T> Drop for AwaitNotify<'a, 'b, T> { + fn drop(&mut self) { + if let Some(key) = self.key { + self.cond.wakers.cancel(key); + } + } +} diff --git a/src/sync/mod.rs b/src/sync/mod.rs index 82759fb..1531f8c 100644 --- a/src/sync/mod.rs +++ b/src/sync/mod.rs @@ -185,8 +185,10 @@ mod rwlock; cfg_unstable! { pub use barrier::{Barrier, BarrierWaitResult}; pub use channel::{channel, Sender, Receiver, RecvError, TryRecvError, TrySendError}; + pub use condvar::Condvar; mod barrier; + mod condvar; mod channel; } diff --git a/src/sync/mutex.rs b/src/sync/mutex.rs index c62b561..ae953fd 100644 --- a/src/sync/mutex.rs +++ b/src/sync/mutex.rs @@ -287,3 +287,8 @@ impl DerefMut for MutexGuard<'_, T> { unsafe { &mut *self.0.value.get() } } } + +#[cfg(feature = "unstable")] +pub fn guard_lock<'a, T>(guard: &MutexGuard<'a, T>) -> &'a Mutex { + guard.0 +} diff --git a/src/sync/waker_set.rs b/src/sync/waker_set.rs index 7e897af..881304b 100644 --- a/src/sync/waker_set.rs +++ b/src/sync/waker_set.rs @@ -80,6 +80,28 @@ impl WakerSet { } } + /// If the waker for this key is still waiting for a notification, then update + /// the waker for the entry, and return false. If the waker has been notified, + /// treat the entry as completed and return true. + #[cfg(feature = "unstable")] + pub fn remove_if_notified(&self, key: usize, cx: &Context<'_>) -> bool { + let mut inner = self.lock(); + + match &mut inner.entries[key] { + None => { + inner.entries.remove(key); + true + } + Some(w) => { + // We were never woken, so update instead + if !w.will_wake(cx.waker()) { + *w = cx.waker().clone(); + } + false + } + } + } + /// Removes the waker of a cancelled operation. /// /// Returns `true` if another blocked operation from the set was notified. diff --git a/tests/condvar.rs b/tests/condvar.rs new file mode 100644 index 0000000..c4d680f --- /dev/null +++ b/tests/condvar.rs @@ -0,0 +1,91 @@ +#![cfg(feature = "unstable")] +use std::sync::Arc; +use std::time::Duration; + +use async_std::sync::{Condvar, Mutex}; +use async_std::task::{self, JoinHandle}; + +#[test] +fn wait_timeout_with_lock() { + task::block_on(async { + let pair = Arc::new((Mutex::new(false), Condvar::new())); + let pair2 = pair.clone(); + + task::spawn(async move { + let (m, c) = &*pair2; + let _g = m.lock().await; + task::sleep(Duration::from_millis(20)).await; + c.notify_one(); + }); + + let (m, c) = &*pair; + let (_, wait_result) = c + .wait_timeout(m.lock().await, Duration::from_millis(10)) + .await; + assert!(wait_result.timed_out()); + }) +} + +#[test] +fn wait_timeout_without_lock() { + task::block_on(async { + let m = Mutex::new(false); + let c = Condvar::new(); + + let (_, wait_result) = c + .wait_timeout(m.lock().await, Duration::from_millis(10)) + .await; + assert!(wait_result.timed_out()); + }) +} + +#[test] +fn wait_timeout_until_timed_out() { + task::block_on(async { + let m = Mutex::new(false); + let c = Condvar::new(); + + let (_, wait_result) = c + .wait_timeout_until(m.lock().await, Duration::from_millis(10), |&mut started| { + started + }) + .await; + assert!(wait_result.timed_out()); + }) +} + +#[test] +fn notify_all() { + task::block_on(async { + let mut tasks: Vec> = Vec::new(); + let pair = Arc::new((Mutex::new(0u32), Condvar::new())); + + for _ in 0..10 { + let pair = pair.clone(); + tasks.push(task::spawn(async move { + let (m, c) = &*pair; + let mut count = m.lock().await; + while *count == 0 { + count = c.wait(count).await; + } + *count += 1; + })); + } + + // Give some time for tasks to start up + task::sleep(Duration::from_millis(5)).await; + + let (m, c) = &*pair; + { + let mut count = m.lock().await; + *count += 1; + c.notify_all(); + } + + for t in tasks { + t.await; + } + let count = m.lock().await; + assert_eq!(11, *count); + }) +}