diff --git a/Cargo.toml b/Cargo.toml
index 961a8cd4..489bf2c5 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -61,6 +61,7 @@ tokio02 = ["tokio"]
 
 [dependencies]
 async-attributes = { version = "1.1.1", optional = true }
+async-rwlock = "1.0.1"
 async-task = { version = "3.0.0", optional = true }
 async-mutex = { version = "1.1.3", optional = true }
 crossbeam-utils = { version = "0.7.2", optional = true }
diff --git a/src/sync/rwlock.rs b/src/sync/rwlock.rs
index 08d8ed84..b66c6d44 100644
--- a/src/sync/rwlock.rs
+++ b/src/sync/rwlock.rs
@@ -1,24 +1,10 @@
-use std::cell::UnsafeCell;
+use async_rwlock::{
+    RwLock as RawRwLock, RwLockReadGuard as RawRwLockReadGuard,
+    RwLockWriteGuard as RawRwLockWriteGuard,
+};
+
 use std::fmt;
-use std::isize;
 use std::ops::{Deref, DerefMut};
-use std::pin::Pin;
-use std::process;
-use std::future::Future;
-use std::sync::atomic::{AtomicUsize, Ordering};
-
-use crate::sync::WakerSet;
-use crate::task::{Context, Poll};
-
-/// Set if a write lock is held.
-#[allow(clippy::identity_op)]
-const WRITE_LOCK: usize = 1 << 0;
-
-/// The value of a single blocked read contributing to the read count.
-const ONE_READ: usize = 1 << 1;
-
-/// The bits in which the read count is stored.
-const READ_COUNT_MASK: usize = !(ONE_READ - 1);
 
 /// A reader-writer lock for protecting shared data.
 ///
@@ -50,10 +36,7 @@ const READ_COUNT_MASK: usize = !(ONE_READ - 1);
 /// # })
 /// ```
 pub struct RwLock<T: ?Sized> {
-    state: AtomicUsize,
-    read_wakers: WakerSet,
-    write_wakers: WakerSet,
-    value: UnsafeCell<T>,
+    inner: RawRwLock<T>,
 }
 
 unsafe impl<T: ?Sized + Send> Send for RwLock<T> {}
@@ -70,11 +53,8 @@ impl<T> RwLock<T> {
     /// let lock = RwLock::new(0);
     /// ```
     pub fn new(t: T) -> RwLock<T> {
-        RwLock {
-            state: AtomicUsize::new(0),
-            read_wakers: WakerSet::new(),
-            write_wakers: WakerSet::new(),
-            value: UnsafeCell::new(t),
+        Self {
+            inner: RawRwLock::new(t),
         }
     }
 }
@@ -101,58 +81,7 @@ impl<T: ?Sized> RwLock<T> {
     /// # })
     /// ```
     pub async fn read(&self) -> RwLockReadGuard<'_, T> {
-        pub struct ReadFuture<'a, T: ?Sized> {
-            lock: &'a RwLock<T>,
-            opt_key: Option<usize>,
-        }
-
-        impl<'a, T: ?Sized> Future for ReadFuture<'a, T> {
-            type Output = RwLockReadGuard<'a, T>;
-
-            fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
-                loop {
-                    // If the current task is in the set, remove it.
-                    if let Some(key) = self.opt_key.take() {
-                        self.lock.read_wakers.remove(key);
-                    }
-
-                    // Try acquiring a read lock.
-                    match self.lock.try_read() {
-                        Some(guard) => return Poll::Ready(guard),
-                        None => {
-                            // Insert this lock operation.
-                            self.opt_key = Some(self.lock.read_wakers.insert(cx));
-
-                            // If the lock is still acquired for writing, return.
-                            if self.lock.state.load(Ordering::SeqCst) & WRITE_LOCK != 0 {
-                                return Poll::Pending;
-                            }
-                        }
-                    }
-                }
-            }
-        }
-
-        impl<T: ?Sized> Drop for ReadFuture<'_, T> {
-            fn drop(&mut self) {
-                // If the current task is still in the set, that means it is being cancelled now.
-                if let Some(key) = self.opt_key {
-                    self.lock.read_wakers.cancel(key);
-
-                    // If there are no active readers, notify a blocked writer if none were
-                    // notified already.
-                    if self.lock.state.load(Ordering::SeqCst) & READ_COUNT_MASK == 0 {
-                        self.lock.write_wakers.notify_any();
-                    }
-                }
-            }
-        }
-
-        ReadFuture {
-            lock: self,
-            opt_key: None,
-        }
-        .await
+        RwLockReadGuard(self.inner.read().await)
     }
 
     /// Attempts to acquire a read lock.
@@ -179,30 +108,7 @@ impl<T: ?Sized> RwLock<T> {
     /// # })
     /// ```
    pub fn try_read(&self) -> Option<RwLockReadGuard<'_, T>> {
-        let mut state = self.state.load(Ordering::SeqCst);
-
-        loop {
-            // If a write lock is currently held, then a read lock cannot be acquired.
-            if state & WRITE_LOCK != 0 {
-                return None;
-            }
-
-            // Make sure the number of readers doesn't overflow.
-            if state > isize::MAX as usize {
-                process::abort();
-            }
-
-            // Increment the number of active reads.
-            match self.state.compare_exchange_weak(
-                state,
-                state + ONE_READ,
-                Ordering::SeqCst,
-                Ordering::SeqCst,
-            ) {
-                Ok(_) => return Some(RwLockReadGuard(self)),
-                Err(s) => state = s,
-            }
-        }
+        self.inner.try_read().map(RwLockReadGuard)
     }
 
     /// Acquires a write lock.
@@ -226,55 +132,7 @@ impl<T: ?Sized> RwLock<T> {
     /// # })
     /// ```
     pub async fn write(&self) -> RwLockWriteGuard<'_, T> {
-        pub struct WriteFuture<'a, T: ?Sized> {
-            lock: &'a RwLock<T>,
-            opt_key: Option<usize>,
-        }
-
-        impl<'a, T: ?Sized> Future for WriteFuture<'a, T> {
-            type Output = RwLockWriteGuard<'a, T>;
-
-            fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
-                loop {
-                    // If the current task is in the set, remove it.
-                    if let Some(key) = self.opt_key.take() {
-                        self.lock.write_wakers.remove(key);
-                    }
-
-                    // Try acquiring a write lock.
-                    match self.lock.try_write() {
-                        Some(guard) => return Poll::Ready(guard),
-                        None => {
-                            // Insert this lock operation.
-                            self.opt_key = Some(self.lock.write_wakers.insert(cx));
-
-                            // If the lock is still acquired for reading or writing, return.
-                            if self.lock.state.load(Ordering::SeqCst) != 0 {
-                                return Poll::Pending;
-                            }
-                        }
-                    }
-                }
-            }
-        }
-
-        impl<T: ?Sized> Drop for WriteFuture<'_, T> {
-            fn drop(&mut self) {
-                // If the current task is still in the set, that means it is being cancelled now.
-                if let Some(key) = self.opt_key {
-                    if !self.lock.write_wakers.cancel(key) {
-                        // If no other blocked reader was notified, notify all readers.
-                        self.lock.read_wakers.notify_all();
-                    }
-                }
-            }
-        }
-
-        WriteFuture {
-            lock: self,
-            opt_key: None,
-        }
-        .await
+        RwLockWriteGuard(self.inner.write().await)
     }
 
     /// Attempts to acquire a write lock.
@@ -301,11 +159,7 @@ impl<T: ?Sized> RwLock<T> {
     /// # })
     /// ```
     pub fn try_write(&self) -> Option<RwLockWriteGuard<'_, T>> {
-        if self.state.compare_and_swap(0, WRITE_LOCK, Ordering::SeqCst) == 0 {
-            Some(RwLockWriteGuard(self))
-        } else {
-            None
-        }
+        self.inner.try_write().map(RwLockWriteGuard)
     }
 
     /// Consumes the lock, returning the underlying data.
@@ -318,8 +172,11 @@ impl<T: ?Sized> RwLock<T> {
     /// let lock = RwLock::new(10);
     /// assert_eq!(lock.into_inner(), 10);
     /// ```
-    pub fn into_inner(self) -> T where T: Sized {
-        self.value.into_inner()
+    pub fn into_inner(self) -> T
+    where
+        T: Sized,
+    {
+        self.inner.into_inner()
     }
 
     /// Returns a mutable reference to the underlying data.
@@ -341,23 +198,13 @@ impl<T: ?Sized> RwLock<T> {
     /// # })
     /// ```
     pub fn get_mut(&mut self) -> &mut T {
-        unsafe { &mut *self.value.get() }
+        self.inner.get_mut()
     }
 }
 
 impl<T: fmt::Debug + ?Sized> fmt::Debug for RwLock<T> {
     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
-        struct Locked;
-        impl fmt::Debug for Locked {
-            fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
-                f.write_str("<locked>")
-            }
-        }
-
-        match self.try_read() {
-            None => f.debug_struct("RwLock").field("data", &Locked).finish(),
-            Some(guard) => f.debug_struct("RwLock").field("data", &&*guard).finish(),
-        }
+        self.inner.fmt(f)
     }
 }
 
@@ -374,31 +221,20 @@ impl<T: Default> Default for RwLock<T> {
 }
 
 /// A guard that releases the read lock when dropped.
-pub struct RwLockReadGuard<'a, T: ?Sized>(&'a RwLock<T>);
+pub struct RwLockReadGuard<'a, T: ?Sized>(RawRwLockReadGuard<'a, T>);
 
 unsafe impl<T: ?Sized + Send> Send for RwLockReadGuard<'_, T> {}
 unsafe impl<T: ?Sized + Sync> Sync for RwLockReadGuard<'_, T> {}
 
-impl<T: ?Sized> Drop for RwLockReadGuard<'_, T> {
-    fn drop(&mut self) {
-        let state = self.0.state.fetch_sub(ONE_READ, Ordering::SeqCst);
-
-        // If this was the last reader, notify a blocked writer if none were notified already.
-        if state & READ_COUNT_MASK == ONE_READ {
-            self.0.write_wakers.notify_any();
-        }
-    }
-}
-
 impl<T: fmt::Debug + ?Sized> fmt::Debug for RwLockReadGuard<'_, T> {
     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
-        fmt::Debug::fmt(&**self, f)
+        self.0.fmt(f)
     }
 }
 
 impl<T: fmt::Display + ?Sized> fmt::Display for RwLockReadGuard<'_, T> {
     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
-        (**self).fmt(f)
+        self.0.fmt(f)
     }
 }
 
@@ -406,7 +242,7 @@ impl<T: ?Sized> Deref for RwLockReadGuard<'_, T> {
     type Target = T;
 
     fn deref(&self) -> &T {
-        unsafe { &*self.0.value.get() }
+        self.0.deref()
     }
 }
 
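The snippet below is an illustrative, doc-test-style check of the public surface that the delegating wrapper above is expected to preserve; it is not part of the patch. It only exercises the methods documented in the diff (read, write, try_read, try_write) and uses async_std::task::block_on exactly as the existing doc examples do.

// Illustrative usage sketch (assumes the async-std public API shown in the diff).
use async_std::sync::RwLock;

fn main() {
    async_std::task::block_on(async {
        let lock = RwLock::new(5);

        // Any number of read guards may be held at once, and they block writers.
        let r1 = lock.read().await;
        let r2 = lock.read().await;
        assert_eq!(*r1 + *r2, 10);
        assert!(lock.try_write().is_none());
        drop((r1, r2));

        // A write guard is exclusive and hands out mutable access.
        let mut w = lock.write().await;
        *w += 1;
        assert!(lock.try_read().is_none());
        drop(w);

        assert_eq!(*lock.read().await, 6);
    });
}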