forked from mirror/async-std
feat: implement sync::Barrier
Based on the implementation in https://github.com/tokio-rs/tokio/pull/1571staging
parent
785371cbc4
commit
b77b72d333
@ -0,0 +1,174 @@
|
|||||||
|
use broadcaster::BroadcastChannel;
|
||||||
|
|
||||||
|
use crate::sync::Mutex;
|
||||||
|
|
||||||
|
/// A barrier enables multiple tasks to synchronize the beginning
|
||||||
|
/// of some computation.
|
||||||
|
///
|
||||||
|
/// ```
|
||||||
|
/// # fn main() { async_std::task::block_on(async {
|
||||||
|
/// #
|
||||||
|
/// use std::sync::Arc;
|
||||||
|
/// use async_std::sync::Barrier;
|
||||||
|
/// use async_std::task;
|
||||||
|
///
|
||||||
|
/// let mut handles = Vec::with_capacity(10);
|
||||||
|
/// let barrier = Arc::new(Barrier::new(10));
|
||||||
|
/// for _ in 0..10 {
|
||||||
|
/// let c = barrier.clone();
|
||||||
|
/// // The same messages will be printed together.
|
||||||
|
/// // You will NOT see any interleaving.
|
||||||
|
/// handles.push(task::spawn(async move {
|
||||||
|
/// println!("before wait");
|
||||||
|
/// let wr = c.wait().await;
|
||||||
|
/// println!("after wait");
|
||||||
|
/// wr
|
||||||
|
/// }));
|
||||||
|
/// }
|
||||||
|
/// // Wait for the other futures to finish.
|
||||||
|
/// for handle in handles {
|
||||||
|
/// handle.await;
|
||||||
|
/// }
|
||||||
|
/// # });
|
||||||
|
/// # }
|
||||||
|
/// ```
|
||||||
|
#[derive(Debug)]
|
||||||
|
pub struct Barrier {
|
||||||
|
state: Mutex<BarrierState>,
|
||||||
|
wait: BroadcastChannel<(usize, usize)>,
|
||||||
|
n: usize,
|
||||||
|
}
|
||||||
|
|
||||||
|
// The inner state of a double barrier
|
||||||
|
#[derive(Debug)]
|
||||||
|
struct BarrierState {
|
||||||
|
waker: BroadcastChannel<(usize, usize)>,
|
||||||
|
count: usize,
|
||||||
|
generation_id: usize,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// A `BarrierWaitResult` is returned by `wait` when all threads in the `Barrier` have rendezvoused.
|
||||||
|
///
|
||||||
|
/// [`wait`]: struct.Barrier.html#method.wait
|
||||||
|
/// [`Barrier`]: struct.Barrier.html
|
||||||
|
///
|
||||||
|
/// # Examples
|
||||||
|
///
|
||||||
|
/// ```
|
||||||
|
/// use async_std::sync::Barrier;
|
||||||
|
///
|
||||||
|
/// let barrier = Barrier::new(1);
|
||||||
|
/// let barrier_wait_result = barrier.wait();
|
||||||
|
/// ```
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub struct BarrierWaitResult(bool);
|
||||||
|
|
||||||
|
impl Barrier {
|
||||||
|
/// Creates a new barrier that can block a given number of tasks.
|
||||||
|
///
|
||||||
|
/// A barrier will block `n`-1 tasks which call [`wait`] and then wake up
|
||||||
|
/// all tasks at once when the `n`th task calls [`wait`].
|
||||||
|
///
|
||||||
|
/// [`wait`]: #method.wait
|
||||||
|
///
|
||||||
|
/// # Examples
|
||||||
|
///
|
||||||
|
/// ```
|
||||||
|
/// use std::sync::Barrier;
|
||||||
|
///
|
||||||
|
/// let barrier = Barrier::new(10);
|
||||||
|
/// ```
|
||||||
|
pub fn new(mut n: usize) -> Barrier {
|
||||||
|
let waker = BroadcastChannel::new();
|
||||||
|
let wait = waker.clone();
|
||||||
|
|
||||||
|
if n == 0 {
|
||||||
|
// if n is 0, it's not clear what behavior the user wants.
|
||||||
|
// in std::sync::Barrier, an n of 0 exhibits the same behavior as n == 1, where every
|
||||||
|
// .wait() immediately unblocks, so we adopt that here as well.
|
||||||
|
n = 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
Barrier {
|
||||||
|
state: Mutex::new(BarrierState {
|
||||||
|
waker,
|
||||||
|
count: 0,
|
||||||
|
generation_id: 1,
|
||||||
|
}),
|
||||||
|
n,
|
||||||
|
wait,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Blocks the current task until all tasks have rendezvoused here.
|
||||||
|
///
|
||||||
|
/// Barriers are re-usable after all tasks have rendezvoused once, and can
|
||||||
|
/// be used continuously.
|
||||||
|
///
|
||||||
|
/// A single (arbitrary) task will receive a [`BarrierWaitResult`] that
|
||||||
|
/// returns `true` from [`is_leader`] when returning from this function, and
|
||||||
|
/// all other tasks will receive a result that will return `false` from
|
||||||
|
/// [`is_leader`].
|
||||||
|
///
|
||||||
|
/// [`BarrierWaitResult`]: struct.BarrierWaitResult.html
|
||||||
|
/// [`is_leader`]: struct.BarrierWaitResult.html#method.is_leader
|
||||||
|
pub async fn wait(&self) -> BarrierWaitResult {
|
||||||
|
let mut lock = self.state.lock().await;
|
||||||
|
let local_gen = lock.generation_id;
|
||||||
|
|
||||||
|
lock.count += 1;
|
||||||
|
|
||||||
|
if lock.count < self.n {
|
||||||
|
let mut wait = self.wait.clone();
|
||||||
|
|
||||||
|
let mut generation_id = lock.generation_id;
|
||||||
|
let mut count = lock.count;
|
||||||
|
|
||||||
|
drop(lock);
|
||||||
|
|
||||||
|
while local_gen == generation_id && count < self.n {
|
||||||
|
let (g, c) = wait.recv().await.expect("sender hasn not been closed");
|
||||||
|
generation_id = g;
|
||||||
|
count = c;
|
||||||
|
}
|
||||||
|
|
||||||
|
BarrierWaitResult(false)
|
||||||
|
} else {
|
||||||
|
lock.count = 0;
|
||||||
|
lock.generation_id = lock.generation_id.wrapping_add(1);
|
||||||
|
|
||||||
|
lock.waker
|
||||||
|
.send(&(lock.generation_id, lock.count))
|
||||||
|
.await
|
||||||
|
.expect("there should be at least one receiver");
|
||||||
|
|
||||||
|
BarrierWaitResult(true)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl BarrierWaitResult {
|
||||||
|
/// Returns `true` if this task from [`wait`] is the "leader task".
|
||||||
|
///
|
||||||
|
/// Only one task will have `true` returned from their result, all other
|
||||||
|
/// tasks will have `false` returned.
|
||||||
|
///
|
||||||
|
/// [`wait`]: struct.Barrier.html#method.wait
|
||||||
|
///
|
||||||
|
/// # Examples
|
||||||
|
///
|
||||||
|
/// ```
|
||||||
|
/// # fn main() { async_std::task::block_on(async {
|
||||||
|
/// #
|
||||||
|
/// use async_std::sync::Barrier;
|
||||||
|
///
|
||||||
|
/// let barrier = Barrier::new(1);
|
||||||
|
/// let barrier_wait_result = barrier.wait().await;
|
||||||
|
/// println!("{:?}", barrier_wait_result.is_leader());
|
||||||
|
/// # });
|
||||||
|
/// # }
|
||||||
|
/// ```
|
||||||
|
pub fn is_leader(&self) -> bool {
|
||||||
|
self.0
|
||||||
|
}
|
||||||
|
}
|
@ -0,0 +1,52 @@
|
|||||||
|
use std::sync::Arc;
|
||||||
|
|
||||||
|
use futures_channel::mpsc::unbounded;
|
||||||
|
use futures_util::sink::SinkExt;
|
||||||
|
use futures_util::stream::StreamExt;
|
||||||
|
|
||||||
|
use async_std::sync::Barrier;
|
||||||
|
use async_std::task;
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_barrier() {
|
||||||
|
// Based on the test in std, I was seeing some race conditions, so running it in a loop to make sure
|
||||||
|
// things are solid.
|
||||||
|
|
||||||
|
for _ in 0..1_000 {
|
||||||
|
task::block_on(async move {
|
||||||
|
const N: usize = 10;
|
||||||
|
|
||||||
|
let barrier = Arc::new(Barrier::new(N));
|
||||||
|
let (tx, mut rx) = unbounded();
|
||||||
|
|
||||||
|
for _ in 0..N - 1 {
|
||||||
|
let c = barrier.clone();
|
||||||
|
let mut tx = tx.clone();
|
||||||
|
task::spawn(async move {
|
||||||
|
let res = c.wait().await;
|
||||||
|
|
||||||
|
tx.send(res.is_leader()).await.unwrap();
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
// At this point, all spawned threads should be blocked,
|
||||||
|
// so we shouldn't get anything from the port
|
||||||
|
let res = rx.try_next();
|
||||||
|
assert!(match res {
|
||||||
|
Err(_err) => true,
|
||||||
|
_ => false,
|
||||||
|
});
|
||||||
|
|
||||||
|
let mut leader_found = barrier.wait().await.is_leader();
|
||||||
|
|
||||||
|
// Now, the barrier is cleared and we should get data.
|
||||||
|
for _ in 0..N - 1 {
|
||||||
|
if rx.next().await.unwrap() {
|
||||||
|
assert!(!leader_found);
|
||||||
|
leader_found = true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
assert!(leader_found);
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
Loading…
Reference in New Issue