Refactor the networking driver

This commit is contained in:
Stjepan Glavina 2019-09-12 18:39:00 +02:00
parent 2ecaf1811b
commit d25dae5419
7 changed files with 192 additions and 600 deletions

View file

@ -1,10 +1,6 @@
use std::fmt; use std::fmt;
use std::io::{Read as _, Write as _};
use std::pin::Pin;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::{Arc, Mutex}; use std::sync::{Arc, Mutex};
use futures_io::{AsyncRead, AsyncWrite};
use lazy_static::lazy_static; use lazy_static::lazy_static;
use mio::{self, Evented}; use mio::{self, Evented};
use slab::Slab; use slab::Slab;
@ -19,9 +15,6 @@ struct Entry {
/// A unique identifier. /// A unique identifier.
token: mio::Token, token: mio::Token,
/// Indicates whether this I/O handle is ready for reading, writing, or if it is disconnected.
readiness: AtomicUsize,
/// Tasks that are blocked on reading from this I/O handle. /// Tasks that are blocked on reading from this I/O handle.
readers: Mutex<Vec<Waker>>, readers: Mutex<Vec<Waker>>,
@ -75,7 +68,6 @@ impl Reactor {
// Allocate an entry and insert it into the slab. // Allocate an entry and insert it into the slab.
let entry = Arc::new(Entry { let entry = Arc::new(Entry {
token, token,
readiness: AtomicUsize::new(mio::Ready::empty().as_usize()),
readers: Mutex::new(Vec::new()), readers: Mutex::new(Vec::new()),
writers: Mutex::new(Vec::new()), writers: Mutex::new(Vec::new()),
}); });
@ -151,9 +143,6 @@ fn main_loop() -> io::Result<()> {
if let Some(entry) = entries.get(token.0) { if let Some(entry) = entries.get(token.0) {
// Set the readiness flags from this I/O event. // Set the readiness flags from this I/O event.
let readiness = event.readiness(); let readiness = event.readiness();
entry
.readiness
.fetch_or(readiness.as_usize(), Ordering::SeqCst);
// Wake up reader tasks blocked on this I/O handle. // Wake up reader tasks blocked on this I/O handle.
if !(readiness & reader_interests()).is_empty() { if !(readiness & reader_interests()).is_empty() {
@ -178,7 +167,7 @@ fn main_loop() -> io::Result<()> {
/// ///
/// This handle wraps an I/O event source and exposes a "futurized" interface on top of it, /// This handle wraps an I/O event source and exposes a "futurized" interface on top of it,
/// implementing traits `AsyncRead` and `AsyncWrite`. /// implementing traits `AsyncRead` and `AsyncWrite`.
pub struct IoHandle<T: Evented> { pub struct Watcher<T: Evented> {
/// Data associated with the I/O handle. /// Data associated with the I/O handle.
entry: Arc<Entry>, entry: Arc<Entry>,
@ -186,13 +175,13 @@ pub struct IoHandle<T: Evented> {
source: Option<T>, source: Option<T>,
} }
impl<T: Evented> IoHandle<T> { impl<T: Evented> Watcher<T> {
/// Creates a new I/O handle. /// Creates a new I/O handle.
/// ///
/// The provided I/O event source will be kept registered inside the reactor's poller for the /// The provided I/O event source will be kept registered inside the reactor's poller for the
/// lifetime of the returned I/O handle. /// lifetime of the returned I/O handle.
pub fn new(source: T) -> IoHandle<T> { pub fn new(source: T) -> Watcher<T> {
IoHandle { Watcher {
entry: REACTOR entry: REACTOR
.register(&source) .register(&source)
.expect("cannot register an I/O event source"), .expect("cannot register an I/O event source"),
@ -205,91 +194,75 @@ impl<T: Evented> IoHandle<T> {
self.source.as_ref().unwrap() self.source.as_ref().unwrap()
} }
/// Polls the I/O handle for reading. /// Polls the inner I/O source for a non-blocking read operation.
/// ///
/// If reading from the I/O handle would block, `Poll::Pending` will be returned. /// If the operation returns an error of the `io::ErrorKind::WouldBlock` kind, the current task
pub fn poll_readable(&self, cx: &mut Context<'_>) -> Poll<io::Result<()>> { /// will be registered for wakeup when the I/O source becomes readable.
let mask = reader_interests(); pub fn poll_read_with<'a, F, R>(&'a self, cx: &mut Context<'_>, mut f: F) -> Poll<io::Result<R>>
let mut readiness = mio::Ready::from_usize(self.entry.readiness.load(Ordering::SeqCst)); where
F: FnMut(&'a T) -> io::Result<R>,
if (readiness & mask).is_empty() { {
let mut list = self.entry.readers.lock().unwrap(); // If the operation isn't blocked, return its result.
if list.iter().all(|w| !w.will_wake(cx.waker())) { match f(self.source.as_ref().unwrap()) {
list.push(cx.waker().clone()); Err(err) if err.kind() == io::ErrorKind::WouldBlock => {}
} res => return Poll::Ready(res),
readiness = mio::Ready::from_usize(self.entry.readiness.fetch_or(0, Ordering::SeqCst));
} }
if (readiness & mask).is_empty() { // Lock the waker list.
Poll::Pending let mut list = self.entry.readers.lock().unwrap();
} else {
Poll::Ready(Ok(())) // Try running the operation again.
match f(self.source.as_ref().unwrap()) {
Err(err) if err.kind() == io::ErrorKind::WouldBlock => {}
res => return Poll::Ready(res),
} }
// Register the task if it isn't registered already.
if list.iter().all(|w| !w.will_wake(cx.waker())) {
list.push(cx.waker().clone());
}
Poll::Pending
} }
/// Clears the readability status. /// Polls the inner I/O source for a non-blocking write operation.
/// ///
/// This method is usually called when an attempt at reading from the OS-level I/O handle /// If the operation returns an error of the `io::ErrorKind::WouldBlock` kind, the current task
/// returns `io::ErrorKind::WouldBlock`. /// will be registered for wakeup when the I/O source becomes writable.
pub fn clear_readable(&self, cx: &mut Context<'_>) -> io::Result<()> { pub fn poll_write_with<'a, F, R>(
let mask = reader_interests() - hup(); &'a self,
self.entry cx: &mut Context<'_>,
.readiness mut f: F,
.fetch_and(!mask.as_usize(), Ordering::SeqCst); ) -> Poll<io::Result<R>>
where
if self.poll_readable(cx)?.is_ready() { F: FnMut(&'a T) -> io::Result<R>,
// Wake the current task. {
cx.waker().wake_by_ref(); // If the operation isn't blocked, return its result.
match f(self.source.as_ref().unwrap()) {
Err(err) if err.kind() == io::ErrorKind::WouldBlock => {}
res => return Poll::Ready(res),
} }
Ok(()) // Lock the waker list.
} let mut list = self.entry.writers.lock().unwrap();
/// Polls the I/O handle for writing. // Try running the operation again.
/// match f(self.source.as_ref().unwrap()) {
/// If writing into the I/O handle would block, `Poll::Pending` will be returned. Err(err) if err.kind() == io::ErrorKind::WouldBlock => {}
pub fn poll_writable(&self, cx: &mut Context<'_>) -> Poll<io::Result<()>> { res => return Poll::Ready(res),
let mask = writer_interests();
let mut readiness = mio::Ready::from_usize(self.entry.readiness.load(Ordering::SeqCst));
if (readiness & mask).is_empty() {
let mut list = self.entry.writers.lock().unwrap();
if list.iter().all(|w| !w.will_wake(cx.waker())) {
list.push(cx.waker().clone());
}
readiness = mio::Ready::from_usize(self.entry.readiness.fetch_or(0, Ordering::SeqCst));
} }
if (readiness & mask).is_empty() { // Register the task if it isn't registered already.
Poll::Pending if list.iter().all(|w| !w.will_wake(cx.waker())) {
} else { list.push(cx.waker().clone());
Poll::Ready(Ok(()))
}
}
/// Clears the writability status.
///
/// This method is usually called when an attempt at writing from the OS-level I/O handle
/// returns `io::ErrorKind::WouldBlock`.
pub fn clear_writable(&self, cx: &mut Context<'_>) -> io::Result<()> {
let mask = writer_interests() - hup();
self.entry
.readiness
.fetch_and(!mask.as_usize(), Ordering::SeqCst);
if self.poll_writable(cx)?.is_ready() {
// Wake the current task.
cx.waker().wake_by_ref();
} }
Ok(()) Poll::Pending
} }
/// Deregisters and returns the inner I/O source. /// Deregisters and returns the inner I/O source.
/// ///
/// This method is typically used to convert `IoHandle`s to raw file descriptors/handles. /// This method is typically used to convert `Watcher`s to raw file descriptors/handles.
pub fn into_inner(mut self) -> T { pub fn into_inner(mut self) -> T {
let source = self.source.take().unwrap(); let source = self.source.take().unwrap();
REACTOR REACTOR
@ -299,7 +272,7 @@ impl<T: Evented> IoHandle<T> {
} }
} }
impl<T: Evented> Drop for IoHandle<T> { impl<T: Evented> Drop for Watcher<T> {
fn drop(&mut self) { fn drop(&mut self) {
if let Some(ref source) = self.source { if let Some(ref source) = self.source {
REACTOR REACTOR
@ -309,125 +282,15 @@ impl<T: Evented> Drop for IoHandle<T> {
} }
} }
impl<T: Evented + fmt::Debug> fmt::Debug for IoHandle<T> { impl<T: Evented + fmt::Debug> fmt::Debug for Watcher<T> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("IoHandle") f.debug_struct("Watcher")
.field("entry", &self.entry) .field("entry", &self.entry)
.field("source", &self.source) .field("source", &self.source)
.finish() .finish()
} }
} }
impl<T: Evented + std::io::Read + Unpin> AsyncRead for IoHandle<T> {
fn poll_read(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut [u8],
) -> Poll<io::Result<usize>> {
futures_core::ready!(Pin::new(&mut *self).poll_readable(cx)?);
match self.source.as_mut().unwrap().read(buf) {
Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => {
self.clear_readable(cx)?;
Poll::Pending
}
res => Poll::Ready(res),
}
}
}
impl<'a, T: Evented + Unpin> AsyncRead for &'a IoHandle<T>
where
&'a T: std::io::Read,
{
fn poll_read(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut [u8],
) -> Poll<io::Result<usize>> {
futures_core::ready!(Pin::new(&mut *self).poll_readable(cx)?);
match self.source.as_ref().unwrap().read(buf) {
Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => {
self.clear_readable(cx)?;
Poll::Pending
}
res => Poll::Ready(res),
}
}
}
impl<T: Evented + std::io::Write + Unpin> AsyncWrite for IoHandle<T> {
fn poll_write(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<io::Result<usize>> {
futures_core::ready!(self.poll_writable(cx)?);
match self.source.as_mut().unwrap().write(buf) {
Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => {
self.clear_writable(cx)?;
Poll::Pending
}
res => Poll::Ready(res),
}
}
fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
futures_core::ready!(self.poll_writable(cx)?);
match self.source.as_mut().unwrap().flush() {
Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => {
self.clear_writable(cx)?;
Poll::Pending
}
res => Poll::Ready(res),
}
}
fn poll_close(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
Poll::Ready(Ok(()))
}
}
impl<'a, T: Evented + Unpin> AsyncWrite for &'a IoHandle<T>
where
&'a T: std::io::Write,
{
fn poll_write(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<io::Result<usize>> {
futures_core::ready!(self.poll_writable(cx)?);
match self.get_ref().write(buf) {
Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => {
self.clear_writable(cx)?;
Poll::Pending
}
res => Poll::Ready(res),
}
}
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
futures_core::ready!(self.poll_writable(cx)?);
match self.get_ref().flush() {
Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => {
self.clear_writable(cx)?;
Poll::Pending
}
res => Poll::Ready(res),
}
}
fn poll_close(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
Poll::Ready(Ok(()))
}
}
/// Returns a mask containing flags that interest tasks reading from I/O handles. /// Returns a mask containing flags that interest tasks reading from I/O handles.
#[inline] #[inline]
fn reader_interests() -> mio::Ready { fn reader_interests() -> mio::Ready {

View file

@ -6,7 +6,7 @@ use cfg_if::cfg_if;
use super::TcpStream; use super::TcpStream;
use crate::future::{self, Future}; use crate::future::{self, Future};
use crate::io; use crate::io;
use crate::net::driver::IoHandle; use crate::net::driver::Watcher;
use crate::net::ToSocketAddrs; use crate::net::ToSocketAddrs;
use crate::task::{Context, Poll}; use crate::task::{Context, Poll};
@ -49,9 +49,7 @@ use crate::task::{Context, Poll};
/// ``` /// ```
#[derive(Debug)] #[derive(Debug)]
pub struct TcpListener { pub struct TcpListener {
io_handle: IoHandle<mio::net::TcpListener>, watcher: Watcher<mio::net::TcpListener>,
// #[cfg(windows)]
// raw_socket: std::os::windows::io::RawSocket,
} }
impl TcpListener { impl TcpListener {
@ -82,17 +80,9 @@ impl TcpListener {
for addr in addrs.to_socket_addrs().await? { for addr in addrs.to_socket_addrs().await? {
match mio::net::TcpListener::bind(&addr) { match mio::net::TcpListener::bind(&addr) {
Ok(mio_listener) => { Ok(mio_listener) => {
#[cfg(unix)] return Ok(TcpListener {
let listener = TcpListener { watcher: Watcher::new(mio_listener),
io_handle: IoHandle::new(mio_listener), });
};
#[cfg(windows)]
let listener = TcpListener {
// raw_socket: mio_listener.as_raw_socket(),
io_handle: IoHandle::new(mio_listener),
};
return Ok(listener);
} }
Err(err) => last_err = Some(err), Err(err) => last_err = Some(err),
} }
@ -123,34 +113,15 @@ impl TcpListener {
/// # Ok(()) }) } /// # Ok(()) }) }
/// ``` /// ```
pub async fn accept(&self) -> io::Result<(TcpStream, SocketAddr)> { pub async fn accept(&self) -> io::Result<(TcpStream, SocketAddr)> {
future::poll_fn(|cx| { let (io, addr) =
futures_core::ready!(self.io_handle.poll_readable(cx)?); future::poll_fn(|cx| self.watcher.poll_read_with(cx, |inner| inner.accept_std()))
.await?;
match self.io_handle.get_ref().accept_std() { let mio_stream = mio::net::TcpStream::from_stream(io)?;
Ok((io, addr)) => { let stream = TcpStream {
let mio_stream = mio::net::TcpStream::from_stream(io)?; watcher: Watcher::new(mio_stream),
};
#[cfg(unix)] Ok((stream, addr))
let stream = TcpStream {
io_handle: IoHandle::new(mio_stream),
};
#[cfg(windows)]
let stream = TcpStream {
// raw_socket: mio_stream.as_raw_socket(),
io_handle: IoHandle::new(mio_stream),
};
Poll::Ready(Ok((stream, addr)))
}
Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => {
self.io_handle.clear_readable(cx)?;
Poll::Pending
}
Err(err) => Poll::Ready(Err(err)),
}
})
.await
} }
/// Returns a stream of incoming connections. /// Returns a stream of incoming connections.
@ -201,7 +172,7 @@ impl TcpListener {
/// # Ok(()) }) } /// # Ok(()) }) }
/// ``` /// ```
pub fn local_addr(&self) -> io::Result<SocketAddr> { pub fn local_addr(&self) -> io::Result<SocketAddr> {
self.io_handle.get_ref().local_addr() self.watcher.get_ref().local_addr()
} }
} }
@ -235,19 +206,9 @@ impl From<std::net::TcpListener> for TcpListener {
/// Converts a `std::net::TcpListener` into its asynchronous equivalent. /// Converts a `std::net::TcpListener` into its asynchronous equivalent.
fn from(listener: std::net::TcpListener) -> TcpListener { fn from(listener: std::net::TcpListener) -> TcpListener {
let mio_listener = mio::net::TcpListener::from_std(listener).unwrap(); let mio_listener = mio::net::TcpListener::from_std(listener).unwrap();
TcpListener {
#[cfg(unix)] watcher: Watcher::new(mio_listener),
let listener = TcpListener { }
io_handle: IoHandle::new(mio_listener),
};
#[cfg(windows)]
let listener = TcpListener {
// raw_socket: mio_listener.as_raw_socket(),
io_handle: IoHandle::new(mio_listener),
};
listener
} }
} }
@ -267,7 +228,7 @@ cfg_if! {
if #[cfg(any(unix, feature = "docs"))] { if #[cfg(any(unix, feature = "docs"))] {
impl AsRawFd for TcpListener { impl AsRawFd for TcpListener {
fn as_raw_fd(&self) -> RawFd { fn as_raw_fd(&self) -> RawFd {
self.io_handle.get_ref().as_raw_fd() self.watcher.get_ref().as_raw_fd()
} }
} }
@ -279,7 +240,7 @@ cfg_if! {
impl IntoRawFd for TcpListener { impl IntoRawFd for TcpListener {
fn into_raw_fd(self) -> RawFd { fn into_raw_fd(self) -> RawFd {
self.io_handle.into_inner().into_raw_fd() self.watcher.into_inner().into_raw_fd()
} }
} }
} }

View file

@ -1,5 +1,4 @@
use std::io::{IoSlice, IoSliceMut}; use std::io::{IoSlice, IoSliceMut, Read as _, Write as _};
use std::mem;
use std::net::SocketAddr; use std::net::SocketAddr;
use std::pin::Pin; use std::pin::Pin;
@ -8,8 +7,9 @@ use futures_io::{AsyncRead, AsyncWrite};
use crate::future; use crate::future;
use crate::io; use crate::io;
use crate::net::driver::IoHandle; use crate::net::driver::Watcher;
use crate::net::ToSocketAddrs; use crate::net::ToSocketAddrs;
use crate::task::blocking;
use crate::task::{Context, Poll}; use crate::task::{Context, Poll};
/// A TCP stream between a local and a remote socket. /// A TCP stream between a local and a remote socket.
@ -50,9 +50,7 @@ use crate::task::{Context, Poll};
/// ``` /// ```
#[derive(Debug)] #[derive(Debug)]
pub struct TcpStream { pub struct TcpStream {
pub(super) io_handle: IoHandle<mio::net::TcpStream>, pub(super) watcher: Watcher<mio::net::TcpStream>,
// #[cfg(windows)]
// pub(super) raw_socket: std::os::windows::io::RawSocket,
} }
impl TcpStream { impl TcpStream {
@ -79,7 +77,14 @@ impl TcpStream {
let mut last_err = None; let mut last_err = None;
for addr in addrs.to_socket_addrs().await? { for addr in addrs.to_socket_addrs().await? {
let res = Self::connect_to(addr).await; let res = blocking::spawn(async move {
let std_stream = std::net::TcpStream::connect(addr)?;
let mio_stream = mio::net::TcpStream::from_stream(std_stream)?;
Ok(TcpStream {
watcher: Watcher::new(mio_stream),
})
})
.await;
match res { match res {
Ok(stream) => return Ok(stream), Ok(stream) => return Ok(stream),
@ -95,59 +100,6 @@ impl TcpStream {
})) }))
} }
/// Creates a new TCP stream connected to the specified address.
async fn connect_to(addr: SocketAddr) -> io::Result<TcpStream> {
let stream = mio::net::TcpStream::connect(&addr).map(|mio_stream| {
#[cfg(unix)]
let stream = TcpStream {
io_handle: IoHandle::new(mio_stream),
};
#[cfg(windows)]
let stream = TcpStream {
// raw_socket: mio_stream.as_raw_socket(),
io_handle: IoHandle::new(mio_stream),
};
stream
});
enum State {
Waiting(TcpStream),
Error(io::Error),
Done,
}
let mut state = match stream {
Ok(stream) => State::Waiting(stream),
Err(err) => State::Error(err),
};
future::poll_fn(|cx| {
match mem::replace(&mut state, State::Done) {
State::Waiting(stream) => {
// Once we've connected, wait for the stream to be writable as that's when
// the actual connection has been initiated. Once we're writable we check
// for `take_socket_error` to see if the connect actually hit an error or
// not.
//
// If all that succeeded then we ship everything on up.
if let Poll::Pending = stream.io_handle.poll_writable(cx)? {
state = State::Waiting(stream);
return Poll::Pending;
}
if let Some(err) = stream.io_handle.get_ref().take_error()? {
return Poll::Ready(Err(err));
}
Poll::Ready(Ok(stream))
}
State::Error(err) => Poll::Ready(Err(err)),
State::Done => panic!("`TcpStream::connect_to()` future polled after completion"),
}
})
.await
}
/// Returns the local address that this stream is connected to. /// Returns the local address that this stream is connected to.
/// ///
/// ## Examples /// ## Examples
@ -163,7 +115,7 @@ impl TcpStream {
/// # Ok(()) }) } /// # Ok(()) }) }
/// ``` /// ```
pub fn local_addr(&self) -> io::Result<SocketAddr> { pub fn local_addr(&self) -> io::Result<SocketAddr> {
self.io_handle.get_ref().local_addr() self.watcher.get_ref().local_addr()
} }
/// Returns the remote address that this stream is connected to. /// Returns the remote address that this stream is connected to.
@ -181,7 +133,7 @@ impl TcpStream {
/// # Ok(()) }) } /// # Ok(()) }) }
/// ``` /// ```
pub fn peer_addr(&self) -> io::Result<SocketAddr> { pub fn peer_addr(&self) -> io::Result<SocketAddr> {
self.io_handle.get_ref().peer_addr() self.watcher.get_ref().peer_addr()
} }
/// Gets the value of the `IP_TTL` option for this socket. /// Gets the value of the `IP_TTL` option for this socket.
@ -205,7 +157,7 @@ impl TcpStream {
/// # Ok(()) }) } /// # Ok(()) }) }
/// ``` /// ```
pub fn ttl(&self) -> io::Result<u32> { pub fn ttl(&self) -> io::Result<u32> {
self.io_handle.get_ref().ttl() self.watcher.get_ref().ttl()
} }
/// Sets the value for the `IP_TTL` option on this socket. /// Sets the value for the `IP_TTL` option on this socket.
@ -228,7 +180,7 @@ impl TcpStream {
/// # Ok(()) }) } /// # Ok(()) }) }
/// ``` /// ```
pub fn set_ttl(&self, ttl: u32) -> io::Result<()> { pub fn set_ttl(&self, ttl: u32) -> io::Result<()> {
self.io_handle.get_ref().set_ttl(ttl) self.watcher.get_ref().set_ttl(ttl)
} }
/// Receives data on the socket from the remote address to which it is connected, without /// Receives data on the socket from the remote address to which it is connected, without
@ -254,20 +206,7 @@ impl TcpStream {
/// # Ok(()) }) } /// # Ok(()) }) }
/// ``` /// ```
pub async fn peek(&self, buf: &mut [u8]) -> io::Result<usize> { pub async fn peek(&self, buf: &mut [u8]) -> io::Result<usize> {
let res = future::poll_fn(|cx| { future::poll_fn(|cx| self.watcher.poll_read_with(cx, |inner| inner.peek(buf))).await
futures_core::ready!(self.io_handle.poll_readable(cx)?);
match self.io_handle.get_ref().peek(buf) {
Ok(len) => Poll::Ready(Ok(len)),
Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => {
self.io_handle.clear_readable(cx)?;
Poll::Pending
}
Err(e) => Poll::Ready(Err(e)),
}
})
.await?;
Ok(res)
} }
/// Gets the value of the `TCP_NODELAY` option on this socket. /// Gets the value of the `TCP_NODELAY` option on this socket.
@ -291,7 +230,7 @@ impl TcpStream {
/// # Ok(()) }) } /// # Ok(()) }) }
/// ``` /// ```
pub fn nodelay(&self) -> io::Result<bool> { pub fn nodelay(&self) -> io::Result<bool> {
self.io_handle.get_ref().nodelay() self.watcher.get_ref().nodelay()
} }
/// Sets the value of the `TCP_NODELAY` option on this socket. /// Sets the value of the `TCP_NODELAY` option on this socket.
@ -317,7 +256,7 @@ impl TcpStream {
/// # Ok(()) }) } /// # Ok(()) }) }
/// ``` /// ```
pub fn set_nodelay(&self, nodelay: bool) -> io::Result<()> { pub fn set_nodelay(&self, nodelay: bool) -> io::Result<()> {
self.io_handle.get_ref().set_nodelay(nodelay) self.watcher.get_ref().set_nodelay(nodelay)
} }
/// Shuts down the read, write, or both halves of this connection. /// Shuts down the read, write, or both halves of this connection.
@ -342,7 +281,7 @@ impl TcpStream {
/// # Ok(()) }) } /// # Ok(()) }) }
/// ``` /// ```
pub fn shutdown(&self, how: std::net::Shutdown) -> std::io::Result<()> { pub fn shutdown(&self, how: std::net::Shutdown) -> std::io::Result<()> {
self.io_handle.get_ref().shutdown(how) self.watcher.get_ref().shutdown(how)
} }
} }
@ -370,15 +309,7 @@ impl AsyncRead for &TcpStream {
cx: &mut Context<'_>, cx: &mut Context<'_>,
buf: &mut [u8], buf: &mut [u8],
) -> Poll<io::Result<usize>> { ) -> Poll<io::Result<usize>> {
Pin::new(&mut &self.io_handle).poll_read(cx, buf) self.watcher.poll_read_with(cx, |mut inner| inner.read(buf))
}
fn poll_read_vectored(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
bufs: &mut [IoSliceMut<'_>],
) -> Poll<io::Result<usize>> {
Pin::new(&mut &self.io_handle).poll_read_vectored(cx, bufs)
} }
} }
@ -414,23 +345,15 @@ impl AsyncWrite for &TcpStream {
cx: &mut Context<'_>, cx: &mut Context<'_>,
buf: &[u8], buf: &[u8],
) -> Poll<io::Result<usize>> { ) -> Poll<io::Result<usize>> {
Pin::new(&mut &self.io_handle).poll_write(cx, buf) self.watcher.poll_write_with(cx, |mut inner| inner.write(buf))
}
fn poll_write_vectored(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
bufs: &[IoSlice<'_>],
) -> Poll<io::Result<usize>> {
Pin::new(&mut &self.io_handle).poll_write_vectored(cx, bufs)
} }
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> { fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
Pin::new(&mut &self.io_handle).poll_flush(cx) self.watcher.poll_write_with(cx, |mut inner| inner.flush())
} }
fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> { fn poll_close(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<io::Result<()>> {
Pin::new(&mut &self.io_handle).poll_close(cx) Poll::Ready(Ok(()))
} }
} }
@ -438,19 +361,9 @@ impl From<std::net::TcpStream> for TcpStream {
/// Converts a `std::net::TcpStream` into its asynchronous equivalent. /// Converts a `std::net::TcpStream` into its asynchronous equivalent.
fn from(stream: std::net::TcpStream) -> TcpStream { fn from(stream: std::net::TcpStream) -> TcpStream {
let mio_stream = mio::net::TcpStream::from_stream(stream).unwrap(); let mio_stream = mio::net::TcpStream::from_stream(stream).unwrap();
TcpStream {
#[cfg(unix)] watcher: Watcher::new(mio_stream),
let stream = TcpStream { }
io_handle: IoHandle::new(mio_stream),
};
#[cfg(windows)]
let stream = TcpStream {
// raw_socket: mio_stream.as_raw_socket(),
io_handle: IoHandle::new(mio_stream),
};
stream
} }
} }
@ -470,7 +383,7 @@ cfg_if! {
if #[cfg(any(unix, feature = "docs"))] { if #[cfg(any(unix, feature = "docs"))] {
impl AsRawFd for TcpStream { impl AsRawFd for TcpStream {
fn as_raw_fd(&self) -> RawFd { fn as_raw_fd(&self) -> RawFd {
self.io_handle.get_ref().as_raw_fd() self.watcher.get_ref().as_raw_fd()
} }
} }
@ -482,7 +395,7 @@ cfg_if! {
impl IntoRawFd for TcpStream { impl IntoRawFd for TcpStream {
fn into_raw_fd(self) -> RawFd { fn into_raw_fd(self) -> RawFd {
self.io_handle.into_inner().into_raw_fd() self.watcher.into_inner().into_raw_fd()
} }
} }
} }

View file

@ -5,9 +5,8 @@ use cfg_if::cfg_if;
use std::net::{Ipv4Addr, Ipv6Addr}; use std::net::{Ipv4Addr, Ipv6Addr};
use crate::future; use crate::future;
use crate::net::driver::IoHandle; use crate::net::driver::Watcher;
use crate::net::ToSocketAddrs; use crate::net::ToSocketAddrs;
use crate::task::Poll;
/// A UDP socket. /// A UDP socket.
/// ///
@ -47,9 +46,7 @@ use crate::task::Poll;
/// ``` /// ```
#[derive(Debug)] #[derive(Debug)]
pub struct UdpSocket { pub struct UdpSocket {
io_handle: IoHandle<mio::net::UdpSocket>, watcher: Watcher<mio::net::UdpSocket>,
// #[cfg(windows)]
// raw_socket: std::os::windows::io::RawSocket,
} }
impl UdpSocket { impl UdpSocket {
@ -77,18 +74,9 @@ impl UdpSocket {
for addr in addr.to_socket_addrs().await? { for addr in addr.to_socket_addrs().await? {
match mio::net::UdpSocket::bind(&addr) { match mio::net::UdpSocket::bind(&addr) {
Ok(mio_socket) => { Ok(mio_socket) => {
#[cfg(unix)] return Ok(UdpSocket {
let socket = UdpSocket { watcher: Watcher::new(mio_socket),
io_handle: IoHandle::new(mio_socket), });
};
#[cfg(windows)]
let socket = UdpSocket {
// raw_socket: mio_socket.as_raw_socket(),
io_handle: IoHandle::new(mio_socket),
};
return Ok(socket);
} }
Err(err) => last_err = Some(err), Err(err) => last_err = Some(err),
} }
@ -120,7 +108,7 @@ impl UdpSocket {
/// # Ok(()) }) } /// # Ok(()) }) }
/// ``` /// ```
pub fn local_addr(&self) -> io::Result<SocketAddr> { pub fn local_addr(&self) -> io::Result<SocketAddr> {
self.io_handle.get_ref().local_addr() self.watcher.get_ref().local_addr()
} }
/// Sends data on the socket to the given address. /// Sends data on the socket to the given address.
@ -161,16 +149,8 @@ impl UdpSocket {
}; };
future::poll_fn(|cx| { future::poll_fn(|cx| {
futures_core::ready!(self.io_handle.poll_writable(cx)?); self.watcher
.poll_write_with(cx, |inner| inner.send_to(buf, &addr))
match self.io_handle.get_ref().send_to(buf, &addr) {
Ok(n) => Poll::Ready(Ok(n)),
Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => {
self.io_handle.clear_writable(cx)?;
Poll::Pending
}
Err(err) => Poll::Ready(Err(err)),
}
}) })
.await .await
} }
@ -196,16 +176,8 @@ impl UdpSocket {
/// ``` /// ```
pub async fn recv_from(&self, buf: &mut [u8]) -> io::Result<(usize, SocketAddr)> { pub async fn recv_from(&self, buf: &mut [u8]) -> io::Result<(usize, SocketAddr)> {
future::poll_fn(|cx| { future::poll_fn(|cx| {
futures_core::ready!(self.io_handle.poll_readable(cx)?); self.watcher
.poll_read_with(cx, |inner| inner.recv_from(buf))
match self.io_handle.get_ref().recv_from(buf) {
Ok(n) => Poll::Ready(Ok(n)),
Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => {
self.io_handle.clear_readable(cx)?;
Poll::Pending
}
Err(err) => Poll::Ready(Err(err)),
}
}) })
.await .await
} }
@ -236,7 +208,8 @@ impl UdpSocket {
let mut last_err = None; let mut last_err = None;
for addr in addrs.to_socket_addrs().await? { for addr in addrs.to_socket_addrs().await? {
match self.io_handle.get_ref().connect(addr) { // TODO(stjepang): connect on the blocking pool
match self.watcher.get_ref().connect(addr) {
Ok(()) => return Ok(()), Ok(()) => return Ok(()),
Err(err) => last_err = Some(err), Err(err) => last_err = Some(err),
} }
@ -277,19 +250,7 @@ impl UdpSocket {
/// # Ok(()) }) } /// # Ok(()) }) }
/// ``` /// ```
pub async fn send(&self, buf: &[u8]) -> io::Result<usize> { pub async fn send(&self, buf: &[u8]) -> io::Result<usize> {
future::poll_fn(|cx| { future::poll_fn(|cx| self.watcher.poll_write_with(cx, |inner| inner.send(buf))).await
futures_core::ready!(self.io_handle.poll_writable(cx)?);
match self.io_handle.get_ref().send(buf) {
Ok(n) => Poll::Ready(Ok(n)),
Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => {
self.io_handle.clear_writable(cx)?;
Poll::Pending
}
Err(err) => Poll::Ready(Err(err)),
}
})
.await
} }
/// Receives data from the socket. /// Receives data from the socket.
@ -312,19 +273,7 @@ impl UdpSocket {
/// # Ok(()) }) } /// # Ok(()) }) }
/// ``` /// ```
pub async fn recv(&self, buf: &mut [u8]) -> io::Result<usize> { pub async fn recv(&self, buf: &mut [u8]) -> io::Result<usize> {
future::poll_fn(|cx| { future::poll_fn(|cx| self.watcher.poll_read_with(cx, |inner| inner.recv(buf))).await
futures_core::ready!(self.io_handle.poll_readable(cx)?);
match self.io_handle.get_ref().recv(buf) {
Ok(n) => Poll::Ready(Ok(n)),
Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => {
self.io_handle.clear_readable(cx)?;
Poll::Pending
}
Err(err) => Poll::Ready(Err(err)),
}
})
.await
} }
/// Gets the value of the `SO_BROADCAST` option for this socket. /// Gets the value of the `SO_BROADCAST` option for this socket.
@ -333,14 +282,14 @@ impl UdpSocket {
/// ///
/// [`set_broadcast`]: #method.set_broadcast /// [`set_broadcast`]: #method.set_broadcast
pub fn broadcast(&self) -> io::Result<bool> { pub fn broadcast(&self) -> io::Result<bool> {
self.io_handle.get_ref().broadcast() self.watcher.get_ref().broadcast()
} }
/// Sets the value of the `SO_BROADCAST` option for this socket. /// Sets the value of the `SO_BROADCAST` option for this socket.
/// ///
/// When enabled, this socket is allowed to send packets to a broadcast address. /// When enabled, this socket is allowed to send packets to a broadcast address.
pub fn set_broadcast(&self, on: bool) -> io::Result<()> { pub fn set_broadcast(&self, on: bool) -> io::Result<()> {
self.io_handle.get_ref().set_broadcast(on) self.watcher.get_ref().set_broadcast(on)
} }
/// Gets the value of the `IP_MULTICAST_LOOP` option for this socket. /// Gets the value of the `IP_MULTICAST_LOOP` option for this socket.
@ -349,7 +298,7 @@ impl UdpSocket {
/// ///
/// [`set_multicast_loop_v4`]: #method.set_multicast_loop_v4 /// [`set_multicast_loop_v4`]: #method.set_multicast_loop_v4
pub fn multicast_loop_v4(&self) -> io::Result<bool> { pub fn multicast_loop_v4(&self) -> io::Result<bool> {
self.io_handle.get_ref().multicast_loop_v4() self.watcher.get_ref().multicast_loop_v4()
} }
/// Sets the value of the `IP_MULTICAST_LOOP` option for this socket. /// Sets the value of the `IP_MULTICAST_LOOP` option for this socket.
@ -360,7 +309,7 @@ impl UdpSocket {
/// ///
/// This may not have any affect on IPv6 sockets. /// This may not have any affect on IPv6 sockets.
pub fn set_multicast_loop_v4(&self, on: bool) -> io::Result<()> { pub fn set_multicast_loop_v4(&self, on: bool) -> io::Result<()> {
self.io_handle.get_ref().set_multicast_loop_v4(on) self.watcher.get_ref().set_multicast_loop_v4(on)
} }
/// Gets the value of the `IP_MULTICAST_TTL` option for this socket. /// Gets the value of the `IP_MULTICAST_TTL` option for this socket.
@ -369,7 +318,7 @@ impl UdpSocket {
/// ///
/// [`set_multicast_ttl_v4`]: #method.set_multicast_ttl_v4 /// [`set_multicast_ttl_v4`]: #method.set_multicast_ttl_v4
pub fn multicast_ttl_v4(&self) -> io::Result<u32> { pub fn multicast_ttl_v4(&self) -> io::Result<u32> {
self.io_handle.get_ref().multicast_ttl_v4() self.watcher.get_ref().multicast_ttl_v4()
} }
/// Sets the value of the `IP_MULTICAST_TTL` option for this socket. /// Sets the value of the `IP_MULTICAST_TTL` option for this socket.
@ -382,7 +331,7 @@ impl UdpSocket {
/// ///
/// This may not have any affect on IPv6 sockets. /// This may not have any affect on IPv6 sockets.
pub fn set_multicast_ttl_v4(&self, ttl: u32) -> io::Result<()> { pub fn set_multicast_ttl_v4(&self, ttl: u32) -> io::Result<()> {
self.io_handle.get_ref().set_multicast_ttl_v4(ttl) self.watcher.get_ref().set_multicast_ttl_v4(ttl)
} }
/// Gets the value of the `IPV6_MULTICAST_LOOP` option for this socket. /// Gets the value of the `IPV6_MULTICAST_LOOP` option for this socket.
@ -391,7 +340,7 @@ impl UdpSocket {
/// ///
/// [`set_multicast_loop_v6`]: #method.set_multicast_loop_v6 /// [`set_multicast_loop_v6`]: #method.set_multicast_loop_v6
pub fn multicast_loop_v6(&self) -> io::Result<bool> { pub fn multicast_loop_v6(&self) -> io::Result<bool> {
self.io_handle.get_ref().multicast_loop_v6() self.watcher.get_ref().multicast_loop_v6()
} }
/// Sets the value of the `IPV6_MULTICAST_LOOP` option for this socket. /// Sets the value of the `IPV6_MULTICAST_LOOP` option for this socket.
@ -402,7 +351,7 @@ impl UdpSocket {
/// ///
/// This may not have any affect on IPv4 sockets. /// This may not have any affect on IPv4 sockets.
pub fn set_multicast_loop_v6(&self, on: bool) -> io::Result<()> { pub fn set_multicast_loop_v6(&self, on: bool) -> io::Result<()> {
self.io_handle.get_ref().set_multicast_loop_v6(on) self.watcher.get_ref().set_multicast_loop_v6(on)
} }
/// Gets the value of the `IP_TTL` option for this socket. /// Gets the value of the `IP_TTL` option for this socket.
@ -411,7 +360,7 @@ impl UdpSocket {
/// ///
/// [`set_ttl`]: #method.set_ttl /// [`set_ttl`]: #method.set_ttl
pub fn ttl(&self) -> io::Result<u32> { pub fn ttl(&self) -> io::Result<u32> {
self.io_handle.get_ref().ttl() self.watcher.get_ref().ttl()
} }
/// Sets the value for the `IP_TTL` option on this socket. /// Sets the value for the `IP_TTL` option on this socket.
@ -419,7 +368,7 @@ impl UdpSocket {
/// This value sets the time-to-live field that is used in every packet sent /// This value sets the time-to-live field that is used in every packet sent
/// from this socket. /// from this socket.
pub fn set_ttl(&self, ttl: u32) -> io::Result<()> { pub fn set_ttl(&self, ttl: u32) -> io::Result<()> {
self.io_handle.get_ref().set_ttl(ttl) self.watcher.get_ref().set_ttl(ttl)
} }
/// Executes an operation of the `IP_ADD_MEMBERSHIP` type. /// Executes an operation of the `IP_ADD_MEMBERSHIP` type.
@ -447,7 +396,7 @@ impl UdpSocket {
/// # Ok(()) }) } /// # Ok(()) }) }
/// ``` /// ```
pub fn join_multicast_v4(&self, multiaddr: &Ipv4Addr, interface: &Ipv4Addr) -> io::Result<()> { pub fn join_multicast_v4(&self, multiaddr: &Ipv4Addr, interface: &Ipv4Addr) -> io::Result<()> {
self.io_handle self.watcher
.get_ref() .get_ref()
.join_multicast_v4(multiaddr, interface) .join_multicast_v4(multiaddr, interface)
} }
@ -476,7 +425,7 @@ impl UdpSocket {
/// # Ok(()) }) } /// # Ok(()) }) }
/// ``` /// ```
pub fn join_multicast_v6(&self, multiaddr: &Ipv6Addr, interface: u32) -> io::Result<()> { pub fn join_multicast_v6(&self, multiaddr: &Ipv6Addr, interface: u32) -> io::Result<()> {
self.io_handle self.watcher
.get_ref() .get_ref()
.join_multicast_v6(multiaddr, interface) .join_multicast_v6(multiaddr, interface)
} }
@ -487,7 +436,7 @@ impl UdpSocket {
/// ///
/// [`join_multicast_v4`]: #method.join_multicast_v4 /// [`join_multicast_v4`]: #method.join_multicast_v4
pub fn leave_multicast_v4(&self, multiaddr: &Ipv4Addr, interface: &Ipv4Addr) -> io::Result<()> { pub fn leave_multicast_v4(&self, multiaddr: &Ipv4Addr, interface: &Ipv4Addr) -> io::Result<()> {
self.io_handle self.watcher
.get_ref() .get_ref()
.leave_multicast_v4(multiaddr, interface) .leave_multicast_v4(multiaddr, interface)
} }
@ -498,7 +447,7 @@ impl UdpSocket {
/// ///
/// [`join_multicast_v6`]: #method.join_multicast_v6 /// [`join_multicast_v6`]: #method.join_multicast_v6
pub fn leave_multicast_v6(&self, multiaddr: &Ipv6Addr, interface: u32) -> io::Result<()> { pub fn leave_multicast_v6(&self, multiaddr: &Ipv6Addr, interface: u32) -> io::Result<()> {
self.io_handle self.watcher
.get_ref() .get_ref()
.leave_multicast_v6(multiaddr, interface) .leave_multicast_v6(multiaddr, interface)
} }
@ -508,19 +457,9 @@ impl From<std::net::UdpSocket> for UdpSocket {
/// Converts a `std::net::UdpSocket` into its asynchronous equivalent. /// Converts a `std::net::UdpSocket` into its asynchronous equivalent.
fn from(socket: std::net::UdpSocket) -> UdpSocket { fn from(socket: std::net::UdpSocket) -> UdpSocket {
let mio_socket = mio::net::UdpSocket::from_socket(socket).unwrap(); let mio_socket = mio::net::UdpSocket::from_socket(socket).unwrap();
UdpSocket {
#[cfg(unix)] watcher: Watcher::new(mio_socket),
let socket = UdpSocket { }
io_handle: IoHandle::new(mio_socket),
};
#[cfg(windows)]
let socket = UdpSocket {
// raw_socket: mio_socket.as_raw_socket(),
io_handle: IoHandle::new(mio_socket),
};
socket
} }
} }
@ -540,7 +479,7 @@ cfg_if! {
if #[cfg(any(unix, feature = "docs"))] { if #[cfg(any(unix, feature = "docs"))] {
impl AsRawFd for UdpSocket { impl AsRawFd for UdpSocket {
fn as_raw_fd(&self) -> RawFd { fn as_raw_fd(&self) -> RawFd {
self.io_handle.get_ref().as_raw_fd() self.watcher.get_ref().as_raw_fd()
} }
} }
@ -552,7 +491,7 @@ cfg_if! {
impl IntoRawFd for UdpSocket { impl IntoRawFd for UdpSocket {
fn into_raw_fd(self) -> RawFd { fn into_raw_fd(self) -> RawFd {
self.io_handle.into_inner().into_raw_fd() self.watcher.into_inner().into_raw_fd()
} }
} }
} }

View file

@ -9,9 +9,9 @@ use mio_uds;
use super::SocketAddr; use super::SocketAddr;
use crate::future; use crate::future;
use crate::io; use crate::io;
use crate::net::driver::IoHandle; use crate::net::driver::Watcher;
use crate::os::unix::io::{AsRawFd, FromRawFd, IntoRawFd, RawFd}; use crate::os::unix::io::{AsRawFd, FromRawFd, IntoRawFd, RawFd};
use crate::task::{blocking, Poll}; use crate::task::blocking;
/// A Unix datagram socket. /// A Unix datagram socket.
/// ///
@ -42,15 +42,13 @@ use crate::task::{blocking, Poll};
/// # Ok(()) }) } /// # Ok(()) }) }
/// ``` /// ```
pub struct UnixDatagram { pub struct UnixDatagram {
#[cfg(not(feature = "docs"))] watcher: Watcher<mio_uds::UnixDatagram>,
io_handle: IoHandle<mio_uds::UnixDatagram>,
} }
impl UnixDatagram { impl UnixDatagram {
#[cfg(not(feature = "docs"))]
fn new(socket: mio_uds::UnixDatagram) -> UnixDatagram { fn new(socket: mio_uds::UnixDatagram) -> UnixDatagram {
UnixDatagram { UnixDatagram {
io_handle: IoHandle::new(socket), watcher: Watcher::new(socket),
} }
} }
@ -137,7 +135,7 @@ impl UnixDatagram {
pub async fn connect<P: AsRef<Path>>(&self, path: P) -> io::Result<()> { pub async fn connect<P: AsRef<Path>>(&self, path: P) -> io::Result<()> {
// TODO(stjepang): Connect the socket on a blocking pool. // TODO(stjepang): Connect the socket on a blocking pool.
let p = path.as_ref(); let p = path.as_ref();
self.io_handle.get_ref().connect(p) self.watcher.get_ref().connect(p)
} }
/// Returns the address of this socket. /// Returns the address of this socket.
@ -155,7 +153,7 @@ impl UnixDatagram {
/// # Ok(()) }) } /// # Ok(()) }) }
/// ``` /// ```
pub fn local_addr(&self) -> io::Result<SocketAddr> { pub fn local_addr(&self) -> io::Result<SocketAddr> {
self.io_handle.get_ref().local_addr() self.watcher.get_ref().local_addr()
} }
/// Returns the address of this socket's peer. /// Returns the address of this socket's peer.
@ -178,7 +176,7 @@ impl UnixDatagram {
/// # Ok(()) }) } /// # Ok(()) }) }
/// ``` /// ```
pub fn peer_addr(&self) -> io::Result<SocketAddr> { pub fn peer_addr(&self) -> io::Result<SocketAddr> {
self.io_handle.get_ref().peer_addr() self.watcher.get_ref().peer_addr()
} }
/// Receives data from the socket. /// Receives data from the socket.
@ -200,16 +198,8 @@ impl UnixDatagram {
/// ``` /// ```
pub async fn recv_from(&self, buf: &mut [u8]) -> io::Result<(usize, SocketAddr)> { pub async fn recv_from(&self, buf: &mut [u8]) -> io::Result<(usize, SocketAddr)> {
future::poll_fn(|cx| { future::poll_fn(|cx| {
futures_core::ready!(self.io_handle.poll_readable(cx)?); self.watcher
.poll_read_with(cx, |inner| inner.recv_from(buf))
match self.io_handle.get_ref().recv_from(buf) {
Ok(n) => Poll::Ready(Ok(n)),
Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => {
self.io_handle.clear_readable(cx)?;
Poll::Pending
}
Err(err) => Poll::Ready(Err(err)),
}
}) })
.await .await
} }
@ -232,19 +222,7 @@ impl UnixDatagram {
/// # Ok(()) }) } /// # Ok(()) }) }
/// ``` /// ```
pub async fn recv(&self, buf: &mut [u8]) -> io::Result<usize> { pub async fn recv(&self, buf: &mut [u8]) -> io::Result<usize> {
future::poll_fn(|cx| { future::poll_fn(|cx| self.watcher.poll_read_with(cx, |inner| inner.recv(buf))).await
futures_core::ready!(self.io_handle.poll_writable(cx)?);
match self.io_handle.get_ref().recv(buf) {
Ok(n) => Poll::Ready(Ok(n)),
Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => {
self.io_handle.clear_writable(cx)?;
Poll::Pending
}
Err(err) => Poll::Ready(Err(err)),
}
})
.await
} }
/// Sends data on the socket to the specified address. /// Sends data on the socket to the specified address.
@ -265,16 +243,8 @@ impl UnixDatagram {
/// ``` /// ```
pub async fn send_to<P: AsRef<Path>>(&self, buf: &[u8], path: P) -> io::Result<usize> { pub async fn send_to<P: AsRef<Path>>(&self, buf: &[u8], path: P) -> io::Result<usize> {
future::poll_fn(|cx| { future::poll_fn(|cx| {
futures_core::ready!(self.io_handle.poll_writable(cx)?); self.watcher
.poll_write_with(cx, |inner| inner.send_to(buf, path.as_ref()))
match self.io_handle.get_ref().send_to(buf, path.as_ref()) {
Ok(n) => Poll::Ready(Ok(n)),
Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => {
self.io_handle.clear_writable(cx)?;
Poll::Pending
}
Err(err) => Poll::Ready(Err(err)),
}
}) })
.await .await
} }
@ -297,19 +267,7 @@ impl UnixDatagram {
/// # Ok(()) }) } /// # Ok(()) }) }
/// ``` /// ```
pub async fn send(&self, buf: &[u8]) -> io::Result<usize> { pub async fn send(&self, buf: &[u8]) -> io::Result<usize> {
future::poll_fn(|cx| { future::poll_fn(|cx| self.watcher.poll_write_with(cx, |inner| inner.send(buf))).await
futures_core::ready!(self.io_handle.poll_writable(cx)?);
match self.io_handle.get_ref().send(buf) {
Ok(n) => Poll::Ready(Ok(n)),
Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => {
self.io_handle.clear_writable(cx)?;
Poll::Pending
}
Err(err) => Poll::Ready(Err(err)),
}
})
.await
} }
/// Shut down the read, write, or both halves of this connection. /// Shut down the read, write, or both halves of this connection.
@ -333,7 +291,7 @@ impl UnixDatagram {
/// # Ok(()) }) } /// # Ok(()) }) }
/// ``` /// ```
pub fn shutdown(&self, how: Shutdown) -> io::Result<()> { pub fn shutdown(&self, how: Shutdown) -> io::Result<()> {
self.io_handle.get_ref().shutdown(how) self.watcher.get_ref().shutdown(how)
} }
} }
@ -359,14 +317,14 @@ impl From<std::os::unix::net::UnixDatagram> for UnixDatagram {
fn from(datagram: std::os::unix::net::UnixDatagram) -> UnixDatagram { fn from(datagram: std::os::unix::net::UnixDatagram) -> UnixDatagram {
let mio_datagram = mio_uds::UnixDatagram::from_datagram(datagram).unwrap(); let mio_datagram = mio_uds::UnixDatagram::from_datagram(datagram).unwrap();
UnixDatagram { UnixDatagram {
io_handle: IoHandle::new(mio_datagram), watcher: Watcher::new(mio_datagram),
} }
} }
} }
impl AsRawFd for UnixDatagram { impl AsRawFd for UnixDatagram {
fn as_raw_fd(&self) -> RawFd { fn as_raw_fd(&self) -> RawFd {
self.io_handle.get_ref().as_raw_fd() self.watcher.get_ref().as_raw_fd()
} }
} }
@ -379,6 +337,6 @@ impl FromRawFd for UnixDatagram {
impl IntoRawFd for UnixDatagram { impl IntoRawFd for UnixDatagram {
fn into_raw_fd(self) -> RawFd { fn into_raw_fd(self) -> RawFd {
self.io_handle.into_inner().into_raw_fd() self.watcher.into_inner().into_raw_fd()
} }
} }

View file

@ -10,7 +10,7 @@ use super::SocketAddr;
use super::UnixStream; use super::UnixStream;
use crate::future::{self, Future}; use crate::future::{self, Future};
use crate::io; use crate::io;
use crate::net::driver::IoHandle; use crate::net::driver::Watcher;
use crate::os::unix::io::{AsRawFd, FromRawFd, IntoRawFd, RawFd}; use crate::os::unix::io::{AsRawFd, FromRawFd, IntoRawFd, RawFd};
use crate::task::{blocking, Context, Poll}; use crate::task::{blocking, Context, Poll};
@ -48,8 +48,7 @@ use crate::task::{blocking, Context, Poll};
/// # Ok(()) }) } /// # Ok(()) }) }
/// ``` /// ```
pub struct UnixListener { pub struct UnixListener {
#[cfg(not(feature = "docs"))] watcher: Watcher<mio_uds::UnixListener>,
io_handle: IoHandle<mio_uds::UnixListener>,
} }
impl UnixListener { impl UnixListener {
@ -71,7 +70,7 @@ impl UnixListener {
let listener = blocking::spawn(async move { mio_uds::UnixListener::bind(path) }).await?; let listener = blocking::spawn(async move { mio_uds::UnixListener::bind(path) }).await?;
Ok(UnixListener { Ok(UnixListener {
io_handle: IoHandle::new(listener), watcher: Watcher::new(listener),
}) })
} }
@ -93,25 +92,18 @@ impl UnixListener {
/// ``` /// ```
pub async fn accept(&self) -> io::Result<(UnixStream, SocketAddr)> { pub async fn accept(&self) -> io::Result<(UnixStream, SocketAddr)> {
future::poll_fn(|cx| { future::poll_fn(|cx| {
futures_core::ready!(self.io_handle.poll_readable(cx)?); let res =
futures_core::ready!(self.watcher.poll_read_with(cx, |inner| inner.accept_std()));
match self.io_handle.get_ref().accept_std() { match res? {
Ok(Some((io, addr))) => { None => Poll::Pending,
Some((io, addr)) => {
let mio_stream = mio_uds::UnixStream::from_stream(io)?; let mio_stream = mio_uds::UnixStream::from_stream(io)?;
let stream = UnixStream { let stream = UnixStream {
io_handle: IoHandle::new(mio_stream), watcher: Watcher::new(mio_stream),
}; };
Poll::Ready(Ok((stream, addr))) Poll::Ready(Ok((stream, addr)))
} }
Ok(None) => {
self.io_handle.clear_readable(cx)?;
Poll::Pending
}
Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => {
self.io_handle.clear_readable(cx)?;
Poll::Pending
}
Err(err) => Poll::Ready(Err(err)),
} }
}) })
.await .await
@ -162,7 +154,7 @@ impl UnixListener {
/// # Ok(()) }) } /// # Ok(()) }) }
/// ``` /// ```
pub fn local_addr(&self) -> io::Result<SocketAddr> { pub fn local_addr(&self) -> io::Result<SocketAddr> {
self.io_handle.get_ref().local_addr() self.watcher.get_ref().local_addr()
} }
} }
@ -210,14 +202,14 @@ impl From<std::os::unix::net::UnixListener> for UnixListener {
fn from(listener: std::os::unix::net::UnixListener) -> UnixListener { fn from(listener: std::os::unix::net::UnixListener) -> UnixListener {
let mio_listener = mio_uds::UnixListener::from_listener(listener).unwrap(); let mio_listener = mio_uds::UnixListener::from_listener(listener).unwrap();
UnixListener { UnixListener {
io_handle: IoHandle::new(mio_listener), watcher: Watcher::new(mio_listener),
} }
} }
} }
impl AsRawFd for UnixListener { impl AsRawFd for UnixListener {
fn as_raw_fd(&self) -> RawFd { fn as_raw_fd(&self) -> RawFd {
self.io_handle.get_ref().as_raw_fd() self.watcher.get_ref().as_raw_fd()
} }
} }
@ -230,6 +222,6 @@ impl FromRawFd for UnixListener {
impl IntoRawFd for UnixListener { impl IntoRawFd for UnixListener {
fn into_raw_fd(self) -> RawFd { fn into_raw_fd(self) -> RawFd {
self.io_handle.into_inner().into_raw_fd() self.watcher.into_inner().into_raw_fd()
} }
} }

View file

@ -1,18 +1,17 @@
//! Unix-specific networking extensions. //! Unix-specific networking extensions.
use std::fmt; use std::fmt;
use std::mem;
use std::net::Shutdown; use std::net::Shutdown;
use std::path::Path; use std::path::Path;
use std::io::{Read as _, Write as _};
use std::pin::Pin; use std::pin::Pin;
use futures_io::{AsyncRead, AsyncWrite}; use futures_io::{AsyncRead, AsyncWrite};
use mio_uds; use mio_uds;
use super::SocketAddr; use super::SocketAddr;
use crate::future;
use crate::io; use crate::io;
use crate::net::driver::IoHandle; use crate::net::driver::Watcher;
use crate::os::unix::io::{AsRawFd, FromRawFd, IntoRawFd, RawFd}; use crate::os::unix::io::{AsRawFd, FromRawFd, IntoRawFd, RawFd};
use crate::task::{blocking, Context, Poll}; use crate::task::{blocking, Context, Poll};
@ -40,8 +39,7 @@ use crate::task::{blocking, Context, Poll};
/// # Ok(()) }) } /// # Ok(()) }) }
/// ``` /// ```
pub struct UnixStream { pub struct UnixStream {
#[cfg(not(feature = "docs"))] pub(super) watcher: Watcher<mio_uds::UnixStream>,
pub(super) io_handle: IoHandle<mio_uds::UnixStream>,
} }
impl UnixStream { impl UnixStream {
@ -59,46 +57,14 @@ impl UnixStream {
/// # Ok(()) }) } /// # Ok(()) }) }
/// ``` /// ```
pub async fn connect<P: AsRef<Path>>(path: P) -> io::Result<UnixStream> { pub async fn connect<P: AsRef<Path>>(path: P) -> io::Result<UnixStream> {
enum State {
Waiting(UnixStream),
Error(io::Error),
Done,
}
let path = path.as_ref().to_owned(); let path = path.as_ref().to_owned();
let mut state = {
match blocking::spawn(async move { mio_uds::UnixStream::connect(path) }).await {
Ok(mio_stream) => State::Waiting(UnixStream {
io_handle: IoHandle::new(mio_stream),
}),
Err(err) => State::Error(err),
}
};
future::poll_fn(|cx| { blocking::spawn(async move {
match &mut state { let std_stream = std::os::unix::net::UnixStream::connect(path)?;
State::Waiting(stream) => { let mio_stream = mio_uds::UnixStream::from_stream(std_stream)?;
futures_core::ready!(stream.io_handle.poll_writable(cx)?); Ok(UnixStream {
watcher: Watcher::new(mio_stream),
if let Some(err) = stream.io_handle.get_ref().take_error()? { })
return Poll::Ready(Err(err));
}
}
State::Error(_) => {
let err = match mem::replace(&mut state, State::Done) {
State::Error(err) => err,
_ => unreachable!(),
};
return Poll::Ready(Err(err));
}
State::Done => panic!("`UnixStream::connect()` future polled after completion"),
}
match mem::replace(&mut state, State::Done) {
State::Waiting(stream) => Poll::Ready(Ok(stream)),
_ => unreachable!(),
}
}) })
.await .await
} }
@ -121,10 +87,10 @@ impl UnixStream {
pub fn pair() -> io::Result<(UnixStream, UnixStream)> { pub fn pair() -> io::Result<(UnixStream, UnixStream)> {
let (a, b) = mio_uds::UnixStream::pair()?; let (a, b) = mio_uds::UnixStream::pair()?;
let a = UnixStream { let a = UnixStream {
io_handle: IoHandle::new(a), watcher: Watcher::new(a),
}; };
let b = UnixStream { let b = UnixStream {
io_handle: IoHandle::new(b), watcher: Watcher::new(b),
}; };
Ok((a, b)) Ok((a, b))
} }
@ -144,7 +110,7 @@ impl UnixStream {
/// # Ok(()) }) } /// # Ok(()) }) }
/// ``` /// ```
pub fn local_addr(&self) -> io::Result<SocketAddr> { pub fn local_addr(&self) -> io::Result<SocketAddr> {
self.io_handle.get_ref().local_addr() self.watcher.get_ref().local_addr()
} }
/// Returns the socket address of the remote half of this connection. /// Returns the socket address of the remote half of this connection.
@ -162,7 +128,7 @@ impl UnixStream {
/// # Ok(()) }) } /// # Ok(()) }) }
/// ``` /// ```
pub fn peer_addr(&self) -> io::Result<SocketAddr> { pub fn peer_addr(&self) -> io::Result<SocketAddr> {
self.io_handle.get_ref().peer_addr() self.watcher.get_ref().peer_addr()
} }
/// Shuts down the read, write, or both halves of this connection. /// Shuts down the read, write, or both halves of this connection.
@ -184,7 +150,7 @@ impl UnixStream {
/// # Ok(()) }) } /// # Ok(()) }) }
/// ``` /// ```
pub fn shutdown(&self, how: Shutdown) -> io::Result<()> { pub fn shutdown(&self, how: Shutdown) -> io::Result<()> {
self.io_handle.get_ref().shutdown(how) self.watcher.get_ref().shutdown(how)
} }
} }
@ -204,7 +170,7 @@ impl AsyncRead for &UnixStream {
cx: &mut Context<'_>, cx: &mut Context<'_>,
buf: &mut [u8], buf: &mut [u8],
) -> Poll<io::Result<usize>> { ) -> Poll<io::Result<usize>> {
Pin::new(&mut &self.io_handle).poll_read(cx, buf) self.watcher.poll_read_with(cx, |mut inner| inner.read(buf))
} }
} }
@ -232,15 +198,15 @@ impl AsyncWrite for &UnixStream {
cx: &mut Context<'_>, cx: &mut Context<'_>,
buf: &[u8], buf: &[u8],
) -> Poll<io::Result<usize>> { ) -> Poll<io::Result<usize>> {
Pin::new(&mut &self.io_handle).poll_write(cx, buf) self.watcher.poll_write_with(cx, |mut inner| inner.write(buf))
} }
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> { fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
Pin::new(&mut &self.io_handle).poll_flush(cx) self.watcher.poll_write_with(cx, |mut inner| inner.flush())
} }
fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> { fn poll_close(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<io::Result<()>> {
Pin::new(&mut &self.io_handle).poll_close(cx) Poll::Ready(Ok(()))
} }
} }
@ -266,14 +232,14 @@ impl From<std::os::unix::net::UnixStream> for UnixStream {
fn from(stream: std::os::unix::net::UnixStream) -> UnixStream { fn from(stream: std::os::unix::net::UnixStream) -> UnixStream {
let mio_stream = mio_uds::UnixStream::from_stream(stream).unwrap(); let mio_stream = mio_uds::UnixStream::from_stream(stream).unwrap();
UnixStream { UnixStream {
io_handle: IoHandle::new(mio_stream), watcher: Watcher::new(mio_stream),
} }
} }
} }
impl AsRawFd for UnixStream { impl AsRawFd for UnixStream {
fn as_raw_fd(&self) -> RawFd { fn as_raw_fd(&self) -> RawFd {
self.io_handle.get_ref().as_raw_fd() self.watcher.get_ref().as_raw_fd()
} }
} }
@ -286,6 +252,6 @@ impl FromRawFd for UnixStream {
impl IntoRawFd for UnixStream { impl IntoRawFd for UnixStream {
fn into_raw_fd(self) -> RawFd { fn into_raw_fd(self) -> RawFd {
self.io_handle.into_inner().into_raw_fd() self.watcher.into_inner().into_raw_fd()
} }
} }