diff --git a/src/lib.rs b/src/lib.rs index f188a68..c75673a 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -47,6 +47,7 @@ #![doc(test(attr(allow(unused_extern_crates, unused_variables))))] #![doc(html_logo_url = "https://async.rs/images/logo--hero.svg")] #![recursion_limit = "1024"] +#![feature(try_trait)] use cfg_if::cfg_if; diff --git a/src/stream/stream/mod.rs b/src/stream/stream/mod.rs index 07de323..bdd964c 100644 --- a/src/stream/stream/mod.rs +++ b/src/stream/stream/mod.rs @@ -40,6 +40,7 @@ mod skip; mod skip_while; mod step_by; mod take; +mod try_for_each; mod zip; use all::AllFuture; @@ -52,6 +53,7 @@ use fold::FoldFuture; use min_by::MinByFuture; use next::NextFuture; use nth::NthFuture; +use try_for_each::TryForEeachFuture; pub use chain::Chain; pub use filter::Filter; @@ -66,6 +68,7 @@ pub use zip::Zip; use std::cmp::Ordering; use std::marker::PhantomData; +use std::ops::Try; use cfg_if::cfg_if; @@ -921,6 +924,52 @@ extension_trait! { Skip::new(self, n) } + #[doc = r#" + Applies a falliable function to each element in a stream, stopping at first error and returning it. + + # Examples + + ``` + # fn main() { async_std::task::block_on(async { + # + use std::collections::VecDeque; + use std::sync::mpsc::channel; + use async_std::prelude::*; + + let (tx, rx) = channel(); + + let s: VecDeque = vec![1, 2, 3].into_iter().collect(); + let s = s.try_for_each(|v| { + if v % 2 == 1 { + tx.clone().send(v).unwrap(); + Ok(()) + } else { + Err("even") + } + }); + + let res = s.await; + drop(tx); + let values: Vec<_> = rx.iter().collect(); + + assert_eq!(values, vec![1]); + assert_eq!(res, Err("even")); + # + # }) } + ``` + "#] + fn try_for_each( + self, + f: F, + ) -> impl Future [TryForEeachFuture] + where + Self: Sized, + F: FnMut(Self::Item) -> R, + R: Try, + { + TryForEeachFuture::new(self, f) + } + #[doc = r#" 'Zips up' two streams into a single stream of pairs. diff --git a/src/stream/stream/try_for_each.rs b/src/stream/stream/try_for_each.rs new file mode 100644 index 0000000..02136e8 --- /dev/null +++ b/src/stream/stream/try_for_each.rs @@ -0,0 +1,56 @@ +use std::marker::PhantomData; +use std::ops::Try; +use std::pin::Pin; + +use crate::future::Future; +use crate::stream::Stream; +use crate::task::{Context, Poll}; + +#[doc(hidden)] +#[allow(missing_debug_implementations)] +pub struct TryForEeachFuture { + stream: S, + f: F, + __from: PhantomData, + __to: PhantomData, +} + +impl TryForEeachFuture { + pin_utils::unsafe_pinned!(stream: S); + pin_utils::unsafe_unpinned!(f: F); + + pub(crate) fn new(stream: S, f: F) -> Self { + TryForEeachFuture { + stream, + f, + __from: PhantomData, + __to: PhantomData, + } + } +} + +impl Future for TryForEeachFuture +where + S: Stream, + S::Item: std::fmt::Debug, + F: FnMut(S::Item) -> R, + R: Try, +{ + type Output = R; + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + loop { + let item = futures_core::ready!(self.as_mut().stream().poll_next(cx)); + + match item { + None => return Poll::Ready(R::from_ok(())), + Some(v) => { + let res = (self.as_mut().f())(v); + if let Err(e) = res.into_result() { + return Poll::Ready(R::from_error(e)); + } + } + } + } + } +}