|
|
|
@ -1,6 +1,4 @@
|
|
|
|
|
use broadcaster::BroadcastChannel;
|
|
|
|
|
|
|
|
|
|
use crate::sync::Mutex;
|
|
|
|
|
use crate::sync::{Condvar,Mutex};
|
|
|
|
|
|
|
|
|
|
/// A barrier enables multiple tasks to synchronize the beginning
|
|
|
|
|
/// of some computation.
|
|
|
|
@ -36,14 +34,13 @@ use crate::sync::Mutex;
|
|
|
|
|
#[derive(Debug)]
|
|
|
|
|
pub struct Barrier {
|
|
|
|
|
state: Mutex<BarrierState>,
|
|
|
|
|
wait: BroadcastChannel<(usize, usize)>,
|
|
|
|
|
n: usize,
|
|
|
|
|
cvar: Condvar,
|
|
|
|
|
num_tasks: usize,
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// The inner state of a double barrier
|
|
|
|
|
#[derive(Debug)]
|
|
|
|
|
struct BarrierState {
|
|
|
|
|
waker: BroadcastChannel<(usize, usize)>,
|
|
|
|
|
count: usize,
|
|
|
|
|
generation_id: usize,
|
|
|
|
|
}
|
|
|
|
@ -81,25 +78,14 @@ impl Barrier {
|
|
|
|
|
///
|
|
|
|
|
/// let barrier = Barrier::new(10);
|
|
|
|
|
/// ```
|
|
|
|
|
pub fn new(mut n: usize) -> Barrier {
|
|
|
|
|
let waker = BroadcastChannel::new();
|
|
|
|
|
let wait = waker.clone();
|
|
|
|
|
|
|
|
|
|
if n == 0 {
|
|
|
|
|
// if n is 0, it's not clear what behavior the user wants.
|
|
|
|
|
// in std::sync::Barrier, an n of 0 exhibits the same behavior as n == 1, where every
|
|
|
|
|
// .wait() immediately unblocks, so we adopt that here as well.
|
|
|
|
|
n = 1;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
pub fn new(n: usize) -> Barrier {
|
|
|
|
|
Barrier {
|
|
|
|
|
state: Mutex::new(BarrierState {
|
|
|
|
|
waker,
|
|
|
|
|
count: 0,
|
|
|
|
|
generation_id: 1,
|
|
|
|
|
}),
|
|
|
|
|
n,
|
|
|
|
|
wait,
|
|
|
|
|
cvar: Condvar::new(),
|
|
|
|
|
num_tasks: n,
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -143,35 +129,20 @@ impl Barrier {
|
|
|
|
|
/// # });
|
|
|
|
|
/// ```
|
|
|
|
|
pub async fn wait(&self) -> BarrierWaitResult {
|
|
|
|
|
let mut lock = self.state.lock().await;
|
|
|
|
|
let local_gen = lock.generation_id;
|
|
|
|
|
|
|
|
|
|
lock.count += 1;
|
|
|
|
|
let mut state = self.state.lock().await;
|
|
|
|
|
let local_gen = state.generation_id;
|
|
|
|
|
state.count += 1;
|
|
|
|
|
|
|
|
|
|
if lock.count < self.n {
|
|
|
|
|
let mut wait = self.wait.clone();
|
|
|
|
|
|
|
|
|
|
let mut generation_id = lock.generation_id;
|
|
|
|
|
let mut count = lock.count;
|
|
|
|
|
|
|
|
|
|
drop(lock);
|
|
|
|
|
|
|
|
|
|
while local_gen == generation_id && count < self.n {
|
|
|
|
|
let (g, c) = wait.recv().await.expect("sender has not been closed");
|
|
|
|
|
generation_id = g;
|
|
|
|
|
count = c;
|
|
|
|
|
if state.count < self.num_tasks {
|
|
|
|
|
while local_gen == state.generation_id && state.count < self.num_tasks {
|
|
|
|
|
state = self.cvar.wait(state).await;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
BarrierWaitResult(false)
|
|
|
|
|
} else {
|
|
|
|
|
lock.count = 0;
|
|
|
|
|
lock.generation_id = lock.generation_id.wrapping_add(1);
|
|
|
|
|
|
|
|
|
|
lock.waker
|
|
|
|
|
.send(&(lock.generation_id, lock.count))
|
|
|
|
|
.await
|
|
|
|
|
.expect("there should be at least one receiver");
|
|
|
|
|
|
|
|
|
|
state.count = 0;
|
|
|
|
|
state.generation_id = state.generation_id.wrapping_add(1);
|
|
|
|
|
self.cvar.notify_all();
|
|
|
|
|
BarrierWaitResult(true)
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|