forked from mirror/async-std
Add utility type WakerSet to the sync module (#390)
* Add utility type Registry to the sync module * Remove unused import * Split unregister into complete and cancel * Refactoring and renaming * Split remove() into complete() and cancel() * Rename to WakerSet * Ignore clippy warning * Ignore another clippy warning * Use stronger SeqCst ordering
This commit is contained in:
parent
3dd59d7056
commit
87de4e1598
7 changed files with 371 additions and 348 deletions
42
benches/mutex.rs
Normal file
42
benches/mutex.rs
Normal file
|
@ -0,0 +1,42 @@
|
|||
#![feature(test)]
|
||||
|
||||
extern crate test;
|
||||
|
||||
use std::sync::Arc;
|
||||
|
||||
use async_std::sync::Mutex;
|
||||
use async_std::task;
|
||||
use test::Bencher;
|
||||
|
||||
#[bench]
|
||||
fn create(b: &mut Bencher) {
|
||||
b.iter(|| Mutex::new(()));
|
||||
}
|
||||
|
||||
#[bench]
|
||||
fn contention(b: &mut Bencher) {
|
||||
b.iter(|| task::block_on(run(10, 1000)));
|
||||
}
|
||||
|
||||
#[bench]
|
||||
fn no_contention(b: &mut Bencher) {
|
||||
b.iter(|| task::block_on(run(1, 10000)));
|
||||
}
|
||||
|
||||
async fn run(task: usize, iter: usize) {
|
||||
let m = Arc::new(Mutex::new(()));
|
||||
let mut tasks = Vec::new();
|
||||
|
||||
for _ in 0..task {
|
||||
let m = m.clone();
|
||||
tasks.push(task::spawn(async move {
|
||||
for _ in 0..iter {
|
||||
let _ = m.lock().await;
|
||||
}
|
||||
}));
|
||||
}
|
||||
|
||||
for t in tasks {
|
||||
t.await;
|
||||
}
|
||||
}
|
11
benches/task.rs
Normal file
11
benches/task.rs
Normal file
|
@ -0,0 +1,11 @@
|
|||
#![feature(test)]
|
||||
|
||||
extern crate test;
|
||||
|
||||
use async_std::task;
|
||||
use test::Bencher;
|
||||
|
||||
#[bench]
|
||||
fn block_on(b: &mut Bencher) {
|
||||
b.iter(|| task::block_on(async {}));
|
||||
}
|
|
@ -4,17 +4,17 @@ use std::future::Future;
|
|||
use std::isize;
|
||||
use std::marker::PhantomData;
|
||||
use std::mem;
|
||||
use std::ops::{Deref, DerefMut};
|
||||
use std::pin::Pin;
|
||||
use std::process;
|
||||
use std::ptr;
|
||||
use std::sync::atomic::{self, AtomicBool, AtomicUsize, Ordering};
|
||||
use std::sync::atomic::{self, AtomicUsize, Ordering};
|
||||
use std::sync::Arc;
|
||||
use std::task::{Context, Poll, Waker};
|
||||
use std::task::{Context, Poll};
|
||||
|
||||
use crossbeam_utils::{Backoff, CachePadded};
|
||||
use futures_core::stream::Stream;
|
||||
use slab::Slab;
|
||||
|
||||
use crate::stream::Stream;
|
||||
use crate::sync::WakerSet;
|
||||
|
||||
/// Creates a bounded multi-producer multi-consumer channel.
|
||||
///
|
||||
|
@ -128,7 +128,7 @@ impl<T> Sender<T> {
|
|||
/// ```
|
||||
pub async fn send(&self, msg: T) {
|
||||
struct SendFuture<'a, T> {
|
||||
sender: &'a Sender<T>,
|
||||
channel: &'a Channel<T>,
|
||||
msg: Option<T>,
|
||||
opt_key: Option<usize>,
|
||||
}
|
||||
|
@ -142,23 +142,23 @@ impl<T> Sender<T> {
|
|||
let msg = self.msg.take().unwrap();
|
||||
|
||||
// Try sending the message.
|
||||
let poll = match self.sender.channel.push(msg) {
|
||||
let poll = match self.channel.try_send(msg) {
|
||||
Ok(()) => Poll::Ready(()),
|
||||
Err(PushError::Disconnected(msg)) => {
|
||||
Err(TrySendError::Disconnected(msg)) => {
|
||||
self.msg = Some(msg);
|
||||
Poll::Pending
|
||||
}
|
||||
Err(PushError::Full(msg)) => {
|
||||
// Register the current task.
|
||||
Err(TrySendError::Full(msg)) => {
|
||||
// Insert this send operation.
|
||||
match self.opt_key {
|
||||
None => self.opt_key = Some(self.sender.channel.sends.register(cx)),
|
||||
Some(key) => self.sender.channel.sends.reregister(key, cx),
|
||||
None => self.opt_key = Some(self.channel.send_wakers.insert(cx)),
|
||||
Some(key) => self.channel.send_wakers.update(key, cx),
|
||||
}
|
||||
|
||||
// Try sending the message again.
|
||||
match self.sender.channel.push(msg) {
|
||||
match self.channel.try_send(msg) {
|
||||
Ok(()) => Poll::Ready(()),
|
||||
Err(PushError::Disconnected(msg)) | Err(PushError::Full(msg)) => {
|
||||
Err(TrySendError::Disconnected(msg)) | Err(TrySendError::Full(msg)) => {
|
||||
self.msg = Some(msg);
|
||||
Poll::Pending
|
||||
}
|
||||
|
@ -167,10 +167,9 @@ impl<T> Sender<T> {
|
|||
};
|
||||
|
||||
if poll.is_ready() {
|
||||
// If the current task was registered, unregister now.
|
||||
// If the current task is in the set, remove it.
|
||||
if let Some(key) = self.opt_key.take() {
|
||||
// `true` means the send operation is completed.
|
||||
self.sender.channel.sends.unregister(key, true);
|
||||
self.channel.send_wakers.complete(key);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -180,16 +179,16 @@ impl<T> Sender<T> {
|
|||
|
||||
impl<T> Drop for SendFuture<'_, T> {
|
||||
fn drop(&mut self) {
|
||||
// If the current task was registered, unregister now.
|
||||
// If the current task is still in the set, that means it is being cancelled now.
|
||||
// Wake up another task instead.
|
||||
if let Some(key) = self.opt_key {
|
||||
// `false` means the send operation is cancelled.
|
||||
self.sender.channel.sends.unregister(key, false);
|
||||
self.channel.send_wakers.cancel(key);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
SendFuture {
|
||||
sender: self,
|
||||
channel: &self.channel,
|
||||
msg: Some(msg),
|
||||
opt_key: None,
|
||||
}
|
||||
|
@ -340,7 +339,7 @@ pub struct Receiver<T> {
|
|||
/// The inner channel.
|
||||
channel: Arc<Channel<T>>,
|
||||
|
||||
/// The registration key for this receiver in the `channel.streams` registry.
|
||||
/// The key for this receiver in the `channel.stream_wakers` set.
|
||||
opt_key: Option<usize>,
|
||||
}
|
||||
|
||||
|
@ -382,16 +381,20 @@ impl<T> Receiver<T> {
|
|||
type Output = Option<T>;
|
||||
|
||||
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
|
||||
poll_recv(&self.channel, &self.channel.recvs, &mut self.opt_key, cx)
|
||||
poll_recv(
|
||||
&self.channel,
|
||||
&self.channel.recv_wakers,
|
||||
&mut self.opt_key,
|
||||
cx,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> Drop for RecvFuture<'_, T> {
|
||||
fn drop(&mut self) {
|
||||
// If the current task was registered, unregister now.
|
||||
// If the current task is still in the set, that means it is being cancelled now.
|
||||
if let Some(key) = self.opt_key {
|
||||
// `false` means the receive operation is cancelled.
|
||||
self.channel.recvs.unregister(key, false);
|
||||
self.channel.recv_wakers.cancel(key);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -484,10 +487,9 @@ impl<T> Receiver<T> {
|
|||
|
||||
impl<T> Drop for Receiver<T> {
|
||||
fn drop(&mut self) {
|
||||
// If the current task was registered as blocked on this stream, unregister now.
|
||||
// If the current task is still in the stream set, that means it is being cancelled now.
|
||||
if let Some(key) = self.opt_key {
|
||||
// `false` means the last request for a stream item is cancelled.
|
||||
self.channel.streams.unregister(key, false);
|
||||
self.channel.stream_wakers.cancel(key);
|
||||
}
|
||||
|
||||
// Decrement the receiver count and disconnect the channel if it drops down to zero.
|
||||
|
@ -518,7 +520,12 @@ impl<T> Stream for Receiver<T> {
|
|||
|
||||
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
|
||||
let this = &mut *self;
|
||||
poll_recv(&this.channel, &this.channel.streams, &mut this.opt_key, cx)
|
||||
poll_recv(
|
||||
&this.channel,
|
||||
&this.channel.stream_wakers,
|
||||
&mut this.opt_key,
|
||||
cx,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -530,39 +537,38 @@ impl<T> fmt::Debug for Receiver<T> {
|
|||
|
||||
/// Polls a receive operation on a channel.
|
||||
///
|
||||
/// If the receive operation is blocked, the current task will be registered in `registry` and its
|
||||
/// registration key will then be stored in `opt_key`.
|
||||
/// If the receive operation is blocked, the current task will be inserted into `wakers` and its
|
||||
/// associated key will then be stored in `opt_key`.
|
||||
fn poll_recv<T>(
|
||||
channel: &Channel<T>,
|
||||
registry: &Registry,
|
||||
wakers: &WakerSet,
|
||||
opt_key: &mut Option<usize>,
|
||||
cx: &mut Context<'_>,
|
||||
) -> Poll<Option<T>> {
|
||||
// Try receiving a message.
|
||||
let poll = match channel.pop() {
|
||||
let poll = match channel.try_recv() {
|
||||
Ok(msg) => Poll::Ready(Some(msg)),
|
||||
Err(PopError::Disconnected) => Poll::Ready(None),
|
||||
Err(PopError::Empty) => {
|
||||
// Register the current task.
|
||||
Err(TryRecvError::Disconnected) => Poll::Ready(None),
|
||||
Err(TryRecvError::Empty) => {
|
||||
// Insert this receive operation.
|
||||
match *opt_key {
|
||||
None => *opt_key = Some(registry.register(cx)),
|
||||
Some(key) => registry.reregister(key, cx),
|
||||
None => *opt_key = Some(wakers.insert(cx)),
|
||||
Some(key) => wakers.update(key, cx),
|
||||
}
|
||||
|
||||
// Try receiving a message again.
|
||||
match channel.pop() {
|
||||
match channel.try_recv() {
|
||||
Ok(msg) => Poll::Ready(Some(msg)),
|
||||
Err(PopError::Disconnected) => Poll::Ready(None),
|
||||
Err(PopError::Empty) => Poll::Pending,
|
||||
Err(TryRecvError::Disconnected) => Poll::Ready(None),
|
||||
Err(TryRecvError::Empty) => Poll::Pending,
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
if poll.is_ready() {
|
||||
// If the current task was registered, unregister now.
|
||||
// If the current task is in the set, remove it.
|
||||
if let Some(key) = opt_key.take() {
|
||||
// `true` means the receive operation is completed.
|
||||
registry.unregister(key, true);
|
||||
wakers.complete(key);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -612,13 +618,13 @@ struct Channel<T> {
|
|||
mark_bit: usize,
|
||||
|
||||
/// Send operations waiting while the channel is full.
|
||||
sends: Registry,
|
||||
send_wakers: WakerSet,
|
||||
|
||||
/// Receive operations waiting while the channel is empty and not disconnected.
|
||||
recvs: Registry,
|
||||
recv_wakers: WakerSet,
|
||||
|
||||
/// Streams waiting while the channel is empty and not disconnected.
|
||||
streams: Registry,
|
||||
stream_wakers: WakerSet,
|
||||
|
||||
/// The number of currently active `Sender`s.
|
||||
sender_count: AtomicUsize,
|
||||
|
@ -672,17 +678,17 @@ impl<T> Channel<T> {
|
|||
mark_bit,
|
||||
head: CachePadded::new(AtomicUsize::new(head)),
|
||||
tail: CachePadded::new(AtomicUsize::new(tail)),
|
||||
sends: Registry::new(),
|
||||
recvs: Registry::new(),
|
||||
streams: Registry::new(),
|
||||
send_wakers: WakerSet::new(),
|
||||
recv_wakers: WakerSet::new(),
|
||||
stream_wakers: WakerSet::new(),
|
||||
sender_count: AtomicUsize::new(1),
|
||||
receiver_count: AtomicUsize::new(1),
|
||||
_marker: PhantomData,
|
||||
}
|
||||
}
|
||||
|
||||
/// Attempts to push a message.
|
||||
fn push(&self, msg: T) -> Result<(), PushError<T>> {
|
||||
/// Attempts to send a message.
|
||||
fn try_send(&self, msg: T) -> Result<(), TrySendError<T>> {
|
||||
let backoff = Backoff::new();
|
||||
let mut tail = self.tail.load(Ordering::Relaxed);
|
||||
|
||||
|
@ -721,10 +727,10 @@ impl<T> Channel<T> {
|
|||
slot.stamp.store(stamp, Ordering::Release);
|
||||
|
||||
// Wake a blocked receive operation.
|
||||
self.recvs.notify_one();
|
||||
self.recv_wakers.notify_one();
|
||||
|
||||
// Wake all blocked streams.
|
||||
self.streams.notify_all();
|
||||
self.stream_wakers.notify_all();
|
||||
|
||||
return Ok(());
|
||||
}
|
||||
|
@ -743,9 +749,9 @@ impl<T> Channel<T> {
|
|||
|
||||
// Check if the channel is disconnected.
|
||||
if tail & self.mark_bit != 0 {
|
||||
return Err(PushError::Disconnected(msg));
|
||||
return Err(TrySendError::Disconnected(msg));
|
||||
} else {
|
||||
return Err(PushError::Full(msg));
|
||||
return Err(TrySendError::Full(msg));
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -759,8 +765,8 @@ impl<T> Channel<T> {
|
|||
}
|
||||
}
|
||||
|
||||
/// Attempts to pop a message.
|
||||
fn pop(&self) -> Result<T, PopError> {
|
||||
/// Attempts to receive a message.
|
||||
fn try_recv(&self) -> Result<T, TryRecvError> {
|
||||
let backoff = Backoff::new();
|
||||
let mut head = self.head.load(Ordering::Relaxed);
|
||||
|
||||
|
@ -799,7 +805,7 @@ impl<T> Channel<T> {
|
|||
slot.stamp.store(stamp, Ordering::Release);
|
||||
|
||||
// Wake a blocked send operation.
|
||||
self.sends.notify_one();
|
||||
self.send_wakers.notify_one();
|
||||
|
||||
return Ok(msg);
|
||||
}
|
||||
|
@ -816,10 +822,10 @@ impl<T> Channel<T> {
|
|||
if (tail & !self.mark_bit) == head {
|
||||
// If the channel is disconnected...
|
||||
if tail & self.mark_bit != 0 {
|
||||
return Err(PopError::Disconnected);
|
||||
return Err(TryRecvError::Disconnected);
|
||||
} else {
|
||||
// Otherwise, the receive operation is not ready.
|
||||
return Err(PopError::Empty);
|
||||
return Err(TryRecvError::Empty);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -888,9 +894,9 @@ impl<T> Channel<T> {
|
|||
|
||||
if tail & self.mark_bit == 0 {
|
||||
// Notify everyone blocked on this channel.
|
||||
self.sends.notify_all();
|
||||
self.recvs.notify_all();
|
||||
self.streams.notify_all();
|
||||
self.send_wakers.notify_all();
|
||||
self.recv_wakers.notify_all();
|
||||
self.stream_wakers.notify_all();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -921,8 +927,8 @@ impl<T> Drop for Channel<T> {
|
|||
}
|
||||
}
|
||||
|
||||
/// An error returned from the `push()` method.
|
||||
enum PushError<T> {
|
||||
/// An error returned from the `try_send()` method.
|
||||
enum TrySendError<T> {
|
||||
/// The channel is full but not disconnected.
|
||||
Full(T),
|
||||
|
||||
|
@ -930,203 +936,11 @@ enum PushError<T> {
|
|||
Disconnected(T),
|
||||
}
|
||||
|
||||
/// An error returned from the `pop()` method.
|
||||
enum PopError {
|
||||
/// An error returned from the `try_recv()` method.
|
||||
enum TryRecvError {
|
||||
/// The channel is empty but not disconnected.
|
||||
Empty,
|
||||
|
||||
/// The channel is empty and disconnected.
|
||||
Disconnected,
|
||||
}
|
||||
|
||||
/// A list of blocked channel operations.
|
||||
struct Blocked {
|
||||
/// A list of registered channel operations.
|
||||
///
|
||||
/// Each entry has a waker associated with the task that is executing the operation. If the
|
||||
/// waker is set to `None`, that means the task has been woken up but hasn't removed itself
|
||||
/// from the registry yet.
|
||||
entries: Slab<Option<Waker>>,
|
||||
|
||||
/// The number of wakers in the entry list.
|
||||
waker_count: usize,
|
||||
}
|
||||
|
||||
/// A registry of blocked channel operations.
|
||||
///
|
||||
/// Blocked operations register themselves in a registry. Successful operations on the opposite
|
||||
/// side of the channel wake blocked operations in the registry.
|
||||
struct Registry {
|
||||
/// A list of blocked channel operations.
|
||||
blocked: Spinlock<Blocked>,
|
||||
|
||||
/// Set to `true` if there are no wakers in the registry.
|
||||
///
|
||||
/// Note that this either means there are no entries in the registry, or that all entries have
|
||||
/// been notified.
|
||||
is_empty: AtomicBool,
|
||||
}
|
||||
|
||||
impl Registry {
|
||||
/// Creates a new registry.
|
||||
fn new() -> Registry {
|
||||
Registry {
|
||||
blocked: Spinlock::new(Blocked {
|
||||
entries: Slab::new(),
|
||||
waker_count: 0,
|
||||
}),
|
||||
is_empty: AtomicBool::new(true),
|
||||
}
|
||||
}
|
||||
|
||||
/// Registers a blocked channel operation and returns a key associated with it.
|
||||
fn register(&self, cx: &Context<'_>) -> usize {
|
||||
let mut blocked = self.blocked.lock();
|
||||
|
||||
// Insert a new entry into the list of blocked tasks.
|
||||
let w = cx.waker().clone();
|
||||
let key = blocked.entries.insert(Some(w));
|
||||
|
||||
blocked.waker_count += 1;
|
||||
if blocked.waker_count == 1 {
|
||||
self.is_empty.store(false, Ordering::SeqCst);
|
||||
}
|
||||
|
||||
key
|
||||
}
|
||||
|
||||
/// Re-registers a blocked channel operation by filling in its waker.
|
||||
fn reregister(&self, key: usize, cx: &Context<'_>) {
|
||||
let mut blocked = self.blocked.lock();
|
||||
|
||||
let was_none = blocked.entries[key].is_none();
|
||||
let w = cx.waker().clone();
|
||||
blocked.entries[key] = Some(w);
|
||||
|
||||
if was_none {
|
||||
blocked.waker_count += 1;
|
||||
if blocked.waker_count == 1 {
|
||||
self.is_empty.store(false, Ordering::SeqCst);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Unregisters a channel operation.
|
||||
///
|
||||
/// If `completed` is `true`, the operation will be removed from the registry. If `completed`
|
||||
/// is `false`, that means the operation was cancelled so another one will be notified.
|
||||
fn unregister(&self, key: usize, completed: bool) {
|
||||
let mut blocked = self.blocked.lock();
|
||||
let mut removed = false;
|
||||
|
||||
match blocked.entries.remove(key) {
|
||||
Some(_) => removed = true,
|
||||
None => {
|
||||
if !completed {
|
||||
// This operation was cancelled. Notify another one.
|
||||
if let Some((_, opt_waker)) = blocked.entries.iter_mut().next() {
|
||||
if let Some(w) = opt_waker.take() {
|
||||
w.wake();
|
||||
removed = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if removed {
|
||||
blocked.waker_count -= 1;
|
||||
if blocked.waker_count == 0 {
|
||||
self.is_empty.store(true, Ordering::SeqCst);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Notifies one blocked channel operation.
|
||||
#[inline]
|
||||
fn notify_one(&self) {
|
||||
if !self.is_empty.load(Ordering::SeqCst) {
|
||||
let mut blocked = self.blocked.lock();
|
||||
|
||||
if let Some((_, opt_waker)) = blocked.entries.iter_mut().next() {
|
||||
// If there is no waker in this entry, that means it was already woken.
|
||||
if let Some(w) = opt_waker.take() {
|
||||
w.wake();
|
||||
|
||||
blocked.waker_count -= 1;
|
||||
if blocked.waker_count == 0 {
|
||||
self.is_empty.store(true, Ordering::SeqCst);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Notifies all blocked channel operations.
|
||||
#[inline]
|
||||
fn notify_all(&self) {
|
||||
if !self.is_empty.load(Ordering::SeqCst) {
|
||||
let mut blocked = self.blocked.lock();
|
||||
|
||||
for (_, opt_waker) in blocked.entries.iter_mut() {
|
||||
// If there is no waker in this entry, that means it was already woken.
|
||||
if let Some(w) = opt_waker.take() {
|
||||
w.wake();
|
||||
}
|
||||
}
|
||||
|
||||
blocked.waker_count = 0;
|
||||
self.is_empty.store(true, Ordering::SeqCst);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// A simple spinlock.
|
||||
struct Spinlock<T> {
|
||||
flag: AtomicBool,
|
||||
value: UnsafeCell<T>,
|
||||
}
|
||||
|
||||
impl<T> Spinlock<T> {
|
||||
/// Returns a new spinlock initialized with `value`.
|
||||
fn new(value: T) -> Spinlock<T> {
|
||||
Spinlock {
|
||||
flag: AtomicBool::new(false),
|
||||
value: UnsafeCell::new(value),
|
||||
}
|
||||
}
|
||||
|
||||
/// Locks the spinlock.
|
||||
fn lock(&self) -> SpinlockGuard<'_, T> {
|
||||
let backoff = Backoff::new();
|
||||
while self.flag.swap(true, Ordering::Acquire) {
|
||||
backoff.snooze();
|
||||
}
|
||||
SpinlockGuard { parent: self }
|
||||
}
|
||||
}
|
||||
|
||||
/// A guard holding a spinlock locked.
|
||||
struct SpinlockGuard<'a, T> {
|
||||
parent: &'a Spinlock<T>,
|
||||
}
|
||||
|
||||
impl<'a, T> Drop for SpinlockGuard<'a, T> {
|
||||
fn drop(&mut self) {
|
||||
self.parent.flag.store(false, Ordering::Release);
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, T> Deref for SpinlockGuard<'a, T> {
|
||||
type Target = T;
|
||||
|
||||
fn deref(&self) -> &T {
|
||||
unsafe { &*self.parent.value.get() }
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, T> DerefMut for SpinlockGuard<'a, T> {
|
||||
fn deref_mut(&mut self) -> &mut T {
|
||||
unsafe { &mut *self.parent.value.get() }
|
||||
}
|
||||
}
|
||||
|
|
|
@ -191,3 +191,6 @@ cfg_unstable! {
|
|||
mod barrier;
|
||||
mod channel;
|
||||
}
|
||||
|
||||
pub(crate) mod waker_set;
|
||||
pub(crate) use waker_set::WakerSet;
|
||||
|
|
|
@ -2,18 +2,11 @@ use std::cell::UnsafeCell;
|
|||
use std::fmt;
|
||||
use std::ops::{Deref, DerefMut};
|
||||
use std::pin::Pin;
|
||||
use std::sync::atomic::{AtomicUsize, Ordering};
|
||||
|
||||
use slab::Slab;
|
||||
use std::sync::atomic::{AtomicBool, Ordering};
|
||||
|
||||
use crate::future::Future;
|
||||
use crate::task::{Context, Poll, Waker};
|
||||
|
||||
/// Set if the mutex is locked.
|
||||
const LOCK: usize = 1;
|
||||
|
||||
/// Set if there are tasks blocked on the mutex.
|
||||
const BLOCKED: usize = 1 << 1;
|
||||
use crate::sync::WakerSet;
|
||||
use crate::task::{Context, Poll};
|
||||
|
||||
/// A mutual exclusion primitive for protecting shared data.
|
||||
///
|
||||
|
@ -49,8 +42,8 @@ const BLOCKED: usize = 1 << 1;
|
|||
/// # })
|
||||
/// ```
|
||||
pub struct Mutex<T> {
|
||||
state: AtomicUsize,
|
||||
blocked: std::sync::Mutex<Slab<Option<Waker>>>,
|
||||
locked: AtomicBool,
|
||||
wakers: WakerSet,
|
||||
value: UnsafeCell<T>,
|
||||
}
|
||||
|
||||
|
@ -69,8 +62,8 @@ impl<T> Mutex<T> {
|
|||
/// ```
|
||||
pub fn new(t: T) -> Mutex<T> {
|
||||
Mutex {
|
||||
state: AtomicUsize::new(0),
|
||||
blocked: std::sync::Mutex::new(Slab::new()),
|
||||
locked: AtomicBool::new(false),
|
||||
wakers: WakerSet::new(),
|
||||
value: UnsafeCell::new(t),
|
||||
}
|
||||
}
|
||||
|
@ -105,75 +98,46 @@ impl<T> Mutex<T> {
|
|||
pub struct LockFuture<'a, T> {
|
||||
mutex: &'a Mutex<T>,
|
||||
opt_key: Option<usize>,
|
||||
acquired: bool,
|
||||
}
|
||||
|
||||
impl<'a, T> Future for LockFuture<'a, T> {
|
||||
type Output = MutexGuard<'a, T>;
|
||||
|
||||
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
|
||||
match self.mutex.try_lock() {
|
||||
Some(guard) => {
|
||||
self.acquired = true;
|
||||
Poll::Ready(guard)
|
||||
}
|
||||
let poll = match self.mutex.try_lock() {
|
||||
Some(guard) => Poll::Ready(guard),
|
||||
None => {
|
||||
let mut blocked = self.mutex.blocked.lock().unwrap();
|
||||
|
||||
// Register the current task.
|
||||
// Insert this lock operation.
|
||||
match self.opt_key {
|
||||
None => {
|
||||
// Insert a new entry into the list of blocked tasks.
|
||||
let w = cx.waker().clone();
|
||||
let key = blocked.insert(Some(w));
|
||||
self.opt_key = Some(key);
|
||||
|
||||
if blocked.len() == 1 {
|
||||
self.mutex.state.fetch_or(BLOCKED, Ordering::Relaxed);
|
||||
}
|
||||
}
|
||||
Some(key) => {
|
||||
// There is already an entry in the list of blocked tasks. Just
|
||||
// reset the waker if it was removed.
|
||||
if blocked[key].is_none() {
|
||||
let w = cx.waker().clone();
|
||||
blocked[key] = Some(w);
|
||||
}
|
||||
}
|
||||
None => self.opt_key = Some(self.mutex.wakers.insert(cx)),
|
||||
Some(key) => self.mutex.wakers.update(key, cx),
|
||||
}
|
||||
|
||||
// Try locking again because it's possible the mutex got unlocked just
|
||||
// before the current task was registered as a blocked task.
|
||||
// before the current task was inserted into the waker set.
|
||||
match self.mutex.try_lock() {
|
||||
Some(guard) => {
|
||||
self.acquired = true;
|
||||
Poll::Ready(guard)
|
||||
}
|
||||
Some(guard) => Poll::Ready(guard),
|
||||
None => Poll::Pending,
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
if poll.is_ready() {
|
||||
// If the current task is in the set, remove it.
|
||||
if let Some(key) = self.opt_key.take() {
|
||||
self.mutex.wakers.complete(key);
|
||||
}
|
||||
}
|
||||
|
||||
poll
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> Drop for LockFuture<'_, T> {
|
||||
fn drop(&mut self) {
|
||||
// If the current task is still in the set, that means it is being cancelled now.
|
||||
if let Some(key) = self.opt_key {
|
||||
let mut blocked = self.mutex.blocked.lock().unwrap();
|
||||
let opt_waker = blocked.remove(key);
|
||||
|
||||
if opt_waker.is_none() && !self.acquired {
|
||||
// We were awoken but didn't acquire the lock. Wake up another task.
|
||||
if let Some((_, opt_waker)) = blocked.iter_mut().next() {
|
||||
if let Some(w) = opt_waker.take() {
|
||||
w.wake();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if blocked.is_empty() {
|
||||
self.mutex.state.fetch_and(!BLOCKED, Ordering::Relaxed);
|
||||
}
|
||||
self.mutex.wakers.cancel(key);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -181,7 +145,6 @@ impl<T> Mutex<T> {
|
|||
LockFuture {
|
||||
mutex: self,
|
||||
opt_key: None,
|
||||
acquired: false,
|
||||
}
|
||||
.await
|
||||
}
|
||||
|
@ -220,7 +183,7 @@ impl<T> Mutex<T> {
|
|||
/// # })
|
||||
/// ```
|
||||
pub fn try_lock(&self) -> Option<MutexGuard<'_, T>> {
|
||||
if self.state.fetch_or(LOCK, Ordering::Acquire) & LOCK == 0 {
|
||||
if !self.locked.swap(true, Ordering::SeqCst) {
|
||||
Some(MutexGuard(self))
|
||||
} else {
|
||||
None
|
||||
|
@ -266,18 +229,15 @@ impl<T> Mutex<T> {
|
|||
|
||||
impl<T: fmt::Debug> fmt::Debug for Mutex<T> {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
match self.try_lock() {
|
||||
None => {
|
||||
struct LockedPlaceholder;
|
||||
impl fmt::Debug for LockedPlaceholder {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
f.write_str("<locked>")
|
||||
}
|
||||
}
|
||||
f.debug_struct("Mutex")
|
||||
.field("data", &LockedPlaceholder)
|
||||
.finish()
|
||||
struct Locked;
|
||||
impl fmt::Debug for Locked {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
f.write_str("<locked>")
|
||||
}
|
||||
}
|
||||
|
||||
match self.try_lock() {
|
||||
None => f.debug_struct("Mutex").field("data", &Locked).finish(),
|
||||
Some(guard) => f.debug_struct("Mutex").field("data", &&*guard).finish(),
|
||||
}
|
||||
}
|
||||
|
@ -303,19 +263,11 @@ unsafe impl<T: Sync> Sync for MutexGuard<'_, T> {}
|
|||
|
||||
impl<T> Drop for MutexGuard<'_, T> {
|
||||
fn drop(&mut self) {
|
||||
let state = self.0.state.fetch_and(!LOCK, Ordering::AcqRel);
|
||||
// Use `SeqCst` ordering to synchronize with `WakerSet::insert()` and `WakerSet::update()`.
|
||||
self.0.locked.store(false, Ordering::SeqCst);
|
||||
|
||||
// If there are any blocked tasks, wake one of them up.
|
||||
if state & BLOCKED != 0 {
|
||||
let mut blocked = self.0.blocked.lock().unwrap();
|
||||
|
||||
if let Some((_, opt_waker)) = blocked.iter_mut().next() {
|
||||
// If there is no waker in this entry, that means it was already woken.
|
||||
if let Some(w) = opt_waker.take() {
|
||||
w.wake();
|
||||
}
|
||||
}
|
||||
}
|
||||
// Notify one blocked `lock()` operation.
|
||||
self.0.wakers.notify_one();
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -10,7 +10,8 @@ use crate::future::Future;
|
|||
use crate::task::{Context, Poll, Waker};
|
||||
|
||||
/// Set if a write lock is held.
|
||||
const WRITE_LOCK: usize = 1;
|
||||
#[allow(clippy::identity_op)]
|
||||
const WRITE_LOCK: usize = 1 << 0;
|
||||
|
||||
/// Set if there are read operations blocked on the lock.
|
||||
const BLOCKED_READS: usize = 1 << 1;
|
||||
|
|
200
src/sync/waker_set.rs
Normal file
200
src/sync/waker_set.rs
Normal file
|
@ -0,0 +1,200 @@
|
|||
//! A common utility for building synchronization primitives.
|
||||
//!
|
||||
//! When an async operation is blocked, it needs to register itself somewhere so that it can be
|
||||
//! notified later on. The `WakerSet` type helps with keeping track of such async operations and
|
||||
//! notifying them when they may make progress.
|
||||
|
||||
use std::cell::UnsafeCell;
|
||||
use std::ops::{Deref, DerefMut};
|
||||
use std::sync::atomic::{AtomicUsize, Ordering};
|
||||
|
||||
use crossbeam_utils::Backoff;
|
||||
use slab::Slab;
|
||||
|
||||
use crate::task::{Context, Waker};
|
||||
|
||||
/// Set when the entry list is locked.
|
||||
#[allow(clippy::identity_op)]
|
||||
const LOCKED: usize = 1 << 0;
|
||||
|
||||
/// Set when there are tasks for `notify_one()` to wake.
|
||||
const NOTIFY_ONE: usize = 1 << 1;
|
||||
|
||||
/// Set when there are tasks for `notify_all()` to wake.
|
||||
const NOTIFY_ALL: usize = 1 << 2;
|
||||
|
||||
/// Inner representation of `WakerSet`.
|
||||
struct Inner {
|
||||
/// A list of entries in the set.
|
||||
///
|
||||
/// Each entry has an optional waker associated with the task that is executing the operation.
|
||||
/// If the waker is set to `None`, that means the task has been woken up but hasn't removed
|
||||
/// itself from the `WakerSet` yet.
|
||||
///
|
||||
/// The key of each entry is its index in the `Slab`.
|
||||
entries: Slab<Option<Waker>>,
|
||||
|
||||
/// The number of entries that have the waker set to `None`.
|
||||
none_count: usize,
|
||||
}
|
||||
|
||||
/// A set holding wakers.
|
||||
pub struct WakerSet {
|
||||
/// Holds three bits: `LOCKED`, `NOTIFY_ONE`, and `NOTIFY_ALL`.
|
||||
flag: AtomicUsize,
|
||||
|
||||
/// A set holding wakers.
|
||||
inner: UnsafeCell<Inner>,
|
||||
}
|
||||
|
||||
impl WakerSet {
|
||||
/// Creates a new `WakerSet`.
|
||||
#[inline]
|
||||
pub fn new() -> WakerSet {
|
||||
WakerSet {
|
||||
flag: AtomicUsize::new(0),
|
||||
inner: UnsafeCell::new(Inner {
|
||||
entries: Slab::new(),
|
||||
none_count: 0,
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
||||
/// Inserts a waker for a blocked operation and returns a key associated with it.
|
||||
pub fn insert(&self, cx: &Context<'_>) -> usize {
|
||||
let w = cx.waker().clone();
|
||||
self.lock().entries.insert(Some(w))
|
||||
}
|
||||
|
||||
/// Updates the waker of a previously inserted entry.
|
||||
pub fn update(&self, key: usize, cx: &Context<'_>) {
|
||||
let mut inner = self.lock();
|
||||
|
||||
match &mut inner.entries[key] {
|
||||
None => {
|
||||
// Fill in the waker.
|
||||
let w = cx.waker().clone();
|
||||
inner.entries[key] = Some(w);
|
||||
inner.none_count -= 1;
|
||||
}
|
||||
Some(w) => {
|
||||
// Replace the waker if the existing one is different.
|
||||
if !w.will_wake(cx.waker()) {
|
||||
*w = cx.waker().clone();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Removes the waker of a completed operation.
|
||||
pub fn complete(&self, key: usize) {
|
||||
let mut inner = self.lock();
|
||||
if inner.entries.remove(key).is_none() {
|
||||
inner.none_count -= 1;
|
||||
}
|
||||
}
|
||||
|
||||
/// Removes the waker of a cancelled operation.
|
||||
pub fn cancel(&self, key: usize) {
|
||||
let mut inner = self.lock();
|
||||
if inner.entries.remove(key).is_none() {
|
||||
inner.none_count -= 1;
|
||||
|
||||
// The operation was cancelled and notified so notify another operation instead.
|
||||
if let Some((_, opt_waker)) = inner.entries.iter_mut().next() {
|
||||
// If there is no waker in this entry, that means it was already woken.
|
||||
if let Some(w) = opt_waker.take() {
|
||||
w.wake();
|
||||
inner.none_count += 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Notifies one blocked operation.
|
||||
#[inline]
|
||||
pub fn notify_one(&self) {
|
||||
// Use `SeqCst` ordering to synchronize with `Lock::drop()`.
|
||||
if self.flag.load(Ordering::SeqCst) & NOTIFY_ONE != 0 {
|
||||
self.notify(false);
|
||||
}
|
||||
}
|
||||
|
||||
/// Notifies all blocked operations.
|
||||
// TODO: Delete this attribute when `crate::sync::channel()` is stabilized.
|
||||
#[cfg(feature = "unstable")]
|
||||
#[inline]
|
||||
pub fn notify_all(&self) {
|
||||
// Use `SeqCst` ordering to synchronize with `Lock::drop()`.
|
||||
if self.flag.load(Ordering::SeqCst) & NOTIFY_ALL != 0 {
|
||||
self.notify(true);
|
||||
}
|
||||
}
|
||||
|
||||
/// Notifies blocked operations, either one or all of them.
|
||||
fn notify(&self, all: bool) {
|
||||
let mut inner = &mut *self.lock();
|
||||
|
||||
for (_, opt_waker) in inner.entries.iter_mut() {
|
||||
// If there is no waker in this entry, that means it was already woken.
|
||||
if let Some(w) = opt_waker.take() {
|
||||
w.wake();
|
||||
inner.none_count += 1;
|
||||
}
|
||||
if !all {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Locks the list of entries.
|
||||
#[cold]
|
||||
fn lock(&self) -> Lock<'_> {
|
||||
let backoff = Backoff::new();
|
||||
while self.flag.fetch_or(LOCKED, Ordering::Acquire) & LOCKED != 0 {
|
||||
backoff.snooze();
|
||||
}
|
||||
Lock { waker_set: self }
|
||||
}
|
||||
}
|
||||
|
||||
/// A guard holding a `WakerSet` locked.
|
||||
struct Lock<'a> {
|
||||
waker_set: &'a WakerSet,
|
||||
}
|
||||
|
||||
impl Drop for Lock<'_> {
|
||||
#[inline]
|
||||
fn drop(&mut self) {
|
||||
let mut flag = 0;
|
||||
|
||||
// If there is at least one entry and all are `Some`, then `notify_one()` has work to do.
|
||||
if !self.entries.is_empty() && self.none_count == 0 {
|
||||
flag |= NOTIFY_ONE;
|
||||
}
|
||||
|
||||
// If there is at least one `Some` entry, then `notify_all()` has work to do.
|
||||
if self.entries.len() - self.none_count > 0 {
|
||||
flag |= NOTIFY_ALL;
|
||||
}
|
||||
|
||||
// Use `SeqCst` ordering to synchronize with `WakerSet::lock_to_notify()`.
|
||||
self.waker_set.flag.store(flag, Ordering::SeqCst);
|
||||
}
|
||||
}
|
||||
|
||||
impl Deref for Lock<'_> {
|
||||
type Target = Inner;
|
||||
|
||||
#[inline]
|
||||
fn deref(&self) -> &Inner {
|
||||
unsafe { &*self.waker_set.inner.get() }
|
||||
}
|
||||
}
|
||||
|
||||
impl DerefMut for Lock<'_> {
|
||||
#[inline]
|
||||
fn deref_mut(&mut self) -> &mut Inner {
|
||||
unsafe { &mut *self.waker_set.inner.get() }
|
||||
}
|
||||
}
|
Loading…
Reference in a new issue