From 238a3c882bcb86da8c4041834ce25688ce270c7f Mon Sep 17 00:00:00 2001 From: DCjanus Date: Thu, 5 Sep 2019 02:09:49 +0800 Subject: [PATCH] Implement an async version of ToSocketAddrs (#74) * Implement an async version of ToSocketAddrs * fix documentation issue * genius hack: pretending to be `impl Future` * replace `std::net::ToSocketAddrs` with `async-std::net::ToSocketAddrs` * Move unit tests into the tests directory * Stylistic changes * Remove re-exports in async_std::net * fix broken link * some mirror changes * remove unnecessary format * migrate: `std::net::ToSocketAddrs` -> `async_std::net::ToSocketAddrs` * fix typo(tutorial) * remove unnecessary type bound * lifetime for future --- docs/src/tutorial/accept_loop.md | 25 ++- docs/src/tutorial/all_together.md | 3 +- docs/src/tutorial/clean_shutdown.md | 6 +- docs/src/tutorial/handling_disconnection.md | 3 +- docs/src/tutorial/implementing_a_client.md | 3 +- docs/src/tutorial/receiving_messages.md | 6 +- src/net/addr.rs | 162 ++++++++++++++++++++ src/net/mod.rs | 2 + src/net/tcp/listener.rs | 11 +- src/net/tcp/stream.rs | 11 +- src/net/udp/mod.rs | 16 +- tests/addr.rs | 84 ++++++++++ 12 files changed, 286 insertions(+), 46 deletions(-) create mode 100644 src/net/addr.rs create mode 100644 tests/addr.rs diff --git a/docs/src/tutorial/accept_loop.md b/docs/src/tutorial/accept_loop.md index d40d348..96c15ba 100644 --- a/docs/src/tutorial/accept_loop.md +++ b/docs/src/tutorial/accept_loop.md @@ -6,23 +6,20 @@ First of all, let's add required import boilerplate: ```rust,edition2018 # extern crate async_std; -use std::net::ToSocketAddrs; // 1 use async_std::{ - prelude::*, // 2 - task, // 3 - net::TcpListener, // 4 + prelude::*, // 1 + task, // 2 + net::{TcpListener, ToSocketAddrs}, // 3 }; -type Result = std::result::Result>; // 5 +type Result = std::result::Result>; // 4 ``` -1. `async_std` uses `std` types where appropriate. - We'll need `ToSocketAddrs` to specify address to listen on. -2. `prelude` re-exports some traits required to work with futures and streams. -3. The `task` module roughly corresponds to the `std::thread` module, but tasks are much lighter weight. +1. `prelude` re-exports some traits required to work with futures and streams. +2. The `task` module roughly corresponds to the `std::thread` module, but tasks are much lighter weight. A single thread can run many tasks. -4. For the socket type, we use `TcpListener` from `async_std`, which is just like `std::net::TcpListener`, but is non-blocking and uses `async` API. -5. We will skip implementing comprehensive error handling in this example. +3. For the socket type, we use `TcpListener` from `async_std`, which is just like `std::net::TcpListener`, but is non-blocking and uses `async` API. +4. We will skip implementing comprehensive error handling in this example. To propagate the errors, we will use a boxed error trait object. Do you know that there's `From<&'_ str> for Box` implementation in stdlib, which allows you to use strings with `?` operator? @@ -31,10 +28,9 @@ Now we can write the server's accept loop: ```rust,edition2018 # extern crate async_std; # use async_std::{ -# net::TcpListener, +# net::{TcpListener, ToSocketAddrs}, # prelude::Stream, # }; -# use std::net::ToSocketAddrs; # # type Result = std::result::Result>; # @@ -69,11 +65,10 @@ Finally, let's add main: ```rust,edition2018 # extern crate async_std; # use async_std::{ -# net::TcpListener, +# net::{TcpListener, ToSocketAddrs}, # prelude::Stream, # task, # }; -# use std::net::ToSocketAddrs; # # type Result = std::result::Result>; # diff --git a/docs/src/tutorial/all_together.md b/docs/src/tutorial/all_together.md index 352d692..415f3b8 100644 --- a/docs/src/tutorial/all_together.md +++ b/docs/src/tutorial/all_together.md @@ -7,7 +7,7 @@ At this point, we only need to start the broker to get a fully-functioning (in t # extern crate futures; use async_std::{ io::{self, BufReader}, - net::{TcpListener, TcpStream}, + net::{TcpListener, TcpStream, ToSocketAddrs}, prelude::*, task, }; @@ -17,7 +17,6 @@ use futures::{ }; use std::{ collections::hash_map::{HashMap, Entry}, - net::ToSocketAddrs, sync::Arc, }; diff --git a/docs/src/tutorial/clean_shutdown.md b/docs/src/tutorial/clean_shutdown.md index 1c2fc76..6bf7056 100644 --- a/docs/src/tutorial/clean_shutdown.md +++ b/docs/src/tutorial/clean_shutdown.md @@ -25,7 +25,7 @@ Let's add waiting to the server: # extern crate futures; # use async_std::{ # io::{self, BufReader}, -# net::{TcpListener, TcpStream}, +# net::{TcpListener, TcpStream, ToSocketAddrs}, # prelude::*, # task, # }; @@ -35,7 +35,6 @@ Let's add waiting to the server: # }; # use std::{ # collections::hash_map::{HashMap, Entry}, -# net::ToSocketAddrs, # sync::Arc, # }; # @@ -160,7 +159,7 @@ And to the broker: # extern crate futures; # use async_std::{ # io::{self, BufReader}, -# net::{TcpListener, TcpStream}, +# net::{TcpListener, TcpStream, ToSocketAddrs}, # prelude::*, # task, # }; @@ -170,7 +169,6 @@ And to the broker: # }; # use std::{ # collections::hash_map::{HashMap, Entry}, -# net::ToSocketAddrs, # sync::Arc, # }; # diff --git a/docs/src/tutorial/handling_disconnection.md b/docs/src/tutorial/handling_disconnection.md index 1cc07b2..30827ba 100644 --- a/docs/src/tutorial/handling_disconnection.md +++ b/docs/src/tutorial/handling_disconnection.md @@ -121,13 +121,12 @@ The final code looks like this: # extern crate futures; use async_std::{ io::{BufReader, BufRead, Write}, - net::{TcpListener, TcpStream}, + net::{TcpListener, TcpStream, ToSocketAddrs}, task, }; use futures::{channel::mpsc, future::Future, select, FutureExt, SinkExt, StreamExt}; use std::{ collections::hash_map::{Entry, HashMap}, - net::ToSocketAddrs, sync::Arc, }; diff --git a/docs/src/tutorial/implementing_a_client.md b/docs/src/tutorial/implementing_a_client.md index 35cccd8..97e7319 100644 --- a/docs/src/tutorial/implementing_a_client.md +++ b/docs/src/tutorial/implementing_a_client.md @@ -19,11 +19,10 @@ With async, we can just use the `select!` macro. # extern crate futures; use async_std::{ io::{stdin, BufRead, BufReader, Write}, - net::TcpStream, + net::{TcpStream, ToSocketAddrs}, task, }; use futures::{select, FutureExt, StreamExt}; -use std::net::ToSocketAddrs; type Result = std::result::Result>; diff --git a/docs/src/tutorial/receiving_messages.md b/docs/src/tutorial/receiving_messages.md index 9cef56d..667cf1c 100644 --- a/docs/src/tutorial/receiving_messages.md +++ b/docs/src/tutorial/receiving_messages.md @@ -11,11 +11,10 @@ We need to: # extern crate async_std; # use async_std::{ # io::{BufRead, BufReader}, -# net::{TcpListener, TcpStream}, +# net::{TcpListener, TcpStream, ToSocketAddrs}, # prelude::Stream, # task, # }; -# use std::net::ToSocketAddrs; # # type Result = std::result::Result>; # @@ -77,11 +76,10 @@ We can "fix" it by waiting for the task to be joined, like this: # extern crate async_std; # use async_std::{ # io::{BufRead, BufReader}, -# net::{TcpListener, TcpStream}, +# net::{TcpListener, TcpStream, ToSocketAddrs}, # prelude::Stream, # task, # }; -# use std::net::ToSocketAddrs; # # type Result = std::result::Result>; # diff --git a/src/net/addr.rs b/src/net/addr.rs new file mode 100644 index 0000000..39dba52 --- /dev/null +++ b/src/net/addr.rs @@ -0,0 +1,162 @@ +use std::net::{IpAddr, Ipv4Addr, Ipv6Addr}; +use std::net::{SocketAddr, SocketAddrV4, SocketAddrV6}; +use std::pin::Pin; + +use cfg_if::cfg_if; +use futures::future::{ready, Ready}; + +use crate::future::Future; +use crate::io; +use crate::task::blocking; +use crate::task::{Context, Poll}; +use std::marker::PhantomData; + +cfg_if! { + if #[cfg(feature = "docs")] { + #[doc(hidden)] + pub struct ImplFuture<'a, T>(std::marker::PhantomData<&'a T>); + + macro_rules! ret { + ($a:lifetime, $f:tt, $i:ty) => (ImplFuture<$a, io::Result<$i>>); + } + } else { + macro_rules! ret { + ($a:lifetime, $f:tt, $i:ty) => ($f<$a, $i>); + } + } +} + +/// A trait for objects which can be converted or resolved to one or more [`SocketAddr`] values. +/// +/// This trait is an async version of [`std::net::ToSocketAddrs`]. +/// +/// [`std::net::ToSocketAddrs`]: https://doc.rust-lang.org/std/net/trait.ToSocketAddrs.html +/// [`SocketAddr`]: https://doc.rust-lang.org/std/net/enum.SocketAddr.html +pub trait ToSocketAddrs { + /// Returned iterator over socket addresses which this type may correspond to. + type Iter: Iterator; + + /// Converts this object to an iterator of resolved `SocketAddr`s. + /// + /// The returned iterator may not actually yield any values depending on the outcome of any + /// resolution performed. + /// + /// Note that this function may block a backend thread while resolution is performed. + fn to_socket_addrs(&self) -> ret!('_, ToSocketAddrsFuture, Self::Iter); +} + +#[doc(hidden)] +#[allow(missing_debug_implementations)] +pub enum ToSocketAddrsFuture<'a, I: Iterator> { + Phantom(PhantomData<&'a ()>), + Join(blocking::JoinHandle>), + Ready(Ready>), +} + +impl> Future for ToSocketAddrsFuture<'_, I> { + type Output = io::Result; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + match self.get_mut() { + ToSocketAddrsFuture::Join(f) => Pin::new(&mut *f).poll(cx), + ToSocketAddrsFuture::Ready(f) => Pin::new(&mut *f).poll(cx), + _ => unreachable!(), + } + } +} + +impl ToSocketAddrs for SocketAddr { + type Iter = std::option::IntoIter; + + fn to_socket_addrs(&self) -> ret!('_, ToSocketAddrsFuture, Self::Iter) { + ToSocketAddrsFuture::Ready(ready(std::net::ToSocketAddrs::to_socket_addrs(self))) + } +} + +impl ToSocketAddrs for SocketAddrV4 { + type Iter = std::option::IntoIter; + + fn to_socket_addrs(&self) -> ret!('_, ToSocketAddrsFuture, Self::Iter) { + ToSocketAddrsFuture::Ready(ready(std::net::ToSocketAddrs::to_socket_addrs(self))) + } +} + +impl ToSocketAddrs for SocketAddrV6 { + type Iter = std::option::IntoIter; + + fn to_socket_addrs(&self) -> ret!('_, ToSocketAddrsFuture, Self::Iter) { + ToSocketAddrsFuture::Ready(ready(std::net::ToSocketAddrs::to_socket_addrs(self))) + } +} + +impl ToSocketAddrs for (IpAddr, u16) { + type Iter = std::option::IntoIter; + + fn to_socket_addrs(&self) -> ret!('_, ToSocketAddrsFuture, Self::Iter) { + ToSocketAddrsFuture::Ready(ready(std::net::ToSocketAddrs::to_socket_addrs(self))) + } +} + +impl ToSocketAddrs for (Ipv4Addr, u16) { + type Iter = std::option::IntoIter; + + fn to_socket_addrs(&self) -> ret!('_, ToSocketAddrsFuture, Self::Iter) { + ToSocketAddrsFuture::Ready(ready(std::net::ToSocketAddrs::to_socket_addrs(self))) + } +} + +impl ToSocketAddrs for (Ipv6Addr, u16) { + type Iter = std::option::IntoIter; + + fn to_socket_addrs(&self) -> ret!('_, ToSocketAddrsFuture, Self::Iter) { + ToSocketAddrsFuture::Ready(ready(std::net::ToSocketAddrs::to_socket_addrs(self))) + } +} + +impl ToSocketAddrs for (&str, u16) { + type Iter = std::vec::IntoIter; + + fn to_socket_addrs(&self) -> ret!('_, ToSocketAddrsFuture, Self::Iter) { + let host = self.0.to_string(); + let port = self.1; + let join = blocking::spawn(async move { + std::net::ToSocketAddrs::to_socket_addrs(&(host.as_str(), port)) + }); + ToSocketAddrsFuture::Join(join) + } +} + +impl ToSocketAddrs for str { + type Iter = std::vec::IntoIter; + + fn to_socket_addrs(&self) -> ret!('_, ToSocketAddrsFuture, Self::Iter) { + let socket_addrs = self.to_string(); + let join = + blocking::spawn(async move { std::net::ToSocketAddrs::to_socket_addrs(&socket_addrs) }); + ToSocketAddrsFuture::Join(join) + } +} + +impl<'a> ToSocketAddrs for &'a [SocketAddr] { + type Iter = std::iter::Cloned>; + + fn to_socket_addrs(&self) -> ret!('_, ToSocketAddrsFuture, Self::Iter) { + ToSocketAddrsFuture::Ready(ready(std::net::ToSocketAddrs::to_socket_addrs(self))) + } +} + +impl ToSocketAddrs for &T { + type Iter = T::Iter; + + fn to_socket_addrs(&self) -> ret!('_, ToSocketAddrsFuture, Self::Iter) { + (**self).to_socket_addrs() + } +} + +impl ToSocketAddrs for String { + type Iter = std::vec::IntoIter; + + fn to_socket_addrs(&self) -> ret!('_, ToSocketAddrsFuture, Self::Iter) { + ToSocketAddrs::to_socket_addrs(self.as_str()) + } +} diff --git a/src/net/mod.rs b/src/net/mod.rs index db4bd3c..259dc1d 100644 --- a/src/net/mod.rs +++ b/src/net/mod.rs @@ -28,9 +28,11 @@ //! # }) } //! ``` +pub use addr::ToSocketAddrs; pub use tcp::{Incoming, TcpListener, TcpStream}; pub use udp::UdpSocket; +mod addr; pub(crate) mod driver; mod tcp; mod udp; diff --git a/src/net/tcp/listener.rs b/src/net/tcp/listener.rs index 60a9689..ac18387 100644 --- a/src/net/tcp/listener.rs +++ b/src/net/tcp/listener.rs @@ -1,4 +1,4 @@ -use std::net::{self, SocketAddr, ToSocketAddrs}; +use std::net::SocketAddr; use std::pin::Pin; use cfg_if::cfg_if; @@ -8,6 +8,7 @@ use super::TcpStream; use crate::future::Future; use crate::io; use crate::net::driver::IoHandle; +use crate::net::ToSocketAddrs; use crate::task::{Context, Poll}; /// A TCP socket server, listening for connections. @@ -82,7 +83,7 @@ impl TcpListener { pub async fn bind(addrs: A) -> io::Result { let mut last_err = None; - for addr in addrs.to_socket_addrs()? { + for addr in addrs.to_socket_addrs().await? { match mio::net::TcpListener::bind(&addr) { Ok(mio_listener) => { #[cfg(unix)] @@ -236,9 +237,9 @@ impl<'a> futures::Stream for Incoming<'a> { } } -impl From for TcpListener { +impl From for TcpListener { /// Converts a `std::net::TcpListener` into its asynchronous equivalent. - fn from(listener: net::TcpListener) -> TcpListener { + fn from(listener: std::net::TcpListener) -> TcpListener { let mio_listener = mio::net::TcpListener::from_std(listener).unwrap(); #[cfg(unix)] @@ -279,7 +280,7 @@ cfg_if! { impl FromRawFd for TcpListener { unsafe fn from_raw_fd(fd: RawFd) -> TcpListener { - net::TcpListener::from_raw_fd(fd).into() + std::net::TcpListener::from_raw_fd(fd).into() } } diff --git a/src/net/tcp/stream.rs b/src/net/tcp/stream.rs index d823317..5ea181f 100644 --- a/src/net/tcp/stream.rs +++ b/src/net/tcp/stream.rs @@ -1,6 +1,6 @@ use std::io::{IoSlice, IoSliceMut}; use std::mem; -use std::net::{self, SocketAddr, ToSocketAddrs}; +use std::net::SocketAddr; use std::pin::Pin; use cfg_if::cfg_if; @@ -9,6 +9,7 @@ use futures::io::{AsyncRead, AsyncWrite}; use crate::io; use crate::net::driver::IoHandle; +use crate::net::ToSocketAddrs; use crate::task::{Context, Poll}; /// A TCP stream between a local and a remote socket. @@ -80,7 +81,7 @@ impl TcpStream { pub async fn connect(addrs: A) -> io::Result { let mut last_err = None; - for addr in addrs.to_socket_addrs()? { + for addr in addrs.to_socket_addrs().await? { let res = Self::connect_to(addr).await; match res { @@ -437,9 +438,9 @@ impl AsyncWrite for &TcpStream { } } -impl From for TcpStream { +impl From for TcpStream { /// Converts a `std::net::TcpStream` into its asynchronous equivalent. - fn from(stream: net::TcpStream) -> TcpStream { + fn from(stream: std::net::TcpStream) -> TcpStream { let mio_stream = mio::net::TcpStream::from_stream(stream).unwrap(); #[cfg(unix)] @@ -480,7 +481,7 @@ cfg_if! { impl FromRawFd for TcpStream { unsafe fn from_raw_fd(fd: RawFd) -> TcpStream { - net::TcpStream::from_raw_fd(fd).into() + std::net::TcpStream::from_raw_fd(fd).into() } } diff --git a/src/net/udp/mod.rs b/src/net/udp/mod.rs index 3e9e749..19119a5 100644 --- a/src/net/udp/mod.rs +++ b/src/net/udp/mod.rs @@ -1,10 +1,12 @@ use std::io; -use std::net::{self, Ipv4Addr, Ipv6Addr, SocketAddr, ToSocketAddrs}; +use std::net::SocketAddr; use cfg_if::cfg_if; use futures::future; +use std::net::{Ipv4Addr, Ipv6Addr}; use crate::net::driver::IoHandle; +use crate::net::ToSocketAddrs; use crate::task::Poll; /// A UDP socket. @@ -75,7 +77,7 @@ impl UdpSocket { pub async fn bind(addr: A) -> io::Result { let mut last_err = None; - for addr in addr.to_socket_addrs()? { + for addr in addr.to_socket_addrs().await? { match mio::net::UdpSocket::bind(&addr) { Ok(mio_socket) => { #[cfg(unix)] @@ -152,7 +154,7 @@ impl UdpSocket { /// # Ok(()) }) } /// ``` pub async fn send_to(&self, buf: &[u8], addrs: A) -> io::Result { - let addr = match addrs.to_socket_addrs()?.next() { + let addr = match addrs.to_socket_addrs().await?.next() { Some(addr) => addr, None => { return Err(io::Error::new( @@ -237,7 +239,7 @@ impl UdpSocket { pub async fn connect(&self, addrs: A) -> io::Result<()> { let mut last_err = None; - for addr in addrs.to_socket_addrs()? { + for addr in addrs.to_socket_addrs().await? { match self.io_handle.get_ref().connect(addr) { Ok(()) => return Ok(()), Err(err) => last_err = Some(err), @@ -506,9 +508,9 @@ impl UdpSocket { } } -impl From for UdpSocket { +impl From for UdpSocket { /// Converts a `std::net::UdpSocket` into its asynchronous equivalent. - fn from(socket: net::UdpSocket) -> UdpSocket { + fn from(socket: std::net::UdpSocket) -> UdpSocket { let mio_socket = mio::net::UdpSocket::from_socket(socket).unwrap(); #[cfg(unix)] @@ -549,7 +551,7 @@ cfg_if! { impl FromRawFd for UdpSocket { unsafe fn from_raw_fd(fd: RawFd) -> UdpSocket { - net::UdpSocket::from_raw_fd(fd).into() + std::net::UdpSocket::from_raw_fd(fd).into() } } diff --git a/tests/addr.rs b/tests/addr.rs new file mode 100644 index 0000000..aada557 --- /dev/null +++ b/tests/addr.rs @@ -0,0 +1,84 @@ +use std::net::{Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6}; + +use async_std::net::ToSocketAddrs; +use async_std::task; + +fn blocking_resolve(a: A) -> Result, String> +where + A: ToSocketAddrs, + A::Iter: Send, +{ + let socket_addrs = task::block_on(a.to_socket_addrs()); + match socket_addrs { + Ok(a) => Ok(a.collect()), + Err(e) => Err(e.to_string()), + } +} + +#[test] +fn to_socket_addr_ipaddr_u16() { + let a = Ipv4Addr::new(77, 88, 21, 11); + let p = 12345; + let e = SocketAddr::V4(SocketAddrV4::new(a, p)); + assert_eq!(Ok(vec![e]), blocking_resolve((a, p))); +} + +#[test] +fn to_socket_addr_str_u16() { + let a = SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(77, 88, 21, 11), 24352)); + assert_eq!(Ok(vec![a]), blocking_resolve(("77.88.21.11", 24352))); + + let a = SocketAddr::V6(SocketAddrV6::new( + Ipv6Addr::new(0x2a02, 0x6b8, 0, 1, 0, 0, 0, 1), + 53, + 0, + 0, + )); + assert_eq!(Ok(vec![a]), blocking_resolve(("2a02:6b8:0:1::1", 53))); + + let a = SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(127, 0, 0, 1), 23924)); + #[cfg(not(target_env = "sgx"))] + assert!(blocking_resolve(("localhost", 23924)).unwrap().contains(&a)); + #[cfg(target_env = "sgx")] + let _ = a; +} + +#[test] +fn to_socket_addr_str() { + let a = SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(77, 88, 21, 11), 24352)); + assert_eq!(Ok(vec![a]), blocking_resolve("77.88.21.11:24352")); + + let a = SocketAddr::V6(SocketAddrV6::new( + Ipv6Addr::new(0x2a02, 0x6b8, 0, 1, 0, 0, 0, 1), + 53, + 0, + 0, + )); + assert_eq!(Ok(vec![a]), blocking_resolve("[2a02:6b8:0:1::1]:53")); + + let a = SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(127, 0, 0, 1), 23924)); + #[cfg(not(target_env = "sgx"))] + assert!(blocking_resolve("localhost:23924").unwrap().contains(&a)); + #[cfg(target_env = "sgx")] + let _ = a; +} + +#[test] +fn to_socket_addr_string() { + let a = SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(77, 88, 21, 11), 24352)); + let s: &str = "77.88.21.11:24352"; + assert_eq!(Ok(vec![a]), blocking_resolve(s)); + + let s: &String = &"77.88.21.11:24352".to_string(); + assert_eq!(Ok(vec![a]), blocking_resolve(s)); + + let s: String = "77.88.21.11:24352".to_string(); + assert_eq!(Ok(vec![a]), blocking_resolve(s)); +} + +// FIXME: figure out why this fails on openbsd and fix it +#[test] +#[cfg(not(any(windows, target_os = "openbsd")))] +fn to_socket_addr_str_bad() { + assert!(blocking_resolve("1200::AB00:1234::2552:7777:1313:34300").is_err()); +}