diff --git a/src/stream/stream/fold.rs b/src/stream/stream/fold.rs new file mode 100644 index 0000000..18ddcd8 --- /dev/null +++ b/src/stream/stream/fold.rs @@ -0,0 +1,53 @@ +use std::marker::PhantomData; +use std::pin::Pin; + +use crate::future::Future; +use crate::stream::Stream; +use crate::task::{Context, Poll}; + +#[doc(hidden)] +#[allow(missing_debug_implementations)] +pub struct FoldFuture<S, F, T, B> { + stream: S, + f: F, + acc: Option<B>, + __t: PhantomData<T>, +} + +impl<S, F, T, B> FoldFuture<S, F, T, B> { + pin_utils::unsafe_pinned!(stream: S); + pin_utils::unsafe_unpinned!(f: F); + pin_utils::unsafe_unpinned!(acc: Option<B>); + + pub(super) fn new(stream: S, init: B, f: F) -> Self { + FoldFuture { + stream, + f, + acc: Some(init), + __t: PhantomData, + } + } +} + +impl<S, F, B> Future for FoldFuture<S, F, S::Item, B> +where + S: Stream + Sized, + F: FnMut(B, S::Item) -> B, +{ + type Output = B; + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { + loop { + let next = futures_core::ready!(self.as_mut().stream().poll_next(cx)); + + match next { + Some(v) => { + let old = self.as_mut().acc().take().unwrap(); + let new = (self.as_mut().f())(old, v); + *self.as_mut().acc() = Some(new); + } + None => return Poll::Ready(self.as_mut().acc().take().unwrap()), + } + } + } +} diff --git a/src/stream/stream/mod.rs b/src/stream/stream/mod.rs index 74bf740..578dd56 100644 --- a/src/stream/stream/mod.rs +++ b/src/stream/stream/mod.rs @@ -27,6 +27,7 @@ mod enumerate; mod filter_map; mod find; mod find_map; +mod fold; mod min_by; mod next; mod nth; @@ -44,6 +45,7 @@ use enumerate::Enumerate; use filter_map::FilterMap; use find::FindFuture; use find_map::FindMapFuture; +use fold::FoldFuture; use min_by::MinByFuture; use next::NextFuture; use nth::NthFuture; @@ -481,6 +483,34 @@ pub trait Stream { FindMapFuture::new(self, f) } + /// A combinator that applies a function to every element in a stream + /// producing a single, final value. + /// + /// # Examples + /// + /// Basic usage: + /// + /// ``` + /// # fn main() { async_std::task::block_on(async { + /// # + /// use async_std::prelude::*; + /// use std::collections::VecDeque; + /// + /// let s: VecDeque<usize> = vec![1, 2, 3].into_iter().collect(); + /// let sum = s.fold(0, |acc, x| acc + x).await; + /// + /// assert_eq!(sum, 6); + /// # + /// # }) } + /// ``` + fn fold<B, F>(self, init: B, f: F) -> FoldFuture<Self, F, Self::Item, B> + where + Self: Sized, + F: FnMut(B, Self::Item) -> B, + { + FoldFuture::new(self, init, f) + } + /// Tests if any element of the stream matches a predicate. /// /// `any()` takes a closure that returns `true` or `false`. It applies