feat: implement Barrier using Condvar

master
Thayne McCombs 5 years ago committed by GitHub
parent 10f7abb3b6
commit 6f6fced103
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -31,7 +31,7 @@ default = [
"pin-project-lite",
]
docs = ["attributes", "unstable", "default"]
unstable = ["std", "broadcaster"]
unstable = ["std"]
attributes = ["async-attributes"]
std = [
"alloc",
@ -55,7 +55,6 @@ alloc = [
[dependencies]
async-attributes = { version = "1.1.1", optional = true }
async-task = { version = "3.0.0", optional = true }
broadcaster = { version = "1.0.0", optional = true }
crossbeam-utils = { version = "0.7.2", optional = true }
futures-core = { version = "0.3.4", optional = true, default-features = false }
futures-io = { version = "0.3.4", optional = true }

@ -1,6 +1,4 @@
use broadcaster::BroadcastChannel;
use crate::sync::Mutex;
use crate::sync::{Condvar,Mutex};
/// A barrier enables multiple tasks to synchronize the beginning
/// of some computation.
@ -36,14 +34,13 @@ use crate::sync::Mutex;
#[derive(Debug)]
pub struct Barrier {
state: Mutex<BarrierState>,
wait: BroadcastChannel<(usize, usize)>,
n: usize,
cvar: Condvar,
num_tasks: usize,
}
// The inner state of a double barrier
#[derive(Debug)]
struct BarrierState {
waker: BroadcastChannel<(usize, usize)>,
count: usize,
generation_id: usize,
}
@ -81,25 +78,14 @@ impl Barrier {
///
/// let barrier = Barrier::new(10);
/// ```
pub fn new(mut n: usize) -> Barrier {
let waker = BroadcastChannel::new();
let wait = waker.clone();
if n == 0 {
// if n is 0, it's not clear what behavior the user wants.
// in std::sync::Barrier, an n of 0 exhibits the same behavior as n == 1, where every
// .wait() immediately unblocks, so we adopt that here as well.
n = 1;
}
pub fn new(n: usize) -> Barrier {
Barrier {
state: Mutex::new(BarrierState {
waker,
count: 0,
generation_id: 1,
}),
n,
wait,
cvar: Condvar::new(),
num_tasks: n,
}
}
@ -143,35 +129,20 @@ impl Barrier {
/// # });
/// ```
pub async fn wait(&self) -> BarrierWaitResult {
let mut lock = self.state.lock().await;
let local_gen = lock.generation_id;
lock.count += 1;
let mut state = self.state.lock().await;
let local_gen = state.generation_id;
state.count += 1;
if lock.count < self.n {
let mut wait = self.wait.clone();
let mut generation_id = lock.generation_id;
let mut count = lock.count;
drop(lock);
while local_gen == generation_id && count < self.n {
let (g, c) = wait.recv().await.expect("sender has not been closed");
generation_id = g;
count = c;
if state.count < self.num_tasks {
while local_gen == state.generation_id && state.count < self.num_tasks {
state = self.cvar.wait(state).await;
}
BarrierWaitResult(false)
} else {
lock.count = 0;
lock.generation_id = lock.generation_id.wrapping_add(1);
lock.waker
.send(&(lock.generation_id, lock.count))
.await
.expect("there should be at least one receiver");
state.count = 0;
state.generation_id = state.generation_id.wrapping_add(1);
self.cvar.notify_all();
BarrierWaitResult(true)
}
}

Loading…
Cancel
Save