From 6b00e5e66cae4bda1017dd4aab494744409c4a20 Mon Sep 17 00:00:00 2001 From: Wouter Geraedts Date: Wed, 16 Oct 2019 02:32:27 +0200 Subject: [PATCH] Implemented StreamExt::try_fold (#344) --- src/stream/stream/mod.rs | 42 +++++++++++++++++++++++++ src/stream/stream/try_fold.rs | 59 +++++++++++++++++++++++++++++++++++ 2 files changed, 101 insertions(+) create mode 100644 src/stream/stream/try_fold.rs diff --git a/src/stream/stream/mod.rs b/src/stream/stream/mod.rs index f2b9830..d582d70 100644 --- a/src/stream/stream/mod.rs +++ b/src/stream/stream/mod.rs @@ -49,6 +49,7 @@ mod skip_while; mod step_by; mod take; mod take_while; +mod try_fold; mod try_for_each; mod zip; @@ -69,6 +70,7 @@ use min_by::MinByFuture; use next::NextFuture; use nth::NthFuture; use partial_cmp::PartialCmpFuture; +use try_fold::TryFoldFuture; use try_for_each::TryForEeachFuture; pub use chain::Chain; @@ -1042,6 +1044,46 @@ extension_trait! { Skip::new(self, n) } + #[doc = r#" + A combinator that applies a function as long as it returns successfully, producing a single, final value. + Immediately returns the error when the function returns unsuccessfully. + + # Examples + + Basic usage: + + ``` + # fn main() { async_std::task::block_on(async { + # + use async_std::prelude::*; + use std::collections::VecDeque; + + let s: VecDeque = vec![1, 2, 3].into_iter().collect(); + let sum = s.try_fold(0, |acc, v| { + if (acc+v) % 2 == 1 { + Ok(v+3) + } else { + Err("fail") + } + }).await; + + assert_eq!(sum, Err("fail")); + # + # }) } + ``` + "#] + fn try_fold( + self, + init: T, + f: F, + ) -> impl Future> [TryFoldFuture] + where + Self: Sized, + F: FnMut(B, Self::Item) -> Result, + { + TryFoldFuture::new(self, init, f) + } + #[doc = r#" Applies a falliable function to each element in a stream, stopping at first error and returning it. diff --git a/src/stream/stream/try_fold.rs b/src/stream/stream/try_fold.rs new file mode 100644 index 0000000..212b058 --- /dev/null +++ b/src/stream/stream/try_fold.rs @@ -0,0 +1,59 @@ +use std::marker::PhantomData; +use std::pin::Pin; + +use crate::future::Future; +use crate::stream::Stream; +use crate::task::{Context, Poll}; + +#[doc(hidden)] +#[allow(missing_debug_implementations)] +pub struct TryFoldFuture { + stream: S, + f: F, + acc: Option, + __t: PhantomData, +} + +impl TryFoldFuture { + pin_utils::unsafe_pinned!(stream: S); + pin_utils::unsafe_unpinned!(f: F); + pin_utils::unsafe_unpinned!(acc: Option); + + pub(super) fn new(stream: S, init: T, f: F) -> Self { + TryFoldFuture { + stream, + f, + acc: Some(init), + __t: PhantomData, + } + } +} + +impl Future for TryFoldFuture +where + S: Stream + Sized, + F: FnMut(T, S::Item) -> Result, +{ + type Output = Result; + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + loop { + let next = futures_core::ready!(self.as_mut().stream().poll_next(cx)); + + match next { + Some(v) => { + let old = self.as_mut().acc().take().unwrap(); + let new = (self.as_mut().f())(old, v); + + match new { + Ok(o) => { + *self.as_mut().acc() = Some(o); + } + Err(e) => return Poll::Ready(Err(e)), + } + } + None => return Poll::Ready(Ok(self.as_mut().acc().take().unwrap())), + } + } + } +}