Add utility type WakerSet to the sync module (#390)

* Add utility type Registry to the sync module

* Remove unused import

* Split unregister into complete and cancel

* Refactoring and renaming

* Split remove() into complete() and cancel()

* Rename to WakerSet

* Ignore clippy warning

* Ignore another clippy warning

* Use stronger SeqCst ordering
pull/433/head
Stjepan Glavina 5 years ago committed by GitHub
parent 3dd59d7056
commit 87de4e1598
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -0,0 +1,42 @@
#![feature(test)]
extern crate test;
use std::sync::Arc;
use async_std::sync::Mutex;
use async_std::task;
use test::Bencher;
#[bench]
fn create(b: &mut Bencher) {
b.iter(|| Mutex::new(()));
}
#[bench]
fn contention(b: &mut Bencher) {
b.iter(|| task::block_on(run(10, 1000)));
}
#[bench]
fn no_contention(b: &mut Bencher) {
b.iter(|| task::block_on(run(1, 10000)));
}
async fn run(task: usize, iter: usize) {
let m = Arc::new(Mutex::new(()));
let mut tasks = Vec::new();
for _ in 0..task {
let m = m.clone();
tasks.push(task::spawn(async move {
for _ in 0..iter {
let _ = m.lock().await;
}
}));
}
for t in tasks {
t.await;
}
}

@ -0,0 +1,11 @@
#![feature(test)]
extern crate test;
use async_std::task;
use test::Bencher;
#[bench]
fn block_on(b: &mut Bencher) {
b.iter(|| task::block_on(async {}));
}

@ -4,17 +4,17 @@ use std::future::Future;
use std::isize; use std::isize;
use std::marker::PhantomData; use std::marker::PhantomData;
use std::mem; use std::mem;
use std::ops::{Deref, DerefMut};
use std::pin::Pin; use std::pin::Pin;
use std::process; use std::process;
use std::ptr; use std::ptr;
use std::sync::atomic::{self, AtomicBool, AtomicUsize, Ordering}; use std::sync::atomic::{self, AtomicUsize, Ordering};
use std::sync::Arc; use std::sync::Arc;
use std::task::{Context, Poll, Waker}; use std::task::{Context, Poll};
use crossbeam_utils::{Backoff, CachePadded}; use crossbeam_utils::{Backoff, CachePadded};
use futures_core::stream::Stream;
use slab::Slab; use crate::stream::Stream;
use crate::sync::WakerSet;
/// Creates a bounded multi-producer multi-consumer channel. /// Creates a bounded multi-producer multi-consumer channel.
/// ///
@ -128,7 +128,7 @@ impl<T> Sender<T> {
/// ``` /// ```
pub async fn send(&self, msg: T) { pub async fn send(&self, msg: T) {
struct SendFuture<'a, T> { struct SendFuture<'a, T> {
sender: &'a Sender<T>, channel: &'a Channel<T>,
msg: Option<T>, msg: Option<T>,
opt_key: Option<usize>, opt_key: Option<usize>,
} }
@ -142,23 +142,23 @@ impl<T> Sender<T> {
let msg = self.msg.take().unwrap(); let msg = self.msg.take().unwrap();
// Try sending the message. // Try sending the message.
let poll = match self.sender.channel.push(msg) { let poll = match self.channel.try_send(msg) {
Ok(()) => Poll::Ready(()), Ok(()) => Poll::Ready(()),
Err(PushError::Disconnected(msg)) => { Err(TrySendError::Disconnected(msg)) => {
self.msg = Some(msg); self.msg = Some(msg);
Poll::Pending Poll::Pending
} }
Err(PushError::Full(msg)) => { Err(TrySendError::Full(msg)) => {
// Register the current task. // Insert this send operation.
match self.opt_key { match self.opt_key {
None => self.opt_key = Some(self.sender.channel.sends.register(cx)), None => self.opt_key = Some(self.channel.send_wakers.insert(cx)),
Some(key) => self.sender.channel.sends.reregister(key, cx), Some(key) => self.channel.send_wakers.update(key, cx),
} }
// Try sending the message again. // Try sending the message again.
match self.sender.channel.push(msg) { match self.channel.try_send(msg) {
Ok(()) => Poll::Ready(()), Ok(()) => Poll::Ready(()),
Err(PushError::Disconnected(msg)) | Err(PushError::Full(msg)) => { Err(TrySendError::Disconnected(msg)) | Err(TrySendError::Full(msg)) => {
self.msg = Some(msg); self.msg = Some(msg);
Poll::Pending Poll::Pending
} }
@ -167,10 +167,9 @@ impl<T> Sender<T> {
}; };
if poll.is_ready() { if poll.is_ready() {
// If the current task was registered, unregister now. // If the current task is in the set, remove it.
if let Some(key) = self.opt_key.take() { if let Some(key) = self.opt_key.take() {
// `true` means the send operation is completed. self.channel.send_wakers.complete(key);
self.sender.channel.sends.unregister(key, true);
} }
} }
@ -180,16 +179,16 @@ impl<T> Sender<T> {
impl<T> Drop for SendFuture<'_, T> { impl<T> Drop for SendFuture<'_, T> {
fn drop(&mut self) { fn drop(&mut self) {
// If the current task was registered, unregister now. // If the current task is still in the set, that means it is being cancelled now.
// Wake up another task instead.
if let Some(key) = self.opt_key { if let Some(key) = self.opt_key {
// `false` means the send operation is cancelled. self.channel.send_wakers.cancel(key);
self.sender.channel.sends.unregister(key, false);
} }
} }
} }
SendFuture { SendFuture {
sender: self, channel: &self.channel,
msg: Some(msg), msg: Some(msg),
opt_key: None, opt_key: None,
} }
@ -340,7 +339,7 @@ pub struct Receiver<T> {
/// The inner channel. /// The inner channel.
channel: Arc<Channel<T>>, channel: Arc<Channel<T>>,
/// The registration key for this receiver in the `channel.streams` registry. /// The key for this receiver in the `channel.stream_wakers` set.
opt_key: Option<usize>, opt_key: Option<usize>,
} }
@ -382,16 +381,20 @@ impl<T> Receiver<T> {
type Output = Option<T>; type Output = Option<T>;
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
poll_recv(&self.channel, &self.channel.recvs, &mut self.opt_key, cx) poll_recv(
&self.channel,
&self.channel.recv_wakers,
&mut self.opt_key,
cx,
)
} }
} }
impl<T> Drop for RecvFuture<'_, T> { impl<T> Drop for RecvFuture<'_, T> {
fn drop(&mut self) { fn drop(&mut self) {
// If the current task was registered, unregister now. // If the current task is still in the set, that means it is being cancelled now.
if let Some(key) = self.opt_key { if let Some(key) = self.opt_key {
// `false` means the receive operation is cancelled. self.channel.recv_wakers.cancel(key);
self.channel.recvs.unregister(key, false);
} }
} }
} }
@ -484,10 +487,9 @@ impl<T> Receiver<T> {
impl<T> Drop for Receiver<T> { impl<T> Drop for Receiver<T> {
fn drop(&mut self) { fn drop(&mut self) {
// If the current task was registered as blocked on this stream, unregister now. // If the current task is still in the stream set, that means it is being cancelled now.
if let Some(key) = self.opt_key { if let Some(key) = self.opt_key {
// `false` means the last request for a stream item is cancelled. self.channel.stream_wakers.cancel(key);
self.channel.streams.unregister(key, false);
} }
// Decrement the receiver count and disconnect the channel if it drops down to zero. // Decrement the receiver count and disconnect the channel if it drops down to zero.
@ -518,7 +520,12 @@ impl<T> Stream for Receiver<T> {
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> { fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
let this = &mut *self; let this = &mut *self;
poll_recv(&this.channel, &this.channel.streams, &mut this.opt_key, cx) poll_recv(
&this.channel,
&this.channel.stream_wakers,
&mut this.opt_key,
cx,
)
} }
} }
@ -530,39 +537,38 @@ impl<T> fmt::Debug for Receiver<T> {
/// Polls a receive operation on a channel. /// Polls a receive operation on a channel.
/// ///
/// If the receive operation is blocked, the current task will be registered in `registry` and its /// If the receive operation is blocked, the current task will be inserted into `wakers` and its
/// registration key will then be stored in `opt_key`. /// associated key will then be stored in `opt_key`.
fn poll_recv<T>( fn poll_recv<T>(
channel: &Channel<T>, channel: &Channel<T>,
registry: &Registry, wakers: &WakerSet,
opt_key: &mut Option<usize>, opt_key: &mut Option<usize>,
cx: &mut Context<'_>, cx: &mut Context<'_>,
) -> Poll<Option<T>> { ) -> Poll<Option<T>> {
// Try receiving a message. // Try receiving a message.
let poll = match channel.pop() { let poll = match channel.try_recv() {
Ok(msg) => Poll::Ready(Some(msg)), Ok(msg) => Poll::Ready(Some(msg)),
Err(PopError::Disconnected) => Poll::Ready(None), Err(TryRecvError::Disconnected) => Poll::Ready(None),
Err(PopError::Empty) => { Err(TryRecvError::Empty) => {
// Register the current task. // Insert this receive operation.
match *opt_key { match *opt_key {
None => *opt_key = Some(registry.register(cx)), None => *opt_key = Some(wakers.insert(cx)),
Some(key) => registry.reregister(key, cx), Some(key) => wakers.update(key, cx),
} }
// Try receiving a message again. // Try receiving a message again.
match channel.pop() { match channel.try_recv() {
Ok(msg) => Poll::Ready(Some(msg)), Ok(msg) => Poll::Ready(Some(msg)),
Err(PopError::Disconnected) => Poll::Ready(None), Err(TryRecvError::Disconnected) => Poll::Ready(None),
Err(PopError::Empty) => Poll::Pending, Err(TryRecvError::Empty) => Poll::Pending,
} }
} }
}; };
if poll.is_ready() { if poll.is_ready() {
// If the current task was registered, unregister now. // If the current task is in the set, remove it.
if let Some(key) = opt_key.take() { if let Some(key) = opt_key.take() {
// `true` means the receive operation is completed. wakers.complete(key);
registry.unregister(key, true);
} }
} }
@ -612,13 +618,13 @@ struct Channel<T> {
mark_bit: usize, mark_bit: usize,
/// Send operations waiting while the channel is full. /// Send operations waiting while the channel is full.
sends: Registry, send_wakers: WakerSet,
/// Receive operations waiting while the channel is empty and not disconnected. /// Receive operations waiting while the channel is empty and not disconnected.
recvs: Registry, recv_wakers: WakerSet,
/// Streams waiting while the channel is empty and not disconnected. /// Streams waiting while the channel is empty and not disconnected.
streams: Registry, stream_wakers: WakerSet,
/// The number of currently active `Sender`s. /// The number of currently active `Sender`s.
sender_count: AtomicUsize, sender_count: AtomicUsize,
@ -672,17 +678,17 @@ impl<T> Channel<T> {
mark_bit, mark_bit,
head: CachePadded::new(AtomicUsize::new(head)), head: CachePadded::new(AtomicUsize::new(head)),
tail: CachePadded::new(AtomicUsize::new(tail)), tail: CachePadded::new(AtomicUsize::new(tail)),
sends: Registry::new(), send_wakers: WakerSet::new(),
recvs: Registry::new(), recv_wakers: WakerSet::new(),
streams: Registry::new(), stream_wakers: WakerSet::new(),
sender_count: AtomicUsize::new(1), sender_count: AtomicUsize::new(1),
receiver_count: AtomicUsize::new(1), receiver_count: AtomicUsize::new(1),
_marker: PhantomData, _marker: PhantomData,
} }
} }
/// Attempts to push a message. /// Attempts to send a message.
fn push(&self, msg: T) -> Result<(), PushError<T>> { fn try_send(&self, msg: T) -> Result<(), TrySendError<T>> {
let backoff = Backoff::new(); let backoff = Backoff::new();
let mut tail = self.tail.load(Ordering::Relaxed); let mut tail = self.tail.load(Ordering::Relaxed);
@ -721,10 +727,10 @@ impl<T> Channel<T> {
slot.stamp.store(stamp, Ordering::Release); slot.stamp.store(stamp, Ordering::Release);
// Wake a blocked receive operation. // Wake a blocked receive operation.
self.recvs.notify_one(); self.recv_wakers.notify_one();
// Wake all blocked streams. // Wake all blocked streams.
self.streams.notify_all(); self.stream_wakers.notify_all();
return Ok(()); return Ok(());
} }
@ -743,9 +749,9 @@ impl<T> Channel<T> {
// Check if the channel is disconnected. // Check if the channel is disconnected.
if tail & self.mark_bit != 0 { if tail & self.mark_bit != 0 {
return Err(PushError::Disconnected(msg)); return Err(TrySendError::Disconnected(msg));
} else { } else {
return Err(PushError::Full(msg)); return Err(TrySendError::Full(msg));
} }
} }
@ -759,8 +765,8 @@ impl<T> Channel<T> {
} }
} }
/// Attempts to pop a message. /// Attempts to receive a message.
fn pop(&self) -> Result<T, PopError> { fn try_recv(&self) -> Result<T, TryRecvError> {
let backoff = Backoff::new(); let backoff = Backoff::new();
let mut head = self.head.load(Ordering::Relaxed); let mut head = self.head.load(Ordering::Relaxed);
@ -799,7 +805,7 @@ impl<T> Channel<T> {
slot.stamp.store(stamp, Ordering::Release); slot.stamp.store(stamp, Ordering::Release);
// Wake a blocked send operation. // Wake a blocked send operation.
self.sends.notify_one(); self.send_wakers.notify_one();
return Ok(msg); return Ok(msg);
} }
@ -816,10 +822,10 @@ impl<T> Channel<T> {
if (tail & !self.mark_bit) == head { if (tail & !self.mark_bit) == head {
// If the channel is disconnected... // If the channel is disconnected...
if tail & self.mark_bit != 0 { if tail & self.mark_bit != 0 {
return Err(PopError::Disconnected); return Err(TryRecvError::Disconnected);
} else { } else {
// Otherwise, the receive operation is not ready. // Otherwise, the receive operation is not ready.
return Err(PopError::Empty); return Err(TryRecvError::Empty);
} }
} }
@ -888,9 +894,9 @@ impl<T> Channel<T> {
if tail & self.mark_bit == 0 { if tail & self.mark_bit == 0 {
// Notify everyone blocked on this channel. // Notify everyone blocked on this channel.
self.sends.notify_all(); self.send_wakers.notify_all();
self.recvs.notify_all(); self.recv_wakers.notify_all();
self.streams.notify_all(); self.stream_wakers.notify_all();
} }
} }
} }
@ -921,8 +927,8 @@ impl<T> Drop for Channel<T> {
} }
} }
/// An error returned from the `push()` method. /// An error returned from the `try_send()` method.
enum PushError<T> { enum TrySendError<T> {
/// The channel is full but not disconnected. /// The channel is full but not disconnected.
Full(T), Full(T),
@ -930,203 +936,11 @@ enum PushError<T> {
Disconnected(T), Disconnected(T),
} }
/// An error returned from the `pop()` method. /// An error returned from the `try_recv()` method.
enum PopError { enum TryRecvError {
/// The channel is empty but not disconnected. /// The channel is empty but not disconnected.
Empty, Empty,
/// The channel is empty and disconnected. /// The channel is empty and disconnected.
Disconnected, Disconnected,
} }
/// A list of blocked channel operations.
struct Blocked {
/// A list of registered channel operations.
///
/// Each entry has a waker associated with the task that is executing the operation. If the
/// waker is set to `None`, that means the task has been woken up but hasn't removed itself
/// from the registry yet.
entries: Slab<Option<Waker>>,
/// The number of wakers in the entry list.
waker_count: usize,
}
/// A registry of blocked channel operations.
///
/// Blocked operations register themselves in a registry. Successful operations on the opposite
/// side of the channel wake blocked operations in the registry.
struct Registry {
/// A list of blocked channel operations.
blocked: Spinlock<Blocked>,
/// Set to `true` if there are no wakers in the registry.
///
/// Note that this either means there are no entries in the registry, or that all entries have
/// been notified.
is_empty: AtomicBool,
}
impl Registry {
/// Creates a new registry.
fn new() -> Registry {
Registry {
blocked: Spinlock::new(Blocked {
entries: Slab::new(),
waker_count: 0,
}),
is_empty: AtomicBool::new(true),
}
}
/// Registers a blocked channel operation and returns a key associated with it.
fn register(&self, cx: &Context<'_>) -> usize {
let mut blocked = self.blocked.lock();
// Insert a new entry into the list of blocked tasks.
let w = cx.waker().clone();
let key = blocked.entries.insert(Some(w));
blocked.waker_count += 1;
if blocked.waker_count == 1 {
self.is_empty.store(false, Ordering::SeqCst);
}
key
}
/// Re-registers a blocked channel operation by filling in its waker.
fn reregister(&self, key: usize, cx: &Context<'_>) {
let mut blocked = self.blocked.lock();
let was_none = blocked.entries[key].is_none();
let w = cx.waker().clone();
blocked.entries[key] = Some(w);
if was_none {
blocked.waker_count += 1;
if blocked.waker_count == 1 {
self.is_empty.store(false, Ordering::SeqCst);
}
}
}
/// Unregisters a channel operation.
///
/// If `completed` is `true`, the operation will be removed from the registry. If `completed`
/// is `false`, that means the operation was cancelled so another one will be notified.
fn unregister(&self, key: usize, completed: bool) {
let mut blocked = self.blocked.lock();
let mut removed = false;
match blocked.entries.remove(key) {
Some(_) => removed = true,
None => {
if !completed {
// This operation was cancelled. Notify another one.
if let Some((_, opt_waker)) = blocked.entries.iter_mut().next() {
if let Some(w) = opt_waker.take() {
w.wake();
removed = true;
}
}
}
}
}
if removed {
blocked.waker_count -= 1;
if blocked.waker_count == 0 {
self.is_empty.store(true, Ordering::SeqCst);
}
}
}
/// Notifies one blocked channel operation.
#[inline]
fn notify_one(&self) {
if !self.is_empty.load(Ordering::SeqCst) {
let mut blocked = self.blocked.lock();
if let Some((_, opt_waker)) = blocked.entries.iter_mut().next() {
// If there is no waker in this entry, that means it was already woken.
if let Some(w) = opt_waker.take() {
w.wake();
blocked.waker_count -= 1;
if blocked.waker_count == 0 {
self.is_empty.store(true, Ordering::SeqCst);
}
}
}
}
}
/// Notifies all blocked channel operations.
#[inline]
fn notify_all(&self) {
if !self.is_empty.load(Ordering::SeqCst) {
let mut blocked = self.blocked.lock();
for (_, opt_waker) in blocked.entries.iter_mut() {
// If there is no waker in this entry, that means it was already woken.
if let Some(w) = opt_waker.take() {
w.wake();
}
}
blocked.waker_count = 0;
self.is_empty.store(true, Ordering::SeqCst);
}
}
}
/// A simple spinlock.
struct Spinlock<T> {
flag: AtomicBool,
value: UnsafeCell<T>,
}
impl<T> Spinlock<T> {
/// Returns a new spinlock initialized with `value`.
fn new(value: T) -> Spinlock<T> {
Spinlock {
flag: AtomicBool::new(false),
value: UnsafeCell::new(value),
}
}
/// Locks the spinlock.
fn lock(&self) -> SpinlockGuard<'_, T> {
let backoff = Backoff::new();
while self.flag.swap(true, Ordering::Acquire) {
backoff.snooze();
}
SpinlockGuard { parent: self }
}
}
/// A guard holding a spinlock locked.
struct SpinlockGuard<'a, T> {
parent: &'a Spinlock<T>,
}
impl<'a, T> Drop for SpinlockGuard<'a, T> {
fn drop(&mut self) {
self.parent.flag.store(false, Ordering::Release);
}
}
impl<'a, T> Deref for SpinlockGuard<'a, T> {
type Target = T;
fn deref(&self) -> &T {
unsafe { &*self.parent.value.get() }
}
}
impl<'a, T> DerefMut for SpinlockGuard<'a, T> {
fn deref_mut(&mut self) -> &mut T {
unsafe { &mut *self.parent.value.get() }
}
}

@ -191,3 +191,6 @@ cfg_unstable! {
mod barrier; mod barrier;
mod channel; mod channel;
} }
pub(crate) mod waker_set;
pub(crate) use waker_set::WakerSet;

@ -2,18 +2,11 @@ use std::cell::UnsafeCell;
use std::fmt; use std::fmt;
use std::ops::{Deref, DerefMut}; use std::ops::{Deref, DerefMut};
use std::pin::Pin; use std::pin::Pin;
use std::sync::atomic::{AtomicUsize, Ordering}; use std::sync::atomic::{AtomicBool, Ordering};
use slab::Slab;
use crate::future::Future; use crate::future::Future;
use crate::task::{Context, Poll, Waker}; use crate::sync::WakerSet;
use crate::task::{Context, Poll};
/// Set if the mutex is locked.
const LOCK: usize = 1;
/// Set if there are tasks blocked on the mutex.
const BLOCKED: usize = 1 << 1;
/// A mutual exclusion primitive for protecting shared data. /// A mutual exclusion primitive for protecting shared data.
/// ///
@ -49,8 +42,8 @@ const BLOCKED: usize = 1 << 1;
/// # }) /// # })
/// ``` /// ```
pub struct Mutex<T> { pub struct Mutex<T> {
state: AtomicUsize, locked: AtomicBool,
blocked: std::sync::Mutex<Slab<Option<Waker>>>, wakers: WakerSet,
value: UnsafeCell<T>, value: UnsafeCell<T>,
} }
@ -69,8 +62,8 @@ impl<T> Mutex<T> {
/// ``` /// ```
pub fn new(t: T) -> Mutex<T> { pub fn new(t: T) -> Mutex<T> {
Mutex { Mutex {
state: AtomicUsize::new(0), locked: AtomicBool::new(false),
blocked: std::sync::Mutex::new(Slab::new()), wakers: WakerSet::new(),
value: UnsafeCell::new(t), value: UnsafeCell::new(t),
} }
} }
@ -105,75 +98,46 @@ impl<T> Mutex<T> {
pub struct LockFuture<'a, T> { pub struct LockFuture<'a, T> {
mutex: &'a Mutex<T>, mutex: &'a Mutex<T>,
opt_key: Option<usize>, opt_key: Option<usize>,
acquired: bool,
} }
impl<'a, T> Future for LockFuture<'a, T> { impl<'a, T> Future for LockFuture<'a, T> {
type Output = MutexGuard<'a, T>; type Output = MutexGuard<'a, T>;
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
match self.mutex.try_lock() { let poll = match self.mutex.try_lock() {
Some(guard) => { Some(guard) => Poll::Ready(guard),
self.acquired = true;
Poll::Ready(guard)
}
None => { None => {
let mut blocked = self.mutex.blocked.lock().unwrap(); // Insert this lock operation.
// Register the current task.
match self.opt_key { match self.opt_key {
None => { None => self.opt_key = Some(self.mutex.wakers.insert(cx)),
// Insert a new entry into the list of blocked tasks. Some(key) => self.mutex.wakers.update(key, cx),
let w = cx.waker().clone();
let key = blocked.insert(Some(w));
self.opt_key = Some(key);
if blocked.len() == 1 {
self.mutex.state.fetch_or(BLOCKED, Ordering::Relaxed);
}
}
Some(key) => {
// There is already an entry in the list of blocked tasks. Just
// reset the waker if it was removed.
if blocked[key].is_none() {
let w = cx.waker().clone();
blocked[key] = Some(w);
}
}
} }
// Try locking again because it's possible the mutex got unlocked just // Try locking again because it's possible the mutex got unlocked just
// before the current task was registered as a blocked task. // before the current task was inserted into the waker set.
match self.mutex.try_lock() { match self.mutex.try_lock() {
Some(guard) => { Some(guard) => Poll::Ready(guard),
self.acquired = true;
Poll::Ready(guard)
}
None => Poll::Pending, None => Poll::Pending,
} }
} }
};
if poll.is_ready() {
// If the current task is in the set, remove it.
if let Some(key) = self.opt_key.take() {
self.mutex.wakers.complete(key);
}
} }
poll
} }
} }
impl<T> Drop for LockFuture<'_, T> { impl<T> Drop for LockFuture<'_, T> {
fn drop(&mut self) { fn drop(&mut self) {
// If the current task is still in the set, that means it is being cancelled now.
if let Some(key) = self.opt_key { if let Some(key) = self.opt_key {
let mut blocked = self.mutex.blocked.lock().unwrap(); self.mutex.wakers.cancel(key);
let opt_waker = blocked.remove(key);
if opt_waker.is_none() && !self.acquired {
// We were awoken but didn't acquire the lock. Wake up another task.
if let Some((_, opt_waker)) = blocked.iter_mut().next() {
if let Some(w) = opt_waker.take() {
w.wake();
}
}
}
if blocked.is_empty() {
self.mutex.state.fetch_and(!BLOCKED, Ordering::Relaxed);
}
} }
} }
} }
@ -181,7 +145,6 @@ impl<T> Mutex<T> {
LockFuture { LockFuture {
mutex: self, mutex: self,
opt_key: None, opt_key: None,
acquired: false,
} }
.await .await
} }
@ -220,7 +183,7 @@ impl<T> Mutex<T> {
/// # }) /// # })
/// ``` /// ```
pub fn try_lock(&self) -> Option<MutexGuard<'_, T>> { pub fn try_lock(&self) -> Option<MutexGuard<'_, T>> {
if self.state.fetch_or(LOCK, Ordering::Acquire) & LOCK == 0 { if !self.locked.swap(true, Ordering::SeqCst) {
Some(MutexGuard(self)) Some(MutexGuard(self))
} else { } else {
None None
@ -266,18 +229,15 @@ impl<T> Mutex<T> {
impl<T: fmt::Debug> fmt::Debug for Mutex<T> { impl<T: fmt::Debug> fmt::Debug for Mutex<T> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self.try_lock() { struct Locked;
None => { impl fmt::Debug for Locked {
struct LockedPlaceholder; fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
impl fmt::Debug for LockedPlaceholder { f.write_str("<locked>")
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.write_str("<locked>")
}
}
f.debug_struct("Mutex")
.field("data", &LockedPlaceholder)
.finish()
} }
}
match self.try_lock() {
None => f.debug_struct("Mutex").field("data", &Locked).finish(),
Some(guard) => f.debug_struct("Mutex").field("data", &&*guard).finish(), Some(guard) => f.debug_struct("Mutex").field("data", &&*guard).finish(),
} }
} }
@ -303,19 +263,11 @@ unsafe impl<T: Sync> Sync for MutexGuard<'_, T> {}
impl<T> Drop for MutexGuard<'_, T> { impl<T> Drop for MutexGuard<'_, T> {
fn drop(&mut self) { fn drop(&mut self) {
let state = self.0.state.fetch_and(!LOCK, Ordering::AcqRel); // Use `SeqCst` ordering to synchronize with `WakerSet::insert()` and `WakerSet::update()`.
self.0.locked.store(false, Ordering::SeqCst);
// If there are any blocked tasks, wake one of them up.
if state & BLOCKED != 0 {
let mut blocked = self.0.blocked.lock().unwrap();
if let Some((_, opt_waker)) = blocked.iter_mut().next() { // Notify one blocked `lock()` operation.
// If there is no waker in this entry, that means it was already woken. self.0.wakers.notify_one();
if let Some(w) = opt_waker.take() {
w.wake();
}
}
}
} }
} }

@ -10,7 +10,8 @@ use crate::future::Future;
use crate::task::{Context, Poll, Waker}; use crate::task::{Context, Poll, Waker};
/// Set if a write lock is held. /// Set if a write lock is held.
const WRITE_LOCK: usize = 1; #[allow(clippy::identity_op)]
const WRITE_LOCK: usize = 1 << 0;
/// Set if there are read operations blocked on the lock. /// Set if there are read operations blocked on the lock.
const BLOCKED_READS: usize = 1 << 1; const BLOCKED_READS: usize = 1 << 1;

@ -0,0 +1,200 @@
//! A common utility for building synchronization primitives.
//!
//! When an async operation is blocked, it needs to register itself somewhere so that it can be
//! notified later on. The `WakerSet` type helps with keeping track of such async operations and
//! notifying them when they may make progress.
use std::cell::UnsafeCell;
use std::ops::{Deref, DerefMut};
use std::sync::atomic::{AtomicUsize, Ordering};
use crossbeam_utils::Backoff;
use slab::Slab;
use crate::task::{Context, Waker};
/// Set when the entry list is locked.
#[allow(clippy::identity_op)]
const LOCKED: usize = 1 << 0;
/// Set when there are tasks for `notify_one()` to wake.
const NOTIFY_ONE: usize = 1 << 1;
/// Set when there are tasks for `notify_all()` to wake.
const NOTIFY_ALL: usize = 1 << 2;
/// Inner representation of `WakerSet`.
struct Inner {
/// A list of entries in the set.
///
/// Each entry has an optional waker associated with the task that is executing the operation.
/// If the waker is set to `None`, that means the task has been woken up but hasn't removed
/// itself from the `WakerSet` yet.
///
/// The key of each entry is its index in the `Slab`.
entries: Slab<Option<Waker>>,
/// The number of entries that have the waker set to `None`.
none_count: usize,
}
/// A set holding wakers.
pub struct WakerSet {
/// Holds three bits: `LOCKED`, `NOTIFY_ONE`, and `NOTIFY_ALL`.
flag: AtomicUsize,
/// A set holding wakers.
inner: UnsafeCell<Inner>,
}
impl WakerSet {
/// Creates a new `WakerSet`.
#[inline]
pub fn new() -> WakerSet {
WakerSet {
flag: AtomicUsize::new(0),
inner: UnsafeCell::new(Inner {
entries: Slab::new(),
none_count: 0,
}),
}
}
/// Inserts a waker for a blocked operation and returns a key associated with it.
pub fn insert(&self, cx: &Context<'_>) -> usize {
let w = cx.waker().clone();
self.lock().entries.insert(Some(w))
}
/// Updates the waker of a previously inserted entry.
pub fn update(&self, key: usize, cx: &Context<'_>) {
let mut inner = self.lock();
match &mut inner.entries[key] {
None => {
// Fill in the waker.
let w = cx.waker().clone();
inner.entries[key] = Some(w);
inner.none_count -= 1;
}
Some(w) => {
// Replace the waker if the existing one is different.
if !w.will_wake(cx.waker()) {
*w = cx.waker().clone();
}
}
}
}
/// Removes the waker of a completed operation.
pub fn complete(&self, key: usize) {
let mut inner = self.lock();
if inner.entries.remove(key).is_none() {
inner.none_count -= 1;
}
}
/// Removes the waker of a cancelled operation.
pub fn cancel(&self, key: usize) {
let mut inner = self.lock();
if inner.entries.remove(key).is_none() {
inner.none_count -= 1;
// The operation was cancelled and notified so notify another operation instead.
if let Some((_, opt_waker)) = inner.entries.iter_mut().next() {
// If there is no waker in this entry, that means it was already woken.
if let Some(w) = opt_waker.take() {
w.wake();
inner.none_count += 1;
}
}
}
}
/// Notifies one blocked operation.
#[inline]
pub fn notify_one(&self) {
// Use `SeqCst` ordering to synchronize with `Lock::drop()`.
if self.flag.load(Ordering::SeqCst) & NOTIFY_ONE != 0 {
self.notify(false);
}
}
/// Notifies all blocked operations.
// TODO: Delete this attribute when `crate::sync::channel()` is stabilized.
#[cfg(feature = "unstable")]
#[inline]
pub fn notify_all(&self) {
// Use `SeqCst` ordering to synchronize with `Lock::drop()`.
if self.flag.load(Ordering::SeqCst) & NOTIFY_ALL != 0 {
self.notify(true);
}
}
/// Notifies blocked operations, either one or all of them.
fn notify(&self, all: bool) {
let mut inner = &mut *self.lock();
for (_, opt_waker) in inner.entries.iter_mut() {
// If there is no waker in this entry, that means it was already woken.
if let Some(w) = opt_waker.take() {
w.wake();
inner.none_count += 1;
}
if !all {
break;
}
}
}
/// Locks the list of entries.
#[cold]
fn lock(&self) -> Lock<'_> {
let backoff = Backoff::new();
while self.flag.fetch_or(LOCKED, Ordering::Acquire) & LOCKED != 0 {
backoff.snooze();
}
Lock { waker_set: self }
}
}
/// A guard holding a `WakerSet` locked.
struct Lock<'a> {
waker_set: &'a WakerSet,
}
impl Drop for Lock<'_> {
#[inline]
fn drop(&mut self) {
let mut flag = 0;
// If there is at least one entry and all are `Some`, then `notify_one()` has work to do.
if !self.entries.is_empty() && self.none_count == 0 {
flag |= NOTIFY_ONE;
}
// If there is at least one `Some` entry, then `notify_all()` has work to do.
if self.entries.len() - self.none_count > 0 {
flag |= NOTIFY_ALL;
}
// Use `SeqCst` ordering to synchronize with `WakerSet::lock_to_notify()`.
self.waker_set.flag.store(flag, Ordering::SeqCst);
}
}
impl Deref for Lock<'_> {
type Target = Inner;
#[inline]
fn deref(&self) -> &Inner {
unsafe { &*self.waker_set.inner.get() }
}
}
impl DerefMut for Lock<'_> {
#[inline]
fn deref_mut(&mut self) -> &mut Inner {
unsafe { &mut *self.waker_set.inner.get() }
}
}
Loading…
Cancel
Save