diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index c622a59..d47eabc 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -11,6 +11,8 @@ jobs: build_and_test: name: Build and test runs-on: ${{ matrix.os }} + env: + RUSTFLAGS: -Dwarnings strategy: matrix: os: [ubuntu-latest, windows-latest, macOS-latest] @@ -29,7 +31,7 @@ jobs: uses: actions-rs/cargo@v1 with: command: check - args: --all --benches --bins --examples --tests + args: --all --bins --examples - name: check unstable uses: actions-rs/cargo@v1 @@ -46,6 +48,8 @@ jobs: check_fmt_and_docs: name: Checking fmt and docs runs-on: ubuntu-latest + env: + RUSTFLAGS: -Dwarnings steps: - uses: actions/checkout@master @@ -77,6 +81,9 @@ jobs: clippy_check: name: Clippy check runs-on: ubuntu-latest + # TODO: There is a lot of warnings + # env: + # RUSTFLAGS: -Dwarnings steps: - uses: actions/checkout@v1 - id: component diff --git a/Cargo.toml b/Cargo.toml index 2741a19..7aaba68 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -43,9 +43,11 @@ pin-utils = "0.1.0-alpha.4" slab = "0.4.2" kv-log-macro = "1.0.4" broadcaster = { version = "0.2.6", optional = true, default-features = false, features = ["default-channels"] } +pin-project-lite = "0.1" [dev-dependencies] femme = "1.2.0" +rand = "0.7.2" # surf = "1.0.2" tempdir = "0.3.7" futures-preview = { version = "=0.3.0-alpha.19", features = ["async-await"] } diff --git a/src/future/timeout.rs b/src/future/timeout.rs index a8338fb..c745d73 100644 --- a/src/future/timeout.rs +++ b/src/future/timeout.rs @@ -4,6 +4,7 @@ use std::pin::Pin; use std::time::Duration; use futures_timer::Delay; +use pin_project_lite::pin_project; use crate::future::Future; use crate::task::{Context, Poll}; @@ -39,24 +40,24 @@ where f.await } -/// A future that times out after a duration of time. -struct TimeoutFuture { - future: F, - delay: Delay, -} - -impl TimeoutFuture { - pin_utils::unsafe_pinned!(future: F); - pin_utils::unsafe_pinned!(delay: Delay); +pin_project! { + /// A future that times out after a duration of time. + struct TimeoutFuture { + #[pin] + future: F, + #[pin] + delay: Delay, + } } impl Future for TimeoutFuture { type Output = Result; - fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - match self.as_mut().future().poll(cx) { + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let this = self.project(); + match this.future.poll(cx) { Poll::Ready(v) => Poll::Ready(Ok(v)), - Poll::Pending => match self.delay().poll(cx) { + Poll::Pending => match this.delay.poll(cx) { Poll::Ready(_) => Poll::Ready(Err(TimeoutError { _private: () })), Poll::Pending => Poll::Pending, }, diff --git a/src/io/buf_read/lines.rs b/src/io/buf_read/lines.rs index 6cb4a07..c60529c 100644 --- a/src/io/buf_read/lines.rs +++ b/src/io/buf_read/lines.rs @@ -2,50 +2,55 @@ use std::mem; use std::pin::Pin; use std::str; +use pin_project_lite::pin_project; + use super::read_until_internal; use crate::io::{self, BufRead}; use crate::stream::Stream; use crate::task::{Context, Poll}; -/// A stream of lines in a byte stream. -/// -/// This stream is created by the [`lines`] method on types that implement [`BufRead`]. -/// -/// This type is an async version of [`std::io::Lines`]. -/// -/// [`lines`]: trait.BufRead.html#method.lines -/// [`BufRead`]: trait.BufRead.html -/// [`std::io::Lines`]: https://doc.rust-lang.org/std/io/struct.Lines.html -#[derive(Debug)] -pub struct Lines { - pub(crate) reader: R, - pub(crate) buf: String, - pub(crate) bytes: Vec, - pub(crate) read: usize, +pin_project! { + /// A stream of lines in a byte stream. + /// + /// This stream is created by the [`lines`] method on types that implement [`BufRead`]. + /// + /// This type is an async version of [`std::io::Lines`]. + /// + /// [`lines`]: trait.BufRead.html#method.lines + /// [`BufRead`]: trait.BufRead.html + /// [`std::io::Lines`]: https://doc.rust-lang.org/std/io/struct.Lines.html + #[derive(Debug)] + pub struct Lines { + #[pin] + pub(crate) reader: R, + pub(crate) buf: String, + pub(crate) bytes: Vec, + pub(crate) read: usize, + } } impl Stream for Lines { type Item = io::Result; fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - let Self { - reader, - buf, - bytes, - read, - } = unsafe { self.get_unchecked_mut() }; - let reader = unsafe { Pin::new_unchecked(reader) }; - let n = futures_core::ready!(read_line_internal(reader, cx, buf, bytes, read))?; - if n == 0 && buf.is_empty() { + let this = self.project(); + let n = futures_core::ready!(read_line_internal( + this.reader, + cx, + this.buf, + this.bytes, + this.read + ))?; + if n == 0 && this.buf.is_empty() { return Poll::Ready(None); } - if buf.ends_with('\n') { - buf.pop(); - if buf.ends_with('\r') { - buf.pop(); + if this.buf.ends_with('\n') { + this.buf.pop(); + if this.buf.ends_with('\r') { + this.buf.pop(); } } - Poll::Ready(Some(Ok(mem::replace(buf, String::new())))) + Poll::Ready(Some(Ok(mem::replace(this.buf, String::new())))) } } diff --git a/src/io/buf_read/split.rs b/src/io/buf_read/split.rs index aa3b6fb..229a99b 100644 --- a/src/io/buf_read/split.rs +++ b/src/io/buf_read/split.rs @@ -1,46 +1,51 @@ use std::mem; use std::pin::Pin; +use pin_project_lite::pin_project; + use super::read_until_internal; use crate::io::{self, BufRead}; use crate::stream::Stream; use crate::task::{Context, Poll}; -/// A stream over the contents of an instance of [`BufRead`] split on a particular byte. -/// -/// This stream is created by the [`split`] method on types that implement [`BufRead`]. -/// -/// This type is an async version of [`std::io::Split`]. -/// -/// [`split`]: trait.BufRead.html#method.lines -/// [`BufRead`]: trait.BufRead.html -/// [`std::io::Split`]: https://doc.rust-lang.org/std/io/struct.Split.html -#[derive(Debug)] -pub struct Split { - pub(crate) reader: R, - pub(crate) buf: Vec, - pub(crate) read: usize, - pub(crate) delim: u8, +pin_project! { + /// A stream over the contents of an instance of [`BufRead`] split on a particular byte. + /// + /// This stream is created by the [`split`] method on types that implement [`BufRead`]. + /// + /// This type is an async version of [`std::io::Split`]. + /// + /// [`split`]: trait.BufRead.html#method.lines + /// [`BufRead`]: trait.BufRead.html + /// [`std::io::Split`]: https://doc.rust-lang.org/std/io/struct.Split.html + #[derive(Debug)] + pub struct Split { + #[pin] + pub(crate) reader: R, + pub(crate) buf: Vec, + pub(crate) read: usize, + pub(crate) delim: u8, + } } impl Stream for Split { type Item = io::Result>; fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - let Self { - reader, - buf, - read, - delim, - } = unsafe { self.get_unchecked_mut() }; - let reader = unsafe { Pin::new_unchecked(reader) }; - let n = futures_core::ready!(read_until_internal(reader, cx, *delim, buf, read))?; - if n == 0 && buf.is_empty() { + let this = self.project(); + let n = futures_core::ready!(read_until_internal( + this.reader, + cx, + *this.delim, + this.buf, + this.read + ))?; + if n == 0 && this.buf.is_empty() { return Poll::Ready(None); } - if buf[buf.len() - 1] == *delim { - buf.pop(); + if this.buf[this.buf.len() - 1] == *this.delim { + this.buf.pop(); } - Poll::Ready(Some(Ok(mem::replace(buf, vec![])))) + Poll::Ready(Some(Ok(mem::replace(this.buf, vec![])))) } } diff --git a/src/io/buf_reader.rs b/src/io/buf_reader.rs index 13cb4cc..1d00b52 100644 --- a/src/io/buf_reader.rs +++ b/src/io/buf_reader.rs @@ -2,51 +2,56 @@ use std::io::{IoSliceMut, Read as _}; use std::pin::Pin; use std::{cmp, fmt}; +use pin_project_lite::pin_project; + use crate::io::{self, BufRead, Read, Seek, SeekFrom}; use crate::task::{Context, Poll}; const DEFAULT_CAPACITY: usize = 8 * 1024; -/// Adds buffering to any reader. -/// -/// It can be excessively inefficient to work directly with a [`Read`] instance. A `BufReader` -/// performs large, infrequent reads on the underlying [`Read`] and maintains an in-memory buffer -/// of the incoming byte stream. -/// -/// `BufReader` can improve the speed of programs that make *small* and *repeated* read calls to -/// the same file or network socket. It does not help when reading very large amounts at once, or -/// reading just one or a few times. It also provides no advantage when reading from a source that -/// is already in memory, like a `Vec`. -/// -/// When the `BufReader` is dropped, the contents of its buffer will be discarded. Creating -/// multiple instances of a `BufReader` on the same stream can cause data loss. -/// -/// This type is an async version of [`std::io::BufReader`]. -/// -/// [`Read`]: trait.Read.html -/// [`std::io::BufReader`]: https://doc.rust-lang.org/std/io/struct.BufReader.html -/// -/// # Examples -/// -/// ```no_run -/// # fn main() -> std::io::Result<()> { async_std::task::block_on(async { -/// # -/// use async_std::fs::File; -/// use async_std::io::BufReader; -/// use async_std::prelude::*; -/// -/// let mut file = BufReader::new(File::open("a.txt").await?); -/// -/// let mut line = String::new(); -/// file.read_line(&mut line).await?; -/// # -/// # Ok(()) }) } -/// ``` -pub struct BufReader { - inner: R, - buf: Box<[u8]>, - pos: usize, - cap: usize, +pin_project! { + /// Adds buffering to any reader. + /// + /// It can be excessively inefficient to work directly with a [`Read`] instance. A `BufReader` + /// performs large, infrequent reads on the underlying [`Read`] and maintains an in-memory buffer + /// of the incoming byte stream. + /// + /// `BufReader` can improve the speed of programs that make *small* and *repeated* read calls to + /// the same file or network socket. It does not help when reading very large amounts at once, or + /// reading just one or a few times. It also provides no advantage when reading from a source that + /// is already in memory, like a `Vec`. + /// + /// When the `BufReader` is dropped, the contents of its buffer will be discarded. Creating + /// multiple instances of a `BufReader` on the same stream can cause data loss. + /// + /// This type is an async version of [`std::io::BufReader`]. + /// + /// [`Read`]: trait.Read.html + /// [`std::io::BufReader`]: https://doc.rust-lang.org/std/io/struct.BufReader.html + /// + /// # Examples + /// + /// ```no_run + /// # fn main() -> std::io::Result<()> { async_std::task::block_on(async { + /// # + /// use async_std::fs::File; + /// use async_std::io::BufReader; + /// use async_std::prelude::*; + /// + /// let mut file = BufReader::new(File::open("a.txt").await?); + /// + /// let mut line = String::new(); + /// file.read_line(&mut line).await?; + /// # + /// # Ok(()) }) } + /// ``` + pub struct BufReader { + #[pin] + inner: R, + buf: Box<[u8]>, + pos: usize, + cap: usize, + } } impl BufReader { @@ -95,10 +100,6 @@ impl BufReader { } impl BufReader { - pin_utils::unsafe_pinned!(inner: R); - pin_utils::unsafe_unpinned!(pos: usize); - pin_utils::unsafe_unpinned!(cap: usize); - /// Gets a reference to the underlying reader. /// /// It is inadvisable to directly read from the underlying reader. @@ -141,6 +142,13 @@ impl BufReader { &mut self.inner } + /// Gets a pinned mutable reference to the underlying reader. + /// + /// It is inadvisable to directly read from the underlying reader. + fn get_pin_mut(self: Pin<&mut Self>) -> Pin<&mut R> { + self.project().inner + } + /// Returns a reference to the internal buffer. /// /// This function will not attempt to fill the buffer if it is empty. @@ -185,9 +193,10 @@ impl BufReader { /// Invalidates all data in the internal buffer. #[inline] - fn discard_buffer(mut self: Pin<&mut Self>) { - *self.as_mut().pos() = 0; - *self.cap() = 0; + fn discard_buffer(self: Pin<&mut Self>) { + let this = self.project(); + *this.pos = 0; + *this.cap = 0; } } @@ -201,7 +210,7 @@ impl Read for BufReader { // (larger than our internal buffer), bypass our internal buffer // entirely. if self.pos == self.cap && buf.len() >= self.buf.len() { - let res = futures_core::ready!(self.as_mut().inner().poll_read(cx, buf)); + let res = futures_core::ready!(self.as_mut().get_pin_mut().poll_read(cx, buf)); self.discard_buffer(); return Poll::Ready(res); } @@ -218,7 +227,8 @@ impl Read for BufReader { ) -> Poll> { let total_len = bufs.iter().map(|b| b.len()).sum::(); if self.pos == self.cap && total_len >= self.buf.len() { - let res = futures_core::ready!(self.as_mut().inner().poll_read_vectored(cx, bufs)); + let res = + futures_core::ready!(self.as_mut().get_pin_mut().poll_read_vectored(cx, bufs)); self.discard_buffer(); return Poll::Ready(res); } @@ -234,28 +244,23 @@ impl BufRead for BufReader { self: Pin<&'a mut Self>, cx: &mut Context<'_>, ) -> Poll> { - let Self { - inner, - buf, - cap, - pos, - } = unsafe { self.get_unchecked_mut() }; - let mut inner = unsafe { Pin::new_unchecked(inner) }; + let mut this = self.project(); // If we've reached the end of our internal buffer then we need to fetch // some more data from the underlying reader. // Branch using `>=` instead of the more correct `==` // to tell the compiler that the pos..cap slice is always valid. - if *pos >= *cap { - debug_assert!(*pos == *cap); - *cap = futures_core::ready!(inner.as_mut().poll_read(cx, buf))?; - *pos = 0; + if *this.pos >= *this.cap { + debug_assert!(*this.pos == *this.cap); + *this.cap = futures_core::ready!(this.inner.as_mut().poll_read(cx, this.buf))?; + *this.pos = 0; } - Poll::Ready(Ok(&buf[*pos..*cap])) + Poll::Ready(Ok(&this.buf[*this.pos..*this.cap])) } - fn consume(mut self: Pin<&mut Self>, amt: usize) { - *self.as_mut().pos() = cmp::min(self.pos + amt, self.cap); + fn consume(self: Pin<&mut Self>, amt: usize) { + let this = self.project(); + *this.pos = cmp::min(*this.pos + amt, *this.cap); } } @@ -305,24 +310,26 @@ impl Seek for BufReader { if let Some(offset) = n.checked_sub(remainder) { result = futures_core::ready!( self.as_mut() - .inner() + .get_pin_mut() .poll_seek(cx, SeekFrom::Current(offset)) )?; } else { // seek backwards by our remainder, and then by the offset futures_core::ready!( self.as_mut() - .inner() + .get_pin_mut() .poll_seek(cx, SeekFrom::Current(-remainder)) )?; self.as_mut().discard_buffer(); result = futures_core::ready!( - self.as_mut().inner().poll_seek(cx, SeekFrom::Current(n)) + self.as_mut() + .get_pin_mut() + .poll_seek(cx, SeekFrom::Current(n)) )?; } } else { // Seeking with Start/End doesn't care about our buffer length. - result = futures_core::ready!(self.as_mut().inner().poll_seek(cx, pos))?; + result = futures_core::ready!(self.as_mut().get_pin_mut().poll_seek(cx, pos))?; } self.discard_buffer(); Poll::Ready(Ok(result)) diff --git a/src/io/buf_writer.rs b/src/io/buf_writer.rs index f12aacb..6327ca7 100644 --- a/src/io/buf_writer.rs +++ b/src/io/buf_writer.rs @@ -2,6 +2,7 @@ use std::fmt; use std::pin::Pin; use futures_core::ready; +use pin_project_lite::pin_project; use crate::io::write::WriteExt; use crate::io::{self, Seek, SeekFrom, Write}; @@ -9,88 +10,88 @@ use crate::task::{Context, Poll}; const DEFAULT_CAPACITY: usize = 8 * 1024; -/// Wraps a writer and buffers its output. -/// -/// It can be excessively inefficient to work directly with something that -/// implements [`Write`]. For example, every call to -/// [`write`][`TcpStream::write`] on [`TcpStream`] results in a system call. A -/// `BufWriter` keeps an in-memory buffer of data and writes it to an underlying -/// writer in large, infrequent batches. -/// -/// `BufWriter` can improve the speed of programs that make *small* and -/// *repeated* write calls to the same file or network socket. It does not -/// help when writing very large amounts at once, or writing just one or a few -/// times. It also provides no advantage when writing to a destination that is -/// in memory, like a `Vec`. -/// -/// When the `BufWriter` is dropped, the contents of its buffer will be written -/// out. However, any errors that happen in the process of flushing the buffer -/// when the writer is dropped will be ignored. Code that wishes to handle such -/// errors must manually call [`flush`] before the writer is dropped. -/// -/// This type is an async version of [`std::io::BufReader`]. -/// -/// [`std::io::BufReader`]: https://doc.rust-lang.org/std/io/struct.BufReader.html -/// -/// # Examples -/// -/// Let's write the numbers one through ten to a [`TcpStream`]: -/// -/// ```no_run -/// # fn main() -> std::io::Result<()> { async_std::task::block_on(async { -/// use async_std::net::TcpStream; -/// use async_std::prelude::*; -/// -/// let mut stream = TcpStream::connect("127.0.0.1:34254").await?; -/// -/// for i in 0..10 { -/// let arr = [i+1]; -/// stream.write(&arr).await?; -/// } -/// # -/// # Ok(()) }) } -/// ``` -/// -/// Because we're not buffering, we write each one in turn, incurring the -/// overhead of a system call per byte written. We can fix this with a -/// `BufWriter`: -/// -/// ```no_run -/// # fn main() -> std::io::Result<()> { async_std::task::block_on(async { -/// use async_std::io::BufWriter; -/// use async_std::net::TcpStream; -/// use async_std::prelude::*; -/// -/// let mut stream = BufWriter::new(TcpStream::connect("127.0.0.1:34254").await?); -/// for i in 0..10 { -/// let arr = [i+1]; -/// stream.write(&arr).await?; -/// }; -/// # -/// # Ok(()) }) } -/// ``` -/// -/// By wrapping the stream with a `BufWriter`, these ten writes are all grouped -/// together by the buffer, and will all be written out in one system call when -/// the `stream` is dropped. -/// -/// [`Write`]: trait.Write.html -/// [`TcpStream::write`]: ../net/struct.TcpStream.html#method.write -/// [`TcpStream`]: ../net/struct.TcpStream.html -/// [`flush`]: trait.Write.html#tymethod.flush -pub struct BufWriter { - inner: W, - buf: Vec, - written: usize, +pin_project! { + /// Wraps a writer and buffers its output. + /// + /// It can be excessively inefficient to work directly with something that + /// implements [`Write`]. For example, every call to + /// [`write`][`TcpStream::write`] on [`TcpStream`] results in a system call. A + /// `BufWriter` keeps an in-memory buffer of data and writes it to an underlying + /// writer in large, infrequent batches. + /// + /// `BufWriter` can improve the speed of programs that make *small* and + /// *repeated* write calls to the same file or network socket. It does not + /// help when writing very large amounts at once, or writing just one or a few + /// times. It also provides no advantage when writing to a destination that is + /// in memory, like a `Vec`. + /// + /// When the `BufWriter` is dropped, the contents of its buffer will be written + /// out. However, any errors that happen in the process of flushing the buffer + /// when the writer is dropped will be ignored. Code that wishes to handle such + /// errors must manually call [`flush`] before the writer is dropped. + /// + /// This type is an async version of [`std::io::BufReader`]. + /// + /// [`std::io::BufReader`]: https://doc.rust-lang.org/std/io/struct.BufReader.html + /// + /// # Examples + /// + /// Let's write the numbers one through ten to a [`TcpStream`]: + /// + /// ```no_run + /// # fn main() -> std::io::Result<()> { async_std::task::block_on(async { + /// use async_std::net::TcpStream; + /// use async_std::prelude::*; + /// + /// let mut stream = TcpStream::connect("127.0.0.1:34254").await?; + /// + /// for i in 0..10 { + /// let arr = [i+1]; + /// stream.write(&arr).await?; + /// } + /// # + /// # Ok(()) }) } + /// ``` + /// + /// Because we're not buffering, we write each one in turn, incurring the + /// overhead of a system call per byte written. We can fix this with a + /// `BufWriter`: + /// + /// ```no_run + /// # fn main() -> std::io::Result<()> { async_std::task::block_on(async { + /// use async_std::io::BufWriter; + /// use async_std::net::TcpStream; + /// use async_std::prelude::*; + /// + /// let mut stream = BufWriter::new(TcpStream::connect("127.0.0.1:34254").await?); + /// for i in 0..10 { + /// let arr = [i+1]; + /// stream.write(&arr).await?; + /// }; + /// # + /// # Ok(()) }) } + /// ``` + /// + /// By wrapping the stream with a `BufWriter`, these ten writes are all grouped + /// together by the buffer, and will all be written out in one system call when + /// the `stream` is dropped. + /// + /// [`Write`]: trait.Write.html + /// [`TcpStream::write`]: ../net/struct.TcpStream.html#method.write + /// [`TcpStream`]: ../net/struct.TcpStream.html + /// [`flush`]: trait.Write.html#tymethod.flush + pub struct BufWriter { + #[pin] + inner: W, + buf: Vec, + written: usize, + } } #[derive(Debug)] pub struct IntoInnerError(W, std::io::Error); impl BufWriter { - pin_utils::unsafe_pinned!(inner: W); - pin_utils::unsafe_unpinned!(buf: Vec); - /// Creates a new `BufWriter` with a default buffer capacity. The default is currently 8 KB, /// but may change in the future. /// @@ -178,6 +179,13 @@ impl BufWriter { &mut self.inner } + /// Gets a pinned mutable reference to the underlying writer. + /// + /// It is inadvisable to directly write to the underlying writer. + fn get_pin_mut(self: Pin<&mut Self>) -> Pin<&mut W> { + self.project().inner + } + /// Consumes BufWriter, returning the underlying writer /// /// This method will not write leftover data, it will be lost. @@ -234,16 +242,15 @@ impl BufWriter { /// /// [`LineWriter`]: struct.LineWriter.html fn poll_flush_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - let Self { - inner, - buf, - written, - } = unsafe { Pin::get_unchecked_mut(self) }; - let mut inner = unsafe { Pin::new_unchecked(inner) }; - let len = buf.len(); + let mut this = self.project(); + let len = this.buf.len(); let mut ret = Ok(()); - while *written < len { - match inner.as_mut().poll_write(cx, &buf[*written..]) { + while *this.written < len { + match this + .inner + .as_mut() + .poll_write(cx, &this.buf[*this.written..]) + { Poll::Ready(Ok(0)) => { ret = Err(io::Error::new( io::ErrorKind::WriteZero, @@ -251,7 +258,7 @@ impl BufWriter { )); break; } - Poll::Ready(Ok(n)) => *written += n, + Poll::Ready(Ok(n)) => *this.written += n, Poll::Ready(Err(ref e)) if e.kind() == io::ErrorKind::Interrupted => {} Poll::Ready(Err(e)) => { ret = Err(e); @@ -260,10 +267,10 @@ impl BufWriter { Poll::Pending => return Poll::Pending, } } - if *written > 0 { - buf.drain(..*written); + if *this.written > 0 { + this.buf.drain(..*this.written); } - *written = 0; + *this.written = 0; Poll::Ready(ret) } } @@ -278,20 +285,20 @@ impl Write for BufWriter { ready!(self.as_mut().poll_flush_buf(cx))?; } if buf.len() >= self.buf.capacity() { - self.inner().poll_write(cx, buf) + self.get_pin_mut().poll_write(cx, buf) } else { - Pin::new(&mut *self.buf()).poll_write(cx, buf) + Pin::new(&mut *self.project().buf).poll_write(cx, buf) } } fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { ready!(self.as_mut().poll_flush_buf(cx))?; - self.inner().poll_flush(cx) + self.get_pin_mut().poll_flush(cx) } fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { ready!(self.as_mut().poll_flush_buf(cx))?; - self.inner().poll_close(cx) + self.get_pin_mut().poll_close(cx) } } @@ -314,6 +321,6 @@ impl Seek for BufWriter { pos: SeekFrom, ) -> Poll> { ready!(self.as_mut().poll_flush_buf(cx))?; - self.inner().poll_seek(cx, pos) + self.get_pin_mut().poll_seek(cx, pos) } } diff --git a/src/io/copy.rs b/src/io/copy.rs index 3840d2a..098df8d 100644 --- a/src/io/copy.rs +++ b/src/io/copy.rs @@ -1,5 +1,7 @@ use std::pin::Pin; +use pin_project_lite::pin_project; + use crate::future::Future; use crate::io::{self, BufRead, BufReader, Read, Write}; use crate::task::{Context, Poll}; @@ -46,47 +48,38 @@ where R: Read + Unpin + ?Sized, W: Write + Unpin + ?Sized, { - pub struct CopyFuture<'a, R, W: ?Sized> { - reader: R, - writer: &'a mut W, - amt: u64, - } - - impl CopyFuture<'_, R, W> { - fn project(self: Pin<&mut Self>) -> (Pin<&mut R>, Pin<&mut W>, &mut u64) { - unsafe { - let this = self.get_unchecked_mut(); - ( - Pin::new_unchecked(&mut this.reader), - Pin::new(&mut *this.writer), - &mut this.amt, - ) - } + pin_project! { + struct CopyFuture { + #[pin] + reader: R, + #[pin] + writer: W, + amt: u64, } } - impl Future for CopyFuture<'_, R, W> + impl Future for CopyFuture where R: BufRead, - W: Write + Unpin + ?Sized, + W: Write + Unpin, { type Output = io::Result; fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - let (mut reader, mut writer, amt) = self.project(); + let mut this = self.project(); loop { - let buffer = futures_core::ready!(reader.as_mut().poll_fill_buf(cx))?; + let buffer = futures_core::ready!(this.reader.as_mut().poll_fill_buf(cx))?; if buffer.is_empty() { - futures_core::ready!(writer.as_mut().poll_flush(cx))?; - return Poll::Ready(Ok(*amt)); + futures_core::ready!(this.writer.as_mut().poll_flush(cx))?; + return Poll::Ready(Ok(*this.amt)); } - let i = futures_core::ready!(writer.as_mut().poll_write(cx, buffer))?; + let i = futures_core::ready!(this.writer.as_mut().poll_write(cx, buffer))?; if i == 0 { return Poll::Ready(Err(io::ErrorKind::WriteZero.into())); } - *amt += i as u64; - reader.as_mut().consume(i); + *this.amt += i as u64; + this.reader.as_mut().consume(i); } } } diff --git a/src/io/read/chain.rs b/src/io/read/chain.rs index 09517cc..335cac2 100644 --- a/src/io/read/chain.rs +++ b/src/io/read/chain.rs @@ -1,20 +1,25 @@ -use crate::io::IoSliceMut; use std::fmt; use std::pin::Pin; -use crate::io::{self, BufRead, Read}; +use pin_project_lite::pin_project; + +use crate::io::{self, BufRead, IoSliceMut, Read}; use crate::task::{Context, Poll}; -/// Adaptor to chain together two readers. -/// -/// This struct is generally created by calling [`chain`] on a reader. -/// Please see the documentation of [`chain`] for more details. -/// -/// [`chain`]: trait.Read.html#method.chain -pub struct Chain { - pub(crate) first: T, - pub(crate) second: U, - pub(crate) done_first: bool, +pin_project! { + /// Adaptor to chain together two readers. + /// + /// This struct is generally created by calling [`chain`] on a reader. + /// Please see the documentation of [`chain`] for more details. + /// + /// [`chain`]: trait.Read.html#method.chain + pub struct Chain { + #[pin] + pub(crate) first: T, + #[pin] + pub(crate) second: U, + pub(crate) done_first: bool, + } } impl Chain { @@ -98,76 +103,64 @@ impl fmt::Debug for Chain { } } -impl Read for Chain { +impl Read for Chain { fn poll_read( - mut self: Pin<&mut Self>, + self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut [u8], ) -> Poll> { - if !self.done_first { - let rd = Pin::new(&mut self.first); - - match futures_core::ready!(rd.poll_read(cx, buf)) { - Ok(0) if !buf.is_empty() => self.done_first = true, + let this = self.project(); + if !*this.done_first { + match futures_core::ready!(this.first.poll_read(cx, buf)) { + Ok(0) if !buf.is_empty() => *this.done_first = true, Ok(n) => return Poll::Ready(Ok(n)), Err(err) => return Poll::Ready(Err(err)), } } - let rd = Pin::new(&mut self.second); - rd.poll_read(cx, buf) + this.second.poll_read(cx, buf) } fn poll_read_vectored( - mut self: Pin<&mut Self>, + self: Pin<&mut Self>, cx: &mut Context<'_>, bufs: &mut [IoSliceMut<'_>], ) -> Poll> { - if !self.done_first { - let rd = Pin::new(&mut self.first); - - match futures_core::ready!(rd.poll_read_vectored(cx, bufs)) { - Ok(0) if !bufs.is_empty() => self.done_first = true, + let this = self.project(); + if !*this.done_first { + match futures_core::ready!(this.first.poll_read_vectored(cx, bufs)) { + Ok(0) if !bufs.is_empty() => *this.done_first = true, Ok(n) => return Poll::Ready(Ok(n)), Err(err) => return Poll::Ready(Err(err)), } } - let rd = Pin::new(&mut self.second); - rd.poll_read_vectored(cx, bufs) + this.second.poll_read_vectored(cx, bufs) } } -impl BufRead for Chain { +impl BufRead for Chain { fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - let Self { - first, - second, - done_first, - } = unsafe { self.get_unchecked_mut() }; - - if !*done_first { - let first = unsafe { Pin::new_unchecked(first) }; - match futures_core::ready!(first.poll_fill_buf(cx)) { + let this = self.project(); + if !*this.done_first { + match futures_core::ready!(this.first.poll_fill_buf(cx)) { Ok(buf) if buf.is_empty() => { - *done_first = true; + *this.done_first = true; } Ok(buf) => return Poll::Ready(Ok(buf)), Err(err) => return Poll::Ready(Err(err)), } } - let second = unsafe { Pin::new_unchecked(second) }; - second.poll_fill_buf(cx) + this.second.poll_fill_buf(cx) } - fn consume(mut self: Pin<&mut Self>, amt: usize) { - if !self.done_first { - let rd = Pin::new(&mut self.first); - rd.consume(amt) + fn consume(self: Pin<&mut Self>, amt: usize) { + let this = self.project(); + if !*this.done_first { + this.first.consume(amt) } else { - let rd = Pin::new(&mut self.second); - rd.consume(amt) + this.second.consume(amt) } } } diff --git a/src/io/read/take.rs b/src/io/read/take.rs index def4e24..09b02c2 100644 --- a/src/io/read/take.rs +++ b/src/io/read/take.rs @@ -1,19 +1,24 @@ use std::cmp; use std::pin::Pin; +use pin_project_lite::pin_project; + use crate::io::{self, BufRead, Read}; use crate::task::{Context, Poll}; -/// Reader adaptor which limits the bytes read from an underlying reader. -/// -/// This struct is generally created by calling [`take`] on a reader. -/// Please see the documentation of [`take`] for more details. -/// -/// [`take`]: trait.Read.html#method.take -#[derive(Debug)] -pub struct Take { - pub(crate) inner: T, - pub(crate) limit: u64, +pin_project! { + /// Reader adaptor which limits the bytes read from an underlying reader. + /// + /// This struct is generally created by calling [`take`] on a reader. + /// Please see the documentation of [`take`] for more details. + /// + /// [`take`]: trait.Read.html#method.take + #[derive(Debug)] + pub struct Take { + #[pin] + pub(crate) inner: T, + pub(crate) limit: u64, + } } impl Take { @@ -152,15 +157,15 @@ impl Take { } } -impl Read for Take { +impl Read for Take { /// Attempt to read from the `AsyncRead` into `buf`. fn poll_read( - mut self: Pin<&mut Self>, + self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut [u8], ) -> Poll> { - let Self { inner, limit } = &mut *self; - take_read_internal(Pin::new(inner), cx, buf, limit) + let this = self.project(); + take_read_internal(this.inner, cx, buf, this.limit) } } @@ -186,31 +191,30 @@ pub fn take_read_internal( } } -impl BufRead for Take { +impl BufRead for Take { fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - let Self { inner, limit } = unsafe { self.get_unchecked_mut() }; - let inner = unsafe { Pin::new_unchecked(inner) }; + let this = self.project(); - if *limit == 0 { + if *this.limit == 0 { return Poll::Ready(Ok(&[])); } - match futures_core::ready!(inner.poll_fill_buf(cx)) { + match futures_core::ready!(this.inner.poll_fill_buf(cx)) { Ok(buf) => { - let cap = cmp::min(buf.len() as u64, *limit) as usize; + let cap = cmp::min(buf.len() as u64, *this.limit) as usize; Poll::Ready(Ok(&buf[..cap])) } Err(e) => Poll::Ready(Err(e)), } } - fn consume(mut self: Pin<&mut Self>, amt: usize) { + fn consume(self: Pin<&mut Self>, amt: usize) { + let this = self.project(); // Don't let callers reset the limit by passing an overlarge value - let amt = cmp::min(amt as u64, self.limit) as usize; - self.limit -= amt as u64; + let amt = cmp::min(amt as u64, *this.limit) as usize; + *this.limit -= amt as u64; - let rd = Pin::new(&mut self.inner); - rd.consume(amt); + this.inner.consume(amt); } } diff --git a/src/io/timeout.rs b/src/io/timeout.rs index 9fcc15e..ec3668e 100644 --- a/src/io/timeout.rs +++ b/src/io/timeout.rs @@ -3,7 +3,7 @@ use std::task::{Context, Poll}; use std::time::Duration; use futures_timer::Delay; -use pin_utils::unsafe_pinned; +use pin_project_lite::pin_project; use crate::future::Future; use crate::io; @@ -43,22 +43,18 @@ where .await } -/// Future returned by the `FutureExt::timeout` method. -#[derive(Debug)] -pub struct Timeout -where - F: Future>, -{ - future: F, - timeout: Delay, -} - -impl Timeout -where - F: Future>, -{ - unsafe_pinned!(future: F); - unsafe_pinned!(timeout: Delay); +pin_project! { + /// Future returned by the `FutureExt::timeout` method. + #[derive(Debug)] + pub struct Timeout + where + F: Future>, + { + #[pin] + future: F, + #[pin] + timeout: Delay, + } } impl Future for Timeout @@ -67,14 +63,15 @@ where { type Output = io::Result; - fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - match self.as_mut().future().poll(cx) { + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let this = self.project(); + match this.future.poll(cx) { Poll::Pending => {} other => return other, } - if self.timeout().poll(cx).is_ready() { - let err = Err(io::Error::new(io::ErrorKind::TimedOut, "future timed out").into()); + if this.timeout.poll(cx).is_ready() { + let err = Err(io::Error::new(io::ErrorKind::TimedOut, "future timed out")); Poll::Ready(err) } else { Poll::Pending diff --git a/src/io/write/write_fmt.rs b/src/io/write/write_fmt.rs index a1149cd..ad3e94a 100644 --- a/src/io/write/write_fmt.rs +++ b/src/io/write/write_fmt.rs @@ -32,7 +32,7 @@ impl Future for WriteFmtFuture<'_, T> { buffer, .. } = &mut *self; - let mut buffer = buffer.as_mut().unwrap(); + let buffer = buffer.as_mut().unwrap(); // Copy the data from the buffer into the writer until it's done. loop { @@ -40,7 +40,7 @@ impl Future for WriteFmtFuture<'_, T> { futures_core::ready!(Pin::new(&mut **writer).poll_flush(cx))?; return Poll::Ready(Ok(())); } - let i = futures_core::ready!(Pin::new(&mut **writer).poll_write(cx, &mut buffer))?; + let i = futures_core::ready!(Pin::new(&mut **writer).poll_write(cx, buffer))?; if i == 0 { return Poll::Ready(Err(io::ErrorKind::WriteZero.into())); } diff --git a/src/net/addr.rs b/src/net/addr.rs index 1dada64..519b184 100644 --- a/src/net/addr.rs +++ b/src/net/addr.rs @@ -210,7 +210,7 @@ impl ToSocketAddrs for str { impl Future, ToSocketAddrsFuture ) { - if let Some(addr) = self.parse().ok() { + if let Ok(addr) = self.parse() { return ToSocketAddrsFuture::Ready(Ok(vec![addr].into_iter())); } diff --git a/src/path/path.rs b/src/path/path.rs index aa15ab2..22dab76 100644 --- a/src/path/path.rs +++ b/src/path/path.rs @@ -799,7 +799,7 @@ impl AsRef for String { impl AsRef for std::path::PathBuf { fn as_ref(&self) -> &Path { - Path::new(self.into()) + Path::new(self) } } diff --git a/src/path/pathbuf.rs b/src/path/pathbuf.rs index 64744e1..38018c9 100644 --- a/src/path/pathbuf.rs +++ b/src/path/pathbuf.rs @@ -5,7 +5,7 @@ use crate::path::Path; /// This struct is an async version of [`std::path::PathBuf`]. /// /// [`std::path::Path`]: https://doc.rust-lang.org/std/path/struct.PathBuf.html -#[derive(Debug, PartialEq)] +#[derive(Debug, PartialEq, Default)] pub struct PathBuf { inner: std::path::PathBuf, } @@ -206,7 +206,7 @@ impl From for PathBuf { impl Into for PathBuf { fn into(self) -> std::path::PathBuf { - self.inner.into() + self.inner } } diff --git a/src/stream/exact_size_stream.rs b/src/stream/exact_size_stream.rs index 7d2e7cb..32a1eb3 100644 --- a/src/stream/exact_size_stream.rs +++ b/src/stream/exact_size_stream.rs @@ -59,7 +59,7 @@ pub use crate::stream::Stream; /// # } /// # } /// # } -/// # fn main() { async_std::task::block_on(async { +/// # async_std::task::block_on(async { /// # /// impl ExactSizeStream for Counter { /// // We can easily calculate the remaining number of iterations. @@ -74,7 +74,6 @@ pub use crate::stream::Stream; /// /// assert_eq!(5, counter.len()); /// # }); -/// # } /// ``` #[cfg(feature = "unstable")] #[cfg_attr(feature = "docs", doc(cfg(unstable)))] diff --git a/src/stream/extend.rs b/src/stream/extend.rs index d9e1481..350418d 100644 --- a/src/stream/extend.rs +++ b/src/stream/extend.rs @@ -14,7 +14,7 @@ use crate::stream::IntoStream; /// ## Examples /// /// ``` -/// # fn main() { async_std::task::block_on(async { +/// # async_std::task::block_on(async { /// # /// use async_std::prelude::*; /// use async_std::stream::{self, Extend}; @@ -25,7 +25,7 @@ use crate::stream::IntoStream; /// /// assert_eq!(v, vec![1, 2, 3, 3, 3]); /// # -/// # }) } +/// # }) /// ``` #[cfg(feature = "unstable")] #[cfg_attr(feature = "docs", doc(cfg(unstable)))] diff --git a/src/stream/from_fn.rs b/src/stream/from_fn.rs index c1cb97a..7fee892 100644 --- a/src/stream/from_fn.rs +++ b/src/stream/from_fn.rs @@ -1,20 +1,25 @@ use std::marker::PhantomData; use std::pin::Pin; +use pin_project_lite::pin_project; + use crate::future::Future; use crate::stream::Stream; use crate::task::{Context, Poll}; -/// A stream that yields elements by calling a closure. -/// -/// This stream is constructed by [`from_fn`] function. -/// -/// [`from_fn`]: fn.from_fn.html -#[derive(Debug)] -pub struct FromFn { - f: F, - future: Option, - __t: PhantomData, +pin_project! { + /// A stream that yields elements by calling a closure. + /// + /// This stream is constructed by [`from_fn`] function. + /// + /// [`from_fn`]: fn.from_fn.html + #[derive(Debug)] + pub struct FromFn { + f: F, + #[pin] + future: Option, + __t: PhantomData, + } } /// Creates a new stream where to produce each new element a provided closure is called. @@ -25,7 +30,7 @@ pub struct FromFn { /// # Examples /// /// ``` -/// # fn main() { async_std::task::block_on(async { +/// # async_std::task::block_on(async { /// # /// use async_std::prelude::*; /// use async_std::sync::Mutex; @@ -53,8 +58,7 @@ pub struct FromFn { /// assert_eq!(s.next().await, Some(3)); /// assert_eq!(s.next().await, None); /// # -/// # }) } -/// +/// # }) /// ``` pub fn from_fn(f: F) -> FromFn where @@ -68,11 +72,6 @@ where } } -impl FromFn { - pin_utils::unsafe_unpinned!(f: F); - pin_utils::unsafe_pinned!(future: Option); -} - impl Stream for FromFn where F: FnMut() -> Fut, @@ -80,20 +79,18 @@ where { type Item = T; - fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let mut this = self.project(); loop { - match &self.future { - Some(_) => { - let next = - futures_core::ready!(self.as_mut().future().as_pin_mut().unwrap().poll(cx)); - self.as_mut().future().set(None); + if this.future.is_some() { + let next = + futures_core::ready!(this.future.as_mut().as_pin_mut().unwrap().poll(cx)); + this.future.set(None); - return Poll::Ready(next); - } - None => { - let fut = (self.as_mut().f())(); - self.as_mut().future().set(Some(fut)); - } + return Poll::Ready(next); + } else { + let fut = (this.f)(); + this.future.set(Some(fut)); } } } diff --git a/src/stream/interval.rs b/src/stream/interval.rs index 043d307..2f7fe9e 100644 --- a/src/stream/interval.rs +++ b/src/stream/interval.rs @@ -4,8 +4,6 @@ use std::time::{Duration, Instant}; use futures_core::future::Future; use futures_core::stream::Stream; -use pin_utils::unsafe_pinned; - use futures_timer::Delay; /// Creates a new stream that yields at a set interval. @@ -62,15 +60,11 @@ pub struct Interval { interval: Duration, } -impl Interval { - unsafe_pinned!(delay: Delay); -} - impl Stream for Interval { type Item = (); fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - if Pin::new(&mut *self).delay().poll(cx).is_pending() { + if Pin::new(&mut self.delay).poll(cx).is_pending() { return Poll::Pending; } let when = Instant::now(); diff --git a/src/stream/once.rs b/src/stream/once.rs index be875e4..ae90d63 100644 --- a/src/stream/once.rs +++ b/src/stream/once.rs @@ -1,5 +1,7 @@ use std::pin::Pin; +use pin_project_lite::pin_project; + use crate::stream::Stream; use crate::task::{Context, Poll}; @@ -24,20 +26,22 @@ pub fn once(t: T) -> Once { Once { value: Some(t) } } -/// A stream that yields a single item. -/// -/// This stream is constructed by the [`once`] function. -/// -/// [`once`]: fn.once.html -#[derive(Debug)] -pub struct Once { - value: Option, +pin_project! { + /// A stream that yields a single item. + /// + /// This stream is constructed by the [`once`] function. + /// + /// [`once`]: fn.once.html + #[derive(Debug)] + pub struct Once { + value: Option, + } } impl Stream for Once { type Item = T; - fn poll_next(mut self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll> { - Poll::Ready(self.value.take()) + fn poll_next(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll> { + Poll::Ready(self.project().value.take()) } } diff --git a/src/stream/repeat_with.rs b/src/stream/repeat_with.rs index f38b323..15d4aa1 100644 --- a/src/stream/repeat_with.rs +++ b/src/stream/repeat_with.rs @@ -1,20 +1,25 @@ use std::marker::PhantomData; use std::pin::Pin; +use pin_project_lite::pin_project; + use crate::future::Future; use crate::stream::Stream; use crate::task::{Context, Poll}; -/// A stream that repeats elements of type `T` endlessly by applying a provided closure. -/// -/// This stream is constructed by the [`repeat_with`] function. -/// -/// [`repeat_with`]: fn.repeat_with.html -#[derive(Debug)] -pub struct RepeatWith { - f: F, - future: Option, - __a: PhantomData, +pin_project! { + /// A stream that repeats elements of type `T` endlessly by applying a provided closure. + /// + /// This stream is constructed by the [`repeat_with`] function. + /// + /// [`repeat_with`]: fn.repeat_with.html + #[derive(Debug)] + pub struct RepeatWith { + f: F, + #[pin] + future: Option, + __a: PhantomData, + } } /// Creates a new stream that repeats elements of type `A` endlessly by applying the provided closure. @@ -24,7 +29,7 @@ pub struct RepeatWith { /// Basic usage: /// /// ``` -/// # fn main() { async_std::task::block_on(async { +/// # async_std::task::block_on(async { /// # /// use async_std::prelude::*; /// use async_std::stream; @@ -37,13 +42,13 @@ pub struct RepeatWith { /// assert_eq!(s.next().await, Some(1)); /// assert_eq!(s.next().await, Some(1)); /// assert_eq!(s.next().await, Some(1)); -/// # }) } +/// # }) /// ``` /// /// Going finite: /// /// ``` -/// # fn main() { async_std::task::block_on(async { +/// # async_std::task::block_on(async { /// # /// use async_std::prelude::*; /// use async_std::stream; @@ -55,7 +60,7 @@ pub struct RepeatWith { /// assert_eq!(s.next().await, Some(1)); /// assert_eq!(s.next().await, Some(1)); /// assert_eq!(s.next().await, None); -/// # }) } +/// # }) /// ``` pub fn repeat_with(repeater: F) -> RepeatWith where @@ -69,11 +74,6 @@ where } } -impl RepeatWith { - pin_utils::unsafe_unpinned!(f: F); - pin_utils::unsafe_pinned!(future: Option); -} - impl Stream for RepeatWith where F: FnMut() -> Fut, @@ -81,22 +81,19 @@ where { type Item = A; - fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let mut this = self.project(); loop { - match &self.future { - Some(_) => { - let res = - futures_core::ready!(self.as_mut().future().as_pin_mut().unwrap().poll(cx)); + if this.future.is_some() { + let res = futures_core::ready!(this.future.as_mut().as_pin_mut().unwrap().poll(cx)); - self.as_mut().future().set(None); + this.future.set(None); - return Poll::Ready(Some(res)); - } - None => { - let fut = (self.as_mut().f())(); + return Poll::Ready(Some(res)); + } else { + let fut = (this.f)(); - self.as_mut().future().set(Some(fut)); - } + this.future.set(Some(fut)); } } } diff --git a/src/stream/stream/chain.rs b/src/stream/stream/chain.rs index 2693382..df31615 100644 --- a/src/stream/stream/chain.rs +++ b/src/stream/stream/chain.rs @@ -1,20 +1,23 @@ use std::pin::Pin; +use pin_project_lite::pin_project; + use super::fuse::Fuse; use crate::prelude::*; use crate::task::{Context, Poll}; -/// Chains two streams one after another. -#[derive(Debug)] -pub struct Chain { - first: Fuse, - second: Fuse, +pin_project! { + /// Chains two streams one after another. + #[derive(Debug)] + pub struct Chain { + #[pin] + first: Fuse, + #[pin] + second: Fuse, + } } impl Chain { - pin_utils::unsafe_pinned!(first: Fuse); - pin_utils::unsafe_pinned!(second: Fuse); - pub(super) fn new(first: S, second: U) -> Self { Chain { first: first.fuse(), @@ -26,22 +29,23 @@ impl Chain { impl> Stream for Chain { type Item = S::Item; - fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - if !self.first.done { - let next = futures_core::ready!(self.as_mut().first().poll_next(cx)); + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let mut this = self.project(); + if !this.first.done { + let next = futures_core::ready!(this.first.as_mut().poll_next(cx)); if let Some(next) = next { return Poll::Ready(Some(next)); } } - if !self.second.done { - let next = futures_core::ready!(self.as_mut().second().poll_next(cx)); + if !this.second.done { + let next = futures_core::ready!(this.second.as_mut().poll_next(cx)); if let Some(next) = next { return Poll::Ready(Some(next)); } } - if self.first.done && self.second.done { + if this.first.done && this.second.done { return Poll::Ready(None); } diff --git a/src/stream/stream/cmp.rs b/src/stream/stream/cmp.rs index fc7161a..df08e9d 100644 --- a/src/stream/stream/cmp.rs +++ b/src/stream/stream/cmp.rs @@ -1,29 +1,30 @@ use std::cmp::Ordering; use std::pin::Pin; +use pin_project_lite::pin_project; + use super::fuse::Fuse; use crate::future::Future; use crate::prelude::*; use crate::stream::Stream; use crate::task::{Context, Poll}; -// Lexicographically compares the elements of this `Stream` with those -// of another using `Ord`. -#[doc(hidden)] -#[allow(missing_debug_implementations)] -pub struct CmpFuture { - l: Fuse, - r: Fuse, - l_cache: Option, - r_cache: Option, +pin_project! { + // Lexicographically compares the elements of this `Stream` with those + // of another using `Ord`. + #[doc(hidden)] + #[allow(missing_debug_implementations)] + pub struct CmpFuture { + #[pin] + l: Fuse, + #[pin] + r: Fuse, + l_cache: Option, + r_cache: Option, + } } impl CmpFuture { - pin_utils::unsafe_pinned!(l: Fuse); - pin_utils::unsafe_pinned!(r: Fuse); - pin_utils::unsafe_unpinned!(l_cache: Option); - pin_utils::unsafe_unpinned!(r_cache: Option); - pub(super) fn new(l: L, r: R) -> Self { CmpFuture { l: l.fuse(), @@ -42,11 +43,12 @@ where { type Output = Ordering; - fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let mut this = self.project(); loop { // Stream that completes earliest can be considered Less, etc - let l_complete = self.l.done && self.as_mut().l_cache.is_none(); - let r_complete = self.r.done && self.as_mut().r_cache.is_none(); + let l_complete = this.l.done && this.l_cache.is_none(); + let r_complete = this.r.done && this.r_cache.is_none(); if l_complete && r_complete { return Poll::Ready(Ordering::Equal); @@ -57,30 +59,30 @@ where } // Get next value if possible and necesary - if !self.l.done && self.as_mut().l_cache.is_none() { - let l_next = futures_core::ready!(self.as_mut().l().poll_next(cx)); + if !this.l.done && this.l_cache.is_none() { + let l_next = futures_core::ready!(this.l.as_mut().poll_next(cx)); if let Some(item) = l_next { - *self.as_mut().l_cache() = Some(item); + *this.l_cache = Some(item); } } - if !self.r.done && self.as_mut().r_cache.is_none() { - let r_next = futures_core::ready!(self.as_mut().r().poll_next(cx)); + if !this.r.done && this.r_cache.is_none() { + let r_next = futures_core::ready!(this.r.as_mut().poll_next(cx)); if let Some(item) = r_next { - *self.as_mut().r_cache() = Some(item); + *this.r_cache = Some(item); } } // Compare if both values are available. - if self.as_mut().l_cache.is_some() && self.as_mut().r_cache.is_some() { - let l_value = self.as_mut().l_cache().take().unwrap(); - let r_value = self.as_mut().r_cache().take().unwrap(); + if this.l_cache.is_some() && this.r_cache.is_some() { + let l_value = this.l_cache.take().unwrap(); + let r_value = this.r_cache.take().unwrap(); let result = l_value.cmp(&r_value); if let Ordering::Equal = result { // Reset cache to prepare for next comparison - *self.as_mut().l_cache() = None; - *self.as_mut().r_cache() = None; + *this.l_cache = None; + *this.r_cache = None; } else { // Return non equal value return Poll::Ready(result); diff --git a/src/stream/stream/enumerate.rs b/src/stream/stream/enumerate.rs index 7d5a3d6..2a3afa8 100644 --- a/src/stream/stream/enumerate.rs +++ b/src/stream/stream/enumerate.rs @@ -1,19 +1,21 @@ -use crate::task::{Context, Poll}; use std::pin::Pin; +use pin_project_lite::pin_project; + use crate::stream::Stream; +use crate::task::{Context, Poll}; -#[doc(hidden)] -#[allow(missing_debug_implementations)] -pub struct Enumerate { - stream: S, - i: usize, +pin_project! { + #[doc(hidden)] + #[allow(missing_debug_implementations)] + pub struct Enumerate { + #[pin] + stream: S, + i: usize, + } } impl Enumerate { - pin_utils::unsafe_pinned!(stream: S); - pin_utils::unsafe_unpinned!(i: usize); - pub(super) fn new(stream: S) -> Self { Enumerate { stream, i: 0 } } @@ -25,13 +27,14 @@ where { type Item = (usize, S::Item); - fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - let next = futures_core::ready!(self.as_mut().stream().poll_next(cx)); + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let this = self.project(); + let next = futures_core::ready!(this.stream.poll_next(cx)); match next { Some(v) => { - let ret = (self.i, v); - *self.as_mut().i() += 1; + let ret = (*this.i, v); + *this.i += 1; Poll::Ready(Some(ret)) } None => Poll::Ready(None), diff --git a/src/stream/stream/filter.rs b/src/stream/stream/filter.rs index 8ed282c..eb4153f 100644 --- a/src/stream/stream/filter.rs +++ b/src/stream/stream/filter.rs @@ -1,21 +1,23 @@ use std::marker::PhantomData; use std::pin::Pin; +use pin_project_lite::pin_project; + use crate::stream::Stream; use crate::task::{Context, Poll}; -/// A stream to filter elements of another stream with a predicate. -#[derive(Debug)] -pub struct Filter { - stream: S, - predicate: P, - __t: PhantomData, +pin_project! { + /// A stream to filter elements of another stream with a predicate. + #[derive(Debug)] + pub struct Filter { + #[pin] + stream: S, + predicate: P, + __t: PhantomData, + } } impl Filter { - pin_utils::unsafe_pinned!(stream: S); - pin_utils::unsafe_unpinned!(predicate: P); - pub(super) fn new(stream: S, predicate: P) -> Self { Filter { stream, @@ -32,11 +34,12 @@ where { type Item = S::Item; - fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - let next = futures_core::ready!(self.as_mut().stream().poll_next(cx)); + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let this = self.project(); + let next = futures_core::ready!(this.stream.poll_next(cx)); match next { - Some(v) if (self.as_mut().predicate())(&v) => Poll::Ready(Some(v)), + Some(v) if (this.predicate)(&v) => Poll::Ready(Some(v)), Some(_) => { cx.waker().wake_by_ref(); Poll::Pending diff --git a/src/stream/stream/filter_map.rs b/src/stream/stream/filter_map.rs index 756efff..6a4593f 100644 --- a/src/stream/stream/filter_map.rs +++ b/src/stream/stream/filter_map.rs @@ -2,21 +2,23 @@ use std::marker::PhantomData; use std::pin::Pin; use std::task::{Context, Poll}; +use pin_project_lite::pin_project; + use crate::stream::Stream; -#[doc(hidden)] -#[allow(missing_debug_implementations)] -pub struct FilterMap { - stream: S, - f: F, - __from: PhantomData, - __to: PhantomData, +pin_project! { + #[doc(hidden)] + #[allow(missing_debug_implementations)] + pub struct FilterMap { + #[pin] + stream: S, + f: F, + __from: PhantomData, + __to: PhantomData, + } } impl FilterMap { - pin_utils::unsafe_pinned!(stream: S); - pin_utils::unsafe_unpinned!(f: F); - pub(crate) fn new(stream: S, f: F) -> Self { FilterMap { stream, @@ -34,10 +36,11 @@ where { type Item = B; - fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - let next = futures_core::ready!(self.as_mut().stream().poll_next(cx)); + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let this = self.project(); + let next = futures_core::ready!(this.stream.poll_next(cx)); match next { - Some(v) => match (self.as_mut().f())(v) { + Some(v) => match (this.f)(v) { Some(b) => Poll::Ready(Some(b)), None => { cx.waker().wake_by_ref(); diff --git a/src/stream/stream/fold.rs b/src/stream/stream/fold.rs index 18ddcd8..5b0eb12 100644 --- a/src/stream/stream/fold.rs +++ b/src/stream/stream/fold.rs @@ -1,24 +1,25 @@ use std::marker::PhantomData; use std::pin::Pin; +use pin_project_lite::pin_project; + use crate::future::Future; use crate::stream::Stream; use crate::task::{Context, Poll}; -#[doc(hidden)] -#[allow(missing_debug_implementations)] -pub struct FoldFuture { - stream: S, - f: F, - acc: Option, - __t: PhantomData, +pin_project! { + #[doc(hidden)] + #[allow(missing_debug_implementations)] + pub struct FoldFuture { + #[pin] + stream: S, + f: F, + acc: Option, + __t: PhantomData, + } } impl FoldFuture { - pin_utils::unsafe_pinned!(stream: S); - pin_utils::unsafe_unpinned!(f: F); - pin_utils::unsafe_unpinned!(acc: Option); - pub(super) fn new(stream: S, init: B, f: F) -> Self { FoldFuture { stream, @@ -36,17 +37,18 @@ where { type Output = B; - fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let mut this = self.project(); loop { - let next = futures_core::ready!(self.as_mut().stream().poll_next(cx)); + let next = futures_core::ready!(this.stream.as_mut().poll_next(cx)); match next { Some(v) => { - let old = self.as_mut().acc().take().unwrap(); - let new = (self.as_mut().f())(old, v); - *self.as_mut().acc() = Some(new); + let old = this.acc.take().unwrap(); + let new = (this.f)(old, v); + *this.acc = Some(new); } - None => return Poll::Ready(self.as_mut().acc().take().unwrap()), + None => return Poll::Ready(this.acc.take().unwrap()), } } } diff --git a/src/stream/stream/for_each.rs b/src/stream/stream/for_each.rs index 0406a50..4696529 100644 --- a/src/stream/stream/for_each.rs +++ b/src/stream/stream/for_each.rs @@ -1,22 +1,24 @@ use std::marker::PhantomData; use std::pin::Pin; +use pin_project_lite::pin_project; + use crate::future::Future; use crate::stream::Stream; use crate::task::{Context, Poll}; -#[doc(hidden)] -#[allow(missing_debug_implementations)] -pub struct ForEachFuture { - stream: S, - f: F, - __t: PhantomData, +pin_project! { + #[doc(hidden)] + #[allow(missing_debug_implementations)] + pub struct ForEachFuture { + #[pin] + stream: S, + f: F, + __t: PhantomData, + } } impl ForEachFuture { - pin_utils::unsafe_pinned!(stream: S); - pin_utils::unsafe_unpinned!(f: F); - pub(super) fn new(stream: S, f: F) -> Self { ForEachFuture { stream, @@ -33,12 +35,13 @@ where { type Output = (); - fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let mut this = self.project(); loop { - let next = futures_core::ready!(self.as_mut().stream().poll_next(cx)); + let next = futures_core::ready!(this.stream.as_mut().poll_next(cx)); match next { - Some(v) => (self.as_mut().f())(v), + Some(v) => (this.f)(v), None => return Poll::Ready(()), } } diff --git a/src/stream/stream/fuse.rs b/src/stream/stream/fuse.rs index ff5bdab..1162970 100644 --- a/src/stream/stream/fuse.rs +++ b/src/stream/stream/fuse.rs @@ -1,33 +1,32 @@ use std::pin::Pin; +use pin_project_lite::pin_project; + use crate::stream::Stream; use crate::task::{Context, Poll}; -/// A `Stream` that is permanently closed once a single call to `poll` results in -/// `Poll::Ready(None)`, returning `Poll::Ready(None)` for all future calls to `poll`. -#[derive(Clone, Debug)] -pub struct Fuse { - pub(crate) stream: S, - pub(crate) done: bool, -} - -impl Unpin for Fuse {} - -impl Fuse { - pin_utils::unsafe_pinned!(stream: S); - pin_utils::unsafe_unpinned!(done: bool); +pin_project! { + /// A `Stream` that is permanently closed once a single call to `poll` results in + /// `Poll::Ready(None)`, returning `Poll::Ready(None)` for all future calls to `poll`. + #[derive(Clone, Debug)] + pub struct Fuse { + #[pin] + pub(crate) stream: S, + pub(crate) done: bool, + } } impl Stream for Fuse { type Item = S::Item; - fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - if self.done { + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let this = self.project(); + if *this.done { Poll::Ready(None) } else { - let next = futures_core::ready!(self.as_mut().stream().poll_next(cx)); + let next = futures_core::ready!(this.stream.poll_next(cx)); if next.is_none() { - *self.as_mut().done() = true; + *this.done = true; } Poll::Ready(next) } diff --git a/src/stream/stream/ge.rs b/src/stream/stream/ge.rs index eb9786b..3dc6031 100644 --- a/src/stream/stream/ge.rs +++ b/src/stream/stream/ge.rs @@ -1,26 +1,29 @@ use std::cmp::Ordering; use std::pin::Pin; +use pin_project_lite::pin_project; + use super::partial_cmp::PartialCmpFuture; use crate::future::Future; use crate::prelude::*; use crate::stream::Stream; use crate::task::{Context, Poll}; -// Determines if the elements of this `Stream` are lexicographically -// greater than or equal to those of another. -#[doc(hidden)] -#[allow(missing_debug_implementations)] -pub struct GeFuture { - partial_cmp: PartialCmpFuture, +pin_project! { + // Determines if the elements of this `Stream` are lexicographically + // greater than or equal to those of another. + #[doc(hidden)] + #[allow(missing_debug_implementations)] + pub struct GeFuture { + #[pin] + partial_cmp: PartialCmpFuture, + } } impl GeFuture where L::Item: PartialOrd, { - pin_utils::unsafe_pinned!(partial_cmp: PartialCmpFuture); - pub(super) fn new(l: L, r: R) -> Self { GeFuture { partial_cmp: l.partial_cmp(r), @@ -30,14 +33,14 @@ where impl Future for GeFuture where - L: Stream + Sized, - R: Stream + Sized, + L: Stream, + R: Stream, L::Item: PartialOrd, { type Output = bool; - fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - let result = futures_core::ready!(self.as_mut().partial_cmp().poll(cx)); + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let result = futures_core::ready!(self.project().partial_cmp.poll(cx)); match result { Some(Ordering::Greater) | Some(Ordering::Equal) => Poll::Ready(true), diff --git a/src/stream/stream/gt.rs b/src/stream/stream/gt.rs index 6c480a2..513ca76 100644 --- a/src/stream/stream/gt.rs +++ b/src/stream/stream/gt.rs @@ -1,26 +1,29 @@ use std::cmp::Ordering; use std::pin::Pin; +use pin_project_lite::pin_project; + use super::partial_cmp::PartialCmpFuture; use crate::future::Future; use crate::prelude::*; use crate::stream::Stream; use crate::task::{Context, Poll}; -// Determines if the elements of this `Stream` are lexicographically -// greater than those of another. -#[doc(hidden)] -#[allow(missing_debug_implementations)] -pub struct GtFuture { - partial_cmp: PartialCmpFuture, +pin_project! { + // Determines if the elements of this `Stream` are lexicographically + // greater than those of another. + #[doc(hidden)] + #[allow(missing_debug_implementations)] + pub struct GtFuture { + #[pin] + partial_cmp: PartialCmpFuture, + } } impl GtFuture where L::Item: PartialOrd, { - pin_utils::unsafe_pinned!(partial_cmp: PartialCmpFuture); - pub(super) fn new(l: L, r: R) -> Self { GtFuture { partial_cmp: l.partial_cmp(r), @@ -36,8 +39,8 @@ where { type Output = bool; - fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - let result = futures_core::ready!(self.as_mut().partial_cmp().poll(cx)); + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let result = futures_core::ready!(self.project().partial_cmp.poll(cx)); match result { Some(Ordering::Greater) => Poll::Ready(true), diff --git a/src/stream/stream/inspect.rs b/src/stream/stream/inspect.rs index e63b584..5de60fb 100644 --- a/src/stream/stream/inspect.rs +++ b/src/stream/stream/inspect.rs @@ -1,21 +1,23 @@ use std::marker::PhantomData; use std::pin::Pin; +use pin_project_lite::pin_project; + use crate::stream::Stream; use crate::task::{Context, Poll}; -/// A stream that does something with each element of another stream. -#[derive(Debug)] -pub struct Inspect { - stream: S, - f: F, - __t: PhantomData, +pin_project! { + /// A stream that does something with each element of another stream. + #[derive(Debug)] + pub struct Inspect { + #[pin] + stream: S, + f: F, + __t: PhantomData, + } } impl Inspect { - pin_utils::unsafe_pinned!(stream: S); - pin_utils::unsafe_unpinned!(f: F); - pub(super) fn new(stream: S, f: F) -> Self { Inspect { stream, @@ -32,11 +34,12 @@ where { type Item = S::Item; - fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - let next = futures_core::ready!(self.as_mut().stream().poll_next(cx)); + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let mut this = self.project(); + let next = futures_core::ready!(this.stream.as_mut().poll_next(cx)); Poll::Ready(next.and_then(|x| { - (self.as_mut().f())(&x); + (this.f)(&x); Some(x) })) } diff --git a/src/stream/stream/last.rs b/src/stream/stream/last.rs index c58dd66..eba01e5 100644 --- a/src/stream/stream/last.rs +++ b/src/stream/stream/last.rs @@ -1,20 +1,22 @@ use std::pin::Pin; +use pin_project_lite::pin_project; + use crate::future::Future; use crate::stream::Stream; use crate::task::{Context, Poll}; -#[doc(hidden)] -#[allow(missing_debug_implementations)] -pub struct LastFuture { - stream: S, - last: Option, +pin_project! { + #[doc(hidden)] + #[allow(missing_debug_implementations)] + pub struct LastFuture { + #[pin] + stream: S, + last: Option, + } } impl LastFuture { - pin_utils::unsafe_pinned!(stream: S); - pin_utils::unsafe_unpinned!(last: Option); - pub(crate) fn new(stream: S) -> Self { LastFuture { stream, last: None } } @@ -27,16 +29,17 @@ where { type Output = Option; - fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - let next = futures_core::ready!(self.as_mut().stream().poll_next(cx)); + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let this = self.project(); + let next = futures_core::ready!(this.stream.poll_next(cx)); match next { Some(new) => { cx.waker().wake_by_ref(); - *self.as_mut().last() = Some(new); + *this.last = Some(new); Poll::Pending } - None => Poll::Ready(self.last), + None => Poll::Ready(*this.last), } } } diff --git a/src/stream/stream/le.rs b/src/stream/stream/le.rs index 37b62d8..af72700 100644 --- a/src/stream/stream/le.rs +++ b/src/stream/stream/le.rs @@ -1,26 +1,29 @@ use std::cmp::Ordering; use std::pin::Pin; +use pin_project_lite::pin_project; + use super::partial_cmp::PartialCmpFuture; use crate::future::Future; use crate::prelude::*; use crate::stream::Stream; use crate::task::{Context, Poll}; -/// Determines if the elements of this `Stream` are lexicographically -/// less or equal to those of another. -#[doc(hidden)] -#[allow(missing_debug_implementations)] -pub struct LeFuture { - partial_cmp: PartialCmpFuture, +pin_project! { + /// Determines if the elements of this `Stream` are lexicographically + /// less or equal to those of another. + #[doc(hidden)] + #[allow(missing_debug_implementations)] + pub struct LeFuture { + #[pin] + partial_cmp: PartialCmpFuture, + } } impl LeFuture where L::Item: PartialOrd, { - pin_utils::unsafe_pinned!(partial_cmp: PartialCmpFuture); - pub(super) fn new(l: L, r: R) -> Self { LeFuture { partial_cmp: l.partial_cmp(r), @@ -36,8 +39,8 @@ where { type Output = bool; - fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - let result = futures_core::ready!(self.as_mut().partial_cmp().poll(cx)); + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let result = futures_core::ready!(self.project().partial_cmp.poll(cx)); match result { Some(Ordering::Less) | Some(Ordering::Equal) => Poll::Ready(true), diff --git a/src/stream/stream/lt.rs b/src/stream/stream/lt.rs index b774d7b..524f268 100644 --- a/src/stream/stream/lt.rs +++ b/src/stream/stream/lt.rs @@ -1,26 +1,29 @@ use std::cmp::Ordering; use std::pin::Pin; +use pin_project_lite::pin_project; + use super::partial_cmp::PartialCmpFuture; use crate::future::Future; use crate::prelude::*; use crate::stream::Stream; use crate::task::{Context, Poll}; -// Determines if the elements of this `Stream` are lexicographically -// less than those of another. -#[doc(hidden)] -#[allow(missing_debug_implementations)] -pub struct LtFuture { - partial_cmp: PartialCmpFuture, +pin_project! { + // Determines if the elements of this `Stream` are lexicographically + // less than those of another. + #[doc(hidden)] + #[allow(missing_debug_implementations)] + pub struct LtFuture { + #[pin] + partial_cmp: PartialCmpFuture, + } } impl LtFuture where L::Item: PartialOrd, { - pin_utils::unsafe_pinned!(partial_cmp: PartialCmpFuture); - pub(super) fn new(l: L, r: R) -> Self { LtFuture { partial_cmp: l.partial_cmp(r), @@ -36,8 +39,8 @@ where { type Output = bool; - fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - let result = futures_core::ready!(self.as_mut().partial_cmp().poll(cx)); + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let result = futures_core::ready!(self.project().partial_cmp.poll(cx)); match result { Some(Ordering::Less) => Poll::Ready(true), diff --git a/src/stream/stream/map.rs b/src/stream/stream/map.rs index 4bc2e36..a1fafc3 100644 --- a/src/stream/stream/map.rs +++ b/src/stream/stream/map.rs @@ -1,22 +1,24 @@ use std::marker::PhantomData; use std::pin::Pin; +use pin_project_lite::pin_project; + use crate::stream::Stream; use crate::task::{Context, Poll}; -#[doc(hidden)] -#[allow(missing_debug_implementations)] -pub struct Map { - stream: S, - f: F, - __from: PhantomData, - __to: PhantomData, +pin_project! { + #[doc(hidden)] + #[allow(missing_debug_implementations)] + pub struct Map { + #[pin] + stream: S, + f: F, + __from: PhantomData, + __to: PhantomData, + } } impl Map { - pin_utils::unsafe_pinned!(stream: S); - pin_utils::unsafe_unpinned!(f: F); - pub(crate) fn new(stream: S, f: F) -> Self { Map { stream, @@ -34,8 +36,9 @@ where { type Item = B; - fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - let next = futures_core::ready!(self.as_mut().stream().poll_next(cx)); - Poll::Ready(next.map(self.as_mut().f())) + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let this = self.project(); + let next = futures_core::ready!(this.stream.poll_next(cx)); + Poll::Ready(next.map(this.f)) } } diff --git a/src/stream/stream/merge.rs b/src/stream/stream/merge.rs index 9889dc7..3ccc223 100644 --- a/src/stream/stream/merge.rs +++ b/src/stream/stream/merge.rs @@ -2,22 +2,25 @@ use std::pin::Pin; use std::task::{Context, Poll}; use futures_core::Stream; +use pin_project_lite::pin_project; -/// A stream that merges two other streams into a single stream. -/// -/// This stream is returned by [`Stream::merge`]. -/// -/// [`Stream::merge`]: trait.Stream.html#method.merge -#[cfg(feature = "unstable")] -#[cfg_attr(feature = "docs", doc(cfg(unstable)))] -#[derive(Debug)] -pub struct Merge { - left: L, - right: R, +pin_project! { + /// A stream that merges two other streams into a single stream. + /// + /// This stream is returned by [`Stream::merge`]. + /// + /// [`Stream::merge`]: trait.Stream.html#method.merge + #[cfg(feature = "unstable")] + #[cfg_attr(feature = "docs", doc(cfg(unstable)))] + #[derive(Debug)] + pub struct Merge { + #[pin] + left: L, + #[pin] + right: R, + } } -impl Unpin for Merge {} - impl Merge { pub(crate) fn new(left: L, right: R) -> Self { Self { left, right } @@ -26,19 +29,20 @@ impl Merge { impl Stream for Merge where - L: Stream + Unpin, - R: Stream + Unpin, + L: Stream, + R: Stream, { type Item = T; - fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - if let Poll::Ready(Some(item)) = Pin::new(&mut self.left).poll_next(cx) { + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let this = self.project(); + if let Poll::Ready(Some(item)) = this.left.poll_next(cx) { // The first stream made progress. The Merge needs to be polled // again to check the progress of the second stream. cx.waker().wake_by_ref(); Poll::Ready(Some(item)) } else { - Pin::new(&mut self.right).poll_next(cx) + this.right.poll_next(cx) } } } diff --git a/src/stream/stream/min_by.rs b/src/stream/stream/min_by.rs index a68cf31..ab12aa0 100644 --- a/src/stream/stream/min_by.rs +++ b/src/stream/stream/min_by.rs @@ -1,23 +1,24 @@ use std::cmp::Ordering; use std::pin::Pin; +use pin_project_lite::pin_project; + use crate::future::Future; use crate::stream::Stream; use crate::task::{Context, Poll}; -#[doc(hidden)] -#[allow(missing_debug_implementations)] -pub struct MinByFuture { - stream: S, - compare: F, - min: Option, +pin_project! { + #[doc(hidden)] + #[allow(missing_debug_implementations)] + pub struct MinByFuture { + #[pin] + stream: S, + compare: F, + min: Option, + } } impl MinByFuture { - pin_utils::unsafe_pinned!(stream: S); - pin_utils::unsafe_unpinned!(compare: F); - pin_utils::unsafe_unpinned!(min: Option); - pub(super) fn new(stream: S, compare: F) -> Self { MinByFuture { stream, @@ -35,22 +36,23 @@ where { type Output = Option; - fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - let next = futures_core::ready!(self.as_mut().stream().poll_next(cx)); + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let this = self.project(); + let next = futures_core::ready!(this.stream.poll_next(cx)); match next { Some(new) => { cx.waker().wake_by_ref(); - match self.as_mut().min().take() { - None => *self.as_mut().min() = Some(new), - Some(old) => match (&mut self.as_mut().compare())(&new, &old) { - Ordering::Less => *self.as_mut().min() = Some(new), - _ => *self.as_mut().min() = Some(old), + match this.min.take() { + None => *this.min = Some(new), + Some(old) => match (this.compare)(&new, &old) { + Ordering::Less => *this.min = Some(new), + _ => *this.min = Some(old), }, } Poll::Pending } - None => Poll::Ready(self.min), + None => Poll::Ready(*this.min), } } } diff --git a/src/stream/stream/partial_cmp.rs b/src/stream/stream/partial_cmp.rs index fac2705..e30c6ea 100644 --- a/src/stream/stream/partial_cmp.rs +++ b/src/stream/stream/partial_cmp.rs @@ -1,29 +1,30 @@ use std::cmp::Ordering; use std::pin::Pin; +use pin_project_lite::pin_project; + use super::fuse::Fuse; use crate::future::Future; use crate::prelude::*; use crate::stream::Stream; use crate::task::{Context, Poll}; -// Lexicographically compares the elements of this `Stream` with those -// of another. -#[doc(hidden)] -#[allow(missing_debug_implementations)] -pub struct PartialCmpFuture { - l: Fuse, - r: Fuse, - l_cache: Option, - r_cache: Option, +pin_project! { + // Lexicographically compares the elements of this `Stream` with those + // of another. + #[doc(hidden)] + #[allow(missing_debug_implementations)] + pub struct PartialCmpFuture { + #[pin] + l: Fuse, + #[pin] + r: Fuse, + l_cache: Option, + r_cache: Option, + } } impl PartialCmpFuture { - pin_utils::unsafe_pinned!(l: Fuse); - pin_utils::unsafe_pinned!(r: Fuse); - pin_utils::unsafe_unpinned!(l_cache: Option); - pin_utils::unsafe_unpinned!(r_cache: Option); - pub(super) fn new(l: L, r: R) -> Self { PartialCmpFuture { l: l.fuse(), @@ -42,12 +43,13 @@ where { type Output = Option; - fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let mut this = self.project(); loop { // Short circuit logic // Stream that completes earliest can be considered Less, etc - let l_complete = self.l.done && self.as_mut().l_cache.is_none(); - let r_complete = self.r.done && self.as_mut().r_cache.is_none(); + let l_complete = this.l.done && this.l_cache.is_none(); + let r_complete = this.r.done && this.r_cache.is_none(); if l_complete && r_complete { return Poll::Ready(Some(Ordering::Equal)); @@ -58,30 +60,30 @@ where } // Get next value if possible and necesary - if !self.l.done && self.as_mut().l_cache.is_none() { - let l_next = futures_core::ready!(self.as_mut().l().poll_next(cx)); + if !this.l.done && this.l_cache.is_none() { + let l_next = futures_core::ready!(this.l.as_mut().poll_next(cx)); if let Some(item) = l_next { - *self.as_mut().l_cache() = Some(item); + *this.l_cache = Some(item); } } - if !self.r.done && self.as_mut().r_cache.is_none() { - let r_next = futures_core::ready!(self.as_mut().r().poll_next(cx)); + if !this.r.done && this.r_cache.is_none() { + let r_next = futures_core::ready!(this.r.as_mut().poll_next(cx)); if let Some(item) = r_next { - *self.as_mut().r_cache() = Some(item); + *this.r_cache = Some(item); } } // Compare if both values are available. - if self.as_mut().l_cache.is_some() && self.as_mut().r_cache.is_some() { - let l_value = self.as_mut().l_cache().take().unwrap(); - let r_value = self.as_mut().r_cache().take().unwrap(); + if this.l_cache.is_some() && this.r_cache.is_some() { + let l_value = this.l_cache.as_mut().take().unwrap(); + let r_value = this.r_cache.as_mut().take().unwrap(); let result = l_value.partial_cmp(&r_value); if let Some(Ordering::Equal) = result { // Reset cache to prepare for next comparison - *self.as_mut().l_cache() = None; - *self.as_mut().r_cache() = None; + *this.l_cache = None; + *this.r_cache = None; } else { // Return non equal value return Poll::Ready(result); diff --git a/src/stream/stream/scan.rs b/src/stream/stream/scan.rs index 7897516..c4771d8 100644 --- a/src/stream/stream/scan.rs +++ b/src/stream/stream/scan.rs @@ -1,13 +1,18 @@ use std::pin::Pin; +use pin_project_lite::pin_project; + use crate::stream::Stream; use crate::task::{Context, Poll}; -/// A stream to maintain state while polling another stream. -#[derive(Debug)] -pub struct Scan { - stream: S, - state_f: (St, F), +pin_project! { + /// A stream to maintain state while polling another stream. + #[derive(Debug)] + pub struct Scan { + #[pin] + stream: S, + state_f: (St, F), + } } impl Scan { @@ -17,13 +22,8 @@ impl Scan { state_f: (initial_state, f), } } - - pin_utils::unsafe_pinned!(stream: S); - pin_utils::unsafe_unpinned!(state_f: (St, F)); } -impl Unpin for Scan {} - impl Stream for Scan where S: Stream, @@ -31,11 +31,12 @@ where { type Item = B; - fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - let poll_result = self.as_mut().stream().poll_next(cx); + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let mut this = self.project(); + let poll_result = this.stream.as_mut().poll_next(cx); poll_result.map(|item| { item.and_then(|item| { - let (state, f) = self.as_mut().state_f(); + let (state, f) = this.state_f; f(state, item) }) }) diff --git a/src/stream/stream/skip.rs b/src/stream/stream/skip.rs index 8a2d966..6562b99 100644 --- a/src/stream/stream/skip.rs +++ b/src/stream/stream/skip.rs @@ -1,19 +1,21 @@ use std::pin::Pin; use std::task::{Context, Poll}; +use pin_project_lite::pin_project; + use crate::stream::Stream; -/// A stream to skip first n elements of another stream. -#[derive(Debug)] -pub struct Skip { - stream: S, - n: usize, +pin_project! { + /// A stream to skip first n elements of another stream. + #[derive(Debug)] + pub struct Skip { + #[pin] + stream: S, + n: usize, + } } impl Skip { - pin_utils::unsafe_pinned!(stream: S); - pin_utils::unsafe_unpinned!(n: usize); - pub(crate) fn new(stream: S, n: usize) -> Self { Skip { stream, n } } @@ -25,14 +27,15 @@ where { type Item = S::Item; - fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let mut this = self.project(); loop { - let next = futures_core::ready!(self.as_mut().stream().poll_next(cx)); + let next = futures_core::ready!(this.stream.as_mut().poll_next(cx)); match next { - Some(v) => match self.n { + Some(v) => match *this.n { 0 => return Poll::Ready(Some(v)), - _ => *self.as_mut().n() -= 1, + _ => *this.n -= 1, }, None => return Poll::Ready(None), } diff --git a/src/stream/stream/skip_while.rs b/src/stream/stream/skip_while.rs index b1a8d9e..0499df2 100644 --- a/src/stream/stream/skip_while.rs +++ b/src/stream/stream/skip_while.rs @@ -1,21 +1,23 @@ use std::marker::PhantomData; use std::pin::Pin; +use pin_project_lite::pin_project; + use crate::stream::Stream; use crate::task::{Context, Poll}; -/// A stream to skip elements of another stream based on a predicate. -#[derive(Debug)] -pub struct SkipWhile { - stream: S, - predicate: Option

, - __t: PhantomData, +pin_project! { + /// A stream to skip elements of another stream based on a predicate. + #[derive(Debug)] + pub struct SkipWhile { + #[pin] + stream: S, + predicate: Option

, + __t: PhantomData, + } } impl SkipWhile { - pin_utils::unsafe_pinned!(stream: S); - pin_utils::unsafe_unpinned!(predicate: Option

); - pub(crate) fn new(stream: S, predicate: P) -> Self { SkipWhile { stream, @@ -32,15 +34,16 @@ where { type Item = S::Item; - fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let mut this = self.project(); loop { - let next = futures_core::ready!(self.as_mut().stream().poll_next(cx)); + let next = futures_core::ready!(this.stream.as_mut().poll_next(cx)); match next { - Some(v) => match self.as_mut().predicate() { + Some(v) => match this.predicate { Some(p) => { if !p(&v) { - *self.as_mut().predicate() = None; + *this.predicate = None; return Poll::Ready(Some(v)); } } diff --git a/src/stream/stream/step_by.rs b/src/stream/stream/step_by.rs index f84feec..ab9e45b 100644 --- a/src/stream/stream/step_by.rs +++ b/src/stream/stream/step_by.rs @@ -1,21 +1,22 @@ use std::pin::Pin; +use pin_project_lite::pin_project; + use crate::stream::Stream; use crate::task::{Context, Poll}; -/// A stream that steps a given amount of elements of another stream. -#[derive(Debug)] -pub struct StepBy { - stream: S, - step: usize, - i: usize, +pin_project! { + /// A stream that steps a given amount of elements of another stream. + #[derive(Debug)] + pub struct StepBy { + #[pin] + stream: S, + step: usize, + i: usize, + } } impl StepBy { - pin_utils::unsafe_pinned!(stream: S); - pin_utils::unsafe_unpinned!(step: usize); - pin_utils::unsafe_unpinned!(i: usize); - pub(crate) fn new(stream: S, step: usize) -> Self { StepBy { stream, @@ -31,17 +32,18 @@ where { type Item = S::Item; - fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let mut this = self.project(); loop { - let next = futures_core::ready!(self.as_mut().stream().poll_next(cx)); + let next = futures_core::ready!(this.stream.as_mut().poll_next(cx)); match next { - Some(v) => match self.i { + Some(v) => match this.i { 0 => { - *self.as_mut().i() = self.step; + *this.i = *this.step; return Poll::Ready(Some(v)); } - _ => *self.as_mut().i() -= 1, + _ => *this.i -= 1, }, None => return Poll::Ready(None), } diff --git a/src/stream/stream/take.rs b/src/stream/stream/take.rs index 81d48d2..835bc44 100644 --- a/src/stream/stream/take.rs +++ b/src/stream/stream/take.rs @@ -1,33 +1,32 @@ use std::pin::Pin; +use pin_project_lite::pin_project; + use crate::stream::Stream; use crate::task::{Context, Poll}; -/// A stream that yields the first `n` items of another stream. -#[derive(Clone, Debug)] -pub struct Take { - pub(crate) stream: S, - pub(crate) remaining: usize, -} - -impl Unpin for Take {} - -impl Take { - pin_utils::unsafe_pinned!(stream: S); - pin_utils::unsafe_unpinned!(remaining: usize); +pin_project! { + /// A stream that yields the first `n` items of another stream. + #[derive(Clone, Debug)] + pub struct Take { + #[pin] + pub(crate) stream: S, + pub(crate) remaining: usize, + } } impl Stream for Take { type Item = S::Item; - fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - if self.remaining == 0 { + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let this = self.project(); + if *this.remaining == 0 { Poll::Ready(None) } else { - let next = futures_core::ready!(self.as_mut().stream().poll_next(cx)); + let next = futures_core::ready!(this.stream.poll_next(cx)); match next { - Some(_) => *self.as_mut().remaining() -= 1, - None => *self.as_mut().remaining() = 0, + Some(_) => *this.remaining -= 1, + None => *this.remaining = 0, } Poll::Ready(next) } diff --git a/src/stream/stream/take_while.rs b/src/stream/stream/take_while.rs index 6f3cc8f..bf89458 100644 --- a/src/stream/stream/take_while.rs +++ b/src/stream/stream/take_while.rs @@ -1,21 +1,23 @@ use std::marker::PhantomData; use std::pin::Pin; +use pin_project_lite::pin_project; + use crate::stream::Stream; use crate::task::{Context, Poll}; -/// A stream that yields elements based on a predicate. -#[derive(Debug)] -pub struct TakeWhile { - stream: S, - predicate: P, - __t: PhantomData, +pin_project! { + /// A stream that yields elements based on a predicate. + #[derive(Debug)] + pub struct TakeWhile { + #[pin] + stream: S, + predicate: P, + __t: PhantomData, + } } impl TakeWhile { - pin_utils::unsafe_pinned!(stream: S); - pin_utils::unsafe_unpinned!(predicate: P); - pub(super) fn new(stream: S, predicate: P) -> Self { TakeWhile { stream, @@ -32,11 +34,12 @@ where { type Item = S::Item; - fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - let next = futures_core::ready!(self.as_mut().stream().poll_next(cx)); + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let this = self.project(); + let next = futures_core::ready!(this.stream.poll_next(cx)); match next { - Some(v) if (self.as_mut().predicate())(&v) => Poll::Ready(Some(v)), + Some(v) if (this.predicate)(&v) => Poll::Ready(Some(v)), Some(_) => { cx.waker().wake_by_ref(); Poll::Pending diff --git a/src/stream/stream/try_fold.rs b/src/stream/stream/try_fold.rs index 212b058..80392e1 100644 --- a/src/stream/stream/try_fold.rs +++ b/src/stream/stream/try_fold.rs @@ -1,24 +1,25 @@ use std::marker::PhantomData; use std::pin::Pin; +use pin_project_lite::pin_project; + use crate::future::Future; use crate::stream::Stream; use crate::task::{Context, Poll}; -#[doc(hidden)] -#[allow(missing_debug_implementations)] -pub struct TryFoldFuture { - stream: S, - f: F, - acc: Option, - __t: PhantomData, +pin_project! { + #[doc(hidden)] + #[allow(missing_debug_implementations)] + pub struct TryFoldFuture { + #[pin] + stream: S, + f: F, + acc: Option, + __t: PhantomData, + } } impl TryFoldFuture { - pin_utils::unsafe_pinned!(stream: S); - pin_utils::unsafe_unpinned!(f: F); - pin_utils::unsafe_unpinned!(acc: Option); - pub(super) fn new(stream: S, init: T, f: F) -> Self { TryFoldFuture { stream, @@ -36,23 +37,22 @@ where { type Output = Result; - fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let mut this = self.project(); loop { - let next = futures_core::ready!(self.as_mut().stream().poll_next(cx)); + let next = futures_core::ready!(this.stream.as_mut().poll_next(cx)); match next { Some(v) => { - let old = self.as_mut().acc().take().unwrap(); - let new = (self.as_mut().f())(old, v); + let old = this.acc.take().unwrap(); + let new = (this.f)(old, v); match new { - Ok(o) => { - *self.as_mut().acc() = Some(o); - } + Ok(o) => *this.acc = Some(o), Err(e) => return Poll::Ready(Err(e)), } } - None => return Poll::Ready(Ok(self.as_mut().acc().take().unwrap())), + None => return Poll::Ready(Ok(this.acc.take().unwrap())), } } } diff --git a/src/stream/stream/try_for_each.rs b/src/stream/stream/try_for_each.rs index ae3d5ea..578b854 100644 --- a/src/stream/stream/try_for_each.rs +++ b/src/stream/stream/try_for_each.rs @@ -1,23 +1,25 @@ use std::marker::PhantomData; use std::pin::Pin; +use pin_project_lite::pin_project; + use crate::future::Future; use crate::stream::Stream; use crate::task::{Context, Poll}; -#[doc(hidden)] -#[allow(missing_debug_implementations)] -pub struct TryForEeachFuture { - stream: S, - f: F, - __from: PhantomData, - __to: PhantomData, +pin_project! { + #[doc(hidden)] + #[allow(missing_debug_implementations)] + pub struct TryForEeachFuture { + #[pin] + stream: S, + f: F, + __from: PhantomData, + __to: PhantomData, + } } impl TryForEeachFuture { - pin_utils::unsafe_pinned!(stream: S); - pin_utils::unsafe_unpinned!(f: F); - pub(crate) fn new(stream: S, f: F) -> Self { TryForEeachFuture { stream, @@ -36,14 +38,15 @@ where { type Output = Result<(), E>; - fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let mut this = self.project(); loop { - let item = futures_core::ready!(self.as_mut().stream().poll_next(cx)); + let item = futures_core::ready!(this.stream.as_mut().poll_next(cx)); match item { None => return Poll::Ready(Ok(())), Some(v) => { - let res = (self.as_mut().f())(v); + let res = (this.f)(v); if let Err(e) = res { return Poll::Ready(Err(e)); } diff --git a/src/stream/stream/zip.rs b/src/stream/stream/zip.rs index 4c66aef..9b7c299 100644 --- a/src/stream/stream/zip.rs +++ b/src/stream/stream/zip.rs @@ -1,14 +1,20 @@ use std::fmt; use std::pin::Pin; +use pin_project_lite::pin_project; + use crate::stream::Stream; use crate::task::{Context, Poll}; -/// An iterator that iterates two other iterators simultaneously. -pub struct Zip { - item_slot: Option, - first: A, - second: B, +pin_project! { + /// An iterator that iterates two other iterators simultaneously. + pub struct Zip { + item_slot: Option, + #[pin] + first: A, + #[pin] + second: B, + } } impl fmt::Debug for Zip { @@ -20,8 +26,6 @@ impl fmt::Debug for Zip { } } -impl Unpin for Zip {} - impl Zip { pub(crate) fn new(first: A, second: B) -> Self { Zip { @@ -30,25 +34,22 @@ impl Zip { second, } } - - pin_utils::unsafe_unpinned!(item_slot: Option); - pin_utils::unsafe_pinned!(first: A); - pin_utils::unsafe_pinned!(second: B); } impl Stream for Zip { type Item = (A::Item, B::Item); - fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - if self.as_mut().item_slot().is_none() { - match self.as_mut().first().poll_next(cx) { + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let this = self.project(); + if this.item_slot.is_none() { + match this.first.poll_next(cx) { Poll::Pending => return Poll::Pending, Poll::Ready(None) => return Poll::Ready(None), - Poll::Ready(Some(item)) => *self.as_mut().item_slot() = Some(item), + Poll::Ready(Some(item)) => *this.item_slot = Some(item), } } - let second_item = futures_core::ready!(self.as_mut().second().poll_next(cx)); - let first_item = self.as_mut().item_slot().take().unwrap(); + let second_item = futures_core::ready!(this.second.poll_next(cx)); + let first_item = this.item_slot.take().unwrap(); Poll::Ready(second_item.map(|second_item| (first_item, second_item))) } } diff --git a/src/sync/barrier.rs b/src/sync/barrier.rs index 080eff8..0163e3f 100644 --- a/src/sync/barrier.rs +++ b/src/sync/barrier.rs @@ -8,7 +8,7 @@ use crate::sync::Mutex; /// # Examples /// /// ``` -/// # fn main() { async_std::task::block_on(async { +/// # async_std::task::block_on(async { /// # /// use async_std::sync::{Arc, Barrier}; /// use async_std::task; @@ -30,7 +30,6 @@ use crate::sync::Mutex; /// handle.await; /// } /// # }); -/// # } /// ``` #[cfg(feature = "unstable")] #[cfg_attr(feature = "docs", doc(cfg(unstable)))] @@ -120,7 +119,7 @@ impl Barrier { /// # Examples /// /// ``` - /// # fn main() { async_std::task::block_on(async { + /// # async_std::task::block_on(async { /// # /// use async_std::sync::{Arc, Barrier}; /// use async_std::task; @@ -142,7 +141,6 @@ impl Barrier { /// handle.await; /// } /// # }); - /// # } /// ``` pub async fn wait(&self) -> BarrierWaitResult { let mut lock = self.state.lock().await; @@ -190,7 +188,7 @@ impl BarrierWaitResult { /// # Examples /// /// ``` - /// # fn main() { async_std::task::block_on(async { + /// # async_std::task::block_on(async { /// # /// use async_std::sync::Barrier; /// @@ -198,7 +196,6 @@ impl BarrierWaitResult { /// let barrier_wait_result = barrier.wait().await; /// println!("{:?}", barrier_wait_result.is_leader()); /// # }); - /// # } /// ``` pub fn is_leader(&self) -> bool { self.0 diff --git a/src/sync/channel.rs b/src/sync/channel.rs new file mode 100644 index 0000000..6751417 --- /dev/null +++ b/src/sync/channel.rs @@ -0,0 +1,1132 @@ +use std::cell::UnsafeCell; +use std::fmt; +use std::future::Future; +use std::isize; +use std::marker::PhantomData; +use std::mem; +use std::ops::{Deref, DerefMut}; +use std::pin::Pin; +use std::process; +use std::ptr; +use std::sync::atomic::{self, AtomicBool, AtomicUsize, Ordering}; +use std::sync::Arc; +use std::task::{Context, Poll, Waker}; + +use crossbeam_utils::{Backoff, CachePadded}; +use futures_core::stream::Stream; +use slab::Slab; + +/// Creates a bounded multi-producer multi-consumer channel. +/// +/// This channel has a buffer that can hold at most `cap` messages at a time. +/// +/// Senders and receivers can be cloned. When all senders associated with a channel get dropped, it +/// becomes closed. Receive operations on a closed and empty channel return `None` instead of +/// blocking. +/// +/// # Panics +/// +/// If `cap` is zero, this function will panic. +/// +/// # Examples +/// +/// ``` +/// # async_std::task::block_on(async { +/// # +/// use std::time::Duration; +/// +/// use async_std::sync::channel; +/// use async_std::task; +/// +/// let (s, r) = channel(1); +/// +/// // This call returns immediately because there is enough space in the channel. +/// s.send(1).await; +/// +/// task::spawn(async move { +/// // This call blocks the current task because the channel is full. +/// // It will be able to complete only after the first message is received. +/// s.send(2).await; +/// }); +/// +/// task::sleep(Duration::from_secs(1)).await; +/// assert_eq!(r.recv().await, Some(1)); +/// assert_eq!(r.recv().await, Some(2)); +/// # +/// # }) +/// ``` +#[cfg(feature = "unstable")] +#[cfg_attr(feature = "docs", doc(cfg(unstable)))] +pub fn channel(cap: usize) -> (Sender, Receiver) { + let channel = Arc::new(Channel::with_capacity(cap)); + let s = Sender { + channel: channel.clone(), + }; + let r = Receiver { + channel, + opt_key: None, + }; + (s, r) +} + +/// The sending side of a channel. +/// +/// # Examples +/// +/// ``` +/// # async_std::task::block_on(async { +/// # +/// use async_std::sync::channel; +/// use async_std::task; +/// +/// let (s1, r) = channel(100); +/// let s2 = s1.clone(); +/// +/// task::spawn(async move { s1.send(1).await }); +/// task::spawn(async move { s2.send(2).await }); +/// +/// let msg1 = r.recv().await.unwrap(); +/// let msg2 = r.recv().await.unwrap(); +/// +/// assert_eq!(msg1 + msg2, 3); +/// # +/// # }) +/// ``` +#[cfg(feature = "unstable")] +#[cfg_attr(feature = "docs", doc(cfg(unstable)))] +pub struct Sender { + /// The inner channel. + channel: Arc>, +} + +impl Sender { + /// Sends a message into the channel. + /// + /// If the channel is full, this method will block the current task until the send operation + /// can proceed. + /// + /// # Examples + /// + /// ``` + /// # async_std::task::block_on(async { + /// # + /// use async_std::sync::channel; + /// use async_std::task; + /// + /// let (s, r) = channel(1); + /// + /// task::spawn(async move { + /// s.send(1).await; + /// s.send(2).await; + /// }); + /// + /// assert_eq!(r.recv().await, Some(1)); + /// assert_eq!(r.recv().await, Some(2)); + /// assert_eq!(r.recv().await, None); + /// # + /// # }) + /// ``` + pub async fn send(&self, msg: T) { + struct SendFuture<'a, T> { + sender: &'a Sender, + msg: Option, + opt_key: Option, + } + + impl Unpin for SendFuture<'_, T> {} + + impl Future for SendFuture<'_, T> { + type Output = (); + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let msg = self.msg.take().unwrap(); + + // Try sending the message. + let poll = match self.sender.channel.push(msg) { + Ok(()) => Poll::Ready(()), + Err(PushError::Disconnected(msg)) => { + self.msg = Some(msg); + Poll::Pending + } + Err(PushError::Full(msg)) => { + // Register the current task. + match self.opt_key { + None => self.opt_key = Some(self.sender.channel.sends.register(cx)), + Some(key) => self.sender.channel.sends.reregister(key, cx), + } + + // Try sending the message again. + match self.sender.channel.push(msg) { + Ok(()) => Poll::Ready(()), + Err(PushError::Disconnected(msg)) | Err(PushError::Full(msg)) => { + self.msg = Some(msg); + Poll::Pending + } + } + } + }; + + if poll.is_ready() { + // If the current task was registered, unregister now. + if let Some(key) = self.opt_key.take() { + // `true` means the send operation is completed. + self.sender.channel.sends.unregister(key, true); + } + } + + poll + } + } + + impl Drop for SendFuture<'_, T> { + fn drop(&mut self) { + // If the current task was registered, unregister now. + if let Some(key) = self.opt_key { + // `false` means the send operation is cancelled. + self.sender.channel.sends.unregister(key, false); + } + } + } + + SendFuture { + sender: self, + msg: Some(msg), + opt_key: None, + } + .await + } + + /// Returns the channel capacity. + /// + /// # Examples + /// + /// ``` + /// use async_std::sync::channel; + /// + /// let (s, _) = channel::(5); + /// assert_eq!(s.capacity(), 5); + /// ``` + pub fn capacity(&self) -> usize { + self.channel.cap + } + + /// Returns `true` if the channel is empty. + /// + /// # Examples + /// + /// ``` + /// # async_std::task::block_on(async { + /// # + /// use async_std::sync::channel; + /// + /// let (s, r) = channel(1); + /// + /// assert!(s.is_empty()); + /// s.send(0).await; + /// assert!(!s.is_empty()); + /// # + /// # }) + /// ``` + pub fn is_empty(&self) -> bool { + self.channel.is_empty() + } + + /// Returns `true` if the channel is full. + /// + /// # Examples + /// + /// ``` + /// # async_std::task::block_on(async { + /// # + /// use async_std::sync::channel; + /// + /// let (s, r) = channel(1); + /// + /// assert!(!s.is_full()); + /// s.send(0).await; + /// assert!(s.is_full()); + /// # + /// # }) + /// ``` + pub fn is_full(&self) -> bool { + self.channel.is_full() + } + + /// Returns the number of messages in the channel. + /// + /// # Examples + /// + /// ``` + /// # async_std::task::block_on(async { + /// # + /// use async_std::sync::channel; + /// + /// let (s, r) = channel(2); + /// assert_eq!(s.len(), 0); + /// + /// s.send(1).await; + /// s.send(2).await; + /// assert_eq!(s.len(), 2); + /// # + /// # }) + /// ``` + pub fn len(&self) -> usize { + self.channel.len() + } +} + +impl Drop for Sender { + fn drop(&mut self) { + // Decrement the sender count and disconnect the channel if it drops down to zero. + if self.channel.sender_count.fetch_sub(1, Ordering::AcqRel) == 1 { + self.channel.disconnect(); + } + } +} + +impl Clone for Sender { + fn clone(&self) -> Sender { + let count = self.channel.sender_count.fetch_add(1, Ordering::Relaxed); + + // Make sure the count never overflows, even if lots of sender clones are leaked. + if count > isize::MAX as usize { + process::abort(); + } + + Sender { + channel: self.channel.clone(), + } + } +} + +impl fmt::Debug for Sender { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.pad("Sender { .. }") + } +} + +/// The receiving side of a channel. +/// +/// This type implements the [`Stream`] trait, which means it can act as an asynchronous iterator. +/// +/// [`Stream`]: ../stream/trait.Stream.html +/// +/// # Examples +/// +/// ``` +/// # async_std::task::block_on(async { +/// # +/// use std::time::Duration; +/// +/// use async_std::sync::channel; +/// use async_std::task; +/// +/// let (s, r) = channel(100); +/// +/// task::spawn(async move { +/// s.send(1).await; +/// task::sleep(Duration::from_secs(1)).await; +/// s.send(2).await; +/// }); +/// +/// assert_eq!(r.recv().await, Some(1)); // Received immediately. +/// assert_eq!(r.recv().await, Some(2)); // Received after 1 second. +/// # +/// # }) +/// ``` +#[cfg(feature = "unstable")] +#[cfg_attr(feature = "docs", doc(cfg(unstable)))] +pub struct Receiver { + /// The inner channel. + channel: Arc>, + + /// The registration key for this receiver in the `channel.streams` registry. + opt_key: Option, +} + +impl Receiver { + /// Receives a message from the channel. + /// + /// If the channel is empty and it still has senders, this method will block the current task + /// until the receive operation can proceed. If the channel is empty and there are no more + /// senders, this method returns `None`. + /// + /// # Examples + /// + /// ``` + /// # async_std::task::block_on(async { + /// # + /// use async_std::sync::channel; + /// use async_std::task; + /// + /// let (s, r) = channel(1); + /// + /// task::spawn(async move { + /// s.send(1).await; + /// s.send(2).await; + /// }); + /// + /// assert_eq!(r.recv().await, Some(1)); + /// assert_eq!(r.recv().await, Some(2)); + /// assert_eq!(r.recv().await, None); + /// # + /// # }) + /// ``` + pub async fn recv(&self) -> Option { + struct RecvFuture<'a, T> { + channel: &'a Channel, + opt_key: Option, + } + + impl Future for RecvFuture<'_, T> { + type Output = Option; + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + poll_recv(&self.channel, &self.channel.recvs, &mut self.opt_key, cx) + } + } + + impl Drop for RecvFuture<'_, T> { + fn drop(&mut self) { + // If the current task was registered, unregister now. + if let Some(key) = self.opt_key { + // `false` means the receive operation is cancelled. + self.channel.recvs.unregister(key, false); + } + } + } + + RecvFuture { + channel: &self.channel, + opt_key: None, + } + .await + } + + /// Returns the channel capacity. + /// + /// # Examples + /// + /// ``` + /// use async_std::sync::channel; + /// + /// let (_, r) = channel::(5); + /// assert_eq!(r.capacity(), 5); + /// ``` + pub fn capacity(&self) -> usize { + self.channel.cap + } + + /// Returns `true` if the channel is empty. + /// + /// # Examples + /// + /// ``` + /// # async_std::task::block_on(async { + /// # + /// use async_std::sync::channel; + /// + /// let (s, r) = channel(1); + /// + /// assert!(r.is_empty()); + /// s.send(0).await; + /// assert!(!r.is_empty()); + /// # + /// # }) + /// ``` + pub fn is_empty(&self) -> bool { + self.channel.is_empty() + } + + /// Returns `true` if the channel is full. + /// + /// # Examples + /// + /// ``` + /// # async_std::task::block_on(async { + /// # + /// use async_std::sync::channel; + /// + /// let (s, r) = channel(1); + /// + /// assert!(!r.is_full()); + /// s.send(0).await; + /// assert!(r.is_full()); + /// # + /// # }) + /// ``` + pub fn is_full(&self) -> bool { + self.channel.is_full() + } + + /// Returns the number of messages in the channel. + /// + /// # Examples + /// + /// ``` + /// # async_std::task::block_on(async { + /// # + /// use async_std::sync::channel; + /// + /// let (s, r) = channel(2); + /// assert_eq!(r.len(), 0); + /// + /// s.send(1).await; + /// s.send(2).await; + /// assert_eq!(r.len(), 2); + /// # + /// # }) + /// ``` + pub fn len(&self) -> usize { + self.channel.len() + } +} + +impl Drop for Receiver { + fn drop(&mut self) { + // If the current task was registered as blocked on this stream, unregister now. + if let Some(key) = self.opt_key { + // `false` means the last request for a stream item is cancelled. + self.channel.streams.unregister(key, false); + } + + // Decrement the receiver count and disconnect the channel if it drops down to zero. + if self.channel.receiver_count.fetch_sub(1, Ordering::AcqRel) == 1 { + self.channel.disconnect(); + } + } +} + +impl Clone for Receiver { + fn clone(&self) -> Receiver { + let count = self.channel.receiver_count.fetch_add(1, Ordering::Relaxed); + + // Make sure the count never overflows, even if lots of receiver clones are leaked. + if count > isize::MAX as usize { + process::abort(); + } + + Receiver { + channel: self.channel.clone(), + opt_key: None, + } + } +} + +impl Stream for Receiver { + type Item = T; + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let this = &mut *self; + poll_recv(&this.channel, &this.channel.streams, &mut this.opt_key, cx) + } +} + +impl fmt::Debug for Receiver { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.pad("Receiver { .. }") + } +} + +/// Polls a receive operation on a channel. +/// +/// If the receive operation is blocked, the current task will be registered in `registry` and its +/// registration key will then be stored in `opt_key`. +fn poll_recv( + channel: &Channel, + registry: &Registry, + opt_key: &mut Option, + cx: &mut Context<'_>, +) -> Poll> { + // Try receiving a message. + let poll = match channel.pop() { + Ok(msg) => Poll::Ready(Some(msg)), + Err(PopError::Disconnected) => Poll::Ready(None), + Err(PopError::Empty) => { + // Register the current task. + match *opt_key { + None => *opt_key = Some(registry.register(cx)), + Some(key) => registry.reregister(key, cx), + } + + // Try receiving a message again. + match channel.pop() { + Ok(msg) => Poll::Ready(Some(msg)), + Err(PopError::Disconnected) => Poll::Ready(None), + Err(PopError::Empty) => Poll::Pending, + } + } + }; + + if poll.is_ready() { + // If the current task was registered, unregister now. + if let Some(key) = opt_key.take() { + // `true` means the receive operation is completed. + registry.unregister(key, true); + } + } + + poll +} + +/// A slot in a channel. +struct Slot { + /// The current stamp. + stamp: AtomicUsize, + + /// The message in this slot. + msg: UnsafeCell, +} + +/// Bounded channel based on a preallocated array. +struct Channel { + /// The head of the channel. + /// + /// This value is a "stamp" consisting of an index into the buffer, a mark bit, and a lap, but + /// packed into a single `usize`. The lower bits represent the index, while the upper bits + /// represent the lap. The mark bit in the head is always zero. + /// + /// Messages are popped from the head of the channel. + head: CachePadded, + + /// The tail of the channel. + /// + /// This value is a "stamp" consisting of an index into the buffer, a mark bit, and a lap, but + /// packed into a single `usize`. The lower bits represent the index, while the upper bits + /// represent the lap. The mark bit indicates that the channel is disconnected. + /// + /// Messages are pushed into the tail of the channel. + tail: CachePadded, + + /// The buffer holding slots. + buffer: *mut Slot, + + /// The channel capacity. + cap: usize, + + /// A stamp with the value of `{ lap: 1, mark: 0, index: 0 }`. + one_lap: usize, + + /// If this bit is set in the tail, that means either all senders were dropped or all receivers + /// were dropped. + mark_bit: usize, + + /// Send operations waiting while the channel is full. + sends: Registry, + + /// Receive operations waiting while the channel is empty and not disconnected. + recvs: Registry, + + /// Streams waiting while the channel is empty and not disconnected. + streams: Registry, + + /// The number of currently active `Sender`s. + sender_count: AtomicUsize, + + /// The number of currently active `Receivers`s. + receiver_count: AtomicUsize, + + /// Indicates that dropping a `Channel` may drop values of type `T`. + _marker: PhantomData, +} + +unsafe impl Send for Channel {} +unsafe impl Sync for Channel {} +impl Unpin for Channel {} + +impl Channel { + /// Creates a bounded channel of capacity `cap`. + fn with_capacity(cap: usize) -> Self { + assert!(cap > 0, "capacity must be positive"); + + // Compute constants `mark_bit` and `one_lap`. + let mark_bit = (cap + 1).next_power_of_two(); + let one_lap = mark_bit * 2; + + // Head is initialized to `{ lap: 0, mark: 0, index: 0 }`. + let head = 0; + // Tail is initialized to `{ lap: 0, mark: 0, index: 0 }`. + let tail = 0; + + // Allocate a buffer of `cap` slots. + let buffer = { + let mut v = Vec::>::with_capacity(cap); + let ptr = v.as_mut_ptr(); + mem::forget(v); + ptr + }; + + // Initialize stamps in the slots. + for i in 0..cap { + unsafe { + // Set the stamp to `{ lap: 0, mark: 0, index: i }`. + let slot = buffer.add(i); + ptr::write(&mut (*slot).stamp, AtomicUsize::new(i)); + } + } + + Channel { + buffer, + cap, + one_lap, + mark_bit, + head: CachePadded::new(AtomicUsize::new(head)), + tail: CachePadded::new(AtomicUsize::new(tail)), + sends: Registry::new(), + recvs: Registry::new(), + streams: Registry::new(), + sender_count: AtomicUsize::new(1), + receiver_count: AtomicUsize::new(1), + _marker: PhantomData, + } + } + + /// Attempts to push a message. + fn push(&self, msg: T) -> Result<(), PushError> { + let backoff = Backoff::new(); + let mut tail = self.tail.load(Ordering::Relaxed); + + loop { + // Deconstruct the tail. + let index = tail & (self.mark_bit - 1); + let lap = tail & !(self.one_lap - 1); + + // Inspect the corresponding slot. + let slot = unsafe { &*self.buffer.add(index) }; + let stamp = slot.stamp.load(Ordering::Acquire); + + // If the tail and the stamp match, we may attempt to push. + if tail == stamp { + let new_tail = if index + 1 < self.cap { + // Same lap, incremented index. + // Set to `{ lap: lap, mark: 0, index: index + 1 }`. + tail + 1 + } else { + // One lap forward, index wraps around to zero. + // Set to `{ lap: lap.wrapping_add(1), mark: 0, index: 0 }`. + lap.wrapping_add(self.one_lap) + }; + + // Try moving the tail. + match self.tail.compare_exchange_weak( + tail, + new_tail, + Ordering::SeqCst, + Ordering::Relaxed, + ) { + Ok(_) => { + // Write the message into the slot and update the stamp. + unsafe { slot.msg.get().write(msg) }; + let stamp = tail + 1; + slot.stamp.store(stamp, Ordering::Release); + + // Wake a blocked receive operation. + self.recvs.notify_one(); + + // Wake all blocked streams. + self.streams.notify_all(); + + return Ok(()); + } + Err(t) => { + tail = t; + backoff.spin(); + } + } + } else if stamp.wrapping_add(self.one_lap) == tail + 1 { + atomic::fence(Ordering::SeqCst); + let head = self.head.load(Ordering::Relaxed); + + // If the head lags one lap behind the tail as well... + if head.wrapping_add(self.one_lap) == tail { + // ...then the channel is full. + + // Check if the channel is disconnected. + if tail & self.mark_bit != 0 { + return Err(PushError::Disconnected(msg)); + } else { + return Err(PushError::Full(msg)); + } + } + + backoff.spin(); + tail = self.tail.load(Ordering::Relaxed); + } else { + // Snooze because we need to wait for the stamp to get updated. + backoff.snooze(); + tail = self.tail.load(Ordering::Relaxed); + } + } + } + + /// Attempts to pop a message. + fn pop(&self) -> Result { + let backoff = Backoff::new(); + let mut head = self.head.load(Ordering::Relaxed); + + loop { + // Deconstruct the head. + let index = head & (self.mark_bit - 1); + let lap = head & !(self.one_lap - 1); + + // Inspect the corresponding slot. + let slot = unsafe { &*self.buffer.add(index) }; + let stamp = slot.stamp.load(Ordering::Acquire); + + // If the the stamp is ahead of the head by 1, we may attempt to pop. + if head + 1 == stamp { + let new = if index + 1 < self.cap { + // Same lap, incremented index. + // Set to `{ lap: lap, mark: 0, index: index + 1 }`. + head + 1 + } else { + // One lap forward, index wraps around to zero. + // Set to `{ lap: lap.wrapping_add(1), mark: 0, index: 0 }`. + lap.wrapping_add(self.one_lap) + }; + + // Try moving the head. + match self.head.compare_exchange_weak( + head, + new, + Ordering::SeqCst, + Ordering::Relaxed, + ) { + Ok(_) => { + // Read the message from the slot and update the stamp. + let msg = unsafe { slot.msg.get().read() }; + let stamp = head.wrapping_add(self.one_lap); + slot.stamp.store(stamp, Ordering::Release); + + // Wake a blocked send operation. + self.sends.notify_one(); + + return Ok(msg); + } + Err(h) => { + head = h; + backoff.spin(); + } + } + } else if stamp == head { + atomic::fence(Ordering::SeqCst); + let tail = self.tail.load(Ordering::Relaxed); + + // If the tail equals the head, that means the channel is empty. + if (tail & !self.mark_bit) == head { + // If the channel is disconnected... + if tail & self.mark_bit != 0 { + return Err(PopError::Disconnected); + } else { + // Otherwise, the receive operation is not ready. + return Err(PopError::Empty); + } + } + + backoff.spin(); + head = self.head.load(Ordering::Relaxed); + } else { + // Snooze because we need to wait for the stamp to get updated. + backoff.snooze(); + head = self.head.load(Ordering::Relaxed); + } + } + } + + /// Returns the current number of messages inside the channel. + fn len(&self) -> usize { + loop { + // Load the tail, then load the head. + let tail = self.tail.load(Ordering::SeqCst); + let head = self.head.load(Ordering::SeqCst); + + // If the tail didn't change, we've got consistent values to work with. + if self.tail.load(Ordering::SeqCst) == tail { + let hix = head & (self.mark_bit - 1); + let tix = tail & (self.mark_bit - 1); + + return if hix < tix { + tix - hix + } else if hix > tix { + self.cap - hix + tix + } else if (tail & !self.mark_bit) == head { + 0 + } else { + self.cap + }; + } + } + } + + /// Returns `true` if the channel is empty. + fn is_empty(&self) -> bool { + let head = self.head.load(Ordering::SeqCst); + let tail = self.tail.load(Ordering::SeqCst); + + // Is the tail equal to the head? + // + // Note: If the head changes just before we load the tail, that means there was a moment + // when the channel was not empty, so it is safe to just return `false`. + (tail & !self.mark_bit) == head + } + + /// Returns `true` if the channel is full. + fn is_full(&self) -> bool { + let tail = self.tail.load(Ordering::SeqCst); + let head = self.head.load(Ordering::SeqCst); + + // Is the head lagging one lap behind tail? + // + // Note: If the tail changes just before we load the head, that means there was a moment + // when the channel was not full, so it is safe to just return `false`. + head.wrapping_add(self.one_lap) == tail & !self.mark_bit + } + + /// Disconnects the channel and wakes up all blocked operations. + fn disconnect(&self) { + let tail = self.tail.fetch_or(self.mark_bit, Ordering::SeqCst); + + if tail & self.mark_bit == 0 { + // Notify everyone blocked on this channel. + self.sends.notify_all(); + self.recvs.notify_all(); + self.streams.notify_all(); + } + } +} + +impl Drop for Channel { + fn drop(&mut self) { + // Get the index of the head. + let hix = self.head.load(Ordering::Relaxed) & (self.mark_bit - 1); + + // Loop over all slots that hold a message and drop them. + for i in 0..self.len() { + // Compute the index of the next slot holding a message. + let index = if hix + i < self.cap { + hix + i + } else { + hix + i - self.cap + }; + + unsafe { + self.buffer.add(index).drop_in_place(); + } + } + + // Finally, deallocate the buffer, but don't run any destructors. + unsafe { + Vec::from_raw_parts(self.buffer, 0, self.cap); + } + } +} + +/// An error returned from the `push()` method. +enum PushError { + /// The channel is full but not disconnected. + Full(T), + + /// The channel is full and disconnected. + Disconnected(T), +} + +/// An error returned from the `pop()` method. +enum PopError { + /// The channel is empty but not disconnected. + Empty, + + /// The channel is empty and disconnected. + Disconnected, +} + +/// A list of blocked channel operations. +struct Blocked { + /// A list of registered channel operations. + /// + /// Each entry has a waker associated with the task that is executing the operation. If the + /// waker is set to `None`, that means the task has been woken up but hasn't removed itself + /// from the registry yet. + entries: Slab>, + + /// The number of wakers in the entry list. + waker_count: usize, +} + +/// A registry of blocked channel operations. +/// +/// Blocked operations register themselves in a registry. Successful operations on the opposite +/// side of the channel wake blocked operations in the registry. +struct Registry { + /// A list of blocked channel operations. + blocked: Spinlock, + + /// Set to `true` if there are no wakers in the registry. + /// + /// Note that this either means there are no entries in the registry, or that all entries have + /// been notified. + is_empty: AtomicBool, +} + +impl Registry { + /// Creates a new registry. + fn new() -> Registry { + Registry { + blocked: Spinlock::new(Blocked { + entries: Slab::new(), + waker_count: 0, + }), + is_empty: AtomicBool::new(true), + } + } + + /// Registers a blocked channel operation and returns a key associated with it. + fn register(&self, cx: &Context<'_>) -> usize { + let mut blocked = self.blocked.lock(); + + // Insert a new entry into the list of blocked tasks. + let w = cx.waker().clone(); + let key = blocked.entries.insert(Some(w)); + + blocked.waker_count += 1; + if blocked.waker_count == 1 { + self.is_empty.store(false, Ordering::SeqCst); + } + + key + } + + /// Re-registers a blocked channel operation by filling in its waker. + fn reregister(&self, key: usize, cx: &Context<'_>) { + let mut blocked = self.blocked.lock(); + + let was_none = blocked.entries[key].is_none(); + let w = cx.waker().clone(); + blocked.entries[key] = Some(w); + + if was_none { + blocked.waker_count += 1; + if blocked.waker_count == 1 { + self.is_empty.store(false, Ordering::SeqCst); + } + } + } + + /// Unregisters a channel operation. + /// + /// If `completed` is `true`, the operation will be removed from the registry. If `completed` + /// is `false`, that means the operation was cancelled so another one will be notified. + fn unregister(&self, key: usize, completed: bool) { + let mut blocked = self.blocked.lock(); + let mut removed = false; + + match blocked.entries.remove(key) { + Some(_) => removed = true, + None => { + if !completed { + // This operation was cancelled. Notify another one. + if let Some((_, opt_waker)) = blocked.entries.iter_mut().next() { + if let Some(w) = opt_waker.take() { + w.wake(); + removed = true; + } + } + } + } + } + + if removed { + blocked.waker_count -= 1; + if blocked.waker_count == 0 { + self.is_empty.store(true, Ordering::SeqCst); + } + } + } + + /// Notifies one blocked channel operation. + #[inline] + fn notify_one(&self) { + if !self.is_empty.load(Ordering::SeqCst) { + let mut blocked = self.blocked.lock(); + + if let Some((_, opt_waker)) = blocked.entries.iter_mut().next() { + // If there is no waker in this entry, that means it was already woken. + if let Some(w) = opt_waker.take() { + w.wake(); + + blocked.waker_count -= 1; + if blocked.waker_count == 0 { + self.is_empty.store(true, Ordering::SeqCst); + } + } + } + } + } + + /// Notifies all blocked channel operations. + #[inline] + fn notify_all(&self) { + if !self.is_empty.load(Ordering::SeqCst) { + let mut blocked = self.blocked.lock(); + + for (_, opt_waker) in blocked.entries.iter_mut() { + // If there is no waker in this entry, that means it was already woken. + if let Some(w) = opt_waker.take() { + w.wake(); + } + } + + blocked.waker_count = 0; + self.is_empty.store(true, Ordering::SeqCst); + } + } +} + +/// A simple spinlock. +struct Spinlock { + flag: AtomicBool, + value: UnsafeCell, +} + +impl Spinlock { + /// Returns a new spinlock initialized with `value`. + fn new(value: T) -> Spinlock { + Spinlock { + flag: AtomicBool::new(false), + value: UnsafeCell::new(value), + } + } + + /// Locks the spinlock. + fn lock(&self) -> SpinlockGuard<'_, T> { + let backoff = Backoff::new(); + while self.flag.swap(true, Ordering::Acquire) { + backoff.snooze(); + } + SpinlockGuard { parent: self } + } +} + +/// A guard holding a spinlock locked. +struct SpinlockGuard<'a, T> { + parent: &'a Spinlock, +} + +impl<'a, T> Drop for SpinlockGuard<'a, T> { + fn drop(&mut self) { + self.parent.flag.store(false, Ordering::Release); + } +} + +impl<'a, T> Deref for SpinlockGuard<'a, T> { + type Target = T; + + fn deref(&self) -> &T { + unsafe { &*self.parent.value.get() } + } +} + +impl<'a, T> DerefMut for SpinlockGuard<'a, T> { + fn deref_mut(&mut self) -> &mut T { + unsafe { &mut *self.parent.value.get() } + } +} diff --git a/src/sync/mod.rs b/src/sync/mod.rs index be74d8f..3ad2776 100644 --- a/src/sync/mod.rs +++ b/src/sync/mod.rs @@ -40,5 +40,8 @@ mod rwlock; cfg_unstable! { pub use barrier::{Barrier, BarrierWaitResult}; + pub use channel::{channel, Sender, Receiver}; + mod barrier; + mod channel; } diff --git a/src/task/block_on.rs b/src/task/block_on.rs index 032bf02..b0adc38 100644 --- a/src/task/block_on.rs +++ b/src/task/block_on.rs @@ -7,6 +7,7 @@ use std::task::{RawWaker, RawWakerVTable}; use std::thread; use crossbeam_utils::sync::Parker; +use pin_project_lite::pin_project; use super::task; use super::task_local; @@ -100,19 +101,18 @@ where } } -struct CatchUnwindFuture { - future: F, -} - -impl CatchUnwindFuture { - pin_utils::unsafe_pinned!(future: F); +pin_project! { + struct CatchUnwindFuture { + #[pin] + future: F, + } } impl Future for CatchUnwindFuture { type Output = thread::Result; fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - panic::catch_unwind(AssertUnwindSafe(|| self.future().poll(cx)))?.map(Ok) + panic::catch_unwind(AssertUnwindSafe(|| self.project().future.poll(cx)))?.map(Ok) } } diff --git a/src/task/yield_now.rs b/src/task/yield_now.rs index 2a4788d..c11408a 100644 --- a/src/task/yield_now.rs +++ b/src/task/yield_now.rs @@ -18,13 +18,13 @@ use std::pin::Pin; /// Basic usage: /// /// ``` -/// # fn main() { async_std::task::block_on(async { +/// # async_std::task::block_on(async { /// # /// use async_std::task; /// /// task::yield_now().await; /// # -/// # }) } +/// # }) /// ``` #[cfg(feature = "unstable")] #[cfg_attr(feature = "docs", doc(cfg(unstable)))] diff --git a/tests/channel.rs b/tests/channel.rs new file mode 100644 index 0000000..91622b0 --- /dev/null +++ b/tests/channel.rs @@ -0,0 +1,350 @@ +use std::sync::atomic::{AtomicUsize, Ordering}; +use std::sync::Arc; +use std::time::Duration; + +use async_std::sync::channel; +use async_std::task; +use rand::{thread_rng, Rng}; + +fn ms(ms: u64) -> Duration { + Duration::from_millis(ms) +} + +#[test] +fn smoke() { + task::block_on(async { + let (s, r) = channel(1); + + s.send(7).await; + assert_eq!(r.recv().await, Some(7)); + + s.send(8).await; + assert_eq!(r.recv().await, Some(8)); + + drop(s); + assert_eq!(r.recv().await, None); + }) +} + +#[test] +fn capacity() { + for i in 1..10 { + let (s, r) = channel::<()>(i); + assert_eq!(s.capacity(), i); + assert_eq!(r.capacity(), i); + } +} + +#[test] +fn len_empty_full() { + task::block_on(async { + let (s, r) = channel(2); + + assert_eq!(s.len(), 0); + assert_eq!(s.is_empty(), true); + assert_eq!(s.is_full(), false); + assert_eq!(r.len(), 0); + assert_eq!(r.is_empty(), true); + assert_eq!(r.is_full(), false); + + s.send(()).await; + + assert_eq!(s.len(), 1); + assert_eq!(s.is_empty(), false); + assert_eq!(s.is_full(), false); + assert_eq!(r.len(), 1); + assert_eq!(r.is_empty(), false); + assert_eq!(r.is_full(), false); + + s.send(()).await; + + assert_eq!(s.len(), 2); + assert_eq!(s.is_empty(), false); + assert_eq!(s.is_full(), true); + assert_eq!(r.len(), 2); + assert_eq!(r.is_empty(), false); + assert_eq!(r.is_full(), true); + + r.recv().await; + + assert_eq!(s.len(), 1); + assert_eq!(s.is_empty(), false); + assert_eq!(s.is_full(), false); + assert_eq!(r.len(), 1); + assert_eq!(r.is_empty(), false); + assert_eq!(r.is_full(), false); + }) +} + +#[test] +fn recv() { + task::block_on(async { + let (s, r) = channel(100); + + task::spawn(async move { + assert_eq!(r.recv().await, Some(7)); + task::sleep(ms(1000)).await; + assert_eq!(r.recv().await, Some(8)); + task::sleep(ms(1000)).await; + assert_eq!(r.recv().await, Some(9)); + assert_eq!(r.recv().await, None); + }); + + task::sleep(ms(1500)).await; + s.send(7).await; + s.send(8).await; + s.send(9).await; + }) +} + +#[test] +fn send() { + task::block_on(async { + let (s, r) = channel(1); + + task::spawn(async move { + s.send(7).await; + task::sleep(ms(1000)).await; + s.send(8).await; + task::sleep(ms(1000)).await; + s.send(9).await; + task::sleep(ms(1000)).await; + s.send(10).await; + }); + + task::sleep(ms(1500)).await; + assert_eq!(r.recv().await, Some(7)); + assert_eq!(r.recv().await, Some(8)); + assert_eq!(r.recv().await, Some(9)); + }) +} + +#[test] +fn recv_after_disconnect() { + task::block_on(async { + let (s, r) = channel(100); + + s.send(1).await; + s.send(2).await; + s.send(3).await; + + drop(s); + + assert_eq!(r.recv().await, Some(1)); + assert_eq!(r.recv().await, Some(2)); + assert_eq!(r.recv().await, Some(3)); + assert_eq!(r.recv().await, None); + }) +} + +#[test] +fn len() { + const COUNT: usize = 25_000; + const CAP: usize = 1000; + + task::block_on(async { + let (s, r) = channel(CAP); + + assert_eq!(s.len(), 0); + assert_eq!(r.len(), 0); + + for _ in 0..CAP / 10 { + for i in 0..50 { + s.send(i).await; + assert_eq!(s.len(), i + 1); + } + + for i in 0..50 { + r.recv().await; + assert_eq!(r.len(), 50 - i - 1); + } + } + + assert_eq!(s.len(), 0); + assert_eq!(r.len(), 0); + + for i in 0..CAP { + s.send(i).await; + assert_eq!(s.len(), i + 1); + } + + for _ in 0..CAP { + r.recv().await.unwrap(); + } + + assert_eq!(s.len(), 0); + assert_eq!(r.len(), 0); + + let child = task::spawn({ + let r = r.clone(); + async move { + for i in 0..COUNT { + assert_eq!(r.recv().await, Some(i)); + let len = r.len(); + assert!(len <= CAP); + } + } + }); + + for i in 0..COUNT { + s.send(i).await; + let len = s.len(); + assert!(len <= CAP); + } + + child.await; + + assert_eq!(s.len(), 0); + assert_eq!(r.len(), 0); + }) +} + +#[test] +fn disconnect_wakes_receiver() { + task::block_on(async { + let (s, r) = channel::<()>(1); + + let child = task::spawn(async move { + assert_eq!(r.recv().await, None); + }); + + task::sleep(ms(1000)).await; + drop(s); + + child.await; + }) +} + +#[test] +fn spsc() { + const COUNT: usize = 100_000; + + task::block_on(async { + let (s, r) = channel(3); + + let child = task::spawn(async move { + for i in 0..COUNT { + assert_eq!(r.recv().await, Some(i)); + } + assert_eq!(r.recv().await, None); + }); + + for i in 0..COUNT { + s.send(i).await; + } + drop(s); + + child.await; + }) +} + +#[test] +fn mpmc() { + const COUNT: usize = 25_000; + const TASKS: usize = 4; + + task::block_on(async { + let (s, r) = channel::(3); + let v = (0..COUNT).map(|_| AtomicUsize::new(0)).collect::>(); + let v = Arc::new(v); + + let mut tasks = Vec::new(); + + for _ in 0..TASKS { + let r = r.clone(); + let v = v.clone(); + tasks.push(task::spawn(async move { + for _ in 0..COUNT { + let n = r.recv().await.unwrap(); + v[n].fetch_add(1, Ordering::SeqCst); + } + })); + } + + for _ in 0..TASKS { + let s = s.clone(); + tasks.push(task::spawn(async move { + for i in 0..COUNT { + s.send(i).await; + } + })); + } + + for t in tasks { + t.await; + } + + for c in v.iter() { + assert_eq!(c.load(Ordering::SeqCst), TASKS); + } + }); +} + +#[test] +fn oneshot() { + const COUNT: usize = 10_000; + + task::block_on(async { + for _ in 0..COUNT { + let (s, r) = channel(1); + + let c1 = task::spawn(async move { r.recv().await.unwrap() }); + let c2 = task::spawn(async move { s.send(0).await }); + + c1.await; + c2.await; + } + }) +} + +#[test] +fn drops() { + const RUNS: usize = 100; + + static DROPS: AtomicUsize = AtomicUsize::new(0); + + #[derive(Debug, PartialEq)] + struct DropCounter; + + impl Drop for DropCounter { + fn drop(&mut self) { + DROPS.fetch_add(1, Ordering::SeqCst); + } + } + + let mut rng = thread_rng(); + + for _ in 0..RUNS { + task::block_on(async { + let steps = rng.gen_range(0, 10_000); + let additional = rng.gen_range(0, 50); + + DROPS.store(0, Ordering::SeqCst); + let (s, r) = channel::(50); + + let child = task::spawn({ + let r = r.clone(); + async move { + for _ in 0..steps { + r.recv().await.unwrap(); + } + } + }); + + for _ in 0..steps { + s.send(DropCounter).await; + } + + child.await; + + for _ in 0..additional { + s.send(DropCounter).await; + } + + assert_eq!(DROPS.load(Ordering::SeqCst), steps); + drop(s); + drop(r); + assert_eq!(DROPS.load(Ordering::SeqCst), steps + additional); + }) + } +}