From 303ac90b7c3b22dbc5d04395d485d893ddd58d6c Mon Sep 17 00:00:00 2001 From: Oleg Nosov Date: Fri, 7 Feb 2020 22:09:42 +0300 Subject: [PATCH] Fixed `flat_map` --- src/stream/stream/flat_map.rs | 17 ++++++--- tests/stream.rs | 70 +++++++++++++++++++++++++++++++++++ 2 files changed, 82 insertions(+), 5 deletions(-) diff --git a/src/stream/stream/flat_map.rs b/src/stream/stream/flat_map.rs index 6c828c9..8d5a12f 100644 --- a/src/stream/stream/flat_map.rs +++ b/src/stream/stream/flat_map.rs @@ -51,14 +51,21 @@ where let mut this = self.project(); loop { if let Some(inner) = this.inner_stream.as_mut().as_pin_mut() { - if let item @ Some(_) = futures_core::ready!(inner.poll_next(cx)) { - return Poll::Ready(item); + let next_item = futures_core::ready!(inner.poll_next(cx)); + + if next_item.is_some() { + return Poll::Ready(next_item); + } else { + this.inner_stream.set(None); } } - match futures_core::ready!(this.stream.as_mut().poll_next(cx)) { - None => return Poll::Ready(None), - Some(inner) => this.inner_stream.set(Some(inner.into_stream())), + let inner = futures_core::ready!(this.stream.as_mut().poll_next(cx)); + + if inner.is_some() { + this.inner_stream.set(inner.map(IntoStream::into_stream)); + } else { + return Poll::Ready(None); } } } diff --git a/tests/stream.rs b/tests/stream.rs index 42a6191..75c1b10 100644 --- a/tests/stream.rs +++ b/tests/stream.rs @@ -98,3 +98,73 @@ fn merge_works_with_unfused_streams() { }); assert_eq!(xs, vec![92, 92]); } + +#[test] +fn flat_map_doesnt_poll_completed_inner_stream() { + async_std::task::block_on(async { + use async_std::prelude::*; + use async_std::task::*; + use std::convert::identity; + use std::marker::Unpin; + use std::pin::Pin; + + struct S(T); + + impl Stream for S { + type Item = T::Item; + + fn poll_next(mut self: Pin<&mut Self>, ctx: &mut Context) -> Poll> { + unsafe { Pin::new_unchecked(&mut self.0) }.poll_next(ctx) + } + } + + struct StrictOnce { + polled: bool, + }; + + impl Stream for StrictOnce { + type Item = (); + + fn poll_next(mut self: Pin<&mut Self>, _: &mut Context) -> Poll> { + if !self.polled { + self.polled = true; + Poll::Ready(None) + } else { + panic!("Polled after completion!"); + } + } + } + + struct Interchanger { + polled: bool, + }; + + impl Stream for Interchanger { + type Item = S + Unpin>>; + + fn poll_next(mut self: Pin<&mut Self>, ctx: &mut Context) -> Poll> { + if self.polled { + let waker = ctx.waker().clone(); + std::thread::spawn(move || { + std::thread::sleep(std::time::Duration::from_millis(10)); + waker.wake_by_ref(); + }); + self.polled = false; + Poll::Pending + } else { + self.polled = true; + Poll::Ready(Some(S(Box::new(StrictOnce { polled: false })))) + } + } + } + + assert_eq!( + Interchanger { polled: false } + .take(2) + .flat_map(identity) + .count() + .await, + 0 + ); + }); +}