feat: implement sync::Barrier

Based on the implementation in https://github.com/tokio-rs/tokio/pull/1571
This commit is contained in:
dignifiedquire 2019-09-25 19:36:59 +02:00
parent 785371cbc4
commit b77b72d333
4 changed files with 229 additions and 0 deletions

View file

@ -42,6 +42,7 @@ num_cpus = "1.10.1"
pin-utils = "0.1.0-alpha.4" pin-utils = "0.1.0-alpha.4"
slab = "0.4.2" slab = "0.4.2"
kv-log-macro = "1.0.4" kv-log-macro = "1.0.4"
broadcaster = "0.2.4"
[dev-dependencies] [dev-dependencies]
femme = "1.2.0" femme = "1.2.0"

174
src/sync/barrier.rs Normal file
View file

@ -0,0 +1,174 @@
use broadcaster::BroadcastChannel;
use crate::sync::Mutex;
/// A barrier enables multiple tasks to synchronize the beginning
/// of some computation.
///
/// ```
/// # fn main() { async_std::task::block_on(async {
/// #
/// use std::sync::Arc;
/// use async_std::sync::Barrier;
/// use async_std::task;
///
/// let mut handles = Vec::with_capacity(10);
/// let barrier = Arc::new(Barrier::new(10));
/// for _ in 0..10 {
/// let c = barrier.clone();
/// // The same messages will be printed together.
/// // You will NOT see any interleaving.
/// handles.push(task::spawn(async move {
/// println!("before wait");
/// let wr = c.wait().await;
/// println!("after wait");
/// wr
/// }));
/// }
/// // Wait for the other futures to finish.
/// for handle in handles {
/// handle.await;
/// }
/// # });
/// # }
/// ```
#[derive(Debug)]
pub struct Barrier {
state: Mutex<BarrierState>,
wait: BroadcastChannel<(usize, usize)>,
n: usize,
}
// The inner state of a double barrier
#[derive(Debug)]
struct BarrierState {
waker: BroadcastChannel<(usize, usize)>,
count: usize,
generation_id: usize,
}
/// A `BarrierWaitResult` is returned by `wait` when all threads in the `Barrier` have rendezvoused.
///
/// [`wait`]: struct.Barrier.html#method.wait
/// [`Barrier`]: struct.Barrier.html
///
/// # Examples
///
/// ```
/// use async_std::sync::Barrier;
///
/// let barrier = Barrier::new(1);
/// let barrier_wait_result = barrier.wait();
/// ```
#[derive(Debug, Clone)]
pub struct BarrierWaitResult(bool);
impl Barrier {
/// Creates a new barrier that can block a given number of tasks.
///
/// A barrier will block `n`-1 tasks which call [`wait`] and then wake up
/// all tasks at once when the `n`th task calls [`wait`].
///
/// [`wait`]: #method.wait
///
/// # Examples
///
/// ```
/// use std::sync::Barrier;
///
/// let barrier = Barrier::new(10);
/// ```
pub fn new(mut n: usize) -> Barrier {
let waker = BroadcastChannel::new();
let wait = waker.clone();
if n == 0 {
// if n is 0, it's not clear what behavior the user wants.
// in std::sync::Barrier, an n of 0 exhibits the same behavior as n == 1, where every
// .wait() immediately unblocks, so we adopt that here as well.
n = 1;
}
Barrier {
state: Mutex::new(BarrierState {
waker,
count: 0,
generation_id: 1,
}),
n,
wait,
}
}
/// Blocks the current task until all tasks have rendezvoused here.
///
/// Barriers are re-usable after all tasks have rendezvoused once, and can
/// be used continuously.
///
/// A single (arbitrary) task will receive a [`BarrierWaitResult`] that
/// returns `true` from [`is_leader`] when returning from this function, and
/// all other tasks will receive a result that will return `false` from
/// [`is_leader`].
///
/// [`BarrierWaitResult`]: struct.BarrierWaitResult.html
/// [`is_leader`]: struct.BarrierWaitResult.html#method.is_leader
pub async fn wait(&self) -> BarrierWaitResult {
let mut lock = self.state.lock().await;
let local_gen = lock.generation_id;
lock.count += 1;
if lock.count < self.n {
let mut wait = self.wait.clone();
let mut generation_id = lock.generation_id;
let mut count = lock.count;
drop(lock);
while local_gen == generation_id && count < self.n {
let (g, c) = wait.recv().await.expect("sender hasn not been closed");
generation_id = g;
count = c;
}
BarrierWaitResult(false)
} else {
lock.count = 0;
lock.generation_id = lock.generation_id.wrapping_add(1);
lock.waker
.send(&(lock.generation_id, lock.count))
.await
.expect("there should be at least one receiver");
BarrierWaitResult(true)
}
}
}
impl BarrierWaitResult {
/// Returns `true` if this task from [`wait`] is the "leader task".
///
/// Only one task will have `true` returned from their result, all other
/// tasks will have `false` returned.
///
/// [`wait`]: struct.Barrier.html#method.wait
///
/// # Examples
///
/// ```
/// # fn main() { async_std::task::block_on(async {
/// #
/// use async_std::sync::Barrier;
///
/// let barrier = Barrier::new(1);
/// let barrier_wait_result = barrier.wait().await;
/// println!("{:?}", barrier_wait_result.is_leader());
/// # });
/// # }
/// ```
pub fn is_leader(&self) -> bool {
self.0
}
}

View file

@ -32,8 +32,10 @@
#[doc(inline)] #[doc(inline)]
pub use std::sync::{Arc, Weak}; pub use std::sync::{Arc, Weak};
pub use barrier::{Barrier, BarrierWaitResult};
pub use mutex::{Mutex, MutexGuard}; pub use mutex::{Mutex, MutexGuard};
pub use rwlock::{RwLock, RwLockReadGuard, RwLockWriteGuard}; pub use rwlock::{RwLock, RwLockReadGuard, RwLockWriteGuard};
mod barrier;
mod mutex; mod mutex;
mod rwlock; mod rwlock;

52
tests/barrier.rs Normal file
View file

@ -0,0 +1,52 @@
use std::sync::Arc;
use futures_channel::mpsc::unbounded;
use futures_util::sink::SinkExt;
use futures_util::stream::StreamExt;
use async_std::sync::Barrier;
use async_std::task;
#[test]
fn test_barrier() {
// Based on the test in std, I was seeing some race conditions, so running it in a loop to make sure
// things are solid.
for _ in 0..1_000 {
task::block_on(async move {
const N: usize = 10;
let barrier = Arc::new(Barrier::new(N));
let (tx, mut rx) = unbounded();
for _ in 0..N - 1 {
let c = barrier.clone();
let mut tx = tx.clone();
task::spawn(async move {
let res = c.wait().await;
tx.send(res.is_leader()).await.unwrap();
});
}
// At this point, all spawned threads should be blocked,
// so we shouldn't get anything from the port
let res = rx.try_next();
assert!(match res {
Err(_err) => true,
_ => false,
});
let mut leader_found = barrier.wait().await.is_leader();
// Now, the barrier is cleared and we should get data.
for _ in 0..N - 1 {
if rx.next().await.unwrap() {
assert!(!leader_found);
leader_found = true;
}
}
assert!(leader_found);
});
}
}