Implement Clone for TcpStream (#689)

* Implement Clone for TcpStream

* Update examples

* Remove accidentally added examples
pull/286/head
Stjepan Glavina 5 years ago committed by GitHub
parent 57974ae0b7
commit 1d875836a2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -14,8 +14,9 @@ use async_std::task;
async fn process(stream: TcpStream) -> io::Result<()> { async fn process(stream: TcpStream) -> io::Result<()> {
println!("Accepted from: {}", stream.peer_addr()?); println!("Accepted from: {}", stream.peer_addr()?);
let (reader, writer) = &mut (&stream, &stream); let mut reader = stream.clone();
io::copy(reader, writer).await?; let mut writer = stream;
io::copy(&mut reader, &mut writer).await?;
Ok(()) Ok(())
} }

@ -15,8 +15,9 @@ use async_std::task;
async fn process(stream: TcpStream) -> io::Result<()> { async fn process(stream: TcpStream) -> io::Result<()> {
println!("Accepted from: {}", stream.peer_addr()?); println!("Accepted from: {}", stream.peer_addr()?);
let (reader, writer) = &mut (&stream, &stream); let mut reader = stream.clone();
io::copy(reader, writer).await?; let mut writer = stream;
io::copy(&mut reader, &mut writer).await?;
Ok(()) Ok(())
} }

@ -1,6 +1,7 @@
use std::future::Future; use std::future::Future;
use std::net::SocketAddr; use std::net::SocketAddr;
use std::pin::Pin; use std::pin::Pin;
use std::sync::Arc;
use crate::future; use crate::future;
use crate::io; use crate::io;
@ -75,9 +76,7 @@ impl TcpListener {
/// [`local_addr`]: #method.local_addr /// [`local_addr`]: #method.local_addr
pub async fn bind<A: ToSocketAddrs>(addrs: A) -> io::Result<TcpListener> { pub async fn bind<A: ToSocketAddrs>(addrs: A) -> io::Result<TcpListener> {
let mut last_err = None; let mut last_err = None;
let addrs = addrs let addrs = addrs.to_socket_addrs().await?;
.to_socket_addrs()
.await?;
for addr in addrs { for addr in addrs {
match mio::net::TcpListener::bind(&addr) { match mio::net::TcpListener::bind(&addr) {
@ -121,7 +120,7 @@ impl TcpListener {
let mio_stream = mio::net::TcpStream::from_stream(io)?; let mio_stream = mio::net::TcpStream::from_stream(io)?;
let stream = TcpStream { let stream = TcpStream {
watcher: Watcher::new(mio_stream), watcher: Arc::new(Watcher::new(mio_stream)),
}; };
Ok((stream, addr)) Ok((stream, addr))
} }

@ -1,6 +1,7 @@
use std::io::{IoSlice, IoSliceMut, Read as _, Write as _}; use std::io::{IoSlice, IoSliceMut, Read as _, Write as _};
use std::net::SocketAddr; use std::net::SocketAddr;
use std::pin::Pin; use std::pin::Pin;
use std::sync::Arc;
use crate::future; use crate::future;
use crate::io::{self, Read, Write}; use crate::io::{self, Read, Write};
@ -44,9 +45,9 @@ use crate::task::{Context, Poll};
/// # /// #
/// # Ok(()) }) } /// # Ok(()) }) }
/// ``` /// ```
#[derive(Debug)] #[derive(Debug, Clone)]
pub struct TcpStream { pub struct TcpStream {
pub(super) watcher: Watcher<mio::net::TcpStream>, pub(super) watcher: Arc<Watcher<mio::net::TcpStream>>,
} }
impl TcpStream { impl TcpStream {
@ -71,9 +72,7 @@ impl TcpStream {
/// ``` /// ```
pub async fn connect<A: ToSocketAddrs>(addrs: A) -> io::Result<TcpStream> { pub async fn connect<A: ToSocketAddrs>(addrs: A) -> io::Result<TcpStream> {
let mut last_err = None; let mut last_err = None;
let addrs = addrs let addrs = addrs.to_socket_addrs().await?;
.to_socket_addrs()
.await?;
for addr in addrs { for addr in addrs {
// mio's TcpStream::connect is non-blocking and may just be in progress // mio's TcpStream::connect is non-blocking and may just be in progress
@ -84,16 +83,20 @@ impl TcpStream {
Ok(s) => Watcher::new(s), Ok(s) => Watcher::new(s),
Err(e) => { Err(e) => {
last_err = Some(e); last_err = Some(e);
continue continue;
} }
}; };
future::poll_fn(|cx| watcher.poll_write_ready(cx)).await; future::poll_fn(|cx| watcher.poll_write_ready(cx)).await;
match watcher.get_ref().take_error() { match watcher.get_ref().take_error() {
Ok(None) => return Ok(TcpStream { watcher }), Ok(None) => {
return Ok(TcpStream {
watcher: Arc::new(watcher),
});
}
Ok(Some(e)) => last_err = Some(e), Ok(Some(e)) => last_err = Some(e),
Err(e) => last_err = Some(e) Err(e) => last_err = Some(e),
} }
} }
@ -369,7 +372,7 @@ impl From<std::net::TcpStream> for TcpStream {
fn from(stream: std::net::TcpStream) -> TcpStream { fn from(stream: std::net::TcpStream) -> TcpStream {
let mio_stream = mio::net::TcpStream::from_stream(stream).unwrap(); let mio_stream = mio::net::TcpStream::from_stream(stream).unwrap();
TcpStream { TcpStream {
watcher: Watcher::new(mio_stream), watcher: Arc::new(Watcher::new(mio_stream)),
} }
} }
} }
@ -391,7 +394,10 @@ cfg_unix! {
impl IntoRawFd for TcpStream { impl IntoRawFd for TcpStream {
fn into_raw_fd(self) -> RawFd { fn into_raw_fd(self) -> RawFd {
self.watcher.into_inner().into_raw_fd() // TODO(stjepang): This does not mean `RawFd` is now the sole owner of the file
// descriptor because it's possible that there are other clones of this `TcpStream`
// using it at the same time. We should probably document that behavior.
self.as_raw_fd()
} }
} }
} }

@ -94,3 +94,25 @@ fn smoke_async_stream_to_std_listener() -> io::Result<()> {
Ok(()) Ok(())
} }
#[test]
fn cloned_streams() -> io::Result<()> {
task::block_on(async {
let listener = TcpListener::bind("127.0.0.1:0").await?;
let addr = listener.local_addr()?;
let mut stream = TcpStream::connect(&addr).await?;
let mut cloned_stream = stream.clone();
let mut incoming = listener.incoming();
let mut write_stream = incoming.next().await.unwrap()?;
write_stream.write_all(b"Each your doing").await?;
let mut buf = [0; 15];
stream.read_exact(&mut buf[..8]).await?;
cloned_stream.read_exact(&mut buf[8..]).await?;
assert_eq!(&buf[..15], b"Each your doing");
Ok(())
})
}

Loading…
Cancel
Save