From 32068942a6130d12a7152706ae6e24637377339d Mon Sep 17 00:00:00 2001 From: Oleg Nosov Date: Sat, 8 Feb 2020 15:41:33 +0300 Subject: [PATCH] Fixed `flatten` --- src/stream/stream/flat_map.rs | 2 +- src/stream/stream/flatten.rs | 21 +++++--- tests/stream.rs | 96 +++++++++++++++++++---------------- 3 files changed, 68 insertions(+), 51 deletions(-) diff --git a/src/stream/stream/flat_map.rs b/src/stream/stream/flat_map.rs index 8d5a12f..f9ceb86 100644 --- a/src/stream/stream/flat_map.rs +++ b/src/stream/stream/flat_map.rs @@ -69,4 +69,4 @@ where } } } -} +} \ No newline at end of file diff --git a/src/stream/stream/flatten.rs b/src/stream/stream/flatten.rs index 1d6fcae..d0e0d20 100644 --- a/src/stream/stream/flatten.rs +++ b/src/stream/stream/flatten.rs @@ -1,5 +1,5 @@ -use std::fmt; -use std::pin::Pin; +use core::fmt; +use core::pin::Pin; use pin_project_lite::pin_project; @@ -52,14 +52,21 @@ where let mut this = self.project(); loop { if let Some(inner) = this.inner_stream.as_mut().as_pin_mut() { - if let item @ Some(_) = futures_core::ready!(inner.poll_next(cx)) { - return Poll::Ready(item); + let next_item = futures_core::ready!(inner.poll_next(cx)); + + if next_item.is_some() { + return Poll::Ready(next_item); + } else { + this.inner_stream.set(None); } } - match futures_core::ready!(this.stream.as_mut().poll_next(cx)) { - None => return Poll::Ready(None), - Some(inner) => this.inner_stream.set(Some(inner.into_stream())), + let inner = futures_core::ready!(this.stream.as_mut().poll_next(cx)); + + if inner.is_some() { + this.inner_stream.set(inner.map(IntoStream::into_stream)); + } else { + return Poll::Ready(None); } } } diff --git a/tests/stream.rs b/tests/stream.rs index e80fc99..210ceae 100644 --- a/tests/stream.rs +++ b/tests/stream.rs @@ -1,3 +1,5 @@ +use std::convert::identity; +use std::marker::Unpin; use std::pin::Pin; use std::task::{Context, Poll}; @@ -99,58 +101,52 @@ fn merge_works_with_unfused_streams() { assert_eq!(xs, vec![92, 92]); } -#[test] -fn flat_map_doesnt_poll_completed_inner_stream() { - async_std::task::block_on(async { - use async_std::prelude::*; - use async_std::task::*; - use std::convert::identity; - use std::marker::Unpin; - use std::pin::Pin; +struct S(T); - struct S(T); +impl Stream for S { + type Item = T::Item; - impl Stream for S { - type Item = T::Item; + fn poll_next(mut self: Pin<&mut Self>, ctx: &mut Context) -> Poll> { + unsafe { Pin::new_unchecked(&mut self.0) }.poll_next(ctx) + } +} - fn poll_next(mut self: Pin<&mut Self>, ctx: &mut Context) -> Poll> { - unsafe { Pin::new_unchecked(&mut self.0) }.poll_next(ctx) - } - } +struct StrictOnce { + polled: bool, +} - struct StrictOnce { - polled: bool, - }; +impl Stream for StrictOnce { + type Item = (); - impl Stream for StrictOnce { - type Item = (); + fn poll_next(mut self: Pin<&mut Self>, _: &mut Context) -> Poll> { + assert!(!self.polled, "Polled after completion!"); + self.polled = true; + Poll::Ready(None) + } +} - fn poll_next(mut self: Pin<&mut Self>, _: &mut Context) -> Poll> { - assert!(!self.polled, "Polled after completion!"); - self.polled = true; - Poll::Ready(None) - } - } +struct Interchanger { + polled: bool, +} - struct Interchanger { - polled: bool, - }; - - impl Stream for Interchanger { - type Item = S + Unpin>>; - - fn poll_next(mut self: Pin<&mut Self>, ctx: &mut Context) -> Poll> { - if self.polled { - self.polled = false; - ctx.waker().wake_by_ref(); - Poll::Pending - } else { - self.polled = true; - Poll::Ready(Some(S(Box::new(StrictOnce { polled: false })))) - } - } +impl Stream for Interchanger { + type Item = S + Unpin>>; + + fn poll_next(mut self: Pin<&mut Self>, ctx: &mut Context) -> Poll> { + if self.polled { + self.polled = false; + ctx.waker().wake_by_ref(); + Poll::Pending + } else { + self.polled = true; + Poll::Ready(Some(S(Box::new(StrictOnce { polled: false })))) } + } +} +#[test] +fn flat_map_doesnt_poll_completed_inner_stream() { + task::block_on(async { assert_eq!( Interchanger { polled: false } .take(2) @@ -161,3 +157,17 @@ fn flat_map_doesnt_poll_completed_inner_stream() { ); }); } + +#[test] +fn flatten_doesnt_poll_completed_inner_stream() { + task::block_on(async { + assert_eq!( + Interchanger { polled: false } + .take(2) + .flatten() + .count() + .await, + 0 + ); + }); +}