diff --git a/src/stream/stream/max_by_key.rs b/src/stream/stream/max_by_key.rs index b5bc7e0..e421f94 100644 --- a/src/stream/stream/max_by_key.rs +++ b/src/stream/stream/max_by_key.rs @@ -1,6 +1,6 @@ use std::cmp::Ordering; -use std::pin::Pin; use std::future::Future; +use std::pin::Pin; use pin_project_lite::pin_project; @@ -13,7 +13,7 @@ pin_project! { pub struct MaxByKeyFuture { #[pin] stream: S, - max: Option, + max: Option<(T, T)>, key_by: K, } } @@ -37,24 +37,29 @@ where type Output = Option; fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + fn key(mut f: impl FnMut(&T) -> B) -> impl FnMut(T) -> (B, T) { + move |x| (f(&x), x) + } + let this = self.project(); let next = futures_core::ready!(this.stream.poll_next(cx)); match next { Some(new) => { - let new = (this.key_by)(&new); + let (key, value) = key(this.key_by)(new); cx.waker().wake_by_ref(); + match this.max.take() { - None => *this.max = Some(new), + None => *this.max = Some((key, value)), - Some(old) => match new.cmp(&old) { - Ordering::Greater => *this.max = Some(new), + Some(old) => match key.cmp(&old.0) { + Ordering::Greater => *this.max = Some((key, value)), _ => *this.max = Some(old), }, } Poll::Pending } - None => Poll::Ready(this.max.take()), + None => Poll::Ready(this.max.take().map(|max| max.1)), } } } diff --git a/src/stream/stream/min_by_key.rs b/src/stream/stream/min_by_key.rs index 8179fb3..142dfe1 100644 --- a/src/stream/stream/min_by_key.rs +++ b/src/stream/stream/min_by_key.rs @@ -1,6 +1,6 @@ use std::cmp::Ordering; -use std::pin::Pin; use std::future::Future; +use std::pin::Pin; use pin_project_lite::pin_project; @@ -13,7 +13,7 @@ pin_project! { pub struct MinByKeyFuture { #[pin] stream: S, - min: Option, + min: Option<(T, T)>, key_by: K, } } @@ -37,24 +37,29 @@ where type Output = Option; fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + fn key(mut f: impl FnMut(&T) -> B) -> impl FnMut(T) -> (B, T) { + move |x| (f(&x), x) + } + let this = self.project(); let next = futures_core::ready!(this.stream.poll_next(cx)); match next { Some(new) => { - let new = (this.key_by)(&new); + let (key, value) = key(this.key_by)(new); cx.waker().wake_by_ref(); + match this.min.take() { - None => *this.min = Some(new), + None => *this.min = Some((key, value)), - Some(old) => match new.cmp(&old) { - Ordering::Less => *this.min = Some(new), + Some(old) => match key.cmp(&old.0) { + Ordering::Less => *this.min = Some((key, value)), _ => *this.min = Some(old), }, } Poll::Pending } - None => Poll::Ready(this.min.take()), + None => Poll::Ready(this.min.take().map(|min| min.1)), } } } diff --git a/src/stream/stream/mod.rs b/src/stream/stream/mod.rs index f876576..48d865e 100644 --- a/src/stream/stream/mod.rs +++ b/src/stream/stream/mod.rs @@ -875,10 +875,10 @@ extension_trait! { use async_std::prelude::*; use async_std::stream; - let s = stream::from_iter(vec![1isize, 2, -3]); + let s = stream::from_iter(vec![-1isize, 2, -3]); let min = s.clone().min_by_key(|x| x.abs()).await; - assert_eq!(min, Some(1)); + assert_eq!(min, Some(-1)); let min = stream::empty::().min_by_key(|x| x.abs()).await; assert_eq!(min, None); @@ -911,12 +911,12 @@ extension_trait! { use async_std::prelude::*; use async_std::stream; - let s = stream::from_iter(vec![-1isize, -2, -3]); + let s = stream::from_iter(vec![-3_i32, 0, 1, 5, -10]); let max = s.clone().max_by_key(|x| x.abs()).await; - assert_eq!(max, Some(3)); + assert_eq!(max, Some(-10)); - let max = stream::empty::().min_by_key(|x| x.abs()).await; + let max = stream::empty::().max_by_key(|x| x.abs()).await; assert_eq!(max, None); # # }) }