From b77b72d33308198ec2e0b8f5541080fdfdc1e66f Mon Sep 17 00:00:00 2001 From: dignifiedquire Date: Wed, 25 Sep 2019 19:36:59 +0200 Subject: [PATCH] feat: implement sync::Barrier Based on the implementation in https://github.com/tokio-rs/tokio/pull/1571 --- Cargo.toml | 1 + src/sync/barrier.rs | 174 ++++++++++++++++++++++++++++++++++++++++++++ src/sync/mod.rs | 2 + tests/barrier.rs | 52 +++++++++++++ 4 files changed, 229 insertions(+) create mode 100644 src/sync/barrier.rs create mode 100644 tests/barrier.rs diff --git a/Cargo.toml b/Cargo.toml index f768f81..dd614b1 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -42,6 +42,7 @@ num_cpus = "1.10.1" pin-utils = "0.1.0-alpha.4" slab = "0.4.2" kv-log-macro = "1.0.4" +broadcaster = "0.2.4" [dev-dependencies] femme = "1.2.0" diff --git a/src/sync/barrier.rs b/src/sync/barrier.rs new file mode 100644 index 0000000..40b042c --- /dev/null +++ b/src/sync/barrier.rs @@ -0,0 +1,174 @@ +use broadcaster::BroadcastChannel; + +use crate::sync::Mutex; + +/// A barrier enables multiple tasks to synchronize the beginning +/// of some computation. +/// +/// ``` +/// # fn main() { async_std::task::block_on(async { +/// # +/// use std::sync::Arc; +/// use async_std::sync::Barrier; +/// use async_std::task; +/// +/// let mut handles = Vec::with_capacity(10); +/// let barrier = Arc::new(Barrier::new(10)); +/// for _ in 0..10 { +/// let c = barrier.clone(); +/// // The same messages will be printed together. +/// // You will NOT see any interleaving. +/// handles.push(task::spawn(async move { +/// println!("before wait"); +/// let wr = c.wait().await; +/// println!("after wait"); +/// wr +/// })); +/// } +/// // Wait for the other futures to finish. +/// for handle in handles { +/// handle.await; +/// } +/// # }); +/// # } +/// ``` +#[derive(Debug)] +pub struct Barrier { + state: Mutex, + wait: BroadcastChannel<(usize, usize)>, + n: usize, +} + +// The inner state of a double barrier +#[derive(Debug)] +struct BarrierState { + waker: BroadcastChannel<(usize, usize)>, + count: usize, + generation_id: usize, +} + +/// A `BarrierWaitResult` is returned by `wait` when all threads in the `Barrier` have rendezvoused. +/// +/// [`wait`]: struct.Barrier.html#method.wait +/// [`Barrier`]: struct.Barrier.html +/// +/// # Examples +/// +/// ``` +/// use async_std::sync::Barrier; +/// +/// let barrier = Barrier::new(1); +/// let barrier_wait_result = barrier.wait(); +/// ``` +#[derive(Debug, Clone)] +pub struct BarrierWaitResult(bool); + +impl Barrier { + /// Creates a new barrier that can block a given number of tasks. + /// + /// A barrier will block `n`-1 tasks which call [`wait`] and then wake up + /// all tasks at once when the `n`th task calls [`wait`]. + /// + /// [`wait`]: #method.wait + /// + /// # Examples + /// + /// ``` + /// use std::sync::Barrier; + /// + /// let barrier = Barrier::new(10); + /// ``` + pub fn new(mut n: usize) -> Barrier { + let waker = BroadcastChannel::new(); + let wait = waker.clone(); + + if n == 0 { + // if n is 0, it's not clear what behavior the user wants. + // in std::sync::Barrier, an n of 0 exhibits the same behavior as n == 1, where every + // .wait() immediately unblocks, so we adopt that here as well. + n = 1; + } + + Barrier { + state: Mutex::new(BarrierState { + waker, + count: 0, + generation_id: 1, + }), + n, + wait, + } + } + + /// Blocks the current task until all tasks have rendezvoused here. + /// + /// Barriers are re-usable after all tasks have rendezvoused once, and can + /// be used continuously. + /// + /// A single (arbitrary) task will receive a [`BarrierWaitResult`] that + /// returns `true` from [`is_leader`] when returning from this function, and + /// all other tasks will receive a result that will return `false` from + /// [`is_leader`]. + /// + /// [`BarrierWaitResult`]: struct.BarrierWaitResult.html + /// [`is_leader`]: struct.BarrierWaitResult.html#method.is_leader + pub async fn wait(&self) -> BarrierWaitResult { + let mut lock = self.state.lock().await; + let local_gen = lock.generation_id; + + lock.count += 1; + + if lock.count < self.n { + let mut wait = self.wait.clone(); + + let mut generation_id = lock.generation_id; + let mut count = lock.count; + + drop(lock); + + while local_gen == generation_id && count < self.n { + let (g, c) = wait.recv().await.expect("sender hasn not been closed"); + generation_id = g; + count = c; + } + + BarrierWaitResult(false) + } else { + lock.count = 0; + lock.generation_id = lock.generation_id.wrapping_add(1); + + lock.waker + .send(&(lock.generation_id, lock.count)) + .await + .expect("there should be at least one receiver"); + + BarrierWaitResult(true) + } + } +} + +impl BarrierWaitResult { + /// Returns `true` if this task from [`wait`] is the "leader task". + /// + /// Only one task will have `true` returned from their result, all other + /// tasks will have `false` returned. + /// + /// [`wait`]: struct.Barrier.html#method.wait + /// + /// # Examples + /// + /// ``` + /// # fn main() { async_std::task::block_on(async { + /// # + /// use async_std::sync::Barrier; + /// + /// let barrier = Barrier::new(1); + /// let barrier_wait_result = barrier.wait().await; + /// println!("{:?}", barrier_wait_result.is_leader()); + /// # }); + /// # } + /// ``` + pub fn is_leader(&self) -> bool { + self.0 + } +} diff --git a/src/sync/mod.rs b/src/sync/mod.rs index 1a8e255..99e63c3 100644 --- a/src/sync/mod.rs +++ b/src/sync/mod.rs @@ -32,8 +32,10 @@ #[doc(inline)] pub use std::sync::{Arc, Weak}; +pub use barrier::{Barrier, BarrierWaitResult}; pub use mutex::{Mutex, MutexGuard}; pub use rwlock::{RwLock, RwLockReadGuard, RwLockWriteGuard}; +mod barrier; mod mutex; mod rwlock; diff --git a/tests/barrier.rs b/tests/barrier.rs new file mode 100644 index 0000000..3284944 --- /dev/null +++ b/tests/barrier.rs @@ -0,0 +1,52 @@ +use std::sync::Arc; + +use futures_channel::mpsc::unbounded; +use futures_util::sink::SinkExt; +use futures_util::stream::StreamExt; + +use async_std::sync::Barrier; +use async_std::task; + +#[test] +fn test_barrier() { + // Based on the test in std, I was seeing some race conditions, so running it in a loop to make sure + // things are solid. + + for _ in 0..1_000 { + task::block_on(async move { + const N: usize = 10; + + let barrier = Arc::new(Barrier::new(N)); + let (tx, mut rx) = unbounded(); + + for _ in 0..N - 1 { + let c = barrier.clone(); + let mut tx = tx.clone(); + task::spawn(async move { + let res = c.wait().await; + + tx.send(res.is_leader()).await.unwrap(); + }); + } + + // At this point, all spawned threads should be blocked, + // so we shouldn't get anything from the port + let res = rx.try_next(); + assert!(match res { + Err(_err) => true, + _ => false, + }); + + let mut leader_found = barrier.wait().await.is_leader(); + + // Now, the barrier is cleared and we should get data. + for _ in 0..N - 1 { + if rx.next().await.unwrap() { + assert!(!leader_found); + leader_found = true; + } + } + assert!(leader_found); + }); + } +}