mirror of
https://github.com/async-rs/async-std.git
synced 2025-02-06 20:55:33 +00:00
feat: implement Barrier using Condvar
This commit is contained in:
parent
10f7abb3b6
commit
6f6fced103
2 changed files with 16 additions and 46 deletions
|
@ -31,7 +31,7 @@ default = [
|
||||||
"pin-project-lite",
|
"pin-project-lite",
|
||||||
]
|
]
|
||||||
docs = ["attributes", "unstable", "default"]
|
docs = ["attributes", "unstable", "default"]
|
||||||
unstable = ["std", "broadcaster"]
|
unstable = ["std"]
|
||||||
attributes = ["async-attributes"]
|
attributes = ["async-attributes"]
|
||||||
std = [
|
std = [
|
||||||
"alloc",
|
"alloc",
|
||||||
|
@ -55,7 +55,6 @@ alloc = [
|
||||||
[dependencies]
|
[dependencies]
|
||||||
async-attributes = { version = "1.1.1", optional = true }
|
async-attributes = { version = "1.1.1", optional = true }
|
||||||
async-task = { version = "3.0.0", optional = true }
|
async-task = { version = "3.0.0", optional = true }
|
||||||
broadcaster = { version = "1.0.0", optional = true }
|
|
||||||
crossbeam-utils = { version = "0.7.2", optional = true }
|
crossbeam-utils = { version = "0.7.2", optional = true }
|
||||||
futures-core = { version = "0.3.4", optional = true, default-features = false }
|
futures-core = { version = "0.3.4", optional = true, default-features = false }
|
||||||
futures-io = { version = "0.3.4", optional = true }
|
futures-io = { version = "0.3.4", optional = true }
|
||||||
|
|
|
@ -1,6 +1,4 @@
|
||||||
use broadcaster::BroadcastChannel;
|
use crate::sync::{Condvar,Mutex};
|
||||||
|
|
||||||
use crate::sync::Mutex;
|
|
||||||
|
|
||||||
/// A barrier enables multiple tasks to synchronize the beginning
|
/// A barrier enables multiple tasks to synchronize the beginning
|
||||||
/// of some computation.
|
/// of some computation.
|
||||||
|
@ -36,14 +34,13 @@ use crate::sync::Mutex;
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
pub struct Barrier {
|
pub struct Barrier {
|
||||||
state: Mutex<BarrierState>,
|
state: Mutex<BarrierState>,
|
||||||
wait: BroadcastChannel<(usize, usize)>,
|
cvar: Condvar,
|
||||||
n: usize,
|
num_tasks: usize,
|
||||||
}
|
}
|
||||||
|
|
||||||
// The inner state of a double barrier
|
// The inner state of a double barrier
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
struct BarrierState {
|
struct BarrierState {
|
||||||
waker: BroadcastChannel<(usize, usize)>,
|
|
||||||
count: usize,
|
count: usize,
|
||||||
generation_id: usize,
|
generation_id: usize,
|
||||||
}
|
}
|
||||||
|
@ -81,25 +78,14 @@ impl Barrier {
|
||||||
///
|
///
|
||||||
/// let barrier = Barrier::new(10);
|
/// let barrier = Barrier::new(10);
|
||||||
/// ```
|
/// ```
|
||||||
pub fn new(mut n: usize) -> Barrier {
|
pub fn new(n: usize) -> Barrier {
|
||||||
let waker = BroadcastChannel::new();
|
|
||||||
let wait = waker.clone();
|
|
||||||
|
|
||||||
if n == 0 {
|
|
||||||
// if n is 0, it's not clear what behavior the user wants.
|
|
||||||
// in std::sync::Barrier, an n of 0 exhibits the same behavior as n == 1, where every
|
|
||||||
// .wait() immediately unblocks, so we adopt that here as well.
|
|
||||||
n = 1;
|
|
||||||
}
|
|
||||||
|
|
||||||
Barrier {
|
Barrier {
|
||||||
state: Mutex::new(BarrierState {
|
state: Mutex::new(BarrierState {
|
||||||
waker,
|
|
||||||
count: 0,
|
count: 0,
|
||||||
generation_id: 1,
|
generation_id: 1,
|
||||||
}),
|
}),
|
||||||
n,
|
cvar: Condvar::new(),
|
||||||
wait,
|
num_tasks: n,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -143,35 +129,20 @@ impl Barrier {
|
||||||
/// # });
|
/// # });
|
||||||
/// ```
|
/// ```
|
||||||
pub async fn wait(&self) -> BarrierWaitResult {
|
pub async fn wait(&self) -> BarrierWaitResult {
|
||||||
let mut lock = self.state.lock().await;
|
let mut state = self.state.lock().await;
|
||||||
let local_gen = lock.generation_id;
|
let local_gen = state.generation_id;
|
||||||
|
state.count += 1;
|
||||||
|
|
||||||
lock.count += 1;
|
if state.count < self.num_tasks {
|
||||||
|
while local_gen == state.generation_id && state.count < self.num_tasks {
|
||||||
if lock.count < self.n {
|
state = self.cvar.wait(state).await;
|
||||||
let mut wait = self.wait.clone();
|
|
||||||
|
|
||||||
let mut generation_id = lock.generation_id;
|
|
||||||
let mut count = lock.count;
|
|
||||||
|
|
||||||
drop(lock);
|
|
||||||
|
|
||||||
while local_gen == generation_id && count < self.n {
|
|
||||||
let (g, c) = wait.recv().await.expect("sender has not been closed");
|
|
||||||
generation_id = g;
|
|
||||||
count = c;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
BarrierWaitResult(false)
|
BarrierWaitResult(false)
|
||||||
} else {
|
} else {
|
||||||
lock.count = 0;
|
state.count = 0;
|
||||||
lock.generation_id = lock.generation_id.wrapping_add(1);
|
state.generation_id = state.generation_id.wrapping_add(1);
|
||||||
|
self.cvar.notify_all();
|
||||||
lock.waker
|
|
||||||
.send(&(lock.generation_id, lock.count))
|
|
||||||
.await
|
|
||||||
.expect("there should be at least one receiver");
|
|
||||||
|
|
||||||
BarrierWaitResult(true)
|
BarrierWaitResult(true)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in a new issue