mirror of
https://github.com/async-rs/async-std.git
synced 2025-01-16 10:49:55 +00:00
feat: implement sync::Barrier
Based on the implementation in https://github.com/tokio-rs/tokio/pull/1571
This commit is contained in:
parent
785371cbc4
commit
b77b72d333
4 changed files with 229 additions and 0 deletions
|
@ -42,6 +42,7 @@ num_cpus = "1.10.1"
|
|||
pin-utils = "0.1.0-alpha.4"
|
||||
slab = "0.4.2"
|
||||
kv-log-macro = "1.0.4"
|
||||
broadcaster = "0.2.4"
|
||||
|
||||
[dev-dependencies]
|
||||
femme = "1.2.0"
|
||||
|
|
174
src/sync/barrier.rs
Normal file
174
src/sync/barrier.rs
Normal file
|
@ -0,0 +1,174 @@
|
|||
use broadcaster::BroadcastChannel;
|
||||
|
||||
use crate::sync::Mutex;
|
||||
|
||||
/// A barrier enables multiple tasks to synchronize the beginning
|
||||
/// of some computation.
|
||||
///
|
||||
/// ```
|
||||
/// # fn main() { async_std::task::block_on(async {
|
||||
/// #
|
||||
/// use std::sync::Arc;
|
||||
/// use async_std::sync::Barrier;
|
||||
/// use async_std::task;
|
||||
///
|
||||
/// let mut handles = Vec::with_capacity(10);
|
||||
/// let barrier = Arc::new(Barrier::new(10));
|
||||
/// for _ in 0..10 {
|
||||
/// let c = barrier.clone();
|
||||
/// // The same messages will be printed together.
|
||||
/// // You will NOT see any interleaving.
|
||||
/// handles.push(task::spawn(async move {
|
||||
/// println!("before wait");
|
||||
/// let wr = c.wait().await;
|
||||
/// println!("after wait");
|
||||
/// wr
|
||||
/// }));
|
||||
/// }
|
||||
/// // Wait for the other futures to finish.
|
||||
/// for handle in handles {
|
||||
/// handle.await;
|
||||
/// }
|
||||
/// # });
|
||||
/// # }
|
||||
/// ```
|
||||
#[derive(Debug)]
|
||||
pub struct Barrier {
|
||||
state: Mutex<BarrierState>,
|
||||
wait: BroadcastChannel<(usize, usize)>,
|
||||
n: usize,
|
||||
}
|
||||
|
||||
// The inner state of a double barrier
|
||||
#[derive(Debug)]
|
||||
struct BarrierState {
|
||||
waker: BroadcastChannel<(usize, usize)>,
|
||||
count: usize,
|
||||
generation_id: usize,
|
||||
}
|
||||
|
||||
/// A `BarrierWaitResult` is returned by `wait` when all threads in the `Barrier` have rendezvoused.
|
||||
///
|
||||
/// [`wait`]: struct.Barrier.html#method.wait
|
||||
/// [`Barrier`]: struct.Barrier.html
|
||||
///
|
||||
/// # Examples
|
||||
///
|
||||
/// ```
|
||||
/// use async_std::sync::Barrier;
|
||||
///
|
||||
/// let barrier = Barrier::new(1);
|
||||
/// let barrier_wait_result = barrier.wait();
|
||||
/// ```
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct BarrierWaitResult(bool);
|
||||
|
||||
impl Barrier {
|
||||
/// Creates a new barrier that can block a given number of tasks.
|
||||
///
|
||||
/// A barrier will block `n`-1 tasks which call [`wait`] and then wake up
|
||||
/// all tasks at once when the `n`th task calls [`wait`].
|
||||
///
|
||||
/// [`wait`]: #method.wait
|
||||
///
|
||||
/// # Examples
|
||||
///
|
||||
/// ```
|
||||
/// use std::sync::Barrier;
|
||||
///
|
||||
/// let barrier = Barrier::new(10);
|
||||
/// ```
|
||||
pub fn new(mut n: usize) -> Barrier {
|
||||
let waker = BroadcastChannel::new();
|
||||
let wait = waker.clone();
|
||||
|
||||
if n == 0 {
|
||||
// if n is 0, it's not clear what behavior the user wants.
|
||||
// in std::sync::Barrier, an n of 0 exhibits the same behavior as n == 1, where every
|
||||
// .wait() immediately unblocks, so we adopt that here as well.
|
||||
n = 1;
|
||||
}
|
||||
|
||||
Barrier {
|
||||
state: Mutex::new(BarrierState {
|
||||
waker,
|
||||
count: 0,
|
||||
generation_id: 1,
|
||||
}),
|
||||
n,
|
||||
wait,
|
||||
}
|
||||
}
|
||||
|
||||
/// Blocks the current task until all tasks have rendezvoused here.
|
||||
///
|
||||
/// Barriers are re-usable after all tasks have rendezvoused once, and can
|
||||
/// be used continuously.
|
||||
///
|
||||
/// A single (arbitrary) task will receive a [`BarrierWaitResult`] that
|
||||
/// returns `true` from [`is_leader`] when returning from this function, and
|
||||
/// all other tasks will receive a result that will return `false` from
|
||||
/// [`is_leader`].
|
||||
///
|
||||
/// [`BarrierWaitResult`]: struct.BarrierWaitResult.html
|
||||
/// [`is_leader`]: struct.BarrierWaitResult.html#method.is_leader
|
||||
pub async fn wait(&self) -> BarrierWaitResult {
|
||||
let mut lock = self.state.lock().await;
|
||||
let local_gen = lock.generation_id;
|
||||
|
||||
lock.count += 1;
|
||||
|
||||
if lock.count < self.n {
|
||||
let mut wait = self.wait.clone();
|
||||
|
||||
let mut generation_id = lock.generation_id;
|
||||
let mut count = lock.count;
|
||||
|
||||
drop(lock);
|
||||
|
||||
while local_gen == generation_id && count < self.n {
|
||||
let (g, c) = wait.recv().await.expect("sender hasn not been closed");
|
||||
generation_id = g;
|
||||
count = c;
|
||||
}
|
||||
|
||||
BarrierWaitResult(false)
|
||||
} else {
|
||||
lock.count = 0;
|
||||
lock.generation_id = lock.generation_id.wrapping_add(1);
|
||||
|
||||
lock.waker
|
||||
.send(&(lock.generation_id, lock.count))
|
||||
.await
|
||||
.expect("there should be at least one receiver");
|
||||
|
||||
BarrierWaitResult(true)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl BarrierWaitResult {
|
||||
/// Returns `true` if this task from [`wait`] is the "leader task".
|
||||
///
|
||||
/// Only one task will have `true` returned from their result, all other
|
||||
/// tasks will have `false` returned.
|
||||
///
|
||||
/// [`wait`]: struct.Barrier.html#method.wait
|
||||
///
|
||||
/// # Examples
|
||||
///
|
||||
/// ```
|
||||
/// # fn main() { async_std::task::block_on(async {
|
||||
/// #
|
||||
/// use async_std::sync::Barrier;
|
||||
///
|
||||
/// let barrier = Barrier::new(1);
|
||||
/// let barrier_wait_result = barrier.wait().await;
|
||||
/// println!("{:?}", barrier_wait_result.is_leader());
|
||||
/// # });
|
||||
/// # }
|
||||
/// ```
|
||||
pub fn is_leader(&self) -> bool {
|
||||
self.0
|
||||
}
|
||||
}
|
|
@ -32,8 +32,10 @@
|
|||
#[doc(inline)]
|
||||
pub use std::sync::{Arc, Weak};
|
||||
|
||||
pub use barrier::{Barrier, BarrierWaitResult};
|
||||
pub use mutex::{Mutex, MutexGuard};
|
||||
pub use rwlock::{RwLock, RwLockReadGuard, RwLockWriteGuard};
|
||||
|
||||
mod barrier;
|
||||
mod mutex;
|
||||
mod rwlock;
|
||||
|
|
52
tests/barrier.rs
Normal file
52
tests/barrier.rs
Normal file
|
@ -0,0 +1,52 @@
|
|||
use std::sync::Arc;
|
||||
|
||||
use futures_channel::mpsc::unbounded;
|
||||
use futures_util::sink::SinkExt;
|
||||
use futures_util::stream::StreamExt;
|
||||
|
||||
use async_std::sync::Barrier;
|
||||
use async_std::task;
|
||||
|
||||
#[test]
|
||||
fn test_barrier() {
|
||||
// Based on the test in std, I was seeing some race conditions, so running it in a loop to make sure
|
||||
// things are solid.
|
||||
|
||||
for _ in 0..1_000 {
|
||||
task::block_on(async move {
|
||||
const N: usize = 10;
|
||||
|
||||
let barrier = Arc::new(Barrier::new(N));
|
||||
let (tx, mut rx) = unbounded();
|
||||
|
||||
for _ in 0..N - 1 {
|
||||
let c = barrier.clone();
|
||||
let mut tx = tx.clone();
|
||||
task::spawn(async move {
|
||||
let res = c.wait().await;
|
||||
|
||||
tx.send(res.is_leader()).await.unwrap();
|
||||
});
|
||||
}
|
||||
|
||||
// At this point, all spawned threads should be blocked,
|
||||
// so we shouldn't get anything from the port
|
||||
let res = rx.try_next();
|
||||
assert!(match res {
|
||||
Err(_err) => true,
|
||||
_ => false,
|
||||
});
|
||||
|
||||
let mut leader_found = barrier.wait().await.is_leader();
|
||||
|
||||
// Now, the barrier is cleared and we should get data.
|
||||
for _ in 0..N - 1 {
|
||||
if rx.next().await.unwrap() {
|
||||
assert!(!leader_found);
|
||||
leader_found = true;
|
||||
}
|
||||
}
|
||||
assert!(leader_found);
|
||||
});
|
||||
}
|
||||
}
|
Loading…
Reference in a new issue