From f7aa962dafc8e49437440148505fd4a114270e17 Mon Sep 17 00:00:00 2001 From: Stjepan Glavina Date: Mon, 28 Sep 2020 18:49:55 +0200 Subject: [PATCH] Store a future inside Incoming --- src/net/tcp/listener.rs | 42 ++++++++++++++++++++++++++++--------- src/os/unix/net/listener.rs | 39 ++++++++++++++++++++++++++-------- 2 files changed, 62 insertions(+), 19 deletions(-) diff --git a/src/net/tcp/listener.rs b/src/net/tcp/listener.rs index fc88cda5..cfefc7d2 100644 --- a/src/net/tcp/listener.rs +++ b/src/net/tcp/listener.rs @@ -1,3 +1,4 @@ +use std::fmt; use std::future::Future; use std::net::SocketAddr; use std::pin::Pin; @@ -8,7 +9,7 @@ use crate::io; use crate::net::{TcpStream, ToSocketAddrs}; use crate::stream::Stream; use crate::sync::Arc; -use crate::task::{Context, Poll}; +use crate::task::{ready, Context, Poll}; /// A TCP socket server, listening for connections. /// @@ -146,7 +147,10 @@ impl TcpListener { /// # Ok(()) }) } /// ``` pub fn incoming(&self) -> Incoming<'_> { - Incoming(self) + Incoming { + listener: self, + accept: None, + } } /// Returns the local address that this listener is bound to. @@ -182,18 +186,36 @@ impl TcpListener { /// [`incoming`]: struct.TcpListener.html#method.incoming /// [`TcpListener`]: struct.TcpListener.html /// [`std::net::Incoming`]: https://doc.rust-lang.org/std/net/struct.Incoming.html -#[derive(Debug)] -pub struct Incoming<'a>(&'a TcpListener); +pub struct Incoming<'a> { + listener: &'a TcpListener, + accept: Option< + Pin> + Send + Sync + 'a>>, + >, +} -impl<'a> Stream for Incoming<'a> { +impl Stream for Incoming<'_> { type Item = io::Result; - fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - let future = self.0.accept(); - pin_utils::pin_mut!(future); + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + loop { + if self.accept.is_none() { + self.accept = Some(Box::pin(self.listener.accept())); + } + + if let Some(f) = &mut self.accept { + let res = ready!(f.as_mut().poll(cx)); + self.accept = None; + return Poll::Ready(Some(res.map(|(stream, _)| stream))); + } + } + } +} - let (socket, _) = futures_core::ready!(future.poll(cx))?; - Poll::Ready(Some(Ok(socket))) +impl fmt::Debug for Incoming<'_> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("Incoming") + .field("listener", self.listener) + .finish() } } diff --git a/src/os/unix/net/listener.rs b/src/os/unix/net/listener.rs index dc4c64c7..3573d7d3 100644 --- a/src/os/unix/net/listener.rs +++ b/src/os/unix/net/listener.rs @@ -14,7 +14,7 @@ use crate::os::unix::io::{AsRawFd, FromRawFd, IntoRawFd, RawFd}; use crate::path::Path; use crate::stream::Stream; use crate::sync::Arc; -use crate::task::{Context, Poll}; +use crate::task::{ready, Context, Poll}; /// A Unix domain socket server, listening for connections. /// @@ -128,7 +128,10 @@ impl UnixListener { /// # Ok(()) }) } /// ``` pub fn incoming(&self) -> Incoming<'_> { - Incoming(self) + Incoming { + listener: self, + accept: None, + } } /// Returns the local socket address of this listener. @@ -174,18 +177,36 @@ impl fmt::Debug for UnixListener { /// [`None`]: https://doc.rust-lang.org/std/option/enum.Option.html#variant.None /// [`incoming`]: struct.UnixListener.html#method.incoming /// [`UnixListener`]: struct.UnixListener.html -#[derive(Debug)] -pub struct Incoming<'a>(&'a UnixListener); +pub struct Incoming<'a> { + listener: &'a UnixListener, + accept: Option< + Pin> + Send + Sync + 'a>>, + >, +} impl Stream for Incoming<'_> { type Item = io::Result; - fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - let future = self.0.accept(); - pin_utils::pin_mut!(future); + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + loop { + if self.accept.is_none() { + self.accept = Some(Box::pin(self.listener.accept())); + } + + if let Some(f) = &mut self.accept { + let res = ready!(f.as_mut().poll(cx)); + self.accept = None; + return Poll::Ready(Some(res.map(|(stream, _)| stream))); + } + } + } +} - let (socket, _) = futures_core::ready!(future.poll(cx))?; - Poll::Ready(Some(Ok(socket))) +impl fmt::Debug for Incoming<'_> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("Incoming") + .field("listener", self.listener) + .finish() } }