From 87de4e15982bbec890aa909600da5997c4f6edc7 Mon Sep 17 00:00:00 2001 From: Stjepan Glavina Date: Fri, 1 Nov 2019 02:45:50 +0100 Subject: [PATCH] Add utility type WakerSet to the sync module (#390) * Add utility type Registry to the sync module * Remove unused import * Split unregister into complete and cancel * Refactoring and renaming * Split remove() into complete() and cancel() * Rename to WakerSet * Ignore clippy warning * Ignore another clippy warning * Use stronger SeqCst ordering --- benches/mutex.rs | 42 ++++++ benches/task.rs | 11 ++ src/sync/channel.rs | 336 ++++++++++-------------------------------- src/sync/mod.rs | 3 + src/sync/mutex.rs | 124 +++++----------- src/sync/rwlock.rs | 3 +- src/sync/waker_set.rs | 200 +++++++++++++++++++++++++ 7 files changed, 371 insertions(+), 348 deletions(-) create mode 100644 benches/mutex.rs create mode 100644 benches/task.rs create mode 100644 src/sync/waker_set.rs diff --git a/benches/mutex.rs b/benches/mutex.rs new file mode 100644 index 00000000..b159ba12 --- /dev/null +++ b/benches/mutex.rs @@ -0,0 +1,42 @@ +#![feature(test)] + +extern crate test; + +use std::sync::Arc; + +use async_std::sync::Mutex; +use async_std::task; +use test::Bencher; + +#[bench] +fn create(b: &mut Bencher) { + b.iter(|| Mutex::new(())); +} + +#[bench] +fn contention(b: &mut Bencher) { + b.iter(|| task::block_on(run(10, 1000))); +} + +#[bench] +fn no_contention(b: &mut Bencher) { + b.iter(|| task::block_on(run(1, 10000))); +} + +async fn run(task: usize, iter: usize) { + let m = Arc::new(Mutex::new(())); + let mut tasks = Vec::new(); + + for _ in 0..task { + let m = m.clone(); + tasks.push(task::spawn(async move { + for _ in 0..iter { + let _ = m.lock().await; + } + })); + } + + for t in tasks { + t.await; + } +} diff --git a/benches/task.rs b/benches/task.rs new file mode 100644 index 00000000..b3144713 --- /dev/null +++ b/benches/task.rs @@ -0,0 +1,11 @@ +#![feature(test)] + +extern crate test; + +use async_std::task; +use test::Bencher; + +#[bench] +fn block_on(b: &mut Bencher) { + b.iter(|| task::block_on(async {})); +} diff --git a/src/sync/channel.rs b/src/sync/channel.rs index 6751417a..403bee74 100644 --- a/src/sync/channel.rs +++ b/src/sync/channel.rs @@ -4,17 +4,17 @@ use std::future::Future; use std::isize; use std::marker::PhantomData; use std::mem; -use std::ops::{Deref, DerefMut}; use std::pin::Pin; use std::process; use std::ptr; -use std::sync::atomic::{self, AtomicBool, AtomicUsize, Ordering}; +use std::sync::atomic::{self, AtomicUsize, Ordering}; use std::sync::Arc; -use std::task::{Context, Poll, Waker}; +use std::task::{Context, Poll}; use crossbeam_utils::{Backoff, CachePadded}; -use futures_core::stream::Stream; -use slab::Slab; + +use crate::stream::Stream; +use crate::sync::WakerSet; /// Creates a bounded multi-producer multi-consumer channel. /// @@ -128,7 +128,7 @@ impl Sender { /// ``` pub async fn send(&self, msg: T) { struct SendFuture<'a, T> { - sender: &'a Sender, + channel: &'a Channel, msg: Option, opt_key: Option, } @@ -142,23 +142,23 @@ impl Sender { let msg = self.msg.take().unwrap(); // Try sending the message. - let poll = match self.sender.channel.push(msg) { + let poll = match self.channel.try_send(msg) { Ok(()) => Poll::Ready(()), - Err(PushError::Disconnected(msg)) => { + Err(TrySendError::Disconnected(msg)) => { self.msg = Some(msg); Poll::Pending } - Err(PushError::Full(msg)) => { - // Register the current task. + Err(TrySendError::Full(msg)) => { + // Insert this send operation. match self.opt_key { - None => self.opt_key = Some(self.sender.channel.sends.register(cx)), - Some(key) => self.sender.channel.sends.reregister(key, cx), + None => self.opt_key = Some(self.channel.send_wakers.insert(cx)), + Some(key) => self.channel.send_wakers.update(key, cx), } // Try sending the message again. - match self.sender.channel.push(msg) { + match self.channel.try_send(msg) { Ok(()) => Poll::Ready(()), - Err(PushError::Disconnected(msg)) | Err(PushError::Full(msg)) => { + Err(TrySendError::Disconnected(msg)) | Err(TrySendError::Full(msg)) => { self.msg = Some(msg); Poll::Pending } @@ -167,10 +167,9 @@ impl Sender { }; if poll.is_ready() { - // If the current task was registered, unregister now. + // If the current task is in the set, remove it. if let Some(key) = self.opt_key.take() { - // `true` means the send operation is completed. - self.sender.channel.sends.unregister(key, true); + self.channel.send_wakers.complete(key); } } @@ -180,16 +179,16 @@ impl Sender { impl Drop for SendFuture<'_, T> { fn drop(&mut self) { - // If the current task was registered, unregister now. + // If the current task is still in the set, that means it is being cancelled now. + // Wake up another task instead. if let Some(key) = self.opt_key { - // `false` means the send operation is cancelled. - self.sender.channel.sends.unregister(key, false); + self.channel.send_wakers.cancel(key); } } } SendFuture { - sender: self, + channel: &self.channel, msg: Some(msg), opt_key: None, } @@ -340,7 +339,7 @@ pub struct Receiver { /// The inner channel. channel: Arc>, - /// The registration key for this receiver in the `channel.streams` registry. + /// The key for this receiver in the `channel.stream_wakers` set. opt_key: Option, } @@ -382,16 +381,20 @@ impl Receiver { type Output = Option; fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - poll_recv(&self.channel, &self.channel.recvs, &mut self.opt_key, cx) + poll_recv( + &self.channel, + &self.channel.recv_wakers, + &mut self.opt_key, + cx, + ) } } impl Drop for RecvFuture<'_, T> { fn drop(&mut self) { - // If the current task was registered, unregister now. + // If the current task is still in the set, that means it is being cancelled now. if let Some(key) = self.opt_key { - // `false` means the receive operation is cancelled. - self.channel.recvs.unregister(key, false); + self.channel.recv_wakers.cancel(key); } } } @@ -484,10 +487,9 @@ impl Receiver { impl Drop for Receiver { fn drop(&mut self) { - // If the current task was registered as blocked on this stream, unregister now. + // If the current task is still in the stream set, that means it is being cancelled now. if let Some(key) = self.opt_key { - // `false` means the last request for a stream item is cancelled. - self.channel.streams.unregister(key, false); + self.channel.stream_wakers.cancel(key); } // Decrement the receiver count and disconnect the channel if it drops down to zero. @@ -518,7 +520,12 @@ impl Stream for Receiver { fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { let this = &mut *self; - poll_recv(&this.channel, &this.channel.streams, &mut this.opt_key, cx) + poll_recv( + &this.channel, + &this.channel.stream_wakers, + &mut this.opt_key, + cx, + ) } } @@ -530,39 +537,38 @@ impl fmt::Debug for Receiver { /// Polls a receive operation on a channel. /// -/// If the receive operation is blocked, the current task will be registered in `registry` and its -/// registration key will then be stored in `opt_key`. +/// If the receive operation is blocked, the current task will be inserted into `wakers` and its +/// associated key will then be stored in `opt_key`. fn poll_recv( channel: &Channel, - registry: &Registry, + wakers: &WakerSet, opt_key: &mut Option, cx: &mut Context<'_>, ) -> Poll> { // Try receiving a message. - let poll = match channel.pop() { + let poll = match channel.try_recv() { Ok(msg) => Poll::Ready(Some(msg)), - Err(PopError::Disconnected) => Poll::Ready(None), - Err(PopError::Empty) => { - // Register the current task. + Err(TryRecvError::Disconnected) => Poll::Ready(None), + Err(TryRecvError::Empty) => { + // Insert this receive operation. match *opt_key { - None => *opt_key = Some(registry.register(cx)), - Some(key) => registry.reregister(key, cx), + None => *opt_key = Some(wakers.insert(cx)), + Some(key) => wakers.update(key, cx), } // Try receiving a message again. - match channel.pop() { + match channel.try_recv() { Ok(msg) => Poll::Ready(Some(msg)), - Err(PopError::Disconnected) => Poll::Ready(None), - Err(PopError::Empty) => Poll::Pending, + Err(TryRecvError::Disconnected) => Poll::Ready(None), + Err(TryRecvError::Empty) => Poll::Pending, } } }; if poll.is_ready() { - // If the current task was registered, unregister now. + // If the current task is in the set, remove it. if let Some(key) = opt_key.take() { - // `true` means the receive operation is completed. - registry.unregister(key, true); + wakers.complete(key); } } @@ -612,13 +618,13 @@ struct Channel { mark_bit: usize, /// Send operations waiting while the channel is full. - sends: Registry, + send_wakers: WakerSet, /// Receive operations waiting while the channel is empty and not disconnected. - recvs: Registry, + recv_wakers: WakerSet, /// Streams waiting while the channel is empty and not disconnected. - streams: Registry, + stream_wakers: WakerSet, /// The number of currently active `Sender`s. sender_count: AtomicUsize, @@ -672,17 +678,17 @@ impl Channel { mark_bit, head: CachePadded::new(AtomicUsize::new(head)), tail: CachePadded::new(AtomicUsize::new(tail)), - sends: Registry::new(), - recvs: Registry::new(), - streams: Registry::new(), + send_wakers: WakerSet::new(), + recv_wakers: WakerSet::new(), + stream_wakers: WakerSet::new(), sender_count: AtomicUsize::new(1), receiver_count: AtomicUsize::new(1), _marker: PhantomData, } } - /// Attempts to push a message. - fn push(&self, msg: T) -> Result<(), PushError> { + /// Attempts to send a message. + fn try_send(&self, msg: T) -> Result<(), TrySendError> { let backoff = Backoff::new(); let mut tail = self.tail.load(Ordering::Relaxed); @@ -721,10 +727,10 @@ impl Channel { slot.stamp.store(stamp, Ordering::Release); // Wake a blocked receive operation. - self.recvs.notify_one(); + self.recv_wakers.notify_one(); // Wake all blocked streams. - self.streams.notify_all(); + self.stream_wakers.notify_all(); return Ok(()); } @@ -743,9 +749,9 @@ impl Channel { // Check if the channel is disconnected. if tail & self.mark_bit != 0 { - return Err(PushError::Disconnected(msg)); + return Err(TrySendError::Disconnected(msg)); } else { - return Err(PushError::Full(msg)); + return Err(TrySendError::Full(msg)); } } @@ -759,8 +765,8 @@ impl Channel { } } - /// Attempts to pop a message. - fn pop(&self) -> Result { + /// Attempts to receive a message. + fn try_recv(&self) -> Result { let backoff = Backoff::new(); let mut head = self.head.load(Ordering::Relaxed); @@ -799,7 +805,7 @@ impl Channel { slot.stamp.store(stamp, Ordering::Release); // Wake a blocked send operation. - self.sends.notify_one(); + self.send_wakers.notify_one(); return Ok(msg); } @@ -816,10 +822,10 @@ impl Channel { if (tail & !self.mark_bit) == head { // If the channel is disconnected... if tail & self.mark_bit != 0 { - return Err(PopError::Disconnected); + return Err(TryRecvError::Disconnected); } else { // Otherwise, the receive operation is not ready. - return Err(PopError::Empty); + return Err(TryRecvError::Empty); } } @@ -888,9 +894,9 @@ impl Channel { if tail & self.mark_bit == 0 { // Notify everyone blocked on this channel. - self.sends.notify_all(); - self.recvs.notify_all(); - self.streams.notify_all(); + self.send_wakers.notify_all(); + self.recv_wakers.notify_all(); + self.stream_wakers.notify_all(); } } } @@ -921,8 +927,8 @@ impl Drop for Channel { } } -/// An error returned from the `push()` method. -enum PushError { +/// An error returned from the `try_send()` method. +enum TrySendError { /// The channel is full but not disconnected. Full(T), @@ -930,203 +936,11 @@ enum PushError { Disconnected(T), } -/// An error returned from the `pop()` method. -enum PopError { +/// An error returned from the `try_recv()` method. +enum TryRecvError { /// The channel is empty but not disconnected. Empty, /// The channel is empty and disconnected. Disconnected, } - -/// A list of blocked channel operations. -struct Blocked { - /// A list of registered channel operations. - /// - /// Each entry has a waker associated with the task that is executing the operation. If the - /// waker is set to `None`, that means the task has been woken up but hasn't removed itself - /// from the registry yet. - entries: Slab>, - - /// The number of wakers in the entry list. - waker_count: usize, -} - -/// A registry of blocked channel operations. -/// -/// Blocked operations register themselves in a registry. Successful operations on the opposite -/// side of the channel wake blocked operations in the registry. -struct Registry { - /// A list of blocked channel operations. - blocked: Spinlock, - - /// Set to `true` if there are no wakers in the registry. - /// - /// Note that this either means there are no entries in the registry, or that all entries have - /// been notified. - is_empty: AtomicBool, -} - -impl Registry { - /// Creates a new registry. - fn new() -> Registry { - Registry { - blocked: Spinlock::new(Blocked { - entries: Slab::new(), - waker_count: 0, - }), - is_empty: AtomicBool::new(true), - } - } - - /// Registers a blocked channel operation and returns a key associated with it. - fn register(&self, cx: &Context<'_>) -> usize { - let mut blocked = self.blocked.lock(); - - // Insert a new entry into the list of blocked tasks. - let w = cx.waker().clone(); - let key = blocked.entries.insert(Some(w)); - - blocked.waker_count += 1; - if blocked.waker_count == 1 { - self.is_empty.store(false, Ordering::SeqCst); - } - - key - } - - /// Re-registers a blocked channel operation by filling in its waker. - fn reregister(&self, key: usize, cx: &Context<'_>) { - let mut blocked = self.blocked.lock(); - - let was_none = blocked.entries[key].is_none(); - let w = cx.waker().clone(); - blocked.entries[key] = Some(w); - - if was_none { - blocked.waker_count += 1; - if blocked.waker_count == 1 { - self.is_empty.store(false, Ordering::SeqCst); - } - } - } - - /// Unregisters a channel operation. - /// - /// If `completed` is `true`, the operation will be removed from the registry. If `completed` - /// is `false`, that means the operation was cancelled so another one will be notified. - fn unregister(&self, key: usize, completed: bool) { - let mut blocked = self.blocked.lock(); - let mut removed = false; - - match blocked.entries.remove(key) { - Some(_) => removed = true, - None => { - if !completed { - // This operation was cancelled. Notify another one. - if let Some((_, opt_waker)) = blocked.entries.iter_mut().next() { - if let Some(w) = opt_waker.take() { - w.wake(); - removed = true; - } - } - } - } - } - - if removed { - blocked.waker_count -= 1; - if blocked.waker_count == 0 { - self.is_empty.store(true, Ordering::SeqCst); - } - } - } - - /// Notifies one blocked channel operation. - #[inline] - fn notify_one(&self) { - if !self.is_empty.load(Ordering::SeqCst) { - let mut blocked = self.blocked.lock(); - - if let Some((_, opt_waker)) = blocked.entries.iter_mut().next() { - // If there is no waker in this entry, that means it was already woken. - if let Some(w) = opt_waker.take() { - w.wake(); - - blocked.waker_count -= 1; - if blocked.waker_count == 0 { - self.is_empty.store(true, Ordering::SeqCst); - } - } - } - } - } - - /// Notifies all blocked channel operations. - #[inline] - fn notify_all(&self) { - if !self.is_empty.load(Ordering::SeqCst) { - let mut blocked = self.blocked.lock(); - - for (_, opt_waker) in blocked.entries.iter_mut() { - // If there is no waker in this entry, that means it was already woken. - if let Some(w) = opt_waker.take() { - w.wake(); - } - } - - blocked.waker_count = 0; - self.is_empty.store(true, Ordering::SeqCst); - } - } -} - -/// A simple spinlock. -struct Spinlock { - flag: AtomicBool, - value: UnsafeCell, -} - -impl Spinlock { - /// Returns a new spinlock initialized with `value`. - fn new(value: T) -> Spinlock { - Spinlock { - flag: AtomicBool::new(false), - value: UnsafeCell::new(value), - } - } - - /// Locks the spinlock. - fn lock(&self) -> SpinlockGuard<'_, T> { - let backoff = Backoff::new(); - while self.flag.swap(true, Ordering::Acquire) { - backoff.snooze(); - } - SpinlockGuard { parent: self } - } -} - -/// A guard holding a spinlock locked. -struct SpinlockGuard<'a, T> { - parent: &'a Spinlock, -} - -impl<'a, T> Drop for SpinlockGuard<'a, T> { - fn drop(&mut self) { - self.parent.flag.store(false, Ordering::Release); - } -} - -impl<'a, T> Deref for SpinlockGuard<'a, T> { - type Target = T; - - fn deref(&self) -> &T { - unsafe { &*self.parent.value.get() } - } -} - -impl<'a, T> DerefMut for SpinlockGuard<'a, T> { - fn deref_mut(&mut self) -> &mut T { - unsafe { &mut *self.parent.value.get() } - } -} diff --git a/src/sync/mod.rs b/src/sync/mod.rs index d10e6bdf..ea991fe6 100644 --- a/src/sync/mod.rs +++ b/src/sync/mod.rs @@ -191,3 +191,6 @@ cfg_unstable! { mod barrier; mod channel; } + +pub(crate) mod waker_set; +pub(crate) use waker_set::WakerSet; diff --git a/src/sync/mutex.rs b/src/sync/mutex.rs index cd7a3577..fcd030d8 100644 --- a/src/sync/mutex.rs +++ b/src/sync/mutex.rs @@ -2,18 +2,11 @@ use std::cell::UnsafeCell; use std::fmt; use std::ops::{Deref, DerefMut}; use std::pin::Pin; -use std::sync::atomic::{AtomicUsize, Ordering}; - -use slab::Slab; +use std::sync::atomic::{AtomicBool, Ordering}; use crate::future::Future; -use crate::task::{Context, Poll, Waker}; - -/// Set if the mutex is locked. -const LOCK: usize = 1; - -/// Set if there are tasks blocked on the mutex. -const BLOCKED: usize = 1 << 1; +use crate::sync::WakerSet; +use crate::task::{Context, Poll}; /// A mutual exclusion primitive for protecting shared data. /// @@ -49,8 +42,8 @@ const BLOCKED: usize = 1 << 1; /// # }) /// ``` pub struct Mutex { - state: AtomicUsize, - blocked: std::sync::Mutex>>, + locked: AtomicBool, + wakers: WakerSet, value: UnsafeCell, } @@ -69,8 +62,8 @@ impl Mutex { /// ``` pub fn new(t: T) -> Mutex { Mutex { - state: AtomicUsize::new(0), - blocked: std::sync::Mutex::new(Slab::new()), + locked: AtomicBool::new(false), + wakers: WakerSet::new(), value: UnsafeCell::new(t), } } @@ -105,75 +98,46 @@ impl Mutex { pub struct LockFuture<'a, T> { mutex: &'a Mutex, opt_key: Option, - acquired: bool, } impl<'a, T> Future for LockFuture<'a, T> { type Output = MutexGuard<'a, T>; fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - match self.mutex.try_lock() { - Some(guard) => { - self.acquired = true; - Poll::Ready(guard) - } + let poll = match self.mutex.try_lock() { + Some(guard) => Poll::Ready(guard), None => { - let mut blocked = self.mutex.blocked.lock().unwrap(); - - // Register the current task. + // Insert this lock operation. match self.opt_key { - None => { - // Insert a new entry into the list of blocked tasks. - let w = cx.waker().clone(); - let key = blocked.insert(Some(w)); - self.opt_key = Some(key); - - if blocked.len() == 1 { - self.mutex.state.fetch_or(BLOCKED, Ordering::Relaxed); - } - } - Some(key) => { - // There is already an entry in the list of blocked tasks. Just - // reset the waker if it was removed. - if blocked[key].is_none() { - let w = cx.waker().clone(); - blocked[key] = Some(w); - } - } + None => self.opt_key = Some(self.mutex.wakers.insert(cx)), + Some(key) => self.mutex.wakers.update(key, cx), } // Try locking again because it's possible the mutex got unlocked just - // before the current task was registered as a blocked task. + // before the current task was inserted into the waker set. match self.mutex.try_lock() { - Some(guard) => { - self.acquired = true; - Poll::Ready(guard) - } + Some(guard) => Poll::Ready(guard), None => Poll::Pending, } } + }; + + if poll.is_ready() { + // If the current task is in the set, remove it. + if let Some(key) = self.opt_key.take() { + self.mutex.wakers.complete(key); + } } + + poll } } impl Drop for LockFuture<'_, T> { fn drop(&mut self) { + // If the current task is still in the set, that means it is being cancelled now. if let Some(key) = self.opt_key { - let mut blocked = self.mutex.blocked.lock().unwrap(); - let opt_waker = blocked.remove(key); - - if opt_waker.is_none() && !self.acquired { - // We were awoken but didn't acquire the lock. Wake up another task. - if let Some((_, opt_waker)) = blocked.iter_mut().next() { - if let Some(w) = opt_waker.take() { - w.wake(); - } - } - } - - if blocked.is_empty() { - self.mutex.state.fetch_and(!BLOCKED, Ordering::Relaxed); - } + self.mutex.wakers.cancel(key); } } } @@ -181,7 +145,6 @@ impl Mutex { LockFuture { mutex: self, opt_key: None, - acquired: false, } .await } @@ -220,7 +183,7 @@ impl Mutex { /// # }) /// ``` pub fn try_lock(&self) -> Option> { - if self.state.fetch_or(LOCK, Ordering::Acquire) & LOCK == 0 { + if !self.locked.swap(true, Ordering::SeqCst) { Some(MutexGuard(self)) } else { None @@ -266,18 +229,15 @@ impl Mutex { impl fmt::Debug for Mutex { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - match self.try_lock() { - None => { - struct LockedPlaceholder; - impl fmt::Debug for LockedPlaceholder { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.write_str("") - } - } - f.debug_struct("Mutex") - .field("data", &LockedPlaceholder) - .finish() + struct Locked; + impl fmt::Debug for Locked { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_str("") } + } + + match self.try_lock() { + None => f.debug_struct("Mutex").field("data", &Locked).finish(), Some(guard) => f.debug_struct("Mutex").field("data", &&*guard).finish(), } } @@ -303,19 +263,11 @@ unsafe impl Sync for MutexGuard<'_, T> {} impl Drop for MutexGuard<'_, T> { fn drop(&mut self) { - let state = self.0.state.fetch_and(!LOCK, Ordering::AcqRel); - - // If there are any blocked tasks, wake one of them up. - if state & BLOCKED != 0 { - let mut blocked = self.0.blocked.lock().unwrap(); + // Use `SeqCst` ordering to synchronize with `WakerSet::insert()` and `WakerSet::update()`. + self.0.locked.store(false, Ordering::SeqCst); - if let Some((_, opt_waker)) = blocked.iter_mut().next() { - // If there is no waker in this entry, that means it was already woken. - if let Some(w) = opt_waker.take() { - w.wake(); - } - } - } + // Notify one blocked `lock()` operation. + self.0.wakers.notify_one(); } } diff --git a/src/sync/rwlock.rs b/src/sync/rwlock.rs index ed1d2185..a0d0f07a 100644 --- a/src/sync/rwlock.rs +++ b/src/sync/rwlock.rs @@ -10,7 +10,8 @@ use crate::future::Future; use crate::task::{Context, Poll, Waker}; /// Set if a write lock is held. -const WRITE_LOCK: usize = 1; +#[allow(clippy::identity_op)] +const WRITE_LOCK: usize = 1 << 0; /// Set if there are read operations blocked on the lock. const BLOCKED_READS: usize = 1 << 1; diff --git a/src/sync/waker_set.rs b/src/sync/waker_set.rs new file mode 100644 index 00000000..8fd1b621 --- /dev/null +++ b/src/sync/waker_set.rs @@ -0,0 +1,200 @@ +//! A common utility for building synchronization primitives. +//! +//! When an async operation is blocked, it needs to register itself somewhere so that it can be +//! notified later on. The `WakerSet` type helps with keeping track of such async operations and +//! notifying them when they may make progress. + +use std::cell::UnsafeCell; +use std::ops::{Deref, DerefMut}; +use std::sync::atomic::{AtomicUsize, Ordering}; + +use crossbeam_utils::Backoff; +use slab::Slab; + +use crate::task::{Context, Waker}; + +/// Set when the entry list is locked. +#[allow(clippy::identity_op)] +const LOCKED: usize = 1 << 0; + +/// Set when there are tasks for `notify_one()` to wake. +const NOTIFY_ONE: usize = 1 << 1; + +/// Set when there are tasks for `notify_all()` to wake. +const NOTIFY_ALL: usize = 1 << 2; + +/// Inner representation of `WakerSet`. +struct Inner { + /// A list of entries in the set. + /// + /// Each entry has an optional waker associated with the task that is executing the operation. + /// If the waker is set to `None`, that means the task has been woken up but hasn't removed + /// itself from the `WakerSet` yet. + /// + /// The key of each entry is its index in the `Slab`. + entries: Slab>, + + /// The number of entries that have the waker set to `None`. + none_count: usize, +} + +/// A set holding wakers. +pub struct WakerSet { + /// Holds three bits: `LOCKED`, `NOTIFY_ONE`, and `NOTIFY_ALL`. + flag: AtomicUsize, + + /// A set holding wakers. + inner: UnsafeCell, +} + +impl WakerSet { + /// Creates a new `WakerSet`. + #[inline] + pub fn new() -> WakerSet { + WakerSet { + flag: AtomicUsize::new(0), + inner: UnsafeCell::new(Inner { + entries: Slab::new(), + none_count: 0, + }), + } + } + + /// Inserts a waker for a blocked operation and returns a key associated with it. + pub fn insert(&self, cx: &Context<'_>) -> usize { + let w = cx.waker().clone(); + self.lock().entries.insert(Some(w)) + } + + /// Updates the waker of a previously inserted entry. + pub fn update(&self, key: usize, cx: &Context<'_>) { + let mut inner = self.lock(); + + match &mut inner.entries[key] { + None => { + // Fill in the waker. + let w = cx.waker().clone(); + inner.entries[key] = Some(w); + inner.none_count -= 1; + } + Some(w) => { + // Replace the waker if the existing one is different. + if !w.will_wake(cx.waker()) { + *w = cx.waker().clone(); + } + } + } + } + + /// Removes the waker of a completed operation. + pub fn complete(&self, key: usize) { + let mut inner = self.lock(); + if inner.entries.remove(key).is_none() { + inner.none_count -= 1; + } + } + + /// Removes the waker of a cancelled operation. + pub fn cancel(&self, key: usize) { + let mut inner = self.lock(); + if inner.entries.remove(key).is_none() { + inner.none_count -= 1; + + // The operation was cancelled and notified so notify another operation instead. + if let Some((_, opt_waker)) = inner.entries.iter_mut().next() { + // If there is no waker in this entry, that means it was already woken. + if let Some(w) = opt_waker.take() { + w.wake(); + inner.none_count += 1; + } + } + } + } + + /// Notifies one blocked operation. + #[inline] + pub fn notify_one(&self) { + // Use `SeqCst` ordering to synchronize with `Lock::drop()`. + if self.flag.load(Ordering::SeqCst) & NOTIFY_ONE != 0 { + self.notify(false); + } + } + + /// Notifies all blocked operations. + // TODO: Delete this attribute when `crate::sync::channel()` is stabilized. + #[cfg(feature = "unstable")] + #[inline] + pub fn notify_all(&self) { + // Use `SeqCst` ordering to synchronize with `Lock::drop()`. + if self.flag.load(Ordering::SeqCst) & NOTIFY_ALL != 0 { + self.notify(true); + } + } + + /// Notifies blocked operations, either one or all of them. + fn notify(&self, all: bool) { + let mut inner = &mut *self.lock(); + + for (_, opt_waker) in inner.entries.iter_mut() { + // If there is no waker in this entry, that means it was already woken. + if let Some(w) = opt_waker.take() { + w.wake(); + inner.none_count += 1; + } + if !all { + break; + } + } + } + + /// Locks the list of entries. + #[cold] + fn lock(&self) -> Lock<'_> { + let backoff = Backoff::new(); + while self.flag.fetch_or(LOCKED, Ordering::Acquire) & LOCKED != 0 { + backoff.snooze(); + } + Lock { waker_set: self } + } +} + +/// A guard holding a `WakerSet` locked. +struct Lock<'a> { + waker_set: &'a WakerSet, +} + +impl Drop for Lock<'_> { + #[inline] + fn drop(&mut self) { + let mut flag = 0; + + // If there is at least one entry and all are `Some`, then `notify_one()` has work to do. + if !self.entries.is_empty() && self.none_count == 0 { + flag |= NOTIFY_ONE; + } + + // If there is at least one `Some` entry, then `notify_all()` has work to do. + if self.entries.len() - self.none_count > 0 { + flag |= NOTIFY_ALL; + } + + // Use `SeqCst` ordering to synchronize with `WakerSet::lock_to_notify()`. + self.waker_set.flag.store(flag, Ordering::SeqCst); + } +} + +impl Deref for Lock<'_> { + type Target = Inner; + + #[inline] + fn deref(&self) -> &Inner { + unsafe { &*self.waker_set.inner.get() } + } +} + +impl DerefMut for Lock<'_> { + #[inline] + fn deref_mut(&mut self) -> &mut Inner { + unsafe { &mut *self.waker_set.inner.get() } + } +}