From dc15533d3cdaae7d792534ee54dd625823a1a26f Mon Sep 17 00:00:00 2001 From: dignifiedquire Date: Sat, 27 Jun 2020 16:53:52 +0200 Subject: [PATCH] feat: new channels - add new top level `channels` module - reimplement existing channels using `async-channel` --- Cargo.toml | 3 + src/channel.rs | 6 + src/lib.rs | 1 + src/sync/channel.rs | 668 +++----------------------------------------- 4 files changed, 55 insertions(+), 623 deletions(-) create mode 100644 src/channel.rs diff --git a/Cargo.toml b/Cargo.toml index a5b5c4f..bca192a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -49,6 +49,7 @@ std = [ "wasm-bindgen-futures", "futures-channel", "async-mutex", + "async-channel", ] alloc = [ "futures-core/alloc", @@ -72,10 +73,12 @@ pin-project-lite = { version = "0.1.4", optional = true } pin-utils = { version = "0.1.0-alpha.4", optional = true } slab = { version = "0.4.2", optional = true } futures-timer = { version = "3.0.2", optional = true } +async-channel = { version = "1.1.1", optional = true } # Devdepencency, but they are not allowed to be optional :/ surf = { version = "1.0.3", optional = true } + [target.'cfg(not(target_os = "unknown"))'.dependencies] smol = { version = "0.1.17", optional = true } diff --git a/src/channel.rs b/src/channel.rs new file mode 100644 index 0000000..90adc14 --- /dev/null +++ b/src/channel.rs @@ -0,0 +1,6 @@ +//! Channels + +#[cfg(feature = "unstable")] +#[cfg_attr(feature = "docs", doc(cfg(unstable)))] +#[doc(inline)] +pub use async_channel::*; diff --git a/src/lib.rs b/src/lib.rs index a8ba46b..5ef1a11 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -306,6 +306,7 @@ cfg_unstable! { pub mod pin; #[cfg(not(target_os = "unknown"))] pub mod process; + pub mod channel; mod unit; mod vec; diff --git a/src/sync/channel.rs b/src/sync/channel.rs index 3207846..7d207f9 100644 --- a/src/sync/channel.rs +++ b/src/sync/channel.rs @@ -1,21 +1,15 @@ -use std::cell::UnsafeCell; -use std::error::Error; -use std::fmt::{self, Debug, Display}; +use std::fmt; use std::future::Future; -use std::isize; -use std::marker::PhantomData; -use std::mem; use std::pin::Pin; -use std::process; -use std::ptr; -use std::sync::atomic::{self, AtomicUsize, Ordering}; -use std::sync::Arc; use std::task::{Context, Poll}; -use crossbeam_utils::Backoff; - use crate::stream::Stream; -use crate::sync::WakerSet; +use crate::channel::{bounded, Sender as InnerSender, Receiver as InnerReceiver, SendError}; + +#[cfg(feature = "unstable")] +#[cfg_attr(feature = "docs", doc(cfg(unstable)))] +#[doc(inline)] +pub use crate::channel::{RecvError, TryRecvError, TrySendError}; /// Creates a bounded multi-producer multi-consumer channel. /// @@ -61,14 +55,16 @@ use crate::sync::WakerSet; #[cfg(feature = "unstable")] #[cfg_attr(feature = "docs", doc(cfg(unstable)))] pub fn channel(cap: usize) -> (Sender, Receiver) { - let channel = Arc::new(Channel::with_capacity(cap)); + let (sender, receiver) = bounded(cap); + let s = Sender { - channel: channel.clone(), + sender, }; + let r = Receiver { - channel, - opt_key: None, + receiver, }; + (s, r) } @@ -103,8 +99,8 @@ pub fn channel(cap: usize) -> (Sender, Receiver) { #[cfg(feature = "unstable")] #[cfg_attr(feature = "docs", doc(cfg(unstable)))] pub struct Sender { - /// The inner channel. - channel: Arc>, + /// The sender. + sender: InnerSender, } impl Sender { @@ -137,62 +133,31 @@ impl Sender { /// ``` pub async fn send(&self, msg: T) { struct SendFuture<'a, T> { - channel: &'a Channel, - msg: Option, - opt_key: Option, + inner: Pin>> + 'a>>, } + unsafe impl Send for SendFuture<'_, T> {} + impl Unpin for SendFuture<'_, T> {} impl Future for SendFuture<'_, T> { type Output = (); fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - loop { - let msg = self.msg.take().unwrap(); - - // If the current task is in the set, remove it. - if let Some(key) = self.opt_key.take() { - self.channel.send_wakers.remove(key); - } - - // Try sending the message. - match self.channel.try_send(msg) { - Ok(()) => return Poll::Ready(()), - Err(TrySendError::Disconnected(msg)) => { - self.msg = Some(msg); - return Poll::Pending; - } - Err(TrySendError::Full(msg)) => { - self.msg = Some(msg); - - // Insert this send operation. - self.opt_key = Some(self.channel.send_wakers.insert(cx)); - - // If the channel is still full and not disconnected, return. - if self.channel.is_full() && !self.channel.is_disconnected() { - return Poll::Pending; - } - } + match Pin::new(&mut self.inner).poll(cx) { + Poll::Pending => Poll::Pending, + Poll::Ready(Ok(())) => Poll::Ready(()), + Poll::Ready(Err(_)) => { + log::error!("attempted to send on dropped receiver"); + // return pending to preserve old behaviour + Poll::Pending } } } - } - - impl Drop for SendFuture<'_, T> { - fn drop(&mut self) { - // If the current task is still in the set, that means it is being cancelled now. - // Wake up another task instead. - if let Some(key) = self.opt_key { - self.channel.send_wakers.cancel(key); - } - } - } + } SendFuture { - channel: &self.channel, - msg: Some(msg), - opt_key: None, + inner: Box::pin(self.sender.send(msg)), } .await } @@ -215,7 +180,7 @@ impl Sender { /// # }) /// ``` pub fn try_send(&self, msg: T) -> Result<(), TrySendError> { - self.channel.try_send(msg) + self.sender.try_send(msg) } /// Returns the channel capacity. @@ -229,7 +194,7 @@ impl Sender { /// assert_eq!(s.capacity(), 5); /// ``` pub fn capacity(&self) -> usize { - self.channel.cap + self.sender.capacity().unwrap() } /// Returns `true` if the channel is empty. @@ -250,7 +215,7 @@ impl Sender { /// # }) /// ``` pub fn is_empty(&self) -> bool { - self.channel.is_empty() + self.sender.is_empty() } /// Returns `true` if the channel is full. @@ -271,7 +236,7 @@ impl Sender { /// # }) /// ``` pub fn is_full(&self) -> bool { - self.channel.is_full() + self.sender.is_full() } /// Returns the number of messages in the channel. @@ -293,30 +258,14 @@ impl Sender { /// # }) /// ``` pub fn len(&self) -> usize { - self.channel.len() - } -} - -impl Drop for Sender { - fn drop(&mut self) { - // Decrement the sender count and disconnect the channel if it drops down to zero. - if self.channel.sender_count.fetch_sub(1, Ordering::AcqRel) == 1 { - self.channel.disconnect(); - } + self.sender.len() } } impl Clone for Sender { - fn clone(&self) -> Sender { - let count = self.channel.sender_count.fetch_add(1, Ordering::Relaxed); - - // Make sure the count never overflows, even if lots of sender clones are leaked. - if count > isize::MAX as usize { - process::abort(); - } - + fn clone(&self) -> Self { Sender { - channel: self.channel.clone(), + sender: self.sender.clone(), } } } @@ -364,11 +313,8 @@ impl fmt::Debug for Sender { #[cfg(feature = "unstable")] #[cfg_attr(feature = "docs", doc(cfg(unstable)))] pub struct Receiver { - /// The inner channel. - channel: Arc>, - - /// The key for this receiver in the `channel.stream_wakers` set. - opt_key: Option, + /// The inner receiver. + receiver: InnerReceiver, } impl Receiver { @@ -403,38 +349,7 @@ impl Receiver { /// # }) } /// ``` pub async fn recv(&self) -> Result { - struct RecvFuture<'a, T> { - channel: &'a Channel, - opt_key: Option, - } - - impl Future for RecvFuture<'_, T> { - type Output = Result; - - fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - poll_recv( - &self.channel, - &self.channel.recv_wakers, - &mut self.opt_key, - cx, - ) - } - } - - impl Drop for RecvFuture<'_, T> { - fn drop(&mut self) { - // If the current task is still in the set, that means it is being cancelled now. - if let Some(key) = self.opt_key { - self.channel.recv_wakers.cancel(key); - } - } - } - - RecvFuture { - channel: &self.channel, - opt_key: None, - } - .await + self.receiver.recv().await } /// Attempts to receive a message from the channel. @@ -458,7 +373,7 @@ impl Receiver { /// # }) /// ``` pub fn try_recv(&self) -> Result { - self.channel.try_recv() + self.receiver.try_recv() } /// Returns the channel capacity. @@ -472,7 +387,7 @@ impl Receiver { /// assert_eq!(r.capacity(), 5); /// ``` pub fn capacity(&self) -> usize { - self.channel.cap + self.receiver.capacity().unwrap() } /// Returns `true` if the channel is empty. @@ -493,7 +408,7 @@ impl Receiver { /// # }) /// ``` pub fn is_empty(&self) -> bool { - self.channel.is_empty() + self.receiver.is_empty() } /// Returns `true` if the channel is full. @@ -514,7 +429,7 @@ impl Receiver { /// # }) /// ``` pub fn is_full(&self) -> bool { - self.channel.is_full() + self.receiver.is_full() } /// Returns the number of messages in the channel. @@ -536,36 +451,14 @@ impl Receiver { /// # }) /// ``` pub fn len(&self) -> usize { - self.channel.len() - } -} - -impl Drop for Receiver { - fn drop(&mut self) { - // If the current task is still in the stream set, that means it is being cancelled now. - if let Some(key) = self.opt_key { - self.channel.stream_wakers.cancel(key); - } - - // Decrement the receiver count and disconnect the channel if it drops down to zero. - if self.channel.receiver_count.fetch_sub(1, Ordering::AcqRel) == 1 { - self.channel.disconnect(); - } + self.receiver.len() } } impl Clone for Receiver { - fn clone(&self) -> Receiver { - let count = self.channel.receiver_count.fetch_add(1, Ordering::Relaxed); - - // Make sure the count never overflows, even if lots of receiver clones are leaked. - if count > isize::MAX as usize { - process::abort(); - } - + fn clone(&self) -> Self { Receiver { - channel: self.channel.clone(), - opt_key: None, + receiver: self.receiver.clone(), } } } @@ -574,14 +467,7 @@ impl Stream for Receiver { type Item = T; fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - let this = &mut *self; - let res = futures_core::ready!(poll_recv( - &this.channel, - &this.channel.stream_wakers, - &mut this.opt_key, - cx, - )); - Poll::Ready(res.ok()) + Pin::new(&mut self.receiver).poll_next(cx) } } @@ -591,469 +477,5 @@ impl fmt::Debug for Receiver { } } -/// Polls a receive operation on a channel. -/// -/// If the receive operation is blocked, the current task will be inserted into `wakers` and its -/// associated key will then be stored in `opt_key`. -fn poll_recv( - channel: &Channel, - wakers: &WakerSet, - opt_key: &mut Option, - cx: &mut Context<'_>, -) -> Poll> { - loop { - // If the current task is in the set, remove it. - if let Some(key) = opt_key.take() { - wakers.remove(key); - } - - // Try receiving a message. - match channel.try_recv() { - Ok(msg) => return Poll::Ready(Ok(msg)), - Err(TryRecvError::Disconnected) => return Poll::Ready(Err(RecvError {})), - Err(TryRecvError::Empty) => { - // Insert this receive operation. - *opt_key = Some(wakers.insert(cx)); - - // If the channel is still empty and not disconnected, return. - if channel.is_empty() && !channel.is_disconnected() { - return Poll::Pending; - } - } - } - } -} - -/// A slot in a channel. -struct Slot { - /// The current stamp. - stamp: AtomicUsize, - - /// The message in this slot. - msg: UnsafeCell, -} - -/// Bounded channel based on a preallocated array. -struct Channel { - /// The head of the channel. - /// - /// This value is a "stamp" consisting of an index into the buffer, a mark bit, and a lap, but - /// packed into a single `usize`. The lower bits represent the index, while the upper bits - /// represent the lap. The mark bit in the head is always zero. - /// - /// Messages are popped from the head of the channel. - head: AtomicUsize, - - /// The tail of the channel. - /// - /// This value is a "stamp" consisting of an index into the buffer, a mark bit, and a lap, but - /// packed into a single `usize`. The lower bits represent the index, while the upper bits - /// represent the lap. The mark bit indicates that the channel is disconnected. - /// - /// Messages are pushed into the tail of the channel. - tail: AtomicUsize, - - /// The buffer holding slots. - buffer: *mut Slot, - - /// The channel capacity. - cap: usize, - - /// A stamp with the value of `{ lap: 1, mark: 0, index: 0 }`. - one_lap: usize, - - /// If this bit is set in the tail, that means either all senders were dropped or all receivers - /// were dropped. - mark_bit: usize, - - /// Send operations waiting while the channel is full. - send_wakers: WakerSet, - - /// Receive operations waiting while the channel is empty and not disconnected. - recv_wakers: WakerSet, - - /// Streams waiting while the channel is empty and not disconnected. - stream_wakers: WakerSet, - - /// The number of currently active `Sender`s. - sender_count: AtomicUsize, - - /// The number of currently active `Receivers`s. - receiver_count: AtomicUsize, - - /// Indicates that dropping a `Channel` may drop values of type `T`. - _marker: PhantomData, -} - -unsafe impl Send for Channel {} -unsafe impl Sync for Channel {} -impl Unpin for Channel {} - -impl Channel { - /// Creates a bounded channel of capacity `cap`. - fn with_capacity(cap: usize) -> Self { - assert!(cap > 0, "capacity must be positive"); - - // Compute constants `mark_bit` and `one_lap`. - let mark_bit = (cap + 1).next_power_of_two(); - let one_lap = mark_bit * 2; - - // Head is initialized to `{ lap: 0, mark: 0, index: 0 }`. - let head = 0; - // Tail is initialized to `{ lap: 0, mark: 0, index: 0 }`. - let tail = 0; - - // Allocate a buffer of `cap` slots. - let buffer = { - let mut v = Vec::>::with_capacity(cap); - let ptr = v.as_mut_ptr(); - mem::forget(v); - ptr - }; - - // Initialize stamps in the slots. - for i in 0..cap { - unsafe { - // Set the stamp to `{ lap: 0, mark: 0, index: i }`. - let slot = buffer.add(i); - ptr::write(&mut (*slot).stamp, AtomicUsize::new(i)); - } - } - - Channel { - buffer, - cap, - one_lap, - mark_bit, - head: AtomicUsize::new(head), - tail: AtomicUsize::new(tail), - send_wakers: WakerSet::new(), - recv_wakers: WakerSet::new(), - stream_wakers: WakerSet::new(), - sender_count: AtomicUsize::new(1), - receiver_count: AtomicUsize::new(1), - _marker: PhantomData, - } - } - - /// Attempts to send a message. - fn try_send(&self, msg: T) -> Result<(), TrySendError> { - let backoff = Backoff::new(); - let mut tail = self.tail.load(Ordering::Relaxed); - - loop { - // Extract mark bit from the tail and unset it. - // - // If the mark bit was set (which means all receivers have been dropped), we will still - // send the message into the channel if there is enough capacity. The message will get - // dropped when the channel is dropped (which means when all senders are also dropped). - let mark_bit = tail & self.mark_bit; - tail ^= mark_bit; - - // Deconstruct the tail. - let index = tail & (self.mark_bit - 1); - let lap = tail & !(self.one_lap - 1); - - // Inspect the corresponding slot. - let slot = unsafe { &*self.buffer.add(index) }; - let stamp = slot.stamp.load(Ordering::Acquire); - - // If the tail and the stamp match, we may attempt to push. - if tail == stamp { - let new_tail = if index + 1 < self.cap { - // Same lap, incremented index. - // Set to `{ lap: lap, mark: 0, index: index + 1 }`. - tail + 1 - } else { - // One lap forward, index wraps around to zero. - // Set to `{ lap: lap.wrapping_add(1), mark: 0, index: 0 }`. - lap.wrapping_add(self.one_lap) - }; - - // Try moving the tail. - match self.tail.compare_exchange_weak( - tail | mark_bit, - new_tail | mark_bit, - Ordering::SeqCst, - Ordering::Relaxed, - ) { - Ok(_) => { - // Write the message into the slot and update the stamp. - unsafe { slot.msg.get().write(msg) }; - let stamp = tail + 1; - slot.stamp.store(stamp, Ordering::Release); - - // Wake a blocked receive operation. - self.recv_wakers.notify_one(); - - // Wake all blocked streams. - self.stream_wakers.notify_all(); - - return Ok(()); - } - Err(t) => { - tail = t; - backoff.spin(); - } - } - } else if stamp.wrapping_add(self.one_lap) == tail + 1 { - atomic::fence(Ordering::SeqCst); - let head = self.head.load(Ordering::Relaxed); - - // If the head lags one lap behind the tail as well... - if head.wrapping_add(self.one_lap) == tail { - // ...then the channel is full. - - // Check if the channel is disconnected. - if mark_bit != 0 { - return Err(TrySendError::Disconnected(msg)); - } else { - return Err(TrySendError::Full(msg)); - } - } - - backoff.spin(); - tail = self.tail.load(Ordering::Relaxed); - } else { - // Snooze because we need to wait for the stamp to get updated. - backoff.snooze(); - tail = self.tail.load(Ordering::Relaxed); - } - } - } - - /// Attempts to receive a message. - fn try_recv(&self) -> Result { - let backoff = Backoff::new(); - let mut head = self.head.load(Ordering::Relaxed); - - loop { - // Deconstruct the head. - let index = head & (self.mark_bit - 1); - let lap = head & !(self.one_lap - 1); - - // Inspect the corresponding slot. - let slot = unsafe { &*self.buffer.add(index) }; - let stamp = slot.stamp.load(Ordering::Acquire); - - // If the the stamp is ahead of the head by 1, we may attempt to pop. - if head + 1 == stamp { - let new = if index + 1 < self.cap { - // Same lap, incremented index. - // Set to `{ lap: lap, mark: 0, index: index + 1 }`. - head + 1 - } else { - // One lap forward, index wraps around to zero. - // Set to `{ lap: lap.wrapping_add(1), mark: 0, index: 0 }`. - lap.wrapping_add(self.one_lap) - }; - - // Try moving the head. - match self.head.compare_exchange_weak( - head, - new, - Ordering::SeqCst, - Ordering::Relaxed, - ) { - Ok(_) => { - // Read the message from the slot and update the stamp. - let msg = unsafe { slot.msg.get().read() }; - let stamp = head.wrapping_add(self.one_lap); - slot.stamp.store(stamp, Ordering::Release); - - // Wake a blocked send operation. - self.send_wakers.notify_one(); - - return Ok(msg); - } - Err(h) => { - head = h; - backoff.spin(); - } - } - } else if stamp == head { - atomic::fence(Ordering::SeqCst); - let tail = self.tail.load(Ordering::Relaxed); - - // If the tail equals the head, that means the channel is empty. - if (tail & !self.mark_bit) == head { - // If the channel is disconnected... - if tail & self.mark_bit != 0 { - return Err(TryRecvError::Disconnected); - } else { - // Otherwise, the receive operation is not ready. - return Err(TryRecvError::Empty); - } - } - - backoff.spin(); - head = self.head.load(Ordering::Relaxed); - } else { - // Snooze because we need to wait for the stamp to get updated. - backoff.snooze(); - head = self.head.load(Ordering::Relaxed); - } - } - } - - /// Returns the current number of messages inside the channel. - fn len(&self) -> usize { - loop { - // Load the tail, then load the head. - let tail = self.tail.load(Ordering::SeqCst); - let head = self.head.load(Ordering::SeqCst); - - // If the tail didn't change, we've got consistent values to work with. - if self.tail.load(Ordering::SeqCst) == tail { - let hix = head & (self.mark_bit - 1); - let tix = tail & (self.mark_bit - 1); - - return if hix < tix { - tix - hix - } else if hix > tix { - self.cap - hix + tix - } else if (tail & !self.mark_bit) == head { - 0 - } else { - self.cap - }; - } - } - } - - /// Returns `true` if the channel is disconnected. - pub fn is_disconnected(&self) -> bool { - self.tail.load(Ordering::SeqCst) & self.mark_bit != 0 - } - - /// Returns `true` if the channel is empty. - fn is_empty(&self) -> bool { - let head = self.head.load(Ordering::SeqCst); - let tail = self.tail.load(Ordering::SeqCst); - - // Is the tail equal to the head? - // - // Note: If the head changes just before we load the tail, that means there was a moment - // when the channel was not empty, so it is safe to just return `false`. - (tail & !self.mark_bit) == head - } - - /// Returns `true` if the channel is full. - fn is_full(&self) -> bool { - let tail = self.tail.load(Ordering::SeqCst); - let head = self.head.load(Ordering::SeqCst); - - // Is the head lagging one lap behind tail? - // - // Note: If the tail changes just before we load the head, that means there was a moment - // when the channel was not full, so it is safe to just return `false`. - head.wrapping_add(self.one_lap) == tail & !self.mark_bit - } - - /// Disconnects the channel and wakes up all blocked operations. - fn disconnect(&self) { - let tail = self.tail.fetch_or(self.mark_bit, Ordering::SeqCst); - - if tail & self.mark_bit == 0 { - // Notify everyone blocked on this channel. - self.send_wakers.notify_all(); - self.recv_wakers.notify_all(); - self.stream_wakers.notify_all(); - } - } -} - -impl Drop for Channel { - fn drop(&mut self) { - // Get the index of the head. - let hix = self.head.load(Ordering::Relaxed) & (self.mark_bit - 1); - - // Loop over all slots that hold a message and drop them. - for i in 0..self.len() { - // Compute the index of the next slot holding a message. - let index = if hix + i < self.cap { - hix + i - } else { - hix + i - self.cap - }; - - unsafe { - self.buffer.add(index).drop_in_place(); - } - } - - // Finally, deallocate the buffer, but don't run any destructors. - unsafe { - Vec::from_raw_parts(self.buffer, 0, self.cap); - } - } -} - -/// An error returned from the `try_send` method. -#[cfg(feature = "unstable")] -#[cfg_attr(feature = "docs", doc(cfg(unstable)))] -#[derive(PartialEq, Eq)] -pub enum TrySendError { - /// The channel is full but not disconnected. - Full(T), - - /// The channel is full and disconnected. - Disconnected(T), -} - -impl Error for TrySendError {} - -impl Debug for TrySendError { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - match self { - Self::Full(_) => Debug::fmt("Full", f), - Self::Disconnected(_) => Debug::fmt("Disconnected", f), - } - } -} - -impl Display for TrySendError { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - match self { - Self::Full(_) => Display::fmt("The channel is full.", f), - Self::Disconnected(_) => Display::fmt("The channel is full and disconnected.", f), - } - } -} - -/// An error returned from the `try_recv` method. -#[cfg(feature = "unstable")] -#[cfg_attr(feature = "docs", doc(cfg(unstable)))] -#[derive(Debug, PartialEq, Eq)] -pub enum TryRecvError { - /// The channel is empty but not disconnected. - Empty, - - /// The channel is empty and disconnected. - Disconnected, -} -impl Error for TryRecvError {} -impl Display for TryRecvError { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - match self { - Self::Empty => Display::fmt("The channel is empty.", f), - Self::Disconnected => Display::fmt("The channel is empty and disconnected.", f), - } - } -} - -/// An error returned from the `recv` method. -#[cfg(feature = "unstable")] -#[cfg_attr(feature = "docs", doc(cfg(unstable)))] -#[derive(Debug, PartialEq, Eq)] -pub struct RecvError; - -impl Error for RecvError {} - -impl Display for RecvError { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - Display::fmt("The channel is empty.", f) - } -}