Add utility type WakerSet to the sync module (#390)

* Add utility type Registry to the sync module

* Remove unused import

* Split unregister into complete and cancel

* Refactoring and renaming

* Split remove() into complete() and cancel()

* Rename to WakerSet

* Ignore clippy warning

* Ignore another clippy warning

* Use stronger SeqCst ordering
pull/433/head
Stjepan Glavina 5 years ago committed by GitHub
parent 3dd59d7056
commit 87de4e1598
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -0,0 +1,42 @@
#![feature(test)]
extern crate test;
use std::sync::Arc;
use async_std::sync::Mutex;
use async_std::task;
use test::Bencher;
#[bench]
fn create(b: &mut Bencher) {
b.iter(|| Mutex::new(()));
}
#[bench]
fn contention(b: &mut Bencher) {
b.iter(|| task::block_on(run(10, 1000)));
}
#[bench]
fn no_contention(b: &mut Bencher) {
b.iter(|| task::block_on(run(1, 10000)));
}
async fn run(task: usize, iter: usize) {
let m = Arc::new(Mutex::new(()));
let mut tasks = Vec::new();
for _ in 0..task {
let m = m.clone();
tasks.push(task::spawn(async move {
for _ in 0..iter {
let _ = m.lock().await;
}
}));
}
for t in tasks {
t.await;
}
}

@ -0,0 +1,11 @@
#![feature(test)]
extern crate test;
use async_std::task;
use test::Bencher;
#[bench]
fn block_on(b: &mut Bencher) {
b.iter(|| task::block_on(async {}));
}

@ -4,17 +4,17 @@ use std::future::Future;
use std::isize;
use std::marker::PhantomData;
use std::mem;
use std::ops::{Deref, DerefMut};
use std::pin::Pin;
use std::process;
use std::ptr;
use std::sync::atomic::{self, AtomicBool, AtomicUsize, Ordering};
use std::sync::atomic::{self, AtomicUsize, Ordering};
use std::sync::Arc;
use std::task::{Context, Poll, Waker};
use std::task::{Context, Poll};
use crossbeam_utils::{Backoff, CachePadded};
use futures_core::stream::Stream;
use slab::Slab;
use crate::stream::Stream;
use crate::sync::WakerSet;
/// Creates a bounded multi-producer multi-consumer channel.
///
@ -128,7 +128,7 @@ impl<T> Sender<T> {
/// ```
pub async fn send(&self, msg: T) {
struct SendFuture<'a, T> {
sender: &'a Sender<T>,
channel: &'a Channel<T>,
msg: Option<T>,
opt_key: Option<usize>,
}
@ -142,23 +142,23 @@ impl<T> Sender<T> {
let msg = self.msg.take().unwrap();
// Try sending the message.
let poll = match self.sender.channel.push(msg) {
let poll = match self.channel.try_send(msg) {
Ok(()) => Poll::Ready(()),
Err(PushError::Disconnected(msg)) => {
Err(TrySendError::Disconnected(msg)) => {
self.msg = Some(msg);
Poll::Pending
}
Err(PushError::Full(msg)) => {
// Register the current task.
Err(TrySendError::Full(msg)) => {
// Insert this send operation.
match self.opt_key {
None => self.opt_key = Some(self.sender.channel.sends.register(cx)),
Some(key) => self.sender.channel.sends.reregister(key, cx),
None => self.opt_key = Some(self.channel.send_wakers.insert(cx)),
Some(key) => self.channel.send_wakers.update(key, cx),
}
// Try sending the message again.
match self.sender.channel.push(msg) {
match self.channel.try_send(msg) {
Ok(()) => Poll::Ready(()),
Err(PushError::Disconnected(msg)) | Err(PushError::Full(msg)) => {
Err(TrySendError::Disconnected(msg)) | Err(TrySendError::Full(msg)) => {
self.msg = Some(msg);
Poll::Pending
}
@ -167,10 +167,9 @@ impl<T> Sender<T> {
};
if poll.is_ready() {
// If the current task was registered, unregister now.
// If the current task is in the set, remove it.
if let Some(key) = self.opt_key.take() {
// `true` means the send operation is completed.
self.sender.channel.sends.unregister(key, true);
self.channel.send_wakers.complete(key);
}
}
@ -180,16 +179,16 @@ impl<T> Sender<T> {
impl<T> Drop for SendFuture<'_, T> {
fn drop(&mut self) {
// If the current task was registered, unregister now.
// If the current task is still in the set, that means it is being cancelled now.
// Wake up another task instead.
if let Some(key) = self.opt_key {
// `false` means the send operation is cancelled.
self.sender.channel.sends.unregister(key, false);
self.channel.send_wakers.cancel(key);
}
}
}
SendFuture {
sender: self,
channel: &self.channel,
msg: Some(msg),
opt_key: None,
}
@ -340,7 +339,7 @@ pub struct Receiver<T> {
/// The inner channel.
channel: Arc<Channel<T>>,
/// The registration key for this receiver in the `channel.streams` registry.
/// The key for this receiver in the `channel.stream_wakers` set.
opt_key: Option<usize>,
}
@ -382,16 +381,20 @@ impl<T> Receiver<T> {
type Output = Option<T>;
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
poll_recv(&self.channel, &self.channel.recvs, &mut self.opt_key, cx)
poll_recv(
&self.channel,
&self.channel.recv_wakers,
&mut self.opt_key,
cx,
)
}
}
impl<T> Drop for RecvFuture<'_, T> {
fn drop(&mut self) {
// If the current task was registered, unregister now.
// If the current task is still in the set, that means it is being cancelled now.
if let Some(key) = self.opt_key {
// `false` means the receive operation is cancelled.
self.channel.recvs.unregister(key, false);
self.channel.recv_wakers.cancel(key);
}
}
}
@ -484,10 +487,9 @@ impl<T> Receiver<T> {
impl<T> Drop for Receiver<T> {
fn drop(&mut self) {
// If the current task was registered as blocked on this stream, unregister now.
// If the current task is still in the stream set, that means it is being cancelled now.
if let Some(key) = self.opt_key {
// `false` means the last request for a stream item is cancelled.
self.channel.streams.unregister(key, false);
self.channel.stream_wakers.cancel(key);
}
// Decrement the receiver count and disconnect the channel if it drops down to zero.
@ -518,7 +520,12 @@ impl<T> Stream for Receiver<T> {
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
let this = &mut *self;
poll_recv(&this.channel, &this.channel.streams, &mut this.opt_key, cx)
poll_recv(
&this.channel,
&this.channel.stream_wakers,
&mut this.opt_key,
cx,
)
}
}
@ -530,39 +537,38 @@ impl<T> fmt::Debug for Receiver<T> {
/// Polls a receive operation on a channel.
///
/// If the receive operation is blocked, the current task will be registered in `registry` and its
/// registration key will then be stored in `opt_key`.
/// If the receive operation is blocked, the current task will be inserted into `wakers` and its
/// associated key will then be stored in `opt_key`.
fn poll_recv<T>(
channel: &Channel<T>,
registry: &Registry,
wakers: &WakerSet,
opt_key: &mut Option<usize>,
cx: &mut Context<'_>,
) -> Poll<Option<T>> {
// Try receiving a message.
let poll = match channel.pop() {
let poll = match channel.try_recv() {
Ok(msg) => Poll::Ready(Some(msg)),
Err(PopError::Disconnected) => Poll::Ready(None),
Err(PopError::Empty) => {
// Register the current task.
Err(TryRecvError::Disconnected) => Poll::Ready(None),
Err(TryRecvError::Empty) => {
// Insert this receive operation.
match *opt_key {
None => *opt_key = Some(registry.register(cx)),
Some(key) => registry.reregister(key, cx),
None => *opt_key = Some(wakers.insert(cx)),
Some(key) => wakers.update(key, cx),
}
// Try receiving a message again.
match channel.pop() {
match channel.try_recv() {
Ok(msg) => Poll::Ready(Some(msg)),
Err(PopError::Disconnected) => Poll::Ready(None),
Err(PopError::Empty) => Poll::Pending,
Err(TryRecvError::Disconnected) => Poll::Ready(None),
Err(TryRecvError::Empty) => Poll::Pending,
}
}
};
if poll.is_ready() {
// If the current task was registered, unregister now.
// If the current task is in the set, remove it.
if let Some(key) = opt_key.take() {
// `true` means the receive operation is completed.
registry.unregister(key, true);
wakers.complete(key);
}
}
@ -612,13 +618,13 @@ struct Channel<T> {
mark_bit: usize,
/// Send operations waiting while the channel is full.
sends: Registry,
send_wakers: WakerSet,
/// Receive operations waiting while the channel is empty and not disconnected.
recvs: Registry,
recv_wakers: WakerSet,
/// Streams waiting while the channel is empty and not disconnected.
streams: Registry,
stream_wakers: WakerSet,
/// The number of currently active `Sender`s.
sender_count: AtomicUsize,
@ -672,17 +678,17 @@ impl<T> Channel<T> {
mark_bit,
head: CachePadded::new(AtomicUsize::new(head)),
tail: CachePadded::new(AtomicUsize::new(tail)),
sends: Registry::new(),
recvs: Registry::new(),
streams: Registry::new(),
send_wakers: WakerSet::new(),
recv_wakers: WakerSet::new(),
stream_wakers: WakerSet::new(),
sender_count: AtomicUsize::new(1),
receiver_count: AtomicUsize::new(1),
_marker: PhantomData,
}
}
/// Attempts to push a message.
fn push(&self, msg: T) -> Result<(), PushError<T>> {
/// Attempts to send a message.
fn try_send(&self, msg: T) -> Result<(), TrySendError<T>> {
let backoff = Backoff::new();
let mut tail = self.tail.load(Ordering::Relaxed);
@ -721,10 +727,10 @@ impl<T> Channel<T> {
slot.stamp.store(stamp, Ordering::Release);
// Wake a blocked receive operation.
self.recvs.notify_one();
self.recv_wakers.notify_one();
// Wake all blocked streams.
self.streams.notify_all();
self.stream_wakers.notify_all();
return Ok(());
}
@ -743,9 +749,9 @@ impl<T> Channel<T> {
// Check if the channel is disconnected.
if tail & self.mark_bit != 0 {
return Err(PushError::Disconnected(msg));
return Err(TrySendError::Disconnected(msg));
} else {
return Err(PushError::Full(msg));
return Err(TrySendError::Full(msg));
}
}
@ -759,8 +765,8 @@ impl<T> Channel<T> {
}
}
/// Attempts to pop a message.
fn pop(&self) -> Result<T, PopError> {
/// Attempts to receive a message.
fn try_recv(&self) -> Result<T, TryRecvError> {
let backoff = Backoff::new();
let mut head = self.head.load(Ordering::Relaxed);
@ -799,7 +805,7 @@ impl<T> Channel<T> {
slot.stamp.store(stamp, Ordering::Release);
// Wake a blocked send operation.
self.sends.notify_one();
self.send_wakers.notify_one();
return Ok(msg);
}
@ -816,10 +822,10 @@ impl<T> Channel<T> {
if (tail & !self.mark_bit) == head {
// If the channel is disconnected...
if tail & self.mark_bit != 0 {
return Err(PopError::Disconnected);
return Err(TryRecvError::Disconnected);
} else {
// Otherwise, the receive operation is not ready.
return Err(PopError::Empty);
return Err(TryRecvError::Empty);
}
}
@ -888,9 +894,9 @@ impl<T> Channel<T> {
if tail & self.mark_bit == 0 {
// Notify everyone blocked on this channel.
self.sends.notify_all();
self.recvs.notify_all();
self.streams.notify_all();
self.send_wakers.notify_all();
self.recv_wakers.notify_all();
self.stream_wakers.notify_all();
}
}
}
@ -921,8 +927,8 @@ impl<T> Drop for Channel<T> {
}
}
/// An error returned from the `push()` method.
enum PushError<T> {
/// An error returned from the `try_send()` method.
enum TrySendError<T> {
/// The channel is full but not disconnected.
Full(T),
@ -930,203 +936,11 @@ enum PushError<T> {
Disconnected(T),
}
/// An error returned from the `pop()` method.
enum PopError {
/// An error returned from the `try_recv()` method.
enum TryRecvError {
/// The channel is empty but not disconnected.
Empty,
/// The channel is empty and disconnected.
Disconnected,
}
/// A list of blocked channel operations.
struct Blocked {
/// A list of registered channel operations.
///
/// Each entry has a waker associated with the task that is executing the operation. If the
/// waker is set to `None`, that means the task has been woken up but hasn't removed itself
/// from the registry yet.
entries: Slab<Option<Waker>>,
/// The number of wakers in the entry list.
waker_count: usize,
}
/// A registry of blocked channel operations.
///
/// Blocked operations register themselves in a registry. Successful operations on the opposite
/// side of the channel wake blocked operations in the registry.
struct Registry {
/// A list of blocked channel operations.
blocked: Spinlock<Blocked>,
/// Set to `true` if there are no wakers in the registry.
///
/// Note that this either means there are no entries in the registry, or that all entries have
/// been notified.
is_empty: AtomicBool,
}
impl Registry {
/// Creates a new registry.
fn new() -> Registry {
Registry {
blocked: Spinlock::new(Blocked {
entries: Slab::new(),
waker_count: 0,
}),
is_empty: AtomicBool::new(true),
}
}
/// Registers a blocked channel operation and returns a key associated with it.
fn register(&self, cx: &Context<'_>) -> usize {
let mut blocked = self.blocked.lock();
// Insert a new entry into the list of blocked tasks.
let w = cx.waker().clone();
let key = blocked.entries.insert(Some(w));
blocked.waker_count += 1;
if blocked.waker_count == 1 {
self.is_empty.store(false, Ordering::SeqCst);
}
key
}
/// Re-registers a blocked channel operation by filling in its waker.
fn reregister(&self, key: usize, cx: &Context<'_>) {
let mut blocked = self.blocked.lock();
let was_none = blocked.entries[key].is_none();
let w = cx.waker().clone();
blocked.entries[key] = Some(w);
if was_none {
blocked.waker_count += 1;
if blocked.waker_count == 1 {
self.is_empty.store(false, Ordering::SeqCst);
}
}
}
/// Unregisters a channel operation.
///
/// If `completed` is `true`, the operation will be removed from the registry. If `completed`
/// is `false`, that means the operation was cancelled so another one will be notified.
fn unregister(&self, key: usize, completed: bool) {
let mut blocked = self.blocked.lock();
let mut removed = false;
match blocked.entries.remove(key) {
Some(_) => removed = true,
None => {
if !completed {
// This operation was cancelled. Notify another one.
if let Some((_, opt_waker)) = blocked.entries.iter_mut().next() {
if let Some(w) = opt_waker.take() {
w.wake();
removed = true;
}
}
}
}
}
if removed {
blocked.waker_count -= 1;
if blocked.waker_count == 0 {
self.is_empty.store(true, Ordering::SeqCst);
}
}
}
/// Notifies one blocked channel operation.
#[inline]
fn notify_one(&self) {
if !self.is_empty.load(Ordering::SeqCst) {
let mut blocked = self.blocked.lock();
if let Some((_, opt_waker)) = blocked.entries.iter_mut().next() {
// If there is no waker in this entry, that means it was already woken.
if let Some(w) = opt_waker.take() {
w.wake();
blocked.waker_count -= 1;
if blocked.waker_count == 0 {
self.is_empty.store(true, Ordering::SeqCst);
}
}
}
}
}
/// Notifies all blocked channel operations.
#[inline]
fn notify_all(&self) {
if !self.is_empty.load(Ordering::SeqCst) {
let mut blocked = self.blocked.lock();
for (_, opt_waker) in blocked.entries.iter_mut() {
// If there is no waker in this entry, that means it was already woken.
if let Some(w) = opt_waker.take() {
w.wake();
}
}
blocked.waker_count = 0;
self.is_empty.store(true, Ordering::SeqCst);
}
}
}
/// A simple spinlock.
struct Spinlock<T> {
flag: AtomicBool,
value: UnsafeCell<T>,
}
impl<T> Spinlock<T> {
/// Returns a new spinlock initialized with `value`.
fn new(value: T) -> Spinlock<T> {
Spinlock {
flag: AtomicBool::new(false),
value: UnsafeCell::new(value),
}
}
/// Locks the spinlock.
fn lock(&self) -> SpinlockGuard<'_, T> {
let backoff = Backoff::new();
while self.flag.swap(true, Ordering::Acquire) {
backoff.snooze();
}
SpinlockGuard { parent: self }
}
}
/// A guard holding a spinlock locked.
struct SpinlockGuard<'a, T> {
parent: &'a Spinlock<T>,
}
impl<'a, T> Drop for SpinlockGuard<'a, T> {
fn drop(&mut self) {
self.parent.flag.store(false, Ordering::Release);
}
}
impl<'a, T> Deref for SpinlockGuard<'a, T> {
type Target = T;
fn deref(&self) -> &T {
unsafe { &*self.parent.value.get() }
}
}
impl<'a, T> DerefMut for SpinlockGuard<'a, T> {
fn deref_mut(&mut self) -> &mut T {
unsafe { &mut *self.parent.value.get() }
}
}

@ -191,3 +191,6 @@ cfg_unstable! {
mod barrier;
mod channel;
}
pub(crate) mod waker_set;
pub(crate) use waker_set::WakerSet;

@ -2,18 +2,11 @@ use std::cell::UnsafeCell;
use std::fmt;
use std::ops::{Deref, DerefMut};
use std::pin::Pin;
use std::sync::atomic::{AtomicUsize, Ordering};
use slab::Slab;
use std::sync::atomic::{AtomicBool, Ordering};
use crate::future::Future;
use crate::task::{Context, Poll, Waker};
/// Set if the mutex is locked.
const LOCK: usize = 1;
/// Set if there are tasks blocked on the mutex.
const BLOCKED: usize = 1 << 1;
use crate::sync::WakerSet;
use crate::task::{Context, Poll};
/// A mutual exclusion primitive for protecting shared data.
///
@ -49,8 +42,8 @@ const BLOCKED: usize = 1 << 1;
/// # })
/// ```
pub struct Mutex<T> {
state: AtomicUsize,
blocked: std::sync::Mutex<Slab<Option<Waker>>>,
locked: AtomicBool,
wakers: WakerSet,
value: UnsafeCell<T>,
}
@ -69,8 +62,8 @@ impl<T> Mutex<T> {
/// ```
pub fn new(t: T) -> Mutex<T> {
Mutex {
state: AtomicUsize::new(0),
blocked: std::sync::Mutex::new(Slab::new()),
locked: AtomicBool::new(false),
wakers: WakerSet::new(),
value: UnsafeCell::new(t),
}
}
@ -105,75 +98,46 @@ impl<T> Mutex<T> {
pub struct LockFuture<'a, T> {
mutex: &'a Mutex<T>,
opt_key: Option<usize>,
acquired: bool,
}
impl<'a, T> Future for LockFuture<'a, T> {
type Output = MutexGuard<'a, T>;
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
match self.mutex.try_lock() {
Some(guard) => {
self.acquired = true;
Poll::Ready(guard)
}
let poll = match self.mutex.try_lock() {
Some(guard) => Poll::Ready(guard),
None => {
let mut blocked = self.mutex.blocked.lock().unwrap();
// Register the current task.
// Insert this lock operation.
match self.opt_key {
None => {
// Insert a new entry into the list of blocked tasks.
let w = cx.waker().clone();
let key = blocked.insert(Some(w));
self.opt_key = Some(key);
if blocked.len() == 1 {
self.mutex.state.fetch_or(BLOCKED, Ordering::Relaxed);
}
}
Some(key) => {
// There is already an entry in the list of blocked tasks. Just
// reset the waker if it was removed.
if blocked[key].is_none() {
let w = cx.waker().clone();
blocked[key] = Some(w);
}
}
None => self.opt_key = Some(self.mutex.wakers.insert(cx)),
Some(key) => self.mutex.wakers.update(key, cx),
}
// Try locking again because it's possible the mutex got unlocked just
// before the current task was registered as a blocked task.
// before the current task was inserted into the waker set.
match self.mutex.try_lock() {
Some(guard) => {
self.acquired = true;
Poll::Ready(guard)
}
Some(guard) => Poll::Ready(guard),
None => Poll::Pending,
}
}
};
if poll.is_ready() {
// If the current task is in the set, remove it.
if let Some(key) = self.opt_key.take() {
self.mutex.wakers.complete(key);
}
}
poll
}
}
impl<T> Drop for LockFuture<'_, T> {
fn drop(&mut self) {
// If the current task is still in the set, that means it is being cancelled now.
if let Some(key) = self.opt_key {
let mut blocked = self.mutex.blocked.lock().unwrap();
let opt_waker = blocked.remove(key);
if opt_waker.is_none() && !self.acquired {
// We were awoken but didn't acquire the lock. Wake up another task.
if let Some((_, opt_waker)) = blocked.iter_mut().next() {
if let Some(w) = opt_waker.take() {
w.wake();
}
}
}
if blocked.is_empty() {
self.mutex.state.fetch_and(!BLOCKED, Ordering::Relaxed);
}
self.mutex.wakers.cancel(key);
}
}
}
@ -181,7 +145,6 @@ impl<T> Mutex<T> {
LockFuture {
mutex: self,
opt_key: None,
acquired: false,
}
.await
}
@ -220,7 +183,7 @@ impl<T> Mutex<T> {
/// # })
/// ```
pub fn try_lock(&self) -> Option<MutexGuard<'_, T>> {
if self.state.fetch_or(LOCK, Ordering::Acquire) & LOCK == 0 {
if !self.locked.swap(true, Ordering::SeqCst) {
Some(MutexGuard(self))
} else {
None
@ -266,18 +229,15 @@ impl<T> Mutex<T> {
impl<T: fmt::Debug> fmt::Debug for Mutex<T> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self.try_lock() {
None => {
struct LockedPlaceholder;
impl fmt::Debug for LockedPlaceholder {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.write_str("<locked>")
}
}
f.debug_struct("Mutex")
.field("data", &LockedPlaceholder)
.finish()
struct Locked;
impl fmt::Debug for Locked {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.write_str("<locked>")
}
}
match self.try_lock() {
None => f.debug_struct("Mutex").field("data", &Locked).finish(),
Some(guard) => f.debug_struct("Mutex").field("data", &&*guard).finish(),
}
}
@ -303,19 +263,11 @@ unsafe impl<T: Sync> Sync for MutexGuard<'_, T> {}
impl<T> Drop for MutexGuard<'_, T> {
fn drop(&mut self) {
let state = self.0.state.fetch_and(!LOCK, Ordering::AcqRel);
// If there are any blocked tasks, wake one of them up.
if state & BLOCKED != 0 {
let mut blocked = self.0.blocked.lock().unwrap();
// Use `SeqCst` ordering to synchronize with `WakerSet::insert()` and `WakerSet::update()`.
self.0.locked.store(false, Ordering::SeqCst);
if let Some((_, opt_waker)) = blocked.iter_mut().next() {
// If there is no waker in this entry, that means it was already woken.
if let Some(w) = opt_waker.take() {
w.wake();
}
}
}
// Notify one blocked `lock()` operation.
self.0.wakers.notify_one();
}
}

@ -10,7 +10,8 @@ use crate::future::Future;
use crate::task::{Context, Poll, Waker};
/// Set if a write lock is held.
const WRITE_LOCK: usize = 1;
#[allow(clippy::identity_op)]
const WRITE_LOCK: usize = 1 << 0;
/// Set if there are read operations blocked on the lock.
const BLOCKED_READS: usize = 1 << 1;

@ -0,0 +1,200 @@
//! A common utility for building synchronization primitives.
//!
//! When an async operation is blocked, it needs to register itself somewhere so that it can be
//! notified later on. The `WakerSet` type helps with keeping track of such async operations and
//! notifying them when they may make progress.
use std::cell::UnsafeCell;
use std::ops::{Deref, DerefMut};
use std::sync::atomic::{AtomicUsize, Ordering};
use crossbeam_utils::Backoff;
use slab::Slab;
use crate::task::{Context, Waker};
/// Set when the entry list is locked.
#[allow(clippy::identity_op)]
const LOCKED: usize = 1 << 0;
/// Set when there are tasks for `notify_one()` to wake.
const NOTIFY_ONE: usize = 1 << 1;
/// Set when there are tasks for `notify_all()` to wake.
const NOTIFY_ALL: usize = 1 << 2;
/// Inner representation of `WakerSet`.
struct Inner {
/// A list of entries in the set.
///
/// Each entry has an optional waker associated with the task that is executing the operation.
/// If the waker is set to `None`, that means the task has been woken up but hasn't removed
/// itself from the `WakerSet` yet.
///
/// The key of each entry is its index in the `Slab`.
entries: Slab<Option<Waker>>,
/// The number of entries that have the waker set to `None`.
none_count: usize,
}
/// A set holding wakers.
pub struct WakerSet {
/// Holds three bits: `LOCKED`, `NOTIFY_ONE`, and `NOTIFY_ALL`.
flag: AtomicUsize,
/// A set holding wakers.
inner: UnsafeCell<Inner>,
}
impl WakerSet {
/// Creates a new `WakerSet`.
#[inline]
pub fn new() -> WakerSet {
WakerSet {
flag: AtomicUsize::new(0),
inner: UnsafeCell::new(Inner {
entries: Slab::new(),
none_count: 0,
}),
}
}
/// Inserts a waker for a blocked operation and returns a key associated with it.
pub fn insert(&self, cx: &Context<'_>) -> usize {
let w = cx.waker().clone();
self.lock().entries.insert(Some(w))
}
/// Updates the waker of a previously inserted entry.
pub fn update(&self, key: usize, cx: &Context<'_>) {
let mut inner = self.lock();
match &mut inner.entries[key] {
None => {
// Fill in the waker.
let w = cx.waker().clone();
inner.entries[key] = Some(w);
inner.none_count -= 1;
}
Some(w) => {
// Replace the waker if the existing one is different.
if !w.will_wake(cx.waker()) {
*w = cx.waker().clone();
}
}
}
}
/// Removes the waker of a completed operation.
pub fn complete(&self, key: usize) {
let mut inner = self.lock();
if inner.entries.remove(key).is_none() {
inner.none_count -= 1;
}
}
/// Removes the waker of a cancelled operation.
pub fn cancel(&self, key: usize) {
let mut inner = self.lock();
if inner.entries.remove(key).is_none() {
inner.none_count -= 1;
// The operation was cancelled and notified so notify another operation instead.
if let Some((_, opt_waker)) = inner.entries.iter_mut().next() {
// If there is no waker in this entry, that means it was already woken.
if let Some(w) = opt_waker.take() {
w.wake();
inner.none_count += 1;
}
}
}
}
/// Notifies one blocked operation.
#[inline]
pub fn notify_one(&self) {
// Use `SeqCst` ordering to synchronize with `Lock::drop()`.
if self.flag.load(Ordering::SeqCst) & NOTIFY_ONE != 0 {
self.notify(false);
}
}
/// Notifies all blocked operations.
// TODO: Delete this attribute when `crate::sync::channel()` is stabilized.
#[cfg(feature = "unstable")]
#[inline]
pub fn notify_all(&self) {
// Use `SeqCst` ordering to synchronize with `Lock::drop()`.
if self.flag.load(Ordering::SeqCst) & NOTIFY_ALL != 0 {
self.notify(true);
}
}
/// Notifies blocked operations, either one or all of them.
fn notify(&self, all: bool) {
let mut inner = &mut *self.lock();
for (_, opt_waker) in inner.entries.iter_mut() {
// If there is no waker in this entry, that means it was already woken.
if let Some(w) = opt_waker.take() {
w.wake();
inner.none_count += 1;
}
if !all {
break;
}
}
}
/// Locks the list of entries.
#[cold]
fn lock(&self) -> Lock<'_> {
let backoff = Backoff::new();
while self.flag.fetch_or(LOCKED, Ordering::Acquire) & LOCKED != 0 {
backoff.snooze();
}
Lock { waker_set: self }
}
}
/// A guard holding a `WakerSet` locked.
struct Lock<'a> {
waker_set: &'a WakerSet,
}
impl Drop for Lock<'_> {
#[inline]
fn drop(&mut self) {
let mut flag = 0;
// If there is at least one entry and all are `Some`, then `notify_one()` has work to do.
if !self.entries.is_empty() && self.none_count == 0 {
flag |= NOTIFY_ONE;
}
// If there is at least one `Some` entry, then `notify_all()` has work to do.
if self.entries.len() - self.none_count > 0 {
flag |= NOTIFY_ALL;
}
// Use `SeqCst` ordering to synchronize with `WakerSet::lock_to_notify()`.
self.waker_set.flag.store(flag, Ordering::SeqCst);
}
}
impl Deref for Lock<'_> {
type Target = Inner;
#[inline]
fn deref(&self) -> &Inner {
unsafe { &*self.waker_set.inner.get() }
}
}
impl DerefMut for Lock<'_> {
#[inline]
fn deref_mut(&mut self) -> &mut Inner {
unsafe { &mut *self.waker_set.inner.get() }
}
}
Loading…
Cancel
Save