commit 5b0a6269a90b20307a430febc21fc6207f771102 Author: Florian Gilcher Date: Thu Aug 8 14:44:48 2019 +0200 Initial commit diff --git a/.gitignore b/.gitignore new file mode 100644 index 00000000..69369904 --- /dev/null +++ b/.gitignore @@ -0,0 +1,3 @@ +/target +**/*.rs.bk +Cargo.lock diff --git a/.travis.yml b/.travis.yml new file mode 100644 index 00000000..b58b8b10 --- /dev/null +++ b/.travis.yml @@ -0,0 +1,20 @@ +language: rust + +before_script: +- rustup component add rustfmt + +matrix: + fast_finish: true + include: + - rust: nightly + os: linux + - rust: nightly + os: osx + - rust: nightly-x86_64-pc-windows-msvc + os: windows + +script: + - cargo check --all --benches --bins --examples --tests + - cargo test --all + - cargo doc --features docs.rs + - cargo fmt --all -- --check diff --git a/Cargo.toml b/Cargo.toml new file mode 100644 index 00000000..34da0398 --- /dev/null +++ b/Cargo.toml @@ -0,0 +1,44 @@ +[package] +name = "async-std" +version = "0.1.0" +authors = ["Stjepan Glavina "] +edition = "2018" +license = "Apache-2.0/MIT" +repository = "https://github.com/stjepang/async-std" +homepage = "https://github.com/stjepang/async-std" +documentation = "https://docs.rs/async-std" +description = "Asynchronous standard library" +keywords = [] +categories = ["asynchronous", "concurrency"] + +[package.metadata.docs.rs] +features = ["docs.rs"] +rustdoc-args = ["--features docs.rs"] + +[features] +"docs.rs" = [] + +[dependencies] +async-task = { path = "async-task" } +cfg-if = "0.1.9" +crossbeam = "0.7.1" +futures-preview = "0.3.0-alpha.17" +futures-timer = "0.3.0" +lazy_static = "1.3.0" +log = { version = "0.4.8", features = ["kv_unstable"] } +mio = "0.6.19" +mio-uds = "0.6.7" +num_cpus = "1.10.0" +pin-utils = "0.1.0-alpha.4" +slab = "0.4.2" + +[dev-dependencies] +femme = "1.1.0" +# surf = { git = "ssh://github.com/yoshuawuyts/surf" } +tempdir = "0.3.7" + +[workspace] +members = [ + ".", + "async-task", +] diff --git a/LICENSE-APACHE b/LICENSE-APACHE new file mode 100644 index 00000000..16fe87b0 --- /dev/null +++ b/LICENSE-APACHE @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + +TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + +1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + +2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + +3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + +4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + +5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + +6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + +7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + +8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + +9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + +END OF TERMS AND CONDITIONS + +APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + +Copyright [yyyy] [name of copyright owner] + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. diff --git a/LICENSE-MIT b/LICENSE-MIT new file mode 100644 index 00000000..31aa7938 --- /dev/null +++ b/LICENSE-MIT @@ -0,0 +1,23 @@ +Permission is hereby granted, free of charge, to any +person obtaining a copy of this software and associated +documentation files (the "Software"), to deal in the +Software without restriction, including without +limitation the rights to use, copy, modify, merge, +publish, distribute, sublicense, and/or sell copies of +the Software, and to permit persons to whom the Software +is furnished to do so, subject to the following +conditions: + +The above copyright notice and this permission notice +shall be included in all copies or substantial portions +of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF +ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED +TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A +PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT +SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY +CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION +OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR +IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +DEALINGS IN THE SOFTWARE. diff --git a/README.md b/README.md new file mode 100644 index 00000000..7a29cd7d --- /dev/null +++ b/README.md @@ -0,0 +1,64 @@ +# Async version of Rust's standard library + + + + + + + + + +[![chat](https://img.shields.io/discord/598880689856970762.svg?logo=discord)](https://discord.gg/JvZeVNe) + +This crate is an async version of [`std`]. + +[`std`]: https://doc.rust-lang.org/std/index.html + +## Quickstart + +Clone the repo: + +``` +git clone git@github.com:stjepang/async-std.git && cd async-std +``` + +Read the docs: + +``` +cargo doc --features docs.rs --open +``` + +Check out the [examples](examples). To run an example: + +``` +cargo run --example hello-world +``` + +## Hello world + +```rust +#![feature(async_await)] + +use async_std::task; + +fn main() { + task::block_on(async { + println!("Hello, world!"); + }) +} +``` + +## License + +Licensed under either of + + * Apache License, Version 2.0 ([LICENSE-APACHE](LICENSE-APACHE) or http://www.apache.org/licenses/LICENSE-2.0) + * MIT license ([LICENSE-MIT](LICENSE-MIT) or http://opensource.org/licenses/MIT) + +at your option. + +#### Contribution + +Unless you explicitly state otherwise, any contribution intentionally submitted +for inclusion in the work by you, as defined in the Apache-2.0 license, shall be +dual licensed as above, without any additional terms or conditions. diff --git a/async-task/Cargo.toml b/async-task/Cargo.toml new file mode 100644 index 00000000..2286f1ee --- /dev/null +++ b/async-task/Cargo.toml @@ -0,0 +1,20 @@ +[package] +name = "async-task" +version = "0.1.0" +authors = ["Stjepan Glavina "] +edition = "2018" +license = "Apache-2.0/MIT" +repository = "https://github.com/stjepang/async-task" +homepage = "https://github.com/stjepang/async-task" +documentation = "https://docs.rs/async-task" +description = "Task abstraction for building executors" +keywords = ["future", "task", "executor", "spawn"] +categories = ["asynchronous", "concurrency"] + +[dependencies] +crossbeam-utils = "0.6.5" + +[dev-dependencies] +crossbeam = "0.7.1" +futures-preview = "0.3.0-alpha.17" +lazy_static = "1.3.0" diff --git a/async-task/LICENSE-APACHE b/async-task/LICENSE-APACHE new file mode 100644 index 00000000..16fe87b0 --- /dev/null +++ b/async-task/LICENSE-APACHE @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + +TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + +1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + +2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + +3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + +4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + +5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + +6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + +7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + +8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + +9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + +END OF TERMS AND CONDITIONS + +APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + +Copyright [yyyy] [name of copyright owner] + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. diff --git a/async-task/LICENSE-MIT b/async-task/LICENSE-MIT new file mode 100644 index 00000000..31aa7938 --- /dev/null +++ b/async-task/LICENSE-MIT @@ -0,0 +1,23 @@ +Permission is hereby granted, free of charge, to any +person obtaining a copy of this software and associated +documentation files (the "Software"), to deal in the +Software without restriction, including without +limitation the rights to use, copy, modify, merge, +publish, distribute, sublicense, and/or sell copies of +the Software, and to permit persons to whom the Software +is furnished to do so, subject to the following +conditions: + +The above copyright notice and this permission notice +shall be included in all copies or substantial portions +of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF +ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED +TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A +PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT +SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY +CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION +OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR +IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +DEALINGS IN THE SOFTWARE. diff --git a/async-task/README.md b/async-task/README.md new file mode 100644 index 00000000..22fe9f23 --- /dev/null +++ b/async-task/README.md @@ -0,0 +1,21 @@ +# async-task + +A task abstraction for building executors. + +This crate makes it possible to build an efficient and extendable executor in few lines of +code. + +## License + +Licensed under either of + + * Apache License, Version 2.0 ([LICENSE-APACHE](LICENSE-APACHE) or http://www.apache.org/licenses/LICENSE-2.0) + * MIT license ([LICENSE-MIT](LICENSE-MIT) or http://opensource.org/licenses/MIT) + +at your option. + +#### Contribution + +Unless you explicitly state otherwise, any contribution intentionally submitted +for inclusion in the work by you, as defined in the Apache-2.0 license, shall be +dual licensed as above, without any additional terms or conditions. diff --git a/async-task/benches/bench.rs b/async-task/benches/bench.rs new file mode 100644 index 00000000..6fd79353 --- /dev/null +++ b/async-task/benches/bench.rs @@ -0,0 +1,43 @@ +#![feature(async_await, test)] + +extern crate test; + +use futures::channel::oneshot; +use futures::executor; +use futures::future::TryFutureExt; +use test::Bencher; + +#[bench] +fn task_create(b: &mut Bencher) { + b.iter(|| { + async_task::spawn(async {}, drop, ()); + }); +} + +#[bench] +fn task_run(b: &mut Bencher) { + b.iter(|| { + let (task, handle) = async_task::spawn(async {}, drop, ()); + task.run(); + executor::block_on(handle).unwrap(); + }); +} + +#[bench] +fn oneshot_create(b: &mut Bencher) { + b.iter(|| { + let (tx, _rx) = oneshot::channel::<()>(); + let _task = Box::new(async move { tx.send(()).map_err(|_| ()) }); + }); +} + +#[bench] +fn oneshot_run(b: &mut Bencher) { + b.iter(|| { + let (tx, rx) = oneshot::channel::<()>(); + let task = Box::new(async move { tx.send(()).map_err(|_| ()) }); + + let future = task.and_then(|_| rx.map_err(|_| ())); + executor::block_on(future).unwrap(); + }); +} diff --git a/async-task/examples/panic-propagation.rs b/async-task/examples/panic-propagation.rs new file mode 100644 index 00000000..9c4f081a --- /dev/null +++ b/async-task/examples/panic-propagation.rs @@ -0,0 +1,75 @@ +//! A single-threaded executor where join handles propagate panics from tasks. + +#![feature(async_await)] + +use std::future::Future; +use std::panic::{resume_unwind, AssertUnwindSafe}; +use std::pin::Pin; +use std::task::{Context, Poll}; +use std::thread; + +use crossbeam::channel::{unbounded, Sender}; +use futures::executor; +use futures::future::FutureExt; +use lazy_static::lazy_static; + +/// Spawns a future on the executor. +fn spawn(future: F) -> JoinHandle +where + F: Future + Send + 'static, + R: Send + 'static, +{ + lazy_static! { + // A channel that holds scheduled tasks. + static ref QUEUE: Sender> = { + let (sender, receiver) = unbounded::>(); + + // Start the executor thread. + thread::spawn(|| { + for task in receiver { + // No need for `catch_unwind()` here because panics are already caught. + task.run(); + } + }); + + sender + }; + } + + // Create a future that catches panics within itself. + let future = AssertUnwindSafe(future).catch_unwind(); + + // Create a task that is scheduled by sending itself into the channel. + let schedule = |t| QUEUE.send(t).unwrap(); + let (task, handle) = async_task::spawn(future, schedule, ()); + + // Schedule the task by sending it into the channel. + task.schedule(); + + // Wrap the handle into one that propagates panics. + JoinHandle(handle) +} + +/// A join handle that propagates panics inside the task. +struct JoinHandle(async_task::JoinHandle, ()>); + +impl Future for JoinHandle { + type Output = Option; + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + match Pin::new(&mut self.0).poll(cx) { + Poll::Pending => Poll::Pending, + Poll::Ready(None) => Poll::Ready(None), + Poll::Ready(Some(Ok(val))) => Poll::Ready(Some(val)), + Poll::Ready(Some(Err(err))) => resume_unwind(err), + } + } +} + +fn main() { + // Spawn a future that panics and block on it. + let handle = spawn(async { + panic!("Ooops!"); + }); + executor::block_on(handle); +} diff --git a/async-task/examples/panic-result.rs b/async-task/examples/panic-result.rs new file mode 100644 index 00000000..b1200a38 --- /dev/null +++ b/async-task/examples/panic-result.rs @@ -0,0 +1,74 @@ +//! A single-threaded executor where join handles catch panics inside tasks. + +#![feature(async_await)] + +use std::future::Future; +use std::panic::AssertUnwindSafe; +use std::thread; + +use crossbeam::channel::{unbounded, Sender}; +use futures::executor; +use futures::future::FutureExt; +use lazy_static::lazy_static; + +/// Spawns a future on the executor. +fn spawn(future: F) -> async_task::JoinHandle, ()> +where + F: Future + Send + 'static, + R: Send + 'static, +{ + lazy_static! { + // A channel that holds scheduled tasks. + static ref QUEUE: Sender> = { + let (sender, receiver) = unbounded::>(); + + // Start the executor thread. + thread::spawn(|| { + for task in receiver { + // No need for `catch_unwind()` here because panics are already caught. + task.run(); + } + }); + + sender + }; + } + + // Create a future that catches panics within itself. + let future = AssertUnwindSafe(future).catch_unwind(); + + // Create a task that is scheduled by sending itself into the channel. + let schedule = |t| QUEUE.send(t).unwrap(); + let (task, handle) = async_task::spawn(future, schedule, ()); + + // Schedule the task by sending it into the channel. + task.schedule(); + + handle +} + +fn main() { + // Spawn a future that completes succesfully. + let handle = spawn(async { + println!("Hello, world!"); + }); + + // Block on the future and report its result. + match executor::block_on(handle) { + None => println!("The task was cancelled."), + Some(Ok(val)) => println!("The task completed with {:?}", val), + Some(Err(_)) => println!("The task has panicked"), + } + + // Spawn a future that panics. + let handle = spawn(async { + panic!("Ooops!"); + }); + + // Block on the future and report its result. + match executor::block_on(handle) { + None => println!("The task was cancelled."), + Some(Ok(val)) => println!("The task completed with {:?}", val), + Some(Err(_)) => println!("The task has panicked"), + } +} diff --git a/async-task/examples/spawn-on-thread.rs b/async-task/examples/spawn-on-thread.rs new file mode 100644 index 00000000..6d5b9a20 --- /dev/null +++ b/async-task/examples/spawn-on-thread.rs @@ -0,0 +1,55 @@ +//! A function that runs a future to completion on a dedicated thread. + +#![feature(async_await)] + +use std::future::Future; +use std::sync::Arc; +use std::thread; + +use crossbeam::channel; +use futures::executor; + +/// Spawns a future on a new dedicated thread. +/// +/// The returned handle can be used to await the output of the future. +fn spawn_on_thread(future: F) -> async_task::JoinHandle +where + F: Future + Send + 'static, + R: Send + 'static, +{ + // Create a channel that holds the task when it is scheduled for running. + let (sender, receiver) = channel::unbounded(); + let sender = Arc::new(sender); + let s = Arc::downgrade(&sender); + + // Wrap the future into one that disconnects the channel on completion. + let future = async move { + // When the inner future completes, the sender gets dropped and disconnects the channel. + let _sender = sender; + future.await + }; + + // Create a task that is scheduled by sending itself into the channel. + let schedule = move |t| s.upgrade().unwrap().send(t).unwrap(); + let (task, handle) = async_task::spawn(future, schedule, ()); + + // Schedule the task by sending it into the channel. + task.schedule(); + + // Spawn a thread running the task to completion. + thread::spawn(move || { + // Keep taking the task from the channel and running it until completion. + for task in receiver { + task.run(); + } + }); + + handle +} + +fn main() { + // Spawn a future on a dedicated thread. + executor::block_on(spawn_on_thread(async { + println!("Hello, world!"); + })); +} diff --git a/async-task/examples/spawn.rs b/async-task/examples/spawn.rs new file mode 100644 index 00000000..6e798c0b --- /dev/null +++ b/async-task/examples/spawn.rs @@ -0,0 +1,52 @@ +//! A simple single-threaded executor. + +#![feature(async_await)] + +use std::future::Future; +use std::panic::catch_unwind; +use std::thread; + +use crossbeam::channel::{unbounded, Sender}; +use futures::executor; +use lazy_static::lazy_static; + +/// Spawns a future on the executor. +fn spawn(future: F) -> async_task::JoinHandle +where + F: Future + Send + 'static, + R: Send + 'static, +{ + lazy_static! { + // A channel that holds scheduled tasks. + static ref QUEUE: Sender> = { + let (sender, receiver) = unbounded::>(); + + // Start the executor thread. + thread::spawn(|| { + for task in receiver { + // Ignore panics for simplicity. + let _ignore_panic = catch_unwind(|| task.run()); + } + }); + + sender + }; + } + + // Create a task that is scheduled by sending itself into the channel. + let schedule = |t| QUEUE.send(t).unwrap(); + let (task, handle) = async_task::spawn(future, schedule, ()); + + // Schedule the task by sending it into the channel. + task.schedule(); + + handle +} + +fn main() { + // Spawn a future and await its result. + let handle = spawn(async { + println!("Hello, world!"); + }); + executor::block_on(handle); +} diff --git a/async-task/examples/task-id.rs b/async-task/examples/task-id.rs new file mode 100644 index 00000000..b3832d07 --- /dev/null +++ b/async-task/examples/task-id.rs @@ -0,0 +1,88 @@ +//! An executor that assigns an ID to every spawned task. + +#![feature(async_await)] + +use std::cell::Cell; +use std::future::Future; +use std::panic::catch_unwind; +use std::thread; + +use crossbeam::atomic::AtomicCell; +use crossbeam::channel::{unbounded, Sender}; +use futures::executor; +use lazy_static::lazy_static; + +#[derive(Clone, Copy, Debug)] +struct TaskId(usize); + +thread_local! { + /// The ID of the current task. + static TASK_ID: Cell> = Cell::new(None); +} + +/// Returns the ID of the currently executing task. +/// +/// Returns `None` if called outside the runtime. +fn task_id() -> Option { + TASK_ID.with(|id| id.get()) +} + +/// Spawns a future on the executor. +fn spawn(future: F) -> async_task::JoinHandle +where + F: Future + Send + 'static, + R: Send + 'static, +{ + lazy_static! { + // A channel that holds scheduled tasks. + static ref QUEUE: Sender> = { + let (sender, receiver) = unbounded::>(); + + // Start the executor thread. + thread::spawn(|| { + TASK_ID.with(|id| { + for task in receiver { + // Store the task ID into the thread-local before running. + id.set(Some(*task.tag())); + + // Ignore panics for simplicity. + let _ignore_panic = catch_unwind(|| task.run()); + } + }) + }); + + sender + }; + + // A counter that assigns IDs to spawned tasks. + static ref COUNTER: AtomicCell = AtomicCell::new(0); + } + + // Reserve an ID for the new task. + let id = TaskId(COUNTER.fetch_add(1)); + + // Create a task that is scheduled by sending itself into the channel. + let schedule = |task| QUEUE.send(task).unwrap(); + let (task, handle) = async_task::spawn(future, schedule, id); + + // Schedule the task by sending it into the channel. + task.schedule(); + + handle +} + +fn main() { + let mut handles = vec![]; + + // Spawn a bunch of tasks. + for _ in 0..10 { + handles.push(spawn(async move { + println!("Hello from task with {:?}", task_id()); + })); + } + + // Wait for the tasks to finish. + for handle in handles { + executor::block_on(handle); + } +} diff --git a/async-task/src/header.rs b/async-task/src/header.rs new file mode 100644 index 00000000..0ce51645 --- /dev/null +++ b/async-task/src/header.rs @@ -0,0 +1,158 @@ +use std::alloc::Layout; +use std::cell::Cell; +use std::fmt; +use std::sync::atomic::{AtomicUsize, Ordering}; +use std::task::Waker; + +use crossbeam_utils::Backoff; + +use crate::raw::TaskVTable; +use crate::state::*; +use crate::utils::{abort_on_panic, extend}; + +/// The header of a task. +/// +/// This header is stored right at the beginning of every heap-allocated task. +pub(crate) struct Header { + /// Current state of the task. + /// + /// Contains flags representing the current state and the reference count. + pub(crate) state: AtomicUsize, + + /// The task that is blocked on the `JoinHandle`. + /// + /// This waker needs to be woken once the task completes or is closed. + pub(crate) awaiter: Cell>, + + /// The virtual table. + /// + /// In addition to the actual waker virtual table, it also contains pointers to several other + /// methods necessary for bookkeeping the heap-allocated task. + pub(crate) vtable: &'static TaskVTable, +} + +impl Header { + /// Cancels the task. + /// + /// This method will only mark the task as closed and will notify the awaiter, but it won't + /// reschedule the task if it's not completed. + pub(crate) fn cancel(&self) { + let mut state = self.state.load(Ordering::Acquire); + + loop { + // If the task has been completed or closed, it can't be cancelled. + if state & (COMPLETED | CLOSED) != 0 { + break; + } + + // Mark the task as closed. + match self.state.compare_exchange_weak( + state, + state | CLOSED, + Ordering::AcqRel, + Ordering::Acquire, + ) { + Ok(_) => { + // Notify the awaiter that the task has been closed. + if state & AWAITER != 0 { + self.notify(); + } + + break; + } + Err(s) => state = s, + } + } + } + + /// Notifies the task blocked on the task. + /// + /// If there is a registered waker, it will be removed from the header and woken. + #[inline] + pub(crate) fn notify(&self) { + if let Some(waker) = self.swap_awaiter(None) { + // We need a safeguard against panics because waking can panic. + abort_on_panic(|| { + waker.wake(); + }); + } + } + + /// Notifies the task blocked on the task unless its waker matches `current`. + /// + /// If there is a registered waker, it will be removed from the header. + #[inline] + pub(crate) fn notify_unless(&self, current: &Waker) { + if let Some(waker) = self.swap_awaiter(None) { + if !waker.will_wake(current) { + // We need a safeguard against panics because waking can panic. + abort_on_panic(|| { + waker.wake(); + }); + } + } + } + + /// Swaps the awaiter and returns the previous value. + #[inline] + pub(crate) fn swap_awaiter(&self, new: Option) -> Option { + let new_is_none = new.is_none(); + + // We're about to try acquiring the lock in a loop. If it's already being held by another + // thread, we'll have to spin for a while so it's best to employ a backoff strategy. + let backoff = Backoff::new(); + loop { + // Acquire the lock. If we're storing an awaiter, then also set the awaiter flag. + let state = if new_is_none { + self.state.fetch_or(LOCKED, Ordering::Acquire) + } else { + self.state.fetch_or(LOCKED | AWAITER, Ordering::Acquire) + }; + + // If the lock was acquired, break from the loop. + if state & LOCKED == 0 { + break; + } + + // Snooze for a little while because the lock is held by another thread. + backoff.snooze(); + } + + // Replace the awaiter. + let old = self.awaiter.replace(new); + + // Release the lock. If we've cleared the awaiter, then also unset the awaiter flag. + if new_is_none { + self.state.fetch_and(!LOCKED & !AWAITER, Ordering::Release); + } else { + self.state.fetch_and(!LOCKED, Ordering::Release); + } + + old + } + + /// Returns the offset at which the tag of type `T` is stored. + #[inline] + pub(crate) fn offset_tag() -> usize { + let layout_header = Layout::new::
(); + let layout_t = Layout::new::(); + let (_, offset_t) = extend(layout_header, layout_t); + offset_t + } +} + +impl fmt::Debug for Header { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let state = self.state.load(Ordering::SeqCst); + + f.debug_struct("Header") + .field("scheduled", &(state & SCHEDULED != 0)) + .field("running", &(state & RUNNING != 0)) + .field("completed", &(state & COMPLETED != 0)) + .field("closed", &(state & CLOSED != 0)) + .field("awaiter", &(state & AWAITER != 0)) + .field("handle", &(state & HANDLE != 0)) + .field("ref_count", &(state / REFERENCE)) + .finish() + } +} diff --git a/async-task/src/join_handle.rs b/async-task/src/join_handle.rs new file mode 100644 index 00000000..fb5c275e --- /dev/null +++ b/async-task/src/join_handle.rs @@ -0,0 +1,333 @@ +use std::fmt; +use std::future::Future; +use std::marker::{PhantomData, Unpin}; +use std::pin::Pin; +use std::ptr::NonNull; +use std::sync::atomic::Ordering; +use std::task::{Context, Poll}; + +use crate::header::Header; +use crate::state::*; +use crate::utils::abort_on_panic; + +/// A handle that awaits the result of a task. +/// +/// If the task has completed with `value`, the handle returns it as `Some(value)`. If the task was +/// cancelled or has panicked, the handle returns `None`. Otherwise, the handle has to wait until +/// the task completes, panics, or gets cancelled. +/// +/// # Examples +/// +/// ``` +/// #![feature(async_await)] +/// +/// use crossbeam::channel; +/// use futures::executor; +/// +/// // The future inside the task. +/// let future = async { 1 + 2 }; +/// +/// // If the task gets woken, it will be sent into this channel. +/// let (s, r) = channel::unbounded(); +/// let schedule = move |task| s.send(task).unwrap(); +/// +/// // Create a task with the future and the schedule function. +/// let (task, handle) = async_task::spawn(future, schedule, ()); +/// +/// // Run the task. In this example, it will complete after a single run. +/// task.run(); +/// assert!(r.is_empty()); +/// +/// // Await the result of the task. +/// let result = executor::block_on(handle); +/// assert_eq!(result, Some(3)); +/// ``` +pub struct JoinHandle { + /// A raw task pointer. + pub(crate) raw_task: NonNull<()>, + + /// A marker capturing the generic type `R`. + pub(crate) _marker: PhantomData<(R, T)>, +} + +unsafe impl Send for JoinHandle {} +unsafe impl Sync for JoinHandle {} + +impl Unpin for JoinHandle {} + +impl JoinHandle { + /// Cancels the task. + /// + /// When cancelled, the task won't be scheduled again even if a [`Waker`] wakes it. An attempt + /// to run it won't do anything. And if it's completed, awaiting its result evaluates to + /// `None`. + /// + /// [`Waker`]: https://doc.rust-lang.org/std/task/struct.Waker.html + /// + /// # Examples + /// + /// ``` + /// # #![feature(async_await)] + /// use crossbeam::channel; + /// use futures::executor; + /// + /// // The future inside the task. + /// let future = async { 1 + 2 }; + /// + /// // If the task gets woken, it will be sent into this channel. + /// let (s, r) = channel::unbounded(); + /// let schedule = move |task| s.send(task).unwrap(); + /// + /// // Create a task with the future and the schedule function. + /// let (task, handle) = async_task::spawn(future, schedule, ()); + /// + /// // Cancel the task. + /// handle.cancel(); + /// + /// // Running a cancelled task does nothing. + /// task.run(); + /// + /// // Await the result of the task. + /// let result = executor::block_on(handle); + /// assert_eq!(result, None); + /// ``` + pub fn cancel(&self) { + let ptr = self.raw_task.as_ptr(); + let header = ptr as *const Header; + + unsafe { + let mut state = (*header).state.load(Ordering::Acquire); + + loop { + // If the task has been completed or closed, it can't be cancelled. + if state & (COMPLETED | CLOSED) != 0 { + break; + } + + // If the task is not scheduled nor running, we'll need to schedule it. + let new = if state & (SCHEDULED | RUNNING) == 0 { + (state | SCHEDULED | CLOSED) + REFERENCE + } else { + state | CLOSED + }; + + // Mark the task as closed. + match (*header).state.compare_exchange_weak( + state, + new, + Ordering::AcqRel, + Ordering::Acquire, + ) { + Ok(_) => { + // If the task is not scheduled nor running, schedule it so that its future + // gets dropped by the executor. + if state & (SCHEDULED | RUNNING) == 0 { + ((*header).vtable.schedule)(ptr); + } + + // Notify the awaiter that the task has been closed. + if state & AWAITER != 0 { + (*header).notify(); + } + + break; + } + Err(s) => state = s, + } + } + } + } + + /// Returns a reference to the tag stored inside the task. + /// + /// # Examples + /// + /// ``` + /// # #![feature(async_await)] + /// use crossbeam::channel; + /// + /// // The future inside the task. + /// let future = async { 1 + 2 }; + /// + /// // If the task gets woken, it will be sent into this channel. + /// let (s, r) = channel::unbounded(); + /// let schedule = move |task| s.send(task).unwrap(); + /// + /// // Create a task with the future and the schedule function. + /// let (task, handle) = async_task::spawn(future, schedule, "a simple task"); + /// + /// // Access the tag. + /// assert_eq!(*handle.tag(), "a simple task"); + /// ``` + pub fn tag(&self) -> &T { + let offset = Header::offset_tag::(); + let ptr = self.raw_task.as_ptr(); + + unsafe { + let raw = (ptr as *mut u8).add(offset) as *const T; + &*raw + } + } +} + +impl Drop for JoinHandle { + fn drop(&mut self) { + let ptr = self.raw_task.as_ptr(); + let header = ptr as *const Header; + + // A place where the output will be stored in case it needs to be dropped. + let mut output = None; + + unsafe { + // Optimistically assume the `JoinHandle` is being dropped just after creating the + // task. This is a common case so if the handle is not used, the overhead of it is only + // one compare-exchange operation. + if let Err(mut state) = (*header).state.compare_exchange_weak( + SCHEDULED | HANDLE | REFERENCE, + SCHEDULED | REFERENCE, + Ordering::AcqRel, + Ordering::Acquire, + ) { + loop { + // If the task has been completed but not yet closed, that means its output + // must be dropped. + if state & COMPLETED != 0 && state & CLOSED == 0 { + // Mark the task as closed in order to grab its output. + match (*header).state.compare_exchange_weak( + state, + state | CLOSED, + Ordering::AcqRel, + Ordering::Acquire, + ) { + Ok(_) => { + // Read the output. + output = + Some((((*header).vtable.get_output)(ptr) as *mut R).read()); + + // Update the state variable because we're continuing the loop. + state |= CLOSED; + } + Err(s) => state = s, + } + } else { + // If this is the last reference to task and it's not closed, then close + // it and schedule one more time so that its future gets dropped by the + // executor. + let new = if state & (!(REFERENCE - 1) | CLOSED) == 0 { + SCHEDULED | CLOSED | REFERENCE + } else { + state & !HANDLE + }; + + // Unset the handle flag. + match (*header).state.compare_exchange_weak( + state, + new, + Ordering::AcqRel, + Ordering::Acquire, + ) { + Ok(_) => { + // If this is the last reference to the task, we need to either + // schedule dropping its future or destroy it. + if state & !(REFERENCE - 1) == 0 { + if state & CLOSED == 0 { + ((*header).vtable.schedule)(ptr); + } else { + ((*header).vtable.destroy)(ptr); + } + } + + break; + } + Err(s) => state = s, + } + } + } + } + } + + // Drop the output if it was taken out of the task. + drop(output); + } +} + +impl Future for JoinHandle { + type Output = Option; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let ptr = self.raw_task.as_ptr(); + let header = ptr as *const Header; + + unsafe { + let mut state = (*header).state.load(Ordering::Acquire); + + loop { + // If the task has been closed, notify the awaiter and return `None`. + if state & CLOSED != 0 { + // Even though the awaiter is most likely the current task, it could also be + // another task. + (*header).notify_unless(cx.waker()); + return Poll::Ready(None); + } + + // If the task is not completed, register the current task. + if state & COMPLETED == 0 { + // Replace the waker with one associated with the current task. We need a + // safeguard against panics because dropping the previous waker can panic. + abort_on_panic(|| { + (*header).swap_awaiter(Some(cx.waker().clone())); + }); + + // Reload the state after registering. It is possible that the task became + // completed or closed just before registration so we need to check for that. + state = (*header).state.load(Ordering::Acquire); + + // If the task has been closed, notify the awaiter and return `None`. + if state & CLOSED != 0 { + // Even though the awaiter is most likely the current task, it could also + // be another task. + (*header).notify_unless(cx.waker()); + return Poll::Ready(None); + } + + // If the task is still not completed, we're blocked on it. + if state & COMPLETED == 0 { + return Poll::Pending; + } + } + + // Since the task is now completed, mark it as closed in order to grab its output. + match (*header).state.compare_exchange( + state, + state | CLOSED, + Ordering::AcqRel, + Ordering::Acquire, + ) { + Ok(_) => { + // Notify the awaiter. Even though the awaiter is most likely the current + // task, it could also be another task. + if state & AWAITER != 0 { + (*header).notify_unless(cx.waker()); + } + + // Take the output from the task. + let output = ((*header).vtable.get_output)(ptr) as *mut R; + return Poll::Ready(Some(output.read())); + } + Err(s) => state = s, + } + } + } + } +} + +impl fmt::Debug for JoinHandle { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let ptr = self.raw_task.as_ptr(); + let header = ptr as *const Header; + + f.debug_struct("JoinHandle") + .field("header", unsafe { &(*header) }) + .finish() + } +} diff --git a/async-task/src/lib.rs b/async-task/src/lib.rs new file mode 100644 index 00000000..55185154 --- /dev/null +++ b/async-task/src/lib.rs @@ -0,0 +1,149 @@ +//! Task abstraction for building executors. +//! +//! # What is an executor? +//! +//! An async block creates a future and an async function returns one. But futures don't do +//! anything unless they are awaited inside other async blocks or async functions. So the question +//! arises: who or what awaits the main future that awaits others? +//! +//! One solution is to call [`block_on()`] on the main future, which will block +//! the current thread and keep polling the future until it completes. But sometimes we don't want +//! to block the current thread and would prefer to *spawn* the future to let a background thread +//! block on it instead. +//! +//! This is where executors step in - they create a number of threads (typically equal to the +//! number of CPU cores on the system) that are dedicated to polling spawned futures. Each executor +//! thread keeps polling spawned futures in a loop and only blocks when all spawned futures are +//! either sleeping or running. +//! +//! # What is a task? +//! +//! In order to spawn a future on an executor, one needs to allocate the future on the heap and +//! keep some state alongside it, like whether the future is ready for polling, waiting to be woken +//! up, or completed. This allocation is usually called a *task*. +//! +//! The executor then runs the spawned task by polling its future. If the future is pending on a +//! resource, a [`Waker`] associated with the task will be registered somewhere so that the task +//! can be woken up and run again at a later time. +//! +//! For example, if the future wants to read something from a TCP socket that is not ready yet, the +//! networking system will clone the task's waker and wake it up once the socket becomes ready. +//! +//! # Task construction +//! +//! A task is constructed with [`Task::create()`]: +//! +//! ``` +//! # #![feature(async_await)] +//! let future = async { 1 + 2 }; +//! let schedule = |task| unimplemented!(); +//! +//! let (task, handle) = async_task::spawn(future, schedule, ()); +//! ``` +//! +//! The first argument to the constructor, `()` in this example, is an arbitrary piece of data +//! called a *tag*. This can be a task identifier, a task name, task-local storage, or something +//! of similar nature. +//! +//! The second argument is the future that gets polled when the task is run. +//! +//! The third argument is the schedule function, which is called every time when the task gets +//! woken up. This function should push the received task into some kind of queue of runnable +//! tasks. +//! +//! The constructor returns a runnable [`Task`] and a [`JoinHandle`] that can await the result of +//! the future. +//! +//! # Task scheduling +//! +//! TODO +//! +//! # Join handles +//! +//! TODO +//! +//! # Cancellation +//! +//! TODO +//! +//! # Performance +//! +//! TODO: explain single allocation, etc. +//! +//! Task [construction] incurs a single allocation only. The [`Task`] can then be run and its +//! result awaited through the [`JoinHandle`]. When woken, the task gets automatically rescheduled. +//! It's also possible to cancel the task so that it stops running and can't be awaited anymore. +//! +//! [construction]: struct.Task.html#method.create +//! [`JoinHandle`]: struct.JoinHandle.html +//! [`Task`]: struct.Task.html +//! [`Future`]: https://doc.rust-lang.org/nightly/std/future/trait.Future.html +//! [`Waker`]: https://doc.rust-lang.org/nightly/std/task/struct.Waker.html +//! [`block_on()`]: https://docs.rs/futures-preview/*/futures/executor/fn.block_on.html +//! +//! # Examples +//! +//! A simple single-threaded executor: +//! +//! ``` +//! # #![feature(async_await)] +//! use std::future::Future; +//! use std::panic::catch_unwind; +//! use std::thread; +//! +//! use async_task::{JoinHandle, Task}; +//! use crossbeam::channel::{unbounded, Sender}; +//! use futures::executor; +//! use lazy_static::lazy_static; +//! +//! /// Spawns a future on the executor. +//! fn spawn(future: F) -> JoinHandle +//! where +//! F: Future + Send + 'static, +//! R: Send + 'static, +//! { +//! lazy_static! { +//! // A channel that holds scheduled tasks. +//! static ref QUEUE: Sender> = { +//! let (sender, receiver) = unbounded::>(); +//! +//! // Start the executor thread. +//! thread::spawn(|| { +//! for task in receiver { +//! // Ignore panics for simplicity. +//! let _ignore_panic = catch_unwind(|| task.run()); +//! } +//! }); +//! +//! sender +//! }; +//! } +//! +//! // Create a task that is scheduled by sending itself into the channel. +//! let schedule = |t| QUEUE.send(t).unwrap(); +//! let (task, handle) = async_task::spawn(future, schedule, ()); +//! +//! // Schedule the task by sending it into the channel. +//! task.schedule(); +//! +//! handle +//! } +//! +//! // Spawn a future and await its result. +//! let handle = spawn(async { +//! println!("Hello, world!"); +//! }); +//! executor::block_on(handle); +//! ``` + +#![warn(missing_docs, missing_debug_implementations, rust_2018_idioms)] + +mod header; +mod join_handle; +mod raw; +mod state; +mod task; +mod utils; + +pub use crate::join_handle::JoinHandle; +pub use crate::task::{spawn, Task}; diff --git a/async-task/src/raw.rs b/async-task/src/raw.rs new file mode 100644 index 00000000..69284275 --- /dev/null +++ b/async-task/src/raw.rs @@ -0,0 +1,629 @@ +use std::alloc::{self, Layout}; +use std::cell::Cell; +use std::future::Future; +use std::marker::PhantomData; +use std::mem::{self, ManuallyDrop}; +use std::pin::Pin; +use std::ptr::NonNull; +use std::sync::atomic::{AtomicUsize, Ordering}; +use std::task::{Context, Poll, RawWaker, RawWakerVTable, Waker}; + +use crate::header::Header; +use crate::state::*; +use crate::utils::{abort_on_panic, extend}; +use crate::Task; + +/// The vtable for a task. +pub(crate) struct TaskVTable { + /// The raw waker vtable. + pub(crate) raw_waker: RawWakerVTable, + + /// Schedules the task. + pub(crate) schedule: unsafe fn(*const ()), + + /// Drops the future inside the task. + pub(crate) drop_future: unsafe fn(*const ()), + + /// Returns a pointer to the output stored after completion. + pub(crate) get_output: unsafe fn(*const ()) -> *const (), + + /// Drops a waker or a task. + pub(crate) decrement: unsafe fn(ptr: *const ()), + + /// Destroys the task. + pub(crate) destroy: unsafe fn(*const ()), + + /// Runs the task. + pub(crate) run: unsafe fn(*const ()), +} + +/// Memory layout of a task. +/// +/// This struct contains the information on: +/// +/// 1. How to allocate and deallocate the task. +/// 2. How to access the fields inside the task. +#[derive(Clone, Copy)] +pub(crate) struct TaskLayout { + /// Memory layout of the whole task. + pub(crate) layout: Layout, + + /// Offset into the task at which the tag is stored. + pub(crate) offset_t: usize, + + /// Offset into the task at which the schedule function is stored. + pub(crate) offset_s: usize, + + /// Offset into the task at which the future is stored. + pub(crate) offset_f: usize, + + /// Offset into the task at which the output is stored. + pub(crate) offset_r: usize, +} + +/// Raw pointers to the fields of a task. +pub(crate) struct RawTask { + /// The task header. + pub(crate) header: *const Header, + + /// The schedule function. + pub(crate) schedule: *const S, + + /// The tag inside the task. + pub(crate) tag: *mut T, + + /// The future. + pub(crate) future: *mut F, + + /// The output of the future. + pub(crate) output: *mut R, +} + +impl Copy for RawTask {} + +impl Clone for RawTask { + fn clone(&self) -> Self { + Self { + header: self.header, + schedule: self.schedule, + tag: self.tag, + future: self.future, + output: self.output, + } + } +} + +impl RawTask +where + F: Future + Send + 'static, + R: Send + 'static, + S: Fn(Task) + Send + Sync + 'static, + T: Send + 'static, +{ + /// Allocates a task with the given `future` and `schedule` function. + /// + /// It is assumed there are initially only the `Task` reference and the `JoinHandle`. + pub(crate) fn allocate(tag: T, future: F, schedule: S) -> NonNull<()> { + // Compute the layout of the task for allocation. Abort if the computation fails. + let task_layout = abort_on_panic(|| Self::task_layout()); + + unsafe { + // Allocate enough space for the entire task. + let raw_task = match NonNull::new(alloc::alloc(task_layout.layout) as *mut ()) { + None => std::process::abort(), + Some(p) => p, + }; + + let raw = Self::from_ptr(raw_task.as_ptr()); + + // Write the header as the first field of the task. + (raw.header as *mut Header).write(Header { + state: AtomicUsize::new(SCHEDULED | HANDLE | REFERENCE), + awaiter: Cell::new(None), + vtable: &TaskVTable { + raw_waker: RawWakerVTable::new( + Self::clone_waker, + Self::wake, + Self::wake_by_ref, + Self::decrement, + ), + schedule: Self::schedule, + drop_future: Self::drop_future, + get_output: Self::get_output, + decrement: Self::decrement, + destroy: Self::destroy, + run: Self::run, + }, + }); + + // Write the tag as the second field of the task. + (raw.tag as *mut T).write(tag); + + // Write the schedule function as the third field of the task. + (raw.schedule as *mut S).write(schedule); + + // Write the future as the fourth field of the task. + raw.future.write(future); + + raw_task + } + } + + /// Creates a `RawTask` from a raw task pointer. + #[inline] + pub(crate) fn from_ptr(ptr: *const ()) -> Self { + let task_layout = Self::task_layout(); + let p = ptr as *const u8; + + unsafe { + Self { + header: p as *const Header, + tag: p.add(task_layout.offset_t) as *mut T, + schedule: p.add(task_layout.offset_s) as *const S, + future: p.add(task_layout.offset_f) as *mut F, + output: p.add(task_layout.offset_r) as *mut R, + } + } + } + + /// Returns the memory layout for a task. + #[inline] + fn task_layout() -> TaskLayout { + // Compute the layouts for `Header`, `T`, `S`, `F`, and `R`. + let layout_header = Layout::new::
(); + let layout_t = Layout::new::(); + let layout_s = Layout::new::(); + let layout_f = Layout::new::(); + let layout_r = Layout::new::(); + + // Compute the layout for `union { F, R }`. + let size_union = layout_f.size().max(layout_r.size()); + let align_union = layout_f.align().max(layout_r.align()); + let layout_union = unsafe { Layout::from_size_align_unchecked(size_union, align_union) }; + + // Compute the layout for `Header` followed by `T`, then `S`, then `union { F, R }`. + let layout = layout_header; + let (layout, offset_t) = extend(layout, layout_t); + let (layout, offset_s) = extend(layout, layout_s); + let (layout, offset_union) = extend(layout, layout_union); + let offset_f = offset_union; + let offset_r = offset_union; + + TaskLayout { + layout, + offset_t, + offset_s, + offset_f, + offset_r, + } + } + + /// Wakes a waker. + unsafe fn wake(ptr: *const ()) { + let raw = Self::from_ptr(ptr); + + let mut state = (*raw.header).state.load(Ordering::Acquire); + + loop { + // If the task is completed or closed, it can't be woken. + if state & (COMPLETED | CLOSED) != 0 { + // Drop the waker. + Self::decrement(ptr); + break; + } + + // If the task is already scheduled, we just need to synchronize with the thread that + // will run the task by "publishing" our current view of the memory. + if state & SCHEDULED != 0 { + // Update the state without actually modifying it. + match (*raw.header).state.compare_exchange_weak( + state, + state, + Ordering::AcqRel, + Ordering::Acquire, + ) { + Ok(_) => { + // Drop the waker. + Self::decrement(ptr); + break; + } + Err(s) => state = s, + } + } else { + // Mark the task as scheduled. + match (*raw.header).state.compare_exchange_weak( + state, + state | SCHEDULED, + Ordering::AcqRel, + Ordering::Acquire, + ) { + Ok(_) => { + // If the task is not yet scheduled and isn't currently running, now is the + // time to schedule it. + if state & (SCHEDULED | RUNNING) == 0 { + // Schedule the task. + let task = Task { + raw_task: NonNull::new_unchecked(ptr as *mut ()), + _marker: PhantomData, + }; + (*raw.schedule)(task); + } else { + // Drop the waker. + Self::decrement(ptr); + } + + break; + } + Err(s) => state = s, + } + } + } + } + + /// Wakes a waker by reference. + unsafe fn wake_by_ref(ptr: *const ()) { + let raw = Self::from_ptr(ptr); + + let mut state = (*raw.header).state.load(Ordering::Acquire); + + loop { + // If the task is completed or closed, it can't be woken. + if state & (COMPLETED | CLOSED) != 0 { + break; + } + + // If the task is already scheduled, we just need to synchronize with the thread that + // will run the task by "publishing" our current view of the memory. + if state & SCHEDULED != 0 { + // Update the state without actually modifying it. + match (*raw.header).state.compare_exchange_weak( + state, + state, + Ordering::AcqRel, + Ordering::Acquire, + ) { + Ok(_) => break, + Err(s) => state = s, + } + } else { + // If the task is not scheduled nor running, we'll need to schedule after waking. + let new = if state & (SCHEDULED | RUNNING) == 0 { + (state | SCHEDULED) + REFERENCE + } else { + state | SCHEDULED + }; + + // Mark the task as scheduled. + match (*raw.header).state.compare_exchange_weak( + state, + new, + Ordering::AcqRel, + Ordering::Acquire, + ) { + Ok(_) => { + // If the task is not scheduled nor running, now is the time to schedule. + if state & (SCHEDULED | RUNNING) == 0 { + // If the reference count overflowed, abort. + if state > isize::max_value() as usize { + std::process::abort(); + } + + // Schedule the task. + let task = Task { + raw_task: NonNull::new_unchecked(ptr as *mut ()), + _marker: PhantomData, + }; + (*raw.schedule)(task); + } + + break; + } + Err(s) => state = s, + } + } + } + } + + /// Clones a waker. + unsafe fn clone_waker(ptr: *const ()) -> RawWaker { + let raw = Self::from_ptr(ptr); + let raw_waker = &(*raw.header).vtable.raw_waker; + + // Increment the reference count. With any kind of reference-counted data structure, + // relaxed ordering is fine when the reference is being cloned. + let state = (*raw.header).state.fetch_add(REFERENCE, Ordering::Relaxed); + + // If the reference count overflowed, abort. + if state > isize::max_value() as usize { + std::process::abort(); + } + + RawWaker::new(ptr, raw_waker) + } + + /// Drops a waker or a task. + /// + /// This function will decrement the reference count. If it drops down to zero and the + /// associated join handle has been dropped too, then the task gets destroyed. + #[inline] + unsafe fn decrement(ptr: *const ()) { + let raw = Self::from_ptr(ptr); + + // Decrement the reference count. + let new = (*raw.header).state.fetch_sub(REFERENCE, Ordering::AcqRel) - REFERENCE; + + // If this was the last reference to the task and the `JoinHandle` has been dropped as + // well, then destroy task. + if new & !(REFERENCE - 1) == 0 && new & HANDLE == 0 { + Self::destroy(ptr); + } + } + + /// Schedules a task for running. + /// + /// This function doesn't modify the state of the task. It only passes the task reference to + /// its schedule function. + unsafe fn schedule(ptr: *const ()) { + let raw = Self::from_ptr(ptr); + + (*raw.schedule)(Task { + raw_task: NonNull::new_unchecked(ptr as *mut ()), + _marker: PhantomData, + }); + } + + /// Drops the future inside a task. + #[inline] + unsafe fn drop_future(ptr: *const ()) { + let raw = Self::from_ptr(ptr); + + // We need a safeguard against panics because the destructor can panic. + abort_on_panic(|| { + raw.future.drop_in_place(); + }) + } + + /// Returns a pointer to the output inside a task. + unsafe fn get_output(ptr: *const ()) -> *const () { + let raw = Self::from_ptr(ptr); + raw.output as *const () + } + + /// Cleans up task's resources and deallocates it. + /// + /// If the task has not been closed, then its future or the output will be dropped. The + /// schedule function and the tag get dropped too. + #[inline] + unsafe fn destroy(ptr: *const ()) { + let raw = Self::from_ptr(ptr); + let task_layout = Self::task_layout(); + + // We need a safeguard against panics because destructors can panic. + abort_on_panic(|| { + // Drop the schedule function. + (raw.schedule as *mut S).drop_in_place(); + + // Drop the tag. + (raw.tag as *mut T).drop_in_place(); + }); + + // Finally, deallocate the memory reserved by the task. + alloc::dealloc(ptr as *mut u8, task_layout.layout); + } + + /// Runs a task. + /// + /// If polling its future panics, the task will be closed and the panic propagated into the + /// caller. + unsafe fn run(ptr: *const ()) { + let raw = Self::from_ptr(ptr); + + // Create a context from the raw task pointer and the vtable inside the its header. + let waker = ManuallyDrop::new(Waker::from_raw(RawWaker::new( + ptr, + &(*raw.header).vtable.raw_waker, + ))); + let cx = &mut Context::from_waker(&waker); + + let mut state = (*raw.header).state.load(Ordering::Acquire); + + // Update the task's state before polling its future. + loop { + // If the task has been closed, drop the task reference and return. + if state & CLOSED != 0 { + // Notify the awaiter that the task has been closed. + if state & AWAITER != 0 { + (*raw.header).notify(); + } + + // Drop the future. + Self::drop_future(ptr); + + // Drop the task reference. + Self::decrement(ptr); + return; + } + + // Mark the task as unscheduled and running. + match (*raw.header).state.compare_exchange_weak( + state, + (state & !SCHEDULED) | RUNNING, + Ordering::AcqRel, + Ordering::Acquire, + ) { + Ok(_) => { + // Update the state because we're continuing with polling the future. + state = (state & !SCHEDULED) | RUNNING; + break; + } + Err(s) => state = s, + } + } + + // Poll the inner future, but surround it with a guard that closes the task in case polling + // panics. + let guard = Guard(raw); + let poll = ::poll(Pin::new_unchecked(&mut *raw.future), cx); + mem::forget(guard); + + match poll { + Poll::Ready(out) => { + // Replace the future with its output. + Self::drop_future(ptr); + raw.output.write(out); + + // A place where the output will be stored in case it needs to be dropped. + let mut output = None; + + // The task is now completed. + loop { + // If the handle is dropped, we'll need to close it and drop the output. + let new = if state & HANDLE == 0 { + (state & !RUNNING & !SCHEDULED) | COMPLETED | CLOSED + } else { + (state & !RUNNING & !SCHEDULED) | COMPLETED + }; + + // Mark the task as not running and completed. + match (*raw.header).state.compare_exchange_weak( + state, + new, + Ordering::AcqRel, + Ordering::Acquire, + ) { + Ok(_) => { + // If the handle is dropped or if the task was closed while running, + // now it's time to drop the output. + if state & HANDLE == 0 || state & CLOSED != 0 { + // Read the output. + output = Some(raw.output.read()); + } + + // Notify the awaiter that the task has been completed. + if state & AWAITER != 0 { + (*raw.header).notify(); + } + + // Drop the task reference. + Self::decrement(ptr); + break; + } + Err(s) => state = s, + } + } + + // Drop the output if it was taken out of the task. + drop(output); + } + Poll::Pending => { + // The task is still not completed. + loop { + // If the task was closed while running, we'll need to unschedule in case it + // was woken and then clean up its resources. + let new = if state & CLOSED != 0 { + state & !RUNNING & !SCHEDULED + } else { + state & !RUNNING + }; + + // Mark the task as not running. + match (*raw.header).state.compare_exchange_weak( + state, + new, + Ordering::AcqRel, + Ordering::Acquire, + ) { + Ok(state) => { + // If the task was closed while running, we need to drop its future. + // If the task was woken while running, we need to schedule it. + // Otherwise, we just drop the task reference. + if state & CLOSED != 0 { + // The thread that closed the task didn't drop the future because + // it was running so now it's our responsibility to do so. + Self::drop_future(ptr); + + // Drop the task reference. + Self::decrement(ptr); + } else if state & SCHEDULED != 0 { + // The thread that has woken the task didn't reschedule it because + // it was running so now it's our responsibility to do so. + Self::schedule(ptr); + } else { + // Drop the task reference. + Self::decrement(ptr); + } + break; + } + Err(s) => state = s, + } + } + } + } + + /// A guard that closes the task if polling its future panics. + struct Guard(RawTask) + where + F: Future + Send + 'static, + R: Send + 'static, + S: Fn(Task) + Send + Sync + 'static, + T: Send + 'static; + + impl Drop for Guard + where + F: Future + Send + 'static, + R: Send + 'static, + S: Fn(Task) + Send + Sync + 'static, + T: Send + 'static, + { + fn drop(&mut self) { + let raw = self.0; + let ptr = raw.header as *const (); + + unsafe { + let mut state = (*raw.header).state.load(Ordering::Acquire); + + loop { + // If the task was closed while running, then unschedule it, drop its + // future, and drop the task reference. + if state & CLOSED != 0 { + // We still need to unschedule the task because it is possible it was + // woken while running. + (*raw.header).state.fetch_and(!SCHEDULED, Ordering::AcqRel); + + // The thread that closed the task didn't drop the future because it + // was running so now it's our responsibility to do so. + RawTask::::drop_future(ptr); + + // Drop the task reference. + RawTask::::decrement(ptr); + break; + } + + // Mark the task as not running, not scheduled, and closed. + match (*raw.header).state.compare_exchange_weak( + state, + (state & !RUNNING & !SCHEDULED) | CLOSED, + Ordering::AcqRel, + Ordering::Acquire, + ) { + Ok(state) => { + // Drop the future because the task is now closed. + RawTask::::drop_future(ptr); + + // Notify the awaiter that the task has been closed. + if state & AWAITER != 0 { + (*raw.header).notify(); + } + + // Drop the task reference. + RawTask::::decrement(ptr); + break; + } + Err(s) => state = s, + } + } + } + } + } + } +} diff --git a/async-task/src/state.rs b/async-task/src/state.rs new file mode 100644 index 00000000..d6ce34fd --- /dev/null +++ b/async-task/src/state.rs @@ -0,0 +1,65 @@ +/// Set if the task is scheduled for running. +/// +/// A task is considered to be scheduled whenever its `Task` reference exists. It is in scheduled +/// state at the moment of creation and when it gets unapused either by its `JoinHandle` or woken +/// by a `Waker`. +/// +/// This flag can't be set when the task is completed. However, it can be set while the task is +/// running, in which case it will be rescheduled as soon as polling finishes. +pub(crate) const SCHEDULED: usize = 1 << 0; + +/// Set if the task is running. +/// +/// A task is running state while its future is being polled. +/// +/// This flag can't be set when the task is completed. However, it can be in scheduled state while +/// it is running, in which case it will be rescheduled when it stops being polled. +pub(crate) const RUNNING: usize = 1 << 1; + +/// Set if the task has been completed. +/// +/// This flag is set when polling returns `Poll::Ready`. The output of the future is then stored +/// inside the task until it becomes stopped. In fact, `JoinHandle` picks the output up by marking +/// the task as stopped. +/// +/// This flag can't be set when the task is scheduled or completed. +pub(crate) const COMPLETED: usize = 1 << 2; + +/// Set if the task is closed. +/// +/// If a task is closed, that means its either cancelled or its output has been consumed by the +/// `JoinHandle`. A task becomes closed when: +/// +/// 1. It gets cancelled by `Task::cancel()` or `JoinHandle::cancel()`. +/// 2. Its output is awaited by the `JoinHandle`. +/// 3. It panics while polling the future. +/// 4. It is completed and the `JoinHandle` is dropped. +pub(crate) const CLOSED: usize = 1 << 3; + +/// Set if the `JoinHandle` still exists. +/// +/// The `JoinHandle` is a special case in that it is only tracked by this flag, while all other +/// task references (`Task` and `Waker`s) are tracked by the reference count. +pub(crate) const HANDLE: usize = 1 << 4; + +/// Set if the `JoinHandle` is awaiting the output. +/// +/// This flag is set while there is a registered awaiter of type `Waker` inside the task. When the +/// task gets closed or completed, we need to wake the awaiter. This flag can be used as a fast +/// check that tells us if we need to wake anyone without acquiring the lock inside the task. +pub(crate) const AWAITER: usize = 1 << 5; + +/// Set if the awaiter is locked. +/// +/// This lock is acquired before a new awaiter is registered or the existing one is woken. +pub(crate) const LOCKED: usize = 1 << 6; + +/// A single reference. +/// +/// The lower bits in the state contain various flags representing the task state, while the upper +/// bits contain the reference count. The value of `REFERENCE` represents a single reference in the +/// total reference count. +/// +/// Note that the reference counter only tracks the `Task` and `Waker`s. The `JoinHandle` is +/// tracked separately by the `HANDLE` flag. +pub(crate) const REFERENCE: usize = 1 << 7; diff --git a/async-task/src/task.rs b/async-task/src/task.rs new file mode 100644 index 00000000..8bfc1643 --- /dev/null +++ b/async-task/src/task.rs @@ -0,0 +1,390 @@ +use std::fmt; +use std::future::Future; +use std::marker::PhantomData; +use std::mem; +use std::ptr::NonNull; + +use crate::header::Header; +use crate::raw::RawTask; +use crate::JoinHandle; + +/// Creates a new task. +/// +/// This constructor returns a `Task` reference that runs the future and a [`JoinHandle`] that +/// awaits its result. +/// +/// The `tag` is stored inside the allocated task. +/// +/// When run, the task polls `future`. When woken, it gets scheduled for running by the +/// `schedule` function. +/// +/// # Examples +/// +/// ``` +/// # #![feature(async_await)] +/// use crossbeam::channel; +/// +/// // The future inside the task. +/// let future = async { +/// println!("Hello, world!"); +/// }; +/// +/// // If the task gets woken, it will be sent into this channel. +/// let (s, r) = channel::unbounded(); +/// let schedule = move |task| s.send(task).unwrap(); +/// +/// // Create a task with the future and the schedule function. +/// let (task, handle) = async_task::spawn(future, schedule, ()); +/// ``` +/// +/// [`JoinHandle`]: struct.JoinHandle.html +pub fn spawn(future: F, schedule: S, tag: T) -> (Task, JoinHandle) +where + F: Future + Send + 'static, + R: Send + 'static, + S: Fn(Task) + Send + Sync + 'static, + T: Send + Sync + 'static, +{ + let raw_task = RawTask::::allocate(tag, future, schedule); + let task = Task { + raw_task, + _marker: PhantomData, + }; + let handle = JoinHandle { + raw_task, + _marker: PhantomData, + }; + (task, handle) +} + +/// A task that runs a future. +/// +/// # Construction +/// +/// A task is a heap-allocated structure containing: +/// +/// * A reference counter. +/// * The state of the task. +/// * Arbitrary piece of data called a *tag*. +/// * A function that schedules the task when woken. +/// * A future or its result if polling has completed. +/// +/// Constructor [`Task::create()`] returns a [`Task`] and a [`JoinHandle`]. Those two references +/// are like two sides of the task: one runs the future and the other awaits its result. +/// +/// # Behavior +/// +/// The [`Task`] reference "owns" the task itself and is used to [run] it. Running consumes the +/// [`Task`] reference and polls its internal future. If the future is still pending after being +/// polled, the [`Task`] reference will be recreated when woken by a [`Waker`]. If the future +/// completes, its result becomes available to the [`JoinHandle`]. +/// +/// The [`JoinHandle`] is a [`Future`] that awaits the result of the task. +/// +/// When the task is woken, its [`Task`] reference is recreated and passed to the schedule function +/// provided during construction. In most executors, scheduling simply pushes the [`Task`] into a +/// queue of runnable tasks. +/// +/// If the [`Task`] reference is dropped without being run, the task is cancelled. +/// +/// Both [`Task`] and [`JoinHandle`] have methods that cancel the task. When cancelled, the task +/// won't be scheduled again even if a [`Waker`] wakes it or the [`JoinHandle`] is polled. An +/// attempt to run a cancelled task won't do anything. And if the cancelled task has already +/// completed, awaiting its result through [`JoinHandle`] will return `None`. +/// +/// If polling the task's future panics, it gets cancelled automatically. +/// +/// # Task states +/// +/// A task can be in the following states: +/// +/// * Sleeping: The [`Task`] reference doesn't exist and is waiting to be scheduled by a [`Waker`]. +/// * Scheduled: The [`Task`] reference exists and is waiting to be [run]. +/// * Completed: The [`Task`] reference doesn't exist anymore and can't be rescheduled, but its +/// result is available to the [`JoinHandle`]. +/// * Cancelled: The [`Task`] reference may or may not exist, but running it does nothing and +/// awaiting the [`JoinHandle`] returns `None`. +/// +/// When constructed, the task is initially in the scheduled state. +/// +/// # Destruction +/// +/// The future inside the task gets dropped in the following cases: +/// +/// * When [`Task`] is dropped. +/// * When [`Task`] is run to completion. +/// +/// If the future hasn't been dropped and the last [`Waker`] or [`JoinHandle`] is dropped, or if +/// a [`JoinHandle`] cancels the task, then the task will be scheduled one last time so that its +/// future gets dropped by the executor. In other words, the task's future can be dropped only by +/// [`Task`]. +/// +/// When the task completes, the result of its future is stored inside the allocation. This result +/// is taken out when the [`JoinHandle`] awaits it. When the task is cancelled or the +/// [`JoinHandle`] is dropped without being awaited, the result gets dropped too. +/// +/// The task gets deallocated when all references to it are dropped, which includes the [`Task`], +/// the [`JoinHandle`], and any associated [`Waker`]s. +/// +/// The tag inside the task and the schedule function get dropped at the time of deallocation. +/// +/// # Panics +/// +/// If polling the inner future inside [`run()`] panics, the panic will be propagated into +/// the caller. Likewise, a panic inside the task result's destructor will be propagated. All other +/// panics result in the process being aborted. +/// +/// More precisely, the process is aborted if a panic occurs: +/// +/// * Inside the schedule function. +/// * While dropping the tag. +/// * While dropping the future. +/// * While dropping the schedule function. +/// * While waking the task awaiting the [`JoinHandle`]. +/// +/// [`run()`]: struct.Task.html#method.run +/// [run]: struct.Task.html#method.run +/// [`JoinHandle`]: struct.JoinHandle.html +/// [`Task`]: struct.Task.html +/// [`Task::create()`]: struct.Task.html#method.create +/// [`Future`]: https://doc.rust-lang.org/std/future/trait.Future.html +/// [`Waker`]: https://doc.rust-lang.org/std/task/struct.Waker.html +/// +/// # Examples +/// +/// ``` +/// # #![feature(async_await)] +/// use async_task::Task; +/// use crossbeam::channel; +/// use futures::executor; +/// +/// // The future inside the task. +/// let future = async { +/// println!("Hello, world!"); +/// }; +/// +/// // If the task gets woken, it will be sent into this channel. +/// let (s, r) = channel::unbounded(); +/// let schedule = move |task| s.send(task).unwrap(); +/// +/// // Create a task with the future and the schedule function. +/// let (task, handle) = async_task::spawn(future, schedule, ()); +/// +/// // Run the task. In this example, it will complete after a single run. +/// task.run(); +/// assert!(r.is_empty()); +/// +/// // Await its result. +/// executor::block_on(handle); +/// ``` +pub struct Task { + /// A pointer to the heap-allocated task. + pub(crate) raw_task: NonNull<()>, + + /// A marker capturing the generic type `T`. + pub(crate) _marker: PhantomData, +} + +unsafe impl Send for Task {} +unsafe impl Sync for Task {} + +impl Task { + /// Schedules the task. + /// + /// This is a convenience method that simply reschedules the task by passing it to its schedule + /// function. + /// + /// If the task is cancelled, this method won't do anything. + /// + /// # Examples + /// + /// ``` + /// # #![feature(async_await)] + /// use crossbeam::channel; + /// + /// // The future inside the task. + /// let future = async { + /// println!("Hello, world!"); + /// }; + /// + /// // If the task gets woken, it will be sent into this channel. + /// let (s, r) = channel::unbounded(); + /// let schedule = move |task| s.send(task).unwrap(); + /// + /// // Create a task with the future and the schedule function. + /// let (task, handle) = async_task::spawn(future, schedule, ()); + /// + /// // Send the task into the channel. + /// task.schedule(); + /// + /// // Retrieve the task back from the channel. + /// let task = r.recv().unwrap(); + /// ``` + pub fn schedule(self) { + let ptr = self.raw_task.as_ptr(); + let header = ptr as *const Header; + mem::forget(self); + + unsafe { + ((*header).vtable.schedule)(ptr); + } + } + + /// Runs the task. + /// + /// This method polls the task's future. If the future completes, its result will become + /// available to the [`JoinHandle`]. And if the future is still pending, the task will have to + /// be woken in order to be rescheduled and then run again. + /// + /// If the task is cancelled, running it won't do anything. + /// + /// # Panics + /// + /// It is possible that polling the future panics, in which case the panic will be propagated + /// into the caller. It is advised that invocations of this method are wrapped inside + /// [`catch_unwind`]. + /// + /// If a panic occurs, the task is automatically cancelled. + /// + /// [`catch_unwind`]: https://doc.rust-lang.org/std/panic/fn.catch_unwind.html + /// + /// # Examples + /// + /// ``` + /// # #![feature(async_await)] + /// use crossbeam::channel; + /// use futures::executor; + /// + /// // The future inside the task. + /// let future = async { 1 + 2 }; + /// + /// // If the task gets woken, it will be sent into this channel. + /// let (s, r) = channel::unbounded(); + /// let schedule = move |task| s.send(task).unwrap(); + /// + /// // Create a task with the future and the schedule function. + /// let (task, handle) = async_task::spawn(future, schedule, ()); + /// + /// // Run the task. In this example, it will complete after a single run. + /// task.run(); + /// assert!(r.is_empty()); + /// + /// // Await the result of the task. + /// let result = executor::block_on(handle); + /// assert_eq!(result, Some(3)); + /// ``` + pub fn run(self) { + let ptr = self.raw_task.as_ptr(); + let header = ptr as *const Header; + mem::forget(self); + + unsafe { + ((*header).vtable.run)(ptr); + } + } + + /// Cancels the task. + /// + /// When cancelled, the task won't be scheduled again even if a [`Waker`] wakes it. An attempt + /// to run it won't do anything. And if it's completed, awaiting its result evaluates to + /// `None`. + /// + /// [`Waker`]: https://doc.rust-lang.org/std/task/struct.Waker.html + /// + /// # Examples + /// + /// ``` + /// # #![feature(async_await)] + /// use crossbeam::channel; + /// use futures::executor; + /// + /// // The future inside the task. + /// let future = async { 1 + 2 }; + /// + /// // If the task gets woken, it will be sent into this channel. + /// let (s, r) = channel::unbounded(); + /// let schedule = move |task| s.send(task).unwrap(); + /// + /// // Create a task with the future and the schedule function. + /// let (task, handle) = async_task::spawn(future, schedule, ()); + /// + /// // Cancel the task. + /// task.cancel(); + /// + /// // Running a cancelled task does nothing. + /// task.run(); + /// + /// // Await the result of the task. + /// let result = executor::block_on(handle); + /// assert_eq!(result, None); + /// ``` + pub fn cancel(&self) { + let ptr = self.raw_task.as_ptr(); + let header = ptr as *const Header; + + unsafe { + (*header).cancel(); + } + } + + /// Returns a reference to the tag stored inside the task. + /// + /// # Examples + /// + /// ``` + /// # #![feature(async_await)] + /// use crossbeam::channel; + /// + /// // The future inside the task. + /// let future = async { 1 + 2 }; + /// + /// // If the task gets woken, it will be sent into this channel. + /// let (s, r) = channel::unbounded(); + /// let schedule = move |task| s.send(task).unwrap(); + /// + /// // Create a task with the future and the schedule function. + /// let (task, handle) = async_task::spawn(future, schedule, "a simple task"); + /// + /// // Access the tag. + /// assert_eq!(*task.tag(), "a simple task"); + /// ``` + pub fn tag(&self) -> &T { + let offset = Header::offset_tag::(); + let ptr = self.raw_task.as_ptr(); + + unsafe { + let raw = (ptr as *mut u8).add(offset) as *const T; + &*raw + } + } +} + +impl Drop for Task { + fn drop(&mut self) { + let ptr = self.raw_task.as_ptr(); + let header = ptr as *const Header; + + unsafe { + // Cancel the task. + (*header).cancel(); + + // Drop the future. + ((*header).vtable.drop_future)(ptr); + + // Drop the task reference. + ((*header).vtable.decrement)(ptr); + } + } +} + +impl fmt::Debug for Task { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let ptr = self.raw_task.as_ptr(); + let header = ptr as *const Header; + + f.debug_struct("Task") + .field("header", unsafe { &(*header) }) + .field("tag", self.tag()) + .finish() + } +} diff --git a/async-task/src/utils.rs b/async-task/src/utils.rs new file mode 100644 index 00000000..441ead1e --- /dev/null +++ b/async-task/src/utils.rs @@ -0,0 +1,48 @@ +use std::alloc::Layout; +use std::mem; + +/// Calls a function and aborts if it panics. +/// +/// This is useful in unsafe code where we can't recover from panics. +#[inline] +pub(crate) fn abort_on_panic(f: impl FnOnce() -> T) -> T { + struct Bomb; + + impl Drop for Bomb { + fn drop(&mut self) { + std::process::abort(); + } + } + + let bomb = Bomb; + let t = f(); + mem::forget(bomb); + t +} + +/// Returns the layout for `a` followed by `b` and the offset of `b`. +/// +/// This function was adapted from the currently unstable `Layout::extend()`: +/// https://doc.rust-lang.org/nightly/std/alloc/struct.Layout.html#method.extend +#[inline] +pub(crate) fn extend(a: Layout, b: Layout) -> (Layout, usize) { + let new_align = a.align().max(b.align()); + let pad = padding_needed_for(a, b.align()); + + let offset = a.size().checked_add(pad).unwrap(); + let new_size = offset.checked_add(b.size()).unwrap(); + + let layout = Layout::from_size_align(new_size, new_align).unwrap(); + (layout, offset) +} + +/// Returns the padding after `layout` that aligns the following address to `align`. +/// +/// This function was adapted from the currently unstable `Layout::padding_needed_for()`: +/// https://doc.rust-lang.org/nightly/std/alloc/struct.Layout.html#method.padding_needed_for +#[inline] +pub(crate) fn padding_needed_for(layout: Layout, align: usize) -> usize { + let len = layout.size(); + let len_rounded_up = len.wrapping_add(align).wrapping_sub(1) & !align.wrapping_sub(1); + len_rounded_up.wrapping_sub(len) +} diff --git a/async-task/tests/basic.rs b/async-task/tests/basic.rs new file mode 100644 index 00000000..b9e181b1 --- /dev/null +++ b/async-task/tests/basic.rs @@ -0,0 +1,314 @@ +#![feature(async_await)] + +use std::future::Future; +use std::pin::Pin; +use std::sync::atomic::{AtomicUsize, Ordering}; +use std::task::{Context, Poll}; + +use async_task::Task; +use crossbeam::atomic::AtomicCell; +use crossbeam::channel; +use futures::future; +use lazy_static::lazy_static; + +// Creates a future with event counters. +// +// Usage: `future!(f, POLL, DROP)` +// +// The future `f` always returns `Poll::Ready`. +// When it gets polled, `POLL` is incremented. +// When it gets dropped, `DROP` is incremented. +macro_rules! future { + ($name:pat, $poll:ident, $drop:ident) => { + lazy_static! { + static ref $poll: AtomicCell = AtomicCell::new(0); + static ref $drop: AtomicCell = AtomicCell::new(0); + } + + let $name = { + struct Fut(Box); + + impl Future for Fut { + type Output = Box; + + fn poll(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll { + $poll.fetch_add(1); + Poll::Ready(Box::new(0)) + } + } + + impl Drop for Fut { + fn drop(&mut self) { + $drop.fetch_add(1); + } + } + + Fut(Box::new(0)) + }; + }; +} + +// Creates a schedule function with event counters. +// +// Usage: `schedule!(s, SCHED, DROP)` +// +// The schedule function `s` does nothing. +// When it gets invoked, `SCHED` is incremented. +// When it gets dropped, `DROP` is incremented. +macro_rules! schedule { + ($name:pat, $sched:ident, $drop:ident) => { + lazy_static! { + static ref $sched: AtomicCell = AtomicCell::new(0); + static ref $drop: AtomicCell = AtomicCell::new(0); + } + + let $name = { + struct Guard(Box); + + impl Drop for Guard { + fn drop(&mut self) { + $drop.fetch_add(1); + } + } + + let guard = Guard(Box::new(0)); + move |_task| { + &guard; + $sched.fetch_add(1); + } + }; + }; +} + +// Creates a task with event counters. +// +// Usage: `task!(task, handle f, s, DROP)` +// +// A task with future `f` and schedule function `s` is created. +// The `Task` and `JoinHandle` are bound to `task` and `handle`, respectively. +// When the tag inside the task gets dropped, `DROP` is incremented. +macro_rules! task { + ($task:pat, $handle: pat, $future:expr, $schedule:expr, $drop:ident) => { + lazy_static! { + static ref $drop: AtomicCell = AtomicCell::new(0); + } + + let ($task, $handle) = { + struct Tag(Box); + + impl Drop for Tag { + fn drop(&mut self) { + $drop.fetch_add(1); + } + } + + async_task::spawn($future, $schedule, Tag(Box::new(0))) + }; + }; +} + +#[test] +fn cancel_and_drop_handle() { + future!(f, POLL, DROP_F); + schedule!(s, SCHEDULE, DROP_S); + task!(task, handle, f, s, DROP_D); + + assert_eq!(POLL.load(), 0); + assert_eq!(SCHEDULE.load(), 0); + assert_eq!(DROP_F.load(), 0); + assert_eq!(DROP_S.load(), 0); + assert_eq!(DROP_D.load(), 0); + + task.cancel(); + assert_eq!(POLL.load(), 0); + assert_eq!(SCHEDULE.load(), 0); + assert_eq!(DROP_F.load(), 0); + assert_eq!(DROP_S.load(), 0); + assert_eq!(DROP_D.load(), 0); + + drop(handle); + assert_eq!(POLL.load(), 0); + assert_eq!(SCHEDULE.load(), 0); + assert_eq!(DROP_F.load(), 0); + assert_eq!(DROP_S.load(), 0); + assert_eq!(DROP_D.load(), 0); + + drop(task); + assert_eq!(POLL.load(), 0); + assert_eq!(SCHEDULE.load(), 0); + assert_eq!(DROP_F.load(), 1); + assert_eq!(DROP_S.load(), 1); + assert_eq!(DROP_D.load(), 1); +} + +#[test] +fn run_and_drop_handle() { + future!(f, POLL, DROP_F); + schedule!(s, SCHEDULE, DROP_S); + task!(task, handle, f, s, DROP_D); + + drop(handle); + assert_eq!(POLL.load(), 0); + assert_eq!(SCHEDULE.load(), 0); + assert_eq!(DROP_F.load(), 0); + assert_eq!(DROP_S.load(), 0); + assert_eq!(DROP_D.load(), 0); + + task.run(); + assert_eq!(POLL.load(), 1); + assert_eq!(SCHEDULE.load(), 0); + assert_eq!(DROP_F.load(), 1); + assert_eq!(DROP_S.load(), 1); + assert_eq!(DROP_D.load(), 1); +} + +#[test] +fn drop_handle_and_run() { + future!(f, POLL, DROP_F); + schedule!(s, SCHEDULE, DROP_S); + task!(task, handle, f, s, DROP_D); + + drop(handle); + assert_eq!(POLL.load(), 0); + assert_eq!(SCHEDULE.load(), 0); + assert_eq!(DROP_F.load(), 0); + assert_eq!(DROP_S.load(), 0); + assert_eq!(DROP_D.load(), 0); + + task.run(); + assert_eq!(POLL.load(), 1); + assert_eq!(SCHEDULE.load(), 0); + assert_eq!(DROP_F.load(), 1); + assert_eq!(DROP_S.load(), 1); + assert_eq!(DROP_D.load(), 1); +} + +#[test] +fn cancel_and_run() { + future!(f, POLL, DROP_F); + schedule!(s, SCHEDULE, DROP_S); + task!(task, handle, f, s, DROP_D); + + handle.cancel(); + assert_eq!(POLL.load(), 0); + assert_eq!(SCHEDULE.load(), 0); + assert_eq!(DROP_F.load(), 0); + assert_eq!(DROP_S.load(), 0); + assert_eq!(DROP_D.load(), 0); + + drop(handle); + assert_eq!(POLL.load(), 0); + assert_eq!(SCHEDULE.load(), 0); + assert_eq!(DROP_F.load(), 0); + assert_eq!(DROP_S.load(), 0); + assert_eq!(DROP_D.load(), 0); + + task.run(); + assert_eq!(POLL.load(), 0); + assert_eq!(SCHEDULE.load(), 0); + assert_eq!(DROP_F.load(), 1); + assert_eq!(DROP_S.load(), 1); + assert_eq!(DROP_D.load(), 1); +} + +#[test] +fn run_and_cancel() { + future!(f, POLL, DROP_F); + schedule!(s, SCHEDULE, DROP_S); + task!(task, handle, f, s, DROP_D); + + task.run(); + assert_eq!(POLL.load(), 1); + assert_eq!(SCHEDULE.load(), 0); + assert_eq!(DROP_F.load(), 1); + assert_eq!(DROP_S.load(), 0); + assert_eq!(DROP_D.load(), 0); + + handle.cancel(); + assert_eq!(POLL.load(), 1); + assert_eq!(SCHEDULE.load(), 0); + assert_eq!(DROP_F.load(), 1); + assert_eq!(DROP_S.load(), 0); + assert_eq!(DROP_D.load(), 0); + + drop(handle); + assert_eq!(POLL.load(), 1); + assert_eq!(SCHEDULE.load(), 0); + assert_eq!(DROP_F.load(), 1); + assert_eq!(DROP_S.load(), 1); + assert_eq!(DROP_D.load(), 1); +} + +#[test] +fn schedule() { + let (s, r) = channel::unbounded(); + let schedule = move |t| s.send(t).unwrap(); + let (task, _handle) = async_task::spawn( + future::poll_fn(|_| Poll::<()>::Pending), + schedule, + Box::new(0), + ); + + assert!(r.is_empty()); + task.schedule(); + + let task = r.recv().unwrap(); + assert!(r.is_empty()); + task.schedule(); + + let task = r.recv().unwrap(); + assert!(r.is_empty()); + task.schedule(); + + r.recv().unwrap(); +} + +#[test] +fn tag() { + let (s, r) = channel::unbounded(); + let schedule = move |t| s.send(t).unwrap(); + let (task, handle) = async_task::spawn( + future::poll_fn(|_| Poll::<()>::Pending), + schedule, + AtomicUsize::new(7), + ); + + assert!(r.is_empty()); + task.schedule(); + + let task = r.recv().unwrap(); + assert!(r.is_empty()); + handle.tag().fetch_add(1, Ordering::SeqCst); + task.schedule(); + + let task = r.recv().unwrap(); + assert_eq!(task.tag().load(Ordering::SeqCst), 8); + assert!(r.is_empty()); + task.schedule(); + + r.recv().unwrap(); +} + +#[test] +fn schedule_counter() { + let (s, r) = channel::unbounded(); + let schedule = move |t: Task| { + t.tag().fetch_add(1, Ordering::SeqCst); + s.send(t).unwrap(); + }; + let (task, handle) = async_task::spawn( + future::poll_fn(|_| Poll::<()>::Pending), + schedule, + AtomicUsize::new(0), + ); + task.schedule(); + + assert_eq!(handle.tag().load(Ordering::SeqCst), 1); + r.recv().unwrap().schedule(); + + assert_eq!(handle.tag().load(Ordering::SeqCst), 2); + r.recv().unwrap().schedule(); + + assert_eq!(handle.tag().load(Ordering::SeqCst), 3); + r.recv().unwrap(); +} diff --git a/async-task/tests/join.rs b/async-task/tests/join.rs new file mode 100644 index 00000000..e0829394 --- /dev/null +++ b/async-task/tests/join.rs @@ -0,0 +1,454 @@ +#![feature(async_await)] + +use std::cell::Cell; +use std::future::Future; +use std::pin::Pin; +use std::task::{Context, Poll}; +use std::thread; +use std::time::Duration; + +use async_task::Task; +use crossbeam::atomic::AtomicCell; +use futures::executor::block_on; +use futures::future; +use lazy_static::lazy_static; + +// Creates a future with event counters. +// +// Usage: `future!(f, POLL, DROP_F, DROP_O)` +// +// The future `f` outputs `Poll::Ready`. +// When it gets polled, `POLL` is incremented. +// When it gets dropped, `DROP_F` is incremented. +// When the output gets dropped, `DROP_O` is incremented. +macro_rules! future { + ($name:pat, $poll:ident, $drop_f:ident, $drop_o:ident) => { + lazy_static! { + static ref $poll: AtomicCell = AtomicCell::new(0); + static ref $drop_f: AtomicCell = AtomicCell::new(0); + static ref $drop_o: AtomicCell = AtomicCell::new(0); + } + + let $name = { + struct Fut(Box); + + impl Future for Fut { + type Output = Out; + + fn poll(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll { + $poll.fetch_add(1); + Poll::Ready(Out(Box::new(0))) + } + } + + impl Drop for Fut { + fn drop(&mut self) { + $drop_f.fetch_add(1); + } + } + + struct Out(Box); + + impl Drop for Out { + fn drop(&mut self) { + $drop_o.fetch_add(1); + } + } + + Fut(Box::new(0)) + }; + }; +} + +// Creates a schedule function with event counters. +// +// Usage: `schedule!(s, SCHED, DROP)` +// +// The schedule function `s` does nothing. +// When it gets invoked, `SCHED` is incremented. +// When it gets dropped, `DROP` is incremented. +macro_rules! schedule { + ($name:pat, $sched:ident, $drop:ident) => { + lazy_static! { + static ref $sched: AtomicCell = AtomicCell::new(0); + static ref $drop: AtomicCell = AtomicCell::new(0); + } + + let $name = { + struct Guard(Box); + + impl Drop for Guard { + fn drop(&mut self) { + $drop.fetch_add(1); + } + } + + let guard = Guard(Box::new(0)); + move |task: Task<_>| { + &guard; + task.schedule(); + $sched.fetch_add(1); + } + }; + }; +} + +// Creates a task with event counters. +// +// Usage: `task!(task, handle f, s, DROP)` +// +// A task with future `f` and schedule function `s` is created. +// The `Task` and `JoinHandle` are bound to `task` and `handle`, respectively. +// When the tag inside the task gets dropped, `DROP` is incremented. +macro_rules! task { + ($task:pat, $handle: pat, $future:expr, $schedule:expr, $drop:ident) => { + lazy_static! { + static ref $drop: AtomicCell = AtomicCell::new(0); + } + + let ($task, $handle) = { + struct Tag(Box); + + impl Drop for Tag { + fn drop(&mut self) { + $drop.fetch_add(1); + } + } + + async_task::spawn($future, $schedule, Tag(Box::new(0))) + }; + }; +} + +fn ms(ms: u64) -> Duration { + Duration::from_millis(ms) +} + +#[test] +fn cancel_and_join() { + future!(f, POLL, DROP_F, DROP_O); + schedule!(s, SCHEDULE, DROP_S); + task!(task, handle, f, s, DROP_D); + + assert_eq!(DROP_O.load(), 0); + + task.cancel(); + drop(task); + assert_eq!(DROP_O.load(), 0); + + assert!(block_on(handle).is_none()); + assert_eq!(POLL.load(), 0); + assert_eq!(SCHEDULE.load(), 0); + assert_eq!(DROP_F.load(), 1); + assert_eq!(DROP_S.load(), 1); + assert_eq!(DROP_D.load(), 1); + assert_eq!(DROP_O.load(), 0); +} + +#[test] +fn run_and_join() { + future!(f, POLL, DROP_F, DROP_O); + schedule!(s, SCHEDULE, DROP_S); + task!(task, handle, f, s, DROP_D); + + assert_eq!(DROP_O.load(), 0); + + task.run(); + assert_eq!(DROP_O.load(), 0); + + assert!(block_on(handle).is_some()); + assert_eq!(POLL.load(), 1); + assert_eq!(SCHEDULE.load(), 0); + assert_eq!(DROP_F.load(), 1); + assert_eq!(DROP_S.load(), 1); + assert_eq!(DROP_D.load(), 1); + assert_eq!(DROP_O.load(), 1); +} + +#[test] +fn drop_handle_and_run() { + future!(f, POLL, DROP_F, DROP_O); + schedule!(s, SCHEDULE, DROP_S); + task!(task, handle, f, s, DROP_D); + + assert_eq!(DROP_O.load(), 0); + + drop(handle); + assert_eq!(DROP_O.load(), 0); + + task.run(); + assert_eq!(POLL.load(), 1); + assert_eq!(SCHEDULE.load(), 0); + assert_eq!(DROP_F.load(), 1); + assert_eq!(DROP_S.load(), 1); + assert_eq!(DROP_D.load(), 1); + assert_eq!(DROP_O.load(), 1); +} + +#[test] +fn join_twice() { + future!(f, POLL, DROP_F, DROP_O); + schedule!(s, SCHEDULE, DROP_S); + task!(task, mut handle, f, s, DROP_D); + + assert_eq!(DROP_O.load(), 0); + + task.run(); + assert_eq!(DROP_O.load(), 0); + + assert!(block_on(&mut handle).is_some()); + assert_eq!(POLL.load(), 1); + assert_eq!(SCHEDULE.load(), 0); + assert_eq!(DROP_F.load(), 1); + assert_eq!(DROP_S.load(), 0); + assert_eq!(DROP_D.load(), 0); + assert_eq!(DROP_O.load(), 1); + + assert!(block_on(&mut handle).is_none()); + assert_eq!(POLL.load(), 1); + assert_eq!(SCHEDULE.load(), 0); + assert_eq!(DROP_F.load(), 1); + assert_eq!(DROP_S.load(), 0); + assert_eq!(DROP_D.load(), 0); + assert_eq!(DROP_O.load(), 1); + + drop(handle); + assert_eq!(DROP_S.load(), 1); + assert_eq!(DROP_D.load(), 1); +} + +#[test] +fn join_and_cancel() { + future!(f, POLL, DROP_F, DROP_O); + schedule!(s, SCHEDULE, DROP_S); + task!(task, handle, f, s, DROP_D); + + crossbeam::scope(|scope| { + scope.spawn(|_| { + thread::sleep(ms(100)); + + task.cancel(); + drop(task); + + thread::sleep(ms(200)); + assert_eq!(POLL.load(), 0); + assert_eq!(SCHEDULE.load(), 0); + assert_eq!(DROP_F.load(), 1); + assert_eq!(DROP_O.load(), 0); + assert_eq!(DROP_S.load(), 1); + assert_eq!(DROP_D.load(), 1); + }); + + assert!(block_on(handle).is_none()); + assert_eq!(POLL.load(), 0); + assert_eq!(SCHEDULE.load(), 0); + + thread::sleep(ms(100)); + assert_eq!(DROP_F.load(), 1); + assert_eq!(DROP_O.load(), 0); + assert_eq!(DROP_S.load(), 1); + assert_eq!(DROP_D.load(), 1); + }) + .unwrap(); +} + +#[test] +fn join_and_run() { + future!(f, POLL, DROP_F, DROP_O); + schedule!(s, SCHEDULE, DROP_S); + task!(task, handle, f, s, DROP_D); + + crossbeam::scope(|scope| { + scope.spawn(|_| { + thread::sleep(ms(200)); + + task.run(); + assert_eq!(POLL.load(), 1); + assert_eq!(SCHEDULE.load(), 0); + assert_eq!(DROP_F.load(), 1); + + thread::sleep(ms(100)); + assert_eq!(DROP_S.load(), 1); + assert_eq!(DROP_D.load(), 1); + }); + + assert!(block_on(handle).is_some()); + assert_eq!(POLL.load(), 1); + assert_eq!(SCHEDULE.load(), 0); + assert_eq!(DROP_F.load(), 1); + assert_eq!(DROP_O.load(), 1); + + thread::sleep(ms(100)); + assert_eq!(DROP_S.load(), 1); + assert_eq!(DROP_D.load(), 1); + }) + .unwrap(); +} + +#[test] +fn try_join_and_run_and_join() { + future!(f, POLL, DROP_F, DROP_O); + schedule!(s, SCHEDULE, DROP_S); + task!(task, mut handle, f, s, DROP_D); + + crossbeam::scope(|scope| { + scope.spawn(|_| { + thread::sleep(ms(200)); + + task.run(); + assert_eq!(POLL.load(), 1); + assert_eq!(SCHEDULE.load(), 0); + assert_eq!(DROP_F.load(), 1); + + thread::sleep(ms(100)); + assert_eq!(DROP_S.load(), 1); + assert_eq!(DROP_D.load(), 1); + }); + + block_on(future::select(&mut handle, future::ready(()))); + assert_eq!(POLL.load(), 0); + assert_eq!(SCHEDULE.load(), 0); + assert_eq!(DROP_F.load(), 0); + assert_eq!(DROP_S.load(), 0); + assert_eq!(DROP_D.load(), 0); + assert_eq!(DROP_O.load(), 0); + + assert!(block_on(handle).is_some()); + assert_eq!(POLL.load(), 1); + assert_eq!(SCHEDULE.load(), 0); + assert_eq!(DROP_F.load(), 1); + assert_eq!(DROP_O.load(), 1); + + thread::sleep(ms(100)); + assert_eq!(DROP_S.load(), 1); + assert_eq!(DROP_D.load(), 1); + }) + .unwrap(); +} + +#[test] +fn try_join_and_cancel_and_run() { + future!(f, POLL, DROP_F, DROP_O); + schedule!(s, SCHEDULE, DROP_S); + task!(task, mut handle, f, s, DROP_D); + + crossbeam::scope(|scope| { + scope.spawn(|_| { + thread::sleep(ms(200)); + + task.run(); + assert_eq!(POLL.load(), 0); + assert_eq!(SCHEDULE.load(), 0); + assert_eq!(DROP_F.load(), 1); + assert_eq!(DROP_S.load(), 1); + assert_eq!(DROP_D.load(), 1); + }); + + block_on(future::select(&mut handle, future::ready(()))); + assert_eq!(POLL.load(), 0); + assert_eq!(SCHEDULE.load(), 0); + assert_eq!(DROP_F.load(), 0); + assert_eq!(DROP_S.load(), 0); + assert_eq!(DROP_D.load(), 0); + assert_eq!(DROP_O.load(), 0); + + handle.cancel(); + assert_eq!(POLL.load(), 0); + assert_eq!(SCHEDULE.load(), 0); + assert_eq!(DROP_F.load(), 0); + assert_eq!(DROP_S.load(), 0); + assert_eq!(DROP_D.load(), 0); + assert_eq!(DROP_O.load(), 0); + + drop(handle); + assert_eq!(POLL.load(), 0); + assert_eq!(SCHEDULE.load(), 0); + assert_eq!(DROP_F.load(), 0); + assert_eq!(DROP_S.load(), 0); + assert_eq!(DROP_D.load(), 0); + assert_eq!(DROP_O.load(), 0); + }) + .unwrap(); +} + +#[test] +fn try_join_and_run_and_cancel() { + future!(f, POLL, DROP_F, DROP_O); + schedule!(s, SCHEDULE, DROP_S); + task!(task, mut handle, f, s, DROP_D); + + crossbeam::scope(|scope| { + scope.spawn(|_| { + thread::sleep(ms(200)); + + task.run(); + assert_eq!(POLL.load(), 1); + assert_eq!(SCHEDULE.load(), 0); + assert_eq!(DROP_F.load(), 1); + assert_eq!(DROP_S.load(), 0); + assert_eq!(DROP_D.load(), 0); + }); + + block_on(future::select(&mut handle, future::ready(()))); + assert_eq!(POLL.load(), 0); + assert_eq!(SCHEDULE.load(), 0); + assert_eq!(DROP_F.load(), 0); + assert_eq!(DROP_S.load(), 0); + assert_eq!(DROP_D.load(), 0); + assert_eq!(DROP_O.load(), 0); + + thread::sleep(ms(400)); + + handle.cancel(); + assert_eq!(POLL.load(), 1); + assert_eq!(SCHEDULE.load(), 0); + assert_eq!(DROP_F.load(), 1); + assert_eq!(DROP_S.load(), 0); + assert_eq!(DROP_D.load(), 0); + assert_eq!(DROP_O.load(), 0); + + drop(handle); + assert_eq!(POLL.load(), 1); + assert_eq!(SCHEDULE.load(), 0); + assert_eq!(DROP_F.load(), 1); + assert_eq!(DROP_S.load(), 1); + assert_eq!(DROP_D.load(), 1); + assert_eq!(DROP_O.load(), 1); + }) + .unwrap(); +} + +#[test] +fn await_output() { + struct Fut(Cell>); + + impl Fut { + fn new(t: T) -> Fut { + Fut(Cell::new(Some(t))) + } + } + + impl Future for Fut { + type Output = T; + + fn poll(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll { + Poll::Ready(self.0.take().unwrap()) + } + } + + for i in 0..10 { + let (task, handle) = async_task::spawn(Fut::new(i), drop, Box::new(0)); + task.run(); + assert_eq!(block_on(handle), Some(i)); + } + + for i in 0..10 { + let (task, handle) = async_task::spawn(Fut::new(vec![7; i]), drop, Box::new(0)); + task.run(); + assert_eq!(block_on(handle), Some(vec![7; i])); + } + + let (task, handle) = async_task::spawn(Fut::new("foo".to_string()), drop, Box::new(0)); + task.run(); + assert_eq!(block_on(handle), Some("foo".to_string())); +} diff --git a/async-task/tests/panic.rs b/async-task/tests/panic.rs new file mode 100644 index 00000000..68058a22 --- /dev/null +++ b/async-task/tests/panic.rs @@ -0,0 +1,288 @@ +#![feature(async_await)] + +use std::future::Future; +use std::panic::catch_unwind; +use std::pin::Pin; +use std::task::{Context, Poll}; +use std::thread; +use std::time::Duration; + +use async_task::Task; +use crossbeam::atomic::AtomicCell; +use futures::executor::block_on; +use futures::future; +use lazy_static::lazy_static; + +// Creates a future with event counters. +// +// Usage: `future!(f, POLL, DROP)` +// +// The future `f` sleeps for 200 ms and then panics. +// When it gets polled, `POLL` is incremented. +// When it gets dropped, `DROP` is incremented. +macro_rules! future { + ($name:pat, $poll:ident, $drop:ident) => { + lazy_static! { + static ref $poll: AtomicCell = AtomicCell::new(0); + static ref $drop: AtomicCell = AtomicCell::new(0); + } + + let $name = { + struct Fut(Box); + + impl Future for Fut { + type Output = (); + + fn poll(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll { + $poll.fetch_add(1); + thread::sleep(ms(200)); + panic!() + } + } + + impl Drop for Fut { + fn drop(&mut self) { + $drop.fetch_add(1); + } + } + + Fut(Box::new(0)) + }; + }; +} + +// Creates a schedule function with event counters. +// +// Usage: `schedule!(s, SCHED, DROP)` +// +// The schedule function `s` does nothing. +// When it gets invoked, `SCHED` is incremented. +// When it gets dropped, `DROP` is incremented. +macro_rules! schedule { + ($name:pat, $sched:ident, $drop:ident) => { + lazy_static! { + static ref $sched: AtomicCell = AtomicCell::new(0); + static ref $drop: AtomicCell = AtomicCell::new(0); + } + + let $name = { + struct Guard(Box); + + impl Drop for Guard { + fn drop(&mut self) { + $drop.fetch_add(1); + } + } + + let guard = Guard(Box::new(0)); + move |_task: Task<_>| { + &guard; + $sched.fetch_add(1); + } + }; + }; +} + +// Creates a task with event counters. +// +// Usage: `task!(task, handle f, s, DROP)` +// +// A task with future `f` and schedule function `s` is created. +// The `Task` and `JoinHandle` are bound to `task` and `handle`, respectively. +// When the tag inside the task gets dropped, `DROP` is incremented. +macro_rules! task { + ($task:pat, $handle: pat, $future:expr, $schedule:expr, $drop:ident) => { + lazy_static! { + static ref $drop: AtomicCell = AtomicCell::new(0); + } + + let ($task, $handle) = { + struct Tag(Box); + + impl Drop for Tag { + fn drop(&mut self) { + $drop.fetch_add(1); + } + } + + async_task::spawn($future, $schedule, Tag(Box::new(0))) + }; + }; +} + +fn ms(ms: u64) -> Duration { + Duration::from_millis(ms) +} + +#[test] +fn cancel_during_run() { + future!(f, POLL, DROP_F); + schedule!(s, SCHEDULE, DROP_S); + task!(task, handle, f, s, DROP_D); + + crossbeam::scope(|scope| { + scope.spawn(|_| { + assert!(catch_unwind(|| task.run()).is_err()); + assert_eq!(POLL.load(), 1); + assert_eq!(SCHEDULE.load(), 0); + assert_eq!(DROP_F.load(), 1); + assert_eq!(DROP_S.load(), 1); + assert_eq!(DROP_D.load(), 1); + }); + + thread::sleep(ms(100)); + + handle.cancel(); + assert_eq!(POLL.load(), 1); + assert_eq!(SCHEDULE.load(), 0); + assert_eq!(DROP_F.load(), 0); + assert_eq!(DROP_S.load(), 0); + assert_eq!(DROP_D.load(), 0); + + drop(handle); + assert_eq!(POLL.load(), 1); + assert_eq!(SCHEDULE.load(), 0); + assert_eq!(DROP_F.load(), 0); + assert_eq!(DROP_S.load(), 0); + assert_eq!(DROP_D.load(), 0); + }) + .unwrap(); +} + +#[test] +fn run_and_join() { + future!(f, POLL, DROP_F); + schedule!(s, SCHEDULE, DROP_S); + task!(task, handle, f, s, DROP_D); + + assert!(catch_unwind(|| task.run()).is_err()); + assert_eq!(POLL.load(), 1); + assert_eq!(SCHEDULE.load(), 0); + assert_eq!(DROP_F.load(), 1); + assert_eq!(DROP_S.load(), 0); + assert_eq!(DROP_D.load(), 0); + + assert!(block_on(handle).is_none()); + assert_eq!(POLL.load(), 1); + assert_eq!(SCHEDULE.load(), 0); + assert_eq!(DROP_F.load(), 1); + assert_eq!(DROP_S.load(), 1); + assert_eq!(DROP_D.load(), 1); +} + +#[test] +fn try_join_and_run_and_join() { + future!(f, POLL, DROP_F); + schedule!(s, SCHEDULE, DROP_S); + task!(task, mut handle, f, s, DROP_D); + + block_on(future::select(&mut handle, future::ready(()))); + assert_eq!(POLL.load(), 0); + assert_eq!(SCHEDULE.load(), 0); + assert_eq!(DROP_F.load(), 0); + assert_eq!(DROP_S.load(), 0); + assert_eq!(DROP_D.load(), 0); + + assert!(catch_unwind(|| task.run()).is_err()); + assert_eq!(POLL.load(), 1); + assert_eq!(SCHEDULE.load(), 0); + assert_eq!(DROP_F.load(), 1); + assert_eq!(DROP_S.load(), 0); + assert_eq!(DROP_D.load(), 0); + + assert!(block_on(handle).is_none()); + assert_eq!(POLL.load(), 1); + assert_eq!(SCHEDULE.load(), 0); + assert_eq!(DROP_F.load(), 1); + assert_eq!(DROP_S.load(), 1); + assert_eq!(DROP_D.load(), 1); +} + +#[test] +fn join_during_run() { + future!(f, POLL, DROP_F); + schedule!(s, SCHEDULE, DROP_S); + task!(task, handle, f, s, DROP_D); + + crossbeam::scope(|scope| { + scope.spawn(|_| { + assert!(catch_unwind(|| task.run()).is_err()); + assert_eq!(POLL.load(), 1); + assert_eq!(SCHEDULE.load(), 0); + assert_eq!(DROP_F.load(), 1); + + thread::sleep(ms(100)); + assert_eq!(DROP_S.load(), 1); + assert_eq!(DROP_D.load(), 1); + }); + + thread::sleep(ms(100)); + + assert!(block_on(handle).is_none()); + assert_eq!(POLL.load(), 1); + assert_eq!(SCHEDULE.load(), 0); + assert_eq!(DROP_F.load(), 1); + + thread::sleep(ms(100)); + assert_eq!(DROP_S.load(), 1); + assert_eq!(DROP_D.load(), 1); + }) + .unwrap(); +} + +#[test] +fn try_join_during_run() { + future!(f, POLL, DROP_F); + schedule!(s, SCHEDULE, DROP_S); + task!(task, mut handle, f, s, DROP_D); + + crossbeam::scope(|scope| { + scope.spawn(|_| { + assert!(catch_unwind(|| task.run()).is_err()); + assert_eq!(POLL.load(), 1); + assert_eq!(SCHEDULE.load(), 0); + assert_eq!(DROP_F.load(), 1); + assert_eq!(DROP_S.load(), 1); + assert_eq!(DROP_D.load(), 1); + }); + + thread::sleep(ms(100)); + + block_on(future::select(&mut handle, future::ready(()))); + assert_eq!(POLL.load(), 1); + assert_eq!(SCHEDULE.load(), 0); + assert_eq!(DROP_F.load(), 0); + assert_eq!(DROP_S.load(), 0); + assert_eq!(DROP_D.load(), 0); + drop(handle); + }) + .unwrap(); +} + +#[test] +fn drop_handle_during_run() { + future!(f, POLL, DROP_F); + schedule!(s, SCHEDULE, DROP_S); + task!(task, handle, f, s, DROP_D); + + crossbeam::scope(|scope| { + scope.spawn(|_| { + assert!(catch_unwind(|| task.run()).is_err()); + assert_eq!(POLL.load(), 1); + assert_eq!(SCHEDULE.load(), 0); + assert_eq!(DROP_F.load(), 1); + assert_eq!(DROP_S.load(), 1); + assert_eq!(DROP_D.load(), 1); + }); + + thread::sleep(ms(100)); + + drop(handle); + assert_eq!(POLL.load(), 1); + assert_eq!(SCHEDULE.load(), 0); + assert_eq!(DROP_F.load(), 0); + assert_eq!(DROP_S.load(), 0); + assert_eq!(DROP_D.load(), 0); + }) + .unwrap(); +} diff --git a/async-task/tests/ready.rs b/async-task/tests/ready.rs new file mode 100644 index 00000000..ecca328b --- /dev/null +++ b/async-task/tests/ready.rs @@ -0,0 +1,265 @@ +#![feature(async_await)] + +use std::future::Future; +use std::pin::Pin; +use std::task::{Context, Poll}; +use std::thread; +use std::time::Duration; + +use async_task::Task; +use crossbeam::atomic::AtomicCell; +use futures::executor::block_on; +use futures::future; +use lazy_static::lazy_static; + +// Creates a future with event counters. +// +// Usage: `future!(f, POLL, DROP_F, DROP_O)` +// +// The future `f` sleeps for 200 ms and outputs `Poll::Ready`. +// When it gets polled, `POLL` is incremented. +// When it gets dropped, `DROP_F` is incremented. +// When the output gets dropped, `DROP_O` is incremented. +macro_rules! future { + ($name:pat, $poll:ident, $drop_f:ident, $drop_o:ident) => { + lazy_static! { + static ref $poll: AtomicCell = AtomicCell::new(0); + static ref $drop_f: AtomicCell = AtomicCell::new(0); + static ref $drop_o: AtomicCell = AtomicCell::new(0); + } + + let $name = { + struct Fut(Box); + + impl Future for Fut { + type Output = Out; + + fn poll(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll { + $poll.fetch_add(1); + thread::sleep(ms(200)); + Poll::Ready(Out(Box::new(0))) + } + } + + impl Drop for Fut { + fn drop(&mut self) { + $drop_f.fetch_add(1); + } + } + + struct Out(Box); + + impl Drop for Out { + fn drop(&mut self) { + $drop_o.fetch_add(1); + } + } + + Fut(Box::new(0)) + }; + }; +} + +// Creates a schedule function with event counters. +// +// Usage: `schedule!(s, SCHED, DROP)` +// +// The schedule function `s` does nothing. +// When it gets invoked, `SCHED` is incremented. +// When it gets dropped, `DROP` is incremented. +macro_rules! schedule { + ($name:pat, $sched:ident, $drop:ident) => { + lazy_static! { + static ref $sched: AtomicCell = AtomicCell::new(0); + static ref $drop: AtomicCell = AtomicCell::new(0); + } + + let $name = { + struct Guard(Box); + + impl Drop for Guard { + fn drop(&mut self) { + $drop.fetch_add(1); + } + } + + let guard = Guard(Box::new(0)); + move |_task: Task<_>| { + &guard; + $sched.fetch_add(1); + } + }; + }; +} + +// Creates a task with event counters. +// +// Usage: `task!(task, handle f, s, DROP)` +// +// A task with future `f` and schedule function `s` is created. +// The `Task` and `JoinHandle` are bound to `task` and `handle`, respectively. +// When the tag inside the task gets dropped, `DROP` is incremented. +macro_rules! task { + ($task:pat, $handle: pat, $future:expr, $schedule:expr, $drop:ident) => { + lazy_static! { + static ref $drop: AtomicCell = AtomicCell::new(0); + } + + let ($task, $handle) = { + struct Tag(Box); + + impl Drop for Tag { + fn drop(&mut self) { + $drop.fetch_add(1); + } + } + + async_task::spawn($future, $schedule, Tag(Box::new(0))) + }; + }; +} + +fn ms(ms: u64) -> Duration { + Duration::from_millis(ms) +} + +#[test] +fn cancel_during_run() { + future!(f, POLL, DROP_F, DROP_O); + schedule!(s, SCHEDULE, DROP_S); + task!(task, handle, f, s, DROP_D); + + crossbeam::scope(|scope| { + scope.spawn(|_| { + task.run(); + assert_eq!(POLL.load(), 1); + assert_eq!(SCHEDULE.load(), 0); + assert_eq!(DROP_F.load(), 1); + assert_eq!(DROP_S.load(), 0); + assert_eq!(DROP_D.load(), 0); + assert_eq!(DROP_O.load(), 1); + }); + + thread::sleep(ms(100)); + + handle.cancel(); + assert_eq!(POLL.load(), 1); + assert_eq!(SCHEDULE.load(), 0); + assert_eq!(DROP_F.load(), 0); + assert_eq!(DROP_S.load(), 0); + assert_eq!(DROP_D.load(), 0); + assert_eq!(DROP_O.load(), 0); + + thread::sleep(ms(200)); + + assert_eq!(POLL.load(), 1); + assert_eq!(SCHEDULE.load(), 0); + assert_eq!(DROP_F.load(), 1); + assert_eq!(DROP_S.load(), 0); + assert_eq!(DROP_D.load(), 0); + assert_eq!(DROP_O.load(), 1); + + drop(handle); + assert_eq!(POLL.load(), 1); + assert_eq!(SCHEDULE.load(), 0); + assert_eq!(DROP_F.load(), 1); + assert_eq!(DROP_S.load(), 1); + assert_eq!(DROP_D.load(), 1); + assert_eq!(DROP_O.load(), 1); + }) + .unwrap(); +} + +#[test] +fn join_during_run() { + future!(f, POLL, DROP_F, DROP_O); + schedule!(s, SCHEDULE, DROP_S); + task!(task, handle, f, s, DROP_D); + + crossbeam::scope(|scope| { + scope.spawn(|_| { + task.run(); + assert_eq!(POLL.load(), 1); + assert_eq!(SCHEDULE.load(), 0); + assert_eq!(DROP_F.load(), 1); + + thread::sleep(ms(100)); + assert_eq!(DROP_S.load(), 1); + assert_eq!(DROP_D.load(), 1); + }); + + thread::sleep(ms(100)); + + assert!(block_on(handle).is_some()); + assert_eq!(POLL.load(), 1); + assert_eq!(SCHEDULE.load(), 0); + assert_eq!(DROP_F.load(), 1); + assert_eq!(DROP_O.load(), 1); + + thread::sleep(ms(100)); + assert_eq!(DROP_S.load(), 1); + assert_eq!(DROP_D.load(), 1); + }) + .unwrap(); +} + +#[test] +fn try_join_during_run() { + future!(f, POLL, DROP_F, DROP_O); + schedule!(s, SCHEDULE, DROP_S); + task!(task, mut handle, f, s, DROP_D); + + crossbeam::scope(|scope| { + scope.spawn(|_| { + task.run(); + assert_eq!(POLL.load(), 1); + assert_eq!(SCHEDULE.load(), 0); + assert_eq!(DROP_F.load(), 1); + assert_eq!(DROP_S.load(), 1); + assert_eq!(DROP_D.load(), 1); + assert_eq!(DROP_O.load(), 1); + }); + + thread::sleep(ms(100)); + + block_on(future::select(&mut handle, future::ready(()))); + assert_eq!(POLL.load(), 1); + assert_eq!(SCHEDULE.load(), 0); + assert_eq!(DROP_F.load(), 0); + assert_eq!(DROP_S.load(), 0); + assert_eq!(DROP_D.load(), 0); + assert_eq!(DROP_O.load(), 0); + drop(handle); + }) + .unwrap(); +} + +#[test] +fn drop_handle_during_run() { + future!(f, POLL, DROP_F, DROP_O); + schedule!(s, SCHEDULE, DROP_S); + task!(task, handle, f, s, DROP_D); + + crossbeam::scope(|scope| { + scope.spawn(|_| { + task.run(); + assert_eq!(POLL.load(), 1); + assert_eq!(SCHEDULE.load(), 0); + assert_eq!(DROP_F.load(), 1); + assert_eq!(DROP_S.load(), 1); + assert_eq!(DROP_D.load(), 1); + assert_eq!(DROP_O.load(), 1); + }); + + thread::sleep(ms(100)); + + drop(handle); + assert_eq!(POLL.load(), 1); + assert_eq!(SCHEDULE.load(), 0); + assert_eq!(DROP_F.load(), 0); + assert_eq!(DROP_S.load(), 0); + assert_eq!(DROP_D.load(), 0); + assert_eq!(DROP_O.load(), 0); + }) + .unwrap(); +} diff --git a/async-task/tests/waker_panic.rs b/async-task/tests/waker_panic.rs new file mode 100644 index 00000000..a683f26f --- /dev/null +++ b/async-task/tests/waker_panic.rs @@ -0,0 +1,357 @@ +#![feature(async_await)] + +use std::cell::Cell; +use std::future::Future; +use std::panic::catch_unwind; +use std::pin::Pin; +use std::task::Waker; +use std::task::{Context, Poll}; +use std::thread; +use std::time::Duration; + +use async_task::Task; +use crossbeam::atomic::AtomicCell; +use crossbeam::channel; +use lazy_static::lazy_static; + +// Creates a future with event counters. +// +// Usage: `future!(f, waker, POLL, DROP)` +// +// The future `f` always sleeps for 200 ms, and panics the second time it is polled. +// When it gets polled, `POLL` is incremented. +// When it gets dropped, `DROP` is incremented. +// +// Every time the future is run, it stores the waker into a global variable. +// This waker can be extracted using the `waker` function. +macro_rules! future { + ($name:pat, $waker:pat, $poll:ident, $drop:ident) => { + lazy_static! { + static ref $poll: AtomicCell = AtomicCell::new(0); + static ref $drop: AtomicCell = AtomicCell::new(0); + static ref WAKER: AtomicCell> = AtomicCell::new(None); + } + + let ($name, $waker) = { + struct Fut(Cell, Box); + + impl Future for Fut { + type Output = (); + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + WAKER.store(Some(cx.waker().clone())); + $poll.fetch_add(1); + thread::sleep(ms(200)); + + if self.0.get() { + panic!() + } else { + self.0.set(true); + Poll::Pending + } + } + } + + impl Drop for Fut { + fn drop(&mut self) { + $drop.fetch_add(1); + } + } + + (Fut(Cell::new(false), Box::new(0)), || { + WAKER.swap(None).unwrap() + }) + }; + }; +} + +// Creates a schedule function with event counters. +// +// Usage: `schedule!(s, chan, SCHED, DROP)` +// +// The schedule function `s` pushes the task into `chan`. +// When it gets invoked, `SCHED` is incremented. +// When it gets dropped, `DROP` is incremented. +// +// Receiver `chan` extracts the task when it is scheduled. +macro_rules! schedule { + ($name:pat, $chan:pat, $sched:ident, $drop:ident) => { + lazy_static! { + static ref $sched: AtomicCell = AtomicCell::new(0); + static ref $drop: AtomicCell = AtomicCell::new(0); + } + + let ($name, $chan) = { + let (s, r) = channel::unbounded(); + + struct Guard(Box); + + impl Drop for Guard { + fn drop(&mut self) { + $drop.fetch_add(1); + } + } + + let guard = Guard(Box::new(0)); + let sched = move |task: Task<_>| { + &guard; + $sched.fetch_add(1); + s.send(task).unwrap(); + }; + + (sched, r) + }; + }; +} + +// Creates a task with event counters. +// +// Usage: `task!(task, handle f, s, DROP)` +// +// A task with future `f` and schedule function `s` is created. +// The `Task` and `JoinHandle` are bound to `task` and `handle`, respectively. +// When the tag inside the task gets dropped, `DROP` is incremented. +macro_rules! task { + ($task:pat, $handle: pat, $future:expr, $schedule:expr, $drop:ident) => { + lazy_static! { + static ref $drop: AtomicCell = AtomicCell::new(0); + } + + let ($task, $handle) = { + struct Tag(Box); + + impl Drop for Tag { + fn drop(&mut self) { + $drop.fetch_add(1); + } + } + + async_task::spawn($future, $schedule, Tag(Box::new(0))) + }; + }; +} + +fn ms(ms: u64) -> Duration { + Duration::from_millis(ms) +} + +#[test] +fn wake_during_run() { + future!(f, waker, POLL, DROP_F); + schedule!(s, chan, SCHEDULE, DROP_S); + task!(task, handle, f, s, DROP_D); + + task.run(); + let w = waker(); + w.wake_by_ref(); + let task = chan.recv().unwrap(); + + crossbeam::scope(|scope| { + scope.spawn(|_| { + assert!(catch_unwind(|| task.run()).is_err()); + drop(waker()); + assert_eq!(POLL.load(), 2); + assert_eq!(SCHEDULE.load(), 1); + assert_eq!(DROP_F.load(), 1); + assert_eq!(DROP_S.load(), 1); + assert_eq!(DROP_D.load(), 1); + assert_eq!(chan.len(), 0); + }); + + thread::sleep(ms(100)); + + w.wake(); + drop(handle); + assert_eq!(POLL.load(), 2); + assert_eq!(SCHEDULE.load(), 1); + assert_eq!(DROP_F.load(), 0); + assert_eq!(DROP_S.load(), 0); + assert_eq!(DROP_D.load(), 0); + assert_eq!(chan.len(), 0); + + thread::sleep(ms(200)); + + assert_eq!(POLL.load(), 2); + assert_eq!(SCHEDULE.load(), 1); + assert_eq!(DROP_F.load(), 1); + assert_eq!(DROP_S.load(), 1); + assert_eq!(DROP_D.load(), 1); + assert_eq!(chan.len(), 0); + }) + .unwrap(); +} + +#[test] +fn cancel_during_run() { + future!(f, waker, POLL, DROP_F); + schedule!(s, chan, SCHEDULE, DROP_S); + task!(task, handle, f, s, DROP_D); + + task.run(); + let w = waker(); + w.wake(); + let task = chan.recv().unwrap(); + + crossbeam::scope(|scope| { + scope.spawn(|_| { + assert!(catch_unwind(|| task.run()).is_err()); + drop(waker()); + assert_eq!(POLL.load(), 2); + assert_eq!(SCHEDULE.load(), 1); + assert_eq!(DROP_F.load(), 1); + assert_eq!(DROP_S.load(), 1); + assert_eq!(DROP_D.load(), 1); + assert_eq!(chan.len(), 0); + }); + + thread::sleep(ms(100)); + + handle.cancel(); + assert_eq!(POLL.load(), 2); + assert_eq!(SCHEDULE.load(), 1); + assert_eq!(DROP_F.load(), 0); + assert_eq!(DROP_S.load(), 0); + assert_eq!(DROP_D.load(), 0); + assert_eq!(chan.len(), 0); + + drop(handle); + assert_eq!(POLL.load(), 2); + assert_eq!(SCHEDULE.load(), 1); + assert_eq!(DROP_F.load(), 0); + assert_eq!(DROP_S.load(), 0); + assert_eq!(DROP_D.load(), 0); + assert_eq!(chan.len(), 0); + + thread::sleep(ms(200)); + + assert_eq!(POLL.load(), 2); + assert_eq!(SCHEDULE.load(), 1); + assert_eq!(DROP_F.load(), 1); + assert_eq!(DROP_S.load(), 1); + assert_eq!(DROP_D.load(), 1); + assert_eq!(chan.len(), 0); + }) + .unwrap(); +} + +#[test] +fn wake_and_cancel_during_run() { + future!(f, waker, POLL, DROP_F); + schedule!(s, chan, SCHEDULE, DROP_S); + task!(task, handle, f, s, DROP_D); + + task.run(); + let w = waker(); + w.wake_by_ref(); + let task = chan.recv().unwrap(); + + crossbeam::scope(|scope| { + scope.spawn(|_| { + assert!(catch_unwind(|| task.run()).is_err()); + drop(waker()); + assert_eq!(POLL.load(), 2); + assert_eq!(SCHEDULE.load(), 1); + assert_eq!(DROP_F.load(), 1); + assert_eq!(DROP_S.load(), 1); + assert_eq!(DROP_D.load(), 1); + assert_eq!(chan.len(), 0); + }); + + thread::sleep(ms(100)); + + w.wake(); + assert_eq!(POLL.load(), 2); + assert_eq!(SCHEDULE.load(), 1); + assert_eq!(DROP_F.load(), 0); + assert_eq!(DROP_S.load(), 0); + assert_eq!(DROP_D.load(), 0); + assert_eq!(chan.len(), 0); + + handle.cancel(); + assert_eq!(POLL.load(), 2); + assert_eq!(SCHEDULE.load(), 1); + assert_eq!(DROP_F.load(), 0); + assert_eq!(DROP_S.load(), 0); + assert_eq!(DROP_D.load(), 0); + assert_eq!(chan.len(), 0); + + drop(handle); + assert_eq!(POLL.load(), 2); + assert_eq!(SCHEDULE.load(), 1); + assert_eq!(DROP_F.load(), 0); + assert_eq!(DROP_S.load(), 0); + assert_eq!(DROP_D.load(), 0); + assert_eq!(chan.len(), 0); + + thread::sleep(ms(200)); + + assert_eq!(POLL.load(), 2); + assert_eq!(SCHEDULE.load(), 1); + assert_eq!(DROP_F.load(), 1); + assert_eq!(DROP_S.load(), 1); + assert_eq!(DROP_D.load(), 1); + assert_eq!(chan.len(), 0); + }) + .unwrap(); +} + +#[test] +fn cancel_and_wake_during_run() { + future!(f, waker, POLL, DROP_F); + schedule!(s, chan, SCHEDULE, DROP_S); + task!(task, handle, f, s, DROP_D); + + task.run(); + let w = waker(); + w.wake_by_ref(); + let task = chan.recv().unwrap(); + + crossbeam::scope(|scope| { + scope.spawn(|_| { + assert!(catch_unwind(|| task.run()).is_err()); + drop(waker()); + assert_eq!(POLL.load(), 2); + assert_eq!(SCHEDULE.load(), 1); + assert_eq!(DROP_F.load(), 1); + assert_eq!(DROP_S.load(), 1); + assert_eq!(DROP_D.load(), 1); + assert_eq!(chan.len(), 0); + }); + + thread::sleep(ms(100)); + + handle.cancel(); + assert_eq!(POLL.load(), 2); + assert_eq!(SCHEDULE.load(), 1); + assert_eq!(DROP_F.load(), 0); + assert_eq!(DROP_S.load(), 0); + assert_eq!(DROP_D.load(), 0); + assert_eq!(chan.len(), 0); + + drop(handle); + assert_eq!(POLL.load(), 2); + assert_eq!(SCHEDULE.load(), 1); + assert_eq!(DROP_F.load(), 0); + assert_eq!(DROP_S.load(), 0); + assert_eq!(DROP_D.load(), 0); + assert_eq!(chan.len(), 0); + + w.wake(); + assert_eq!(POLL.load(), 2); + assert_eq!(SCHEDULE.load(), 1); + assert_eq!(DROP_F.load(), 0); + assert_eq!(DROP_S.load(), 0); + assert_eq!(DROP_D.load(), 0); + assert_eq!(chan.len(), 0); + + thread::sleep(ms(200)); + + assert_eq!(POLL.load(), 2); + assert_eq!(SCHEDULE.load(), 1); + assert_eq!(DROP_F.load(), 1); + assert_eq!(DROP_S.load(), 1); + assert_eq!(DROP_D.load(), 1); + assert_eq!(chan.len(), 0); + }) + .unwrap(); +} diff --git a/async-task/tests/waker_pending.rs b/async-task/tests/waker_pending.rs new file mode 100644 index 00000000..547ff7a3 --- /dev/null +++ b/async-task/tests/waker_pending.rs @@ -0,0 +1,348 @@ +#![feature(async_await)] + +use std::future::Future; +use std::pin::Pin; +use std::task::Waker; +use std::task::{Context, Poll}; +use std::thread; +use std::time::Duration; + +use async_task::Task; +use crossbeam::atomic::AtomicCell; +use crossbeam::channel; +use lazy_static::lazy_static; + +// Creates a future with event counters. +// +// Usage: `future!(f, waker, POLL, DROP)` +// +// The future `f` always sleeps for 200 ms and returns `Poll::Pending`. +// When it gets polled, `POLL` is incremented. +// When it gets dropped, `DROP` is incremented. +// +// Every time the future is run, it stores the waker into a global variable. +// This waker can be extracted using the `waker` function. +macro_rules! future { + ($name:pat, $waker:pat, $poll:ident, $drop:ident) => { + lazy_static! { + static ref $poll: AtomicCell = AtomicCell::new(0); + static ref $drop: AtomicCell = AtomicCell::new(0); + static ref WAKER: AtomicCell> = AtomicCell::new(None); + } + + let ($name, $waker) = { + struct Fut(Box); + + impl Future for Fut { + type Output = (); + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + WAKER.store(Some(cx.waker().clone())); + $poll.fetch_add(1); + thread::sleep(ms(200)); + Poll::Pending + } + } + + impl Drop for Fut { + fn drop(&mut self) { + $drop.fetch_add(1); + } + } + + (Fut(Box::new(0)), || WAKER.swap(None).unwrap()) + }; + }; +} + +// Creates a schedule function with event counters. +// +// Usage: `schedule!(s, chan, SCHED, DROP)` +// +// The schedule function `s` pushes the task into `chan`. +// When it gets invoked, `SCHED` is incremented. +// When it gets dropped, `DROP` is incremented. +// +// Receiver `chan` extracts the task when it is scheduled. +macro_rules! schedule { + ($name:pat, $chan:pat, $sched:ident, $drop:ident) => { + lazy_static! { + static ref $sched: AtomicCell = AtomicCell::new(0); + static ref $drop: AtomicCell = AtomicCell::new(0); + } + + let ($name, $chan) = { + let (s, r) = channel::unbounded(); + + struct Guard(Box); + + impl Drop for Guard { + fn drop(&mut self) { + $drop.fetch_add(1); + } + } + + let guard = Guard(Box::new(0)); + let sched = move |task: Task<_>| { + &guard; + $sched.fetch_add(1); + s.send(task).unwrap(); + }; + + (sched, r) + }; + }; +} + +// Creates a task with event counters. +// +// Usage: `task!(task, handle f, s, DROP)` +// +// A task with future `f` and schedule function `s` is created. +// The `Task` and `JoinHandle` are bound to `task` and `handle`, respectively. +// When the tag inside the task gets dropped, `DROP` is incremented. +macro_rules! task { + ($task:pat, $handle: pat, $future:expr, $schedule:expr, $drop:ident) => { + lazy_static! { + static ref $drop: AtomicCell = AtomicCell::new(0); + } + + let ($task, $handle) = { + struct Tag(Box); + + impl Drop for Tag { + fn drop(&mut self) { + $drop.fetch_add(1); + } + } + + async_task::spawn($future, $schedule, Tag(Box::new(0))) + }; + }; +} + +fn ms(ms: u64) -> Duration { + Duration::from_millis(ms) +} + +#[test] +fn wake_during_run() { + future!(f, waker, POLL, DROP_F); + schedule!(s, chan, SCHEDULE, DROP_S); + task!(task, _handle, f, s, DROP_D); + + task.run(); + let w = waker(); + w.wake_by_ref(); + let task = chan.recv().unwrap(); + + crossbeam::scope(|scope| { + scope.spawn(|_| { + task.run(); + assert_eq!(POLL.load(), 2); + assert_eq!(SCHEDULE.load(), 2); + assert_eq!(DROP_F.load(), 0); + assert_eq!(DROP_S.load(), 0); + assert_eq!(DROP_D.load(), 0); + assert_eq!(chan.len(), 1); + }); + + thread::sleep(ms(100)); + + w.wake_by_ref(); + assert_eq!(POLL.load(), 2); + assert_eq!(SCHEDULE.load(), 1); + assert_eq!(DROP_F.load(), 0); + assert_eq!(DROP_S.load(), 0); + assert_eq!(DROP_D.load(), 0); + assert_eq!(chan.len(), 0); + + thread::sleep(ms(200)); + + assert_eq!(POLL.load(), 2); + assert_eq!(SCHEDULE.load(), 2); + assert_eq!(DROP_F.load(), 0); + assert_eq!(DROP_S.load(), 0); + assert_eq!(DROP_D.load(), 0); + assert_eq!(chan.len(), 1); + }) + .unwrap(); + + chan.recv().unwrap(); + drop(waker()); +} + +#[test] +fn cancel_during_run() { + future!(f, waker, POLL, DROP_F); + schedule!(s, chan, SCHEDULE, DROP_S); + task!(task, handle, f, s, DROP_D); + + task.run(); + let w = waker(); + w.wake(); + let task = chan.recv().unwrap(); + + crossbeam::scope(|scope| { + scope.spawn(|_| { + task.run(); + drop(waker()); + assert_eq!(POLL.load(), 2); + assert_eq!(SCHEDULE.load(), 1); + assert_eq!(DROP_F.load(), 1); + assert_eq!(DROP_S.load(), 1); + assert_eq!(DROP_D.load(), 1); + assert_eq!(chan.len(), 0); + }); + + thread::sleep(ms(100)); + + handle.cancel(); + assert_eq!(POLL.load(), 2); + assert_eq!(SCHEDULE.load(), 1); + assert_eq!(DROP_F.load(), 0); + assert_eq!(DROP_S.load(), 0); + assert_eq!(DROP_D.load(), 0); + assert_eq!(chan.len(), 0); + + drop(handle); + assert_eq!(POLL.load(), 2); + assert_eq!(SCHEDULE.load(), 1); + assert_eq!(DROP_F.load(), 0); + assert_eq!(DROP_S.load(), 0); + assert_eq!(DROP_D.load(), 0); + assert_eq!(chan.len(), 0); + + thread::sleep(ms(200)); + + assert_eq!(POLL.load(), 2); + assert_eq!(SCHEDULE.load(), 1); + assert_eq!(DROP_F.load(), 1); + assert_eq!(DROP_S.load(), 1); + assert_eq!(DROP_D.load(), 1); + assert_eq!(chan.len(), 0); + }) + .unwrap(); +} + +#[test] +fn wake_and_cancel_during_run() { + future!(f, waker, POLL, DROP_F); + schedule!(s, chan, SCHEDULE, DROP_S); + task!(task, handle, f, s, DROP_D); + + task.run(); + let w = waker(); + w.wake_by_ref(); + let task = chan.recv().unwrap(); + + crossbeam::scope(|scope| { + scope.spawn(|_| { + task.run(); + drop(waker()); + assert_eq!(POLL.load(), 2); + assert_eq!(SCHEDULE.load(), 1); + assert_eq!(DROP_F.load(), 1); + assert_eq!(DROP_S.load(), 1); + assert_eq!(DROP_D.load(), 1); + assert_eq!(chan.len(), 0); + }); + + thread::sleep(ms(100)); + + w.wake(); + assert_eq!(POLL.load(), 2); + assert_eq!(SCHEDULE.load(), 1); + assert_eq!(DROP_F.load(), 0); + assert_eq!(DROP_S.load(), 0); + assert_eq!(DROP_D.load(), 0); + assert_eq!(chan.len(), 0); + + handle.cancel(); + assert_eq!(POLL.load(), 2); + assert_eq!(SCHEDULE.load(), 1); + assert_eq!(DROP_F.load(), 0); + assert_eq!(DROP_S.load(), 0); + assert_eq!(DROP_D.load(), 0); + assert_eq!(chan.len(), 0); + + drop(handle); + assert_eq!(POLL.load(), 2); + assert_eq!(SCHEDULE.load(), 1); + assert_eq!(DROP_F.load(), 0); + assert_eq!(DROP_S.load(), 0); + assert_eq!(DROP_D.load(), 0); + assert_eq!(chan.len(), 0); + + thread::sleep(ms(200)); + + assert_eq!(POLL.load(), 2); + assert_eq!(SCHEDULE.load(), 1); + assert_eq!(DROP_F.load(), 1); + assert_eq!(DROP_S.load(), 1); + assert_eq!(DROP_D.load(), 1); + assert_eq!(chan.len(), 0); + }) + .unwrap(); +} + +#[test] +fn cancel_and_wake_during_run() { + future!(f, waker, POLL, DROP_F); + schedule!(s, chan, SCHEDULE, DROP_S); + task!(task, handle, f, s, DROP_D); + + task.run(); + let w = waker(); + w.wake_by_ref(); + let task = chan.recv().unwrap(); + + crossbeam::scope(|scope| { + scope.spawn(|_| { + task.run(); + drop(waker()); + assert_eq!(POLL.load(), 2); + assert_eq!(SCHEDULE.load(), 1); + assert_eq!(DROP_F.load(), 1); + assert_eq!(DROP_S.load(), 1); + assert_eq!(DROP_D.load(), 1); + assert_eq!(chan.len(), 0); + }); + + thread::sleep(ms(100)); + + handle.cancel(); + assert_eq!(POLL.load(), 2); + assert_eq!(SCHEDULE.load(), 1); + assert_eq!(DROP_F.load(), 0); + assert_eq!(DROP_S.load(), 0); + assert_eq!(DROP_D.load(), 0); + assert_eq!(chan.len(), 0); + + drop(handle); + assert_eq!(POLL.load(), 2); + assert_eq!(SCHEDULE.load(), 1); + assert_eq!(DROP_F.load(), 0); + assert_eq!(DROP_S.load(), 0); + assert_eq!(DROP_D.load(), 0); + assert_eq!(chan.len(), 0); + + w.wake(); + assert_eq!(POLL.load(), 2); + assert_eq!(SCHEDULE.load(), 1); + assert_eq!(DROP_F.load(), 0); + assert_eq!(DROP_S.load(), 0); + assert_eq!(DROP_D.load(), 0); + assert_eq!(chan.len(), 0); + + thread::sleep(ms(200)); + + assert_eq!(POLL.load(), 2); + assert_eq!(SCHEDULE.load(), 1); + assert_eq!(DROP_F.load(), 1); + assert_eq!(DROP_S.load(), 1); + assert_eq!(DROP_D.load(), 1); + assert_eq!(chan.len(), 0); + }) + .unwrap(); +} diff --git a/async-task/tests/waker_ready.rs b/async-task/tests/waker_ready.rs new file mode 100644 index 00000000..e64cc554 --- /dev/null +++ b/async-task/tests/waker_ready.rs @@ -0,0 +1,328 @@ +#![feature(async_await)] + +use std::cell::Cell; +use std::future::Future; +use std::pin::Pin; +use std::task::Waker; +use std::task::{Context, Poll}; +use std::thread; +use std::time::Duration; + +use async_task::Task; +use crossbeam::atomic::AtomicCell; +use crossbeam::channel; +use lazy_static::lazy_static; + +// Creates a future with event counters. +// +// Usage: `future!(f, waker, POLL, DROP)` +// +// The future `f` always sleeps for 200 ms, and returns `Poll::Ready` the second time it is polled. +// When it gets polled, `POLL` is incremented. +// When it gets dropped, `DROP` is incremented. +// +// Every time the future is run, it stores the waker into a global variable. +// This waker can be extracted using the `waker` function. +macro_rules! future { + ($name:pat, $waker:pat, $poll:ident, $drop:ident) => { + lazy_static! { + static ref $poll: AtomicCell = AtomicCell::new(0); + static ref $drop: AtomicCell = AtomicCell::new(0); + static ref WAKER: AtomicCell> = AtomicCell::new(None); + } + + let ($name, $waker) = { + struct Fut(Cell, Box); + + impl Future for Fut { + type Output = Box; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + WAKER.store(Some(cx.waker().clone())); + $poll.fetch_add(1); + thread::sleep(ms(200)); + + if self.0.get() { + Poll::Ready(Box::new(0)) + } else { + self.0.set(true); + Poll::Pending + } + } + } + + impl Drop for Fut { + fn drop(&mut self) { + $drop.fetch_add(1); + } + } + + (Fut(Cell::new(false), Box::new(0)), || { + WAKER.swap(None).unwrap() + }) + }; + }; +} + +// Creates a schedule function with event counters. +// +// Usage: `schedule!(s, chan, SCHED, DROP)` +// +// The schedule function `s` pushes the task into `chan`. +// When it gets invoked, `SCHED` is incremented. +// When it gets dropped, `DROP` is incremented. +// +// Receiver `chan` extracts the task when it is scheduled. +macro_rules! schedule { + ($name:pat, $chan:pat, $sched:ident, $drop:ident) => { + lazy_static! { + static ref $sched: AtomicCell = AtomicCell::new(0); + static ref $drop: AtomicCell = AtomicCell::new(0); + } + + let ($name, $chan) = { + let (s, r) = channel::unbounded(); + + struct Guard(Box); + + impl Drop for Guard { + fn drop(&mut self) { + $drop.fetch_add(1); + } + } + + let guard = Guard(Box::new(0)); + let sched = move |task: Task<_>| { + &guard; + $sched.fetch_add(1); + s.send(task).unwrap(); + }; + + (sched, r) + }; + }; +} + +// Creates a task with event counters. +// +// Usage: `task!(task, handle f, s, DROP)` +// +// A task with future `f` and schedule function `s` is created. +// The `Task` and `JoinHandle` are bound to `task` and `handle`, respectively. +// When the tag inside the task gets dropped, `DROP` is incremented. +macro_rules! task { + ($task:pat, $handle: pat, $future:expr, $schedule:expr, $drop:ident) => { + lazy_static! { + static ref $drop: AtomicCell = AtomicCell::new(0); + } + + let ($task, $handle) = { + struct Tag(Box); + + impl Drop for Tag { + fn drop(&mut self) { + $drop.fetch_add(1); + } + } + + async_task::spawn($future, $schedule, Tag(Box::new(0))) + }; + }; +} + +fn ms(ms: u64) -> Duration { + Duration::from_millis(ms) +} + +#[test] +fn wake() { + future!(f, waker, POLL, DROP_F); + schedule!(s, chan, SCHEDULE, DROP_S); + task!(mut task, _, f, s, DROP_D); + + assert!(chan.is_empty()); + + task.run(); + assert_eq!(POLL.load(), 1); + assert_eq!(SCHEDULE.load(), 0); + assert_eq!(DROP_F.load(), 0); + assert_eq!(DROP_S.load(), 0); + assert_eq!(DROP_D.load(), 0); + assert_eq!(chan.len(), 0); + + waker().wake(); + task = chan.recv().unwrap(); + assert_eq!(POLL.load(), 1); + assert_eq!(SCHEDULE.load(), 1); + assert_eq!(DROP_F.load(), 0); + assert_eq!(DROP_S.load(), 0); + assert_eq!(DROP_D.load(), 0); + assert_eq!(chan.len(), 0); + + task.run(); + assert_eq!(POLL.load(), 2); + assert_eq!(SCHEDULE.load(), 1); + assert_eq!(DROP_F.load(), 1); + assert_eq!(DROP_S.load(), 0); + assert_eq!(DROP_D.load(), 0); + assert_eq!(chan.len(), 0); + + waker().wake(); + assert_eq!(POLL.load(), 2); + assert_eq!(SCHEDULE.load(), 1); + assert_eq!(DROP_F.load(), 1); + assert_eq!(DROP_S.load(), 1); + assert_eq!(DROP_D.load(), 1); + assert_eq!(chan.len(), 0); +} + +#[test] +fn wake_by_ref() { + future!(f, waker, POLL, DROP_F); + schedule!(s, chan, SCHEDULE, DROP_S); + task!(mut task, _, f, s, DROP_D); + + assert!(chan.is_empty()); + + task.run(); + assert_eq!(POLL.load(), 1); + assert_eq!(SCHEDULE.load(), 0); + assert_eq!(DROP_F.load(), 0); + assert_eq!(DROP_S.load(), 0); + assert_eq!(DROP_D.load(), 0); + assert_eq!(chan.len(), 0); + + waker().wake_by_ref(); + task = chan.recv().unwrap(); + assert_eq!(POLL.load(), 1); + assert_eq!(SCHEDULE.load(), 1); + assert_eq!(DROP_F.load(), 0); + assert_eq!(DROP_S.load(), 0); + assert_eq!(DROP_D.load(), 0); + assert_eq!(chan.len(), 0); + + task.run(); + assert_eq!(POLL.load(), 2); + assert_eq!(SCHEDULE.load(), 1); + assert_eq!(DROP_F.load(), 1); + assert_eq!(DROP_S.load(), 0); + assert_eq!(DROP_D.load(), 0); + assert_eq!(chan.len(), 0); + + waker().wake_by_ref(); + assert_eq!(POLL.load(), 2); + assert_eq!(SCHEDULE.load(), 1); + assert_eq!(DROP_F.load(), 1); + assert_eq!(DROP_S.load(), 1); + assert_eq!(DROP_D.load(), 1); + assert_eq!(chan.len(), 0); +} + +#[test] +fn clone() { + future!(f, waker, POLL, DROP_F); + schedule!(s, chan, SCHEDULE, DROP_S); + task!(mut task, _, f, s, DROP_D); + + task.run(); + assert_eq!(POLL.load(), 1); + assert_eq!(SCHEDULE.load(), 0); + assert_eq!(DROP_F.load(), 0); + assert_eq!(DROP_S.load(), 0); + assert_eq!(DROP_D.load(), 0); + assert_eq!(chan.len(), 0); + + let w2 = waker().clone(); + let w3 = w2.clone(); + let w4 = w3.clone(); + w4.wake(); + + task = chan.recv().unwrap(); + task.run(); + assert_eq!(POLL.load(), 2); + assert_eq!(SCHEDULE.load(), 1); + assert_eq!(DROP_F.load(), 1); + assert_eq!(DROP_S.load(), 0); + assert_eq!(DROP_D.load(), 0); + assert_eq!(chan.len(), 0); + + w3.wake(); + assert_eq!(POLL.load(), 2); + assert_eq!(SCHEDULE.load(), 1); + assert_eq!(DROP_F.load(), 1); + assert_eq!(DROP_S.load(), 0); + assert_eq!(DROP_D.load(), 0); + assert_eq!(chan.len(), 0); + + drop(w2); + drop(waker()); + assert_eq!(DROP_S.load(), 1); + assert_eq!(DROP_D.load(), 1); +} + +#[test] +fn wake_cancelled() { + future!(f, waker, POLL, DROP_F); + schedule!(s, chan, SCHEDULE, DROP_S); + task!(task, _, f, s, DROP_D); + + task.run(); + assert_eq!(POLL.load(), 1); + assert_eq!(SCHEDULE.load(), 0); + assert_eq!(DROP_F.load(), 0); + assert_eq!(DROP_S.load(), 0); + assert_eq!(DROP_D.load(), 0); + assert_eq!(chan.len(), 0); + + let w = waker(); + + w.wake_by_ref(); + chan.recv().unwrap().cancel(); + assert_eq!(POLL.load(), 1); + assert_eq!(SCHEDULE.load(), 1); + assert_eq!(DROP_F.load(), 1); + assert_eq!(DROP_S.load(), 0); + assert_eq!(DROP_D.load(), 0); + assert_eq!(chan.len(), 0); + + w.wake(); + assert_eq!(POLL.load(), 1); + assert_eq!(SCHEDULE.load(), 1); + assert_eq!(DROP_F.load(), 1); + assert_eq!(DROP_S.load(), 1); + assert_eq!(DROP_D.load(), 1); + assert_eq!(chan.len(), 0); +} + +#[test] +fn wake_completed() { + future!(f, waker, POLL, DROP_F); + schedule!(s, chan, SCHEDULE, DROP_S); + task!(task, _, f, s, DROP_D); + + task.run(); + let w = waker(); + assert_eq!(POLL.load(), 1); + assert_eq!(SCHEDULE.load(), 0); + assert_eq!(DROP_F.load(), 0); + assert_eq!(DROP_S.load(), 0); + assert_eq!(DROP_D.load(), 0); + assert_eq!(chan.len(), 0); + + w.wake(); + chan.recv().unwrap().run(); + assert_eq!(POLL.load(), 2); + assert_eq!(SCHEDULE.load(), 1); + assert_eq!(DROP_F.load(), 1); + assert_eq!(DROP_S.load(), 0); + assert_eq!(DROP_D.load(), 0); + assert_eq!(chan.len(), 0); + + waker().wake(); + assert_eq!(POLL.load(), 2); + assert_eq!(SCHEDULE.load(), 1); + assert_eq!(DROP_F.load(), 1); + assert_eq!(DROP_S.load(), 1); + assert_eq!(DROP_D.load(), 1); + assert_eq!(chan.len(), 0); +} diff --git a/benches/task_local.rs b/benches/task_local.rs new file mode 100644 index 00000000..28003651 --- /dev/null +++ b/benches/task_local.rs @@ -0,0 +1,20 @@ +#![feature(async_await, test)] + +extern crate test; + +use async_std::task; +use async_std::task_local; +use test::{black_box, Bencher}; + +#[bench] +fn get(b: &mut Bencher) { + task_local! { + static VAL: u64 = 1; + } + + let mut sum = 0; + task::block_on(async { + b.iter(|| VAL.with(|v| sum += v)); + }); + black_box(sum); +} diff --git a/examples/fetch-html.rs b/examples/fetch-html.rs new file mode 100644 index 00000000..e1482214 --- /dev/null +++ b/examples/fetch-html.rs @@ -0,0 +1,15 @@ +//! Fetches the HTML contents of the Rust website. + +#![feature(async_await)] + +use std::error::Error; + +use async_std::task; + +fn main() -> Result<(), Box> { + task::block_on(async { + // let contents = surf::get("https://www.rust-lang.org").recv_string().await?; + // println!("{}", contents); + Ok(()) + }) +} diff --git a/examples/hello-world.rs b/examples/hello-world.rs new file mode 100644 index 00000000..a3f064cf --- /dev/null +++ b/examples/hello-world.rs @@ -0,0 +1,13 @@ +//! Spawns a task that says hello. + +#![feature(async_await)] + +use async_std::task; + +async fn say_hi() { + println!("Hello, world!"); +} + +fn main() { + task::block_on(say_hi()) +} diff --git a/examples/list-dir.rs b/examples/list-dir.rs new file mode 100644 index 00000000..545d68da --- /dev/null +++ b/examples/list-dir.rs @@ -0,0 +1,21 @@ +//! Lists files in a directory given as an argument. + +#![feature(async_await)] + +use std::env::args; + +use async_std::{fs, io, prelude::*, task}; + +fn main() -> io::Result<()> { + let path = args().nth(1).expect("missing path argument"); + + task::block_on(async { + let mut dir = fs::read_dir(&path).await?; + + while let Some(entry) = dir.next().await { + println!("{}", entry?.file_name().to_string_lossy()); + } + + Ok(()) + }) +} diff --git a/examples/logging.rs b/examples/logging.rs new file mode 100644 index 00000000..8ada5c7d --- /dev/null +++ b/examples/logging.rs @@ -0,0 +1,17 @@ +//! Prints the runtime's execution log on the standard output. + +#![feature(async_await)] + +use async_std::task; + +fn main() { + femme::start(log::LevelFilter::Trace).unwrap(); + + task::block_on(async { + let handle = task::spawn(async { + log::info!("Hello world!"); + }); + + handle.await; + }) +} diff --git a/examples/print-file.rs b/examples/print-file.rs new file mode 100644 index 00000000..b204a469 --- /dev/null +++ b/examples/print-file.rs @@ -0,0 +1,34 @@ +//! Prints a file given as an argument to stdout. + +#![feature(async_await)] + +use std::env::args; + +use async_std::{fs, io, prelude::*, task}; + +const LEN: usize = 4 * 1024 * 1024; // 4 Mb + +fn main() -> io::Result<()> { + let path = args().nth(1).expect("missing path argument"); + + task::block_on(async { + let mut file = fs::File::open(&path).await?; + let mut stdout = io::stdout(); + let mut buf = vec![0u8; LEN]; + + loop { + // Read a buffer from the file. + let n = file.read(&mut buf).await?; + + // If this is the end of file, clean up and return. + if n == 0 { + stdout.flush().await?; + file.close().await?; + return Ok(()); + } + + // Write the buffer into stdout. + stdout.write_all(&buf[..n]).await?; + } + }) +} diff --git a/examples/stdin-echo.rs b/examples/stdin-echo.rs new file mode 100644 index 00000000..e85bcbd0 --- /dev/null +++ b/examples/stdin-echo.rs @@ -0,0 +1,28 @@ +//! Echoes lines read on stdin to stdout. + +#![feature(async_await)] + +use async_std::{io, prelude::*, task}; + +fn main() -> io::Result<()> { + task::block_on(async { + let stdin = io::stdin(); + let mut stdout = io::stdout(); + let mut line = String::new(); + + loop { + // Read a line from stdin. + let n = stdin.read_line(&mut line).await?; + + // If this is the end of stdin, return. + if n == 0 { + return Ok(()); + } + + // Write the line to stdout. + stdout.write_all(line.as_bytes()).await?; + stdout.flush().await?; + line.clear(); + } + }) +} diff --git a/examples/stdin-timeout.rs b/examples/stdin-timeout.rs new file mode 100644 index 00000000..188112dc --- /dev/null +++ b/examples/stdin-timeout.rs @@ -0,0 +1,28 @@ +//! Reads a line from stdin, or exits with an error if nothing is read in 5 seconds. + +#![feature(async_await)] + +use std::time::Duration; + +use async_std::{io, prelude::*, task}; + +fn main() -> io::Result<()> { + task::block_on(async { + let stdin = io::stdin(); + let mut line = String::new(); + + match stdin + .read_line(&mut line) + .timeout(Duration::from_secs(5)) + .await + { + Ok(res) => { + res?; + print!("Got line: {}", line); + } + Err(_) => println!("You have only 5 seconds to enter a line. Try again :)"), + } + + Ok(()) + }) +} diff --git a/examples/task-local.rs b/examples/task-local.rs new file mode 100644 index 00000000..ed541803 --- /dev/null +++ b/examples/task-local.rs @@ -0,0 +1,19 @@ +//! Creates a task-local value. + +#![feature(async_await)] + +use std::cell::Cell; + +use async_std::{task, task_local}; + +task_local! { + static VAR: Cell = Cell::new(1); +} + +fn main() { + task::block_on(async { + println!("var = {}", VAR.with(|v| v.get())); + VAR.with(|v| v.set(2)); + println!("var = {}", VAR.with(|v| v.get())); + }) +} diff --git a/examples/task-name.rs b/examples/task-name.rs new file mode 100644 index 00000000..39ad080a --- /dev/null +++ b/examples/task-name.rs @@ -0,0 +1,19 @@ +//! Spawns a named task that prints its name. + +#![feature(async_await)] + +use async_std::task; + +async fn print_name() { + println!("My name is {:?}", task::current().name()); +} + +fn main() { + task::block_on(async { + task::Builder::new() + .name("my-task".to_string()) + .spawn(print_name()) + .unwrap() + .await; + }) +} diff --git a/examples/tcp-client.rs b/examples/tcp-client.rs new file mode 100644 index 00000000..dd613db5 --- /dev/null +++ b/examples/tcp-client.rs @@ -0,0 +1,34 @@ +//! TCP client. +//! +//! First start the echo server: +//! +//! ```sh +//! $ cargo run --example tcp-echo +//! ``` +//! +//! Then run the client: +//! +//! ```sh +//! $ cargo run --example tcp-client +//! ``` + +#![feature(async_await)] + +use async_std::{io, net, prelude::*, task}; + +fn main() -> io::Result<()> { + task::block_on(async { + let mut stream = net::TcpStream::connect("127.0.0.1:8080").await?; + println!("Connected to {}", &stream.peer_addr()?); + + let msg = "hello world"; + println!("<- {}", msg); + stream.write_all(msg.as_bytes()).await?; + + let mut buf = vec![0u8; 1024]; + let n = stream.read(&mut buf).await?; + println!("-> {}\n", String::from_utf8_lossy(&buf[..n])); + + Ok(()) + }) +} diff --git a/examples/tcp-echo.rs b/examples/tcp-echo.rs new file mode 100644 index 00000000..498bb03c --- /dev/null +++ b/examples/tcp-echo.rs @@ -0,0 +1,37 @@ +//! TCP echo server. +//! +//! To send messages, do: +//! +//! ```sh +//! $ nc localhost 8080 +//! ``` + +#![feature(async_await)] + +use async_std::{io, net, prelude::*, task}; + +async fn process(stream: net::TcpStream) -> io::Result<()> { + println!("Accepted from: {}", stream.peer_addr()?); + + let (reader, writer) = &mut (&stream, &stream); + io::copy(reader, writer).await?; + + Ok(()) +} + +fn main() -> io::Result<()> { + task::block_on(async { + let listener = net::TcpListener::bind("127.0.0.1:8080").await?; + println!("Listening on {}", listener.local_addr()?); + + let mut incoming = listener.incoming(); + + while let Some(stream) = incoming.next().await { + let stream = stream?; + task::spawn(async { + process(stream).await.unwrap(); + }); + } + Ok(()) + }) +} diff --git a/examples/udp-client.rs b/examples/udp-client.rs new file mode 100644 index 00000000..5e5bbcad --- /dev/null +++ b/examples/udp-client.rs @@ -0,0 +1,34 @@ +//! UDP client. +//! +//! First start the echo server: +//! +//! ```sh +//! $ cargo run --example udp-echo +//! ``` +//! +//! Then run the client: +//! +//! ```sh +//! $ cargo run --example udp-client +//! ``` + +#![feature(async_await)] + +use async_std::{io, net, task}; + +fn main() -> io::Result<()> { + task::block_on(async { + let socket = net::UdpSocket::bind("127.0.0.1:8081").await?; + println!("Listening on {}", socket.local_addr()?); + + let msg = "hello world"; + println!("<- {}", msg); + socket.send_to(msg.as_bytes(), "127.0.0.1:8080").await?; + + let mut buf = vec![0u8; 1024]; + let (n, _) = socket.recv_from(&mut buf).await?; + println!("-> {}\n", String::from_utf8_lossy(&buf[..n])); + + Ok(()) + }) +} diff --git a/examples/udp-echo.rs b/examples/udp-echo.rs new file mode 100644 index 00000000..3ee6b284 --- /dev/null +++ b/examples/udp-echo.rs @@ -0,0 +1,26 @@ +//! UDP echo server. +//! +//! To send messages, do: +//! +//! ```sh +//! $ nc -u localhost 8080 +//! ``` + +#![feature(async_await)] + +use async_std::{io, net, task}; + +fn main() -> io::Result<()> { + task::block_on(async { + let socket = net::UdpSocket::bind("127.0.0.1:8080").await?; + let mut buf = vec![0u8; 1024]; + + println!("Listening on {}", socket.local_addr()?); + + loop { + let (n, peer) = socket.recv_from(&mut buf).await?; + let sent = socket.send_to(&buf[..n], &peer).await?; + println!("Sent {} out of {} bytes to {}", sent, n, peer); + } + }) +} diff --git a/src/fs/dir_builder.rs b/src/fs/dir_builder.rs new file mode 100644 index 00000000..cb2d86cb --- /dev/null +++ b/src/fs/dir_builder.rs @@ -0,0 +1,125 @@ +use std::fs; +use std::future::Future; +use std::io; +use std::path::Path; + +use cfg_if::cfg_if; + +use crate::task::blocking; + +/// A builder for creating directories in various manners. +/// +/// This type is an async version of [`std::fs::DirBuilder`]. +/// +/// [`std::fs::DirBuilder`]: https://doc.rust-lang.org/std/fs/struct.DirBuilder.html +#[derive(Debug)] +pub struct DirBuilder { + recursive: bool, + + #[cfg(unix)] + mode: Option, +} + +impl DirBuilder { + /// Creates a new builder with [`recursive`] set to `false`. + /// + /// [`recursive`]: #method.recursive + /// + /// # Examples + /// + /// ``` + /// use async_std::fs::DirBuilder; + /// + /// let builder = DirBuilder::new(); + /// ``` + pub fn new() -> DirBuilder { + #[cfg(unix)] + let builder = DirBuilder { + recursive: false, + mode: None, + }; + + #[cfg(windows)] + let builder = DirBuilder { recursive: false }; + + builder + } + + /// Sets the option for recursive mode. + /// + /// This option, when `true`, means that all parent directories should be created recursively + /// if they don't exist. Parents are created with the same security settings and permissions as + /// the final directory. + /// + /// This option defaults to `false`. + /// + /// # Examples + /// + /// ``` + /// use async_std::fs::DirBuilder; + /// + /// let mut builder = DirBuilder::new(); + /// builder.recursive(true); + /// ``` + pub fn recursive(&mut self, recursive: bool) -> &mut Self { + self.recursive = recursive; + self + } + + /// Creates a directory with the configured options. + /// + /// It is considered an error if the directory already exists unless recursive mode is enabled. + /// + /// # Examples + /// + /// ```no_run + /// # #![feature(async_await)] + /// use async_std::fs::{metadata, DirBuilder}; + /// + /// # futures::executor::block_on(async { + /// let path = "/tmp/foo/bar/baz"; + /// + /// DirBuilder::new() + /// .recursive(true) + /// .create(path) + /// .await?; + /// + /// assert!(metadata(path).await?.is_dir()); + /// # std::io::Result::Ok(()) + /// # }).unwrap(); + /// ``` + pub fn create>(&self, path: P) -> impl Future> { + let mut builder = fs::DirBuilder::new(); + builder.recursive(self.recursive); + + #[cfg(unix)] + { + if let Some(mode) = self.mode { + std::os::unix::fs::DirBuilderExt::mode(&mut builder, mode); + } + } + + let path = path.as_ref().to_owned(); + async move { blocking::spawn(async move { builder.create(path) }).await } + } +} + +cfg_if! { + if #[cfg(feature = "docs.rs")] { + use crate::os::unix::fs::DirBuilderExt; + } else if #[cfg(unix)] { + use std::os::unix::fs::DirBuilderExt; + } +} + +#[cfg_attr(feature = "docs.rs", doc(cfg(unix)))] +cfg_if! { + if #[cfg(any(unix, feature = "docs.rs"))] { + impl DirBuilderExt for DirBuilder { + fn mode(&mut self, mode: u32) -> &mut Self { + self.mode = Some(mode); + self + } + } + } +} diff --git a/src/fs/dir_entry.rs b/src/fs/dir_entry.rs new file mode 100644 index 00000000..92fee6c9 --- /dev/null +++ b/src/fs/dir_entry.rs @@ -0,0 +1,248 @@ +use std::ffi::OsString; +use std::fs; +use std::future::Future; +use std::io; +use std::path::PathBuf; +use std::pin::Pin; +use std::sync::Mutex; +use std::task::Poll; + +use cfg_if::cfg_if; +use futures::prelude::*; + +use crate::task::blocking; + +/// An entry inside a directory. +/// +/// An instance of `DirEntry` represents an entry inside a directory on the filesystem. Each entry +/// carriers additional information like the full path or metadata. +/// +/// This type is an async version of [`std::fs::DirEntry`]. +/// +/// [`std::fs::DirEntry`]: https://doc.rust-lang.org/std/fs/struct.DirEntry.html +#[derive(Debug)] +pub struct DirEntry { + /// The state of the entry. + state: Mutex, + + /// The full path to the entry. + path: PathBuf, + + #[cfg(unix)] + ino: u64, + + /// The bare name of the entry without the leading path. + file_name: OsString, +} + +/// The state of an asynchronous `DirEntry`. +/// +/// The `DirEntry` can be either idle or busy performing an asynchronous operation. +#[derive(Debug)] +enum State { + Idle(Option), + Busy(blocking::JoinHandle), +} + +impl DirEntry { + /// Creates an asynchronous `DirEntry` from a synchronous handle. + pub(crate) fn new(inner: fs::DirEntry) -> DirEntry { + #[cfg(unix)] + let dir_entry = DirEntry { + path: inner.path(), + file_name: inner.file_name(), + ino: inner.ino(), + state: Mutex::new(State::Idle(Some(inner))), + }; + + #[cfg(windows)] + let dir_entry = DirEntry { + path: inner.path(), + file_name: inner.file_name(), + state: Mutex::new(State::Idle(Some(inner))), + }; + + dir_entry + } + + /// Returns the full path to this entry. + /// + /// The full path is created by joining the original path passed to [`read_dir`] with the name + /// of this entry. + /// + /// [`read_dir`]: fn.read_dir.html + /// + /// # Examples + /// + /// ```no_run + /// # #![feature(async_await)] + /// use async_std::fs::read_dir; + /// use async_std::prelude::*; + /// + /// # futures::executor::block_on(async { + /// let mut dir = read_dir(".").await?; + /// + /// while let Some(entry) = dir.next().await { + /// let entry = entry?; + /// println!("{:?}", entry.path()); + /// } + /// # std::io::Result::Ok(()) + /// # }).unwrap(); + /// ``` + pub fn path(&self) -> PathBuf { + self.path.clone() + } + + /// Returns the metadata for this entry. + /// + /// This function will not traverse symlinks if this entry points at a symlink. + /// + /// # Examples + /// + /// ```no_run + /// # #![feature(async_await)] + /// use async_std::fs::read_dir; + /// use async_std::prelude::*; + /// + /// # futures::executor::block_on(async { + /// let mut dir = read_dir(".").await?; + /// + /// while let Some(entry) = dir.next().await { + /// let entry = entry?; + /// println!("{:?}", entry.metadata().await?); + /// } + /// # std::io::Result::Ok(()) + /// # }).unwrap(); + /// ``` + pub async fn metadata(&self) -> io::Result { + future::poll_fn(|cx| { + let state = &mut *self.state.lock().unwrap(); + + loop { + match state { + State::Idle(opt) => match opt.take() { + None => return Poll::Ready(None), + Some(inner) => { + let (s, r) = futures::channel::oneshot::channel(); + + // Start the operation asynchronously. + *state = State::Busy(blocking::spawn(async move { + let res = inner.metadata(); + let _ = s.send(res); + State::Idle(Some(inner)) + })); + + return Poll::Ready(Some(r)); + } + }, + // Poll the asynchronous operation the file is currently blocked on. + State::Busy(task) => *state = futures::ready!(Pin::new(task).poll(cx)), + } + } + }) + .map(|opt| opt.ok_or_else(|| io_error("invalid state"))) + .await? + .map_err(|_| io_error("blocking task failed")) + .await? + } + + /// Returns the file type for this entry. + /// + /// This function will not traverse symlinks if this entry points at a symlink. + /// + /// # Examples + /// + /// ```no_run + /// # #![feature(async_await)] + /// use async_std::fs::read_dir; + /// use async_std::prelude::*; + /// + /// # futures::executor::block_on(async { + /// let mut dir = read_dir(".").await?; + /// + /// while let Some(entry) = dir.next().await { + /// let entry = entry?; + /// println!("{:?}", entry.file_type().await?); + /// } + /// # std::io::Result::Ok(()) + /// # }).unwrap(); + /// ``` + pub async fn file_type(&self) -> io::Result { + future::poll_fn(|cx| { + let state = &mut *self.state.lock().unwrap(); + + loop { + match state { + State::Idle(opt) => match opt.take() { + None => return Poll::Ready(None), + Some(inner) => { + let (s, r) = futures::channel::oneshot::channel(); + + // Start the operation asynchronously. + *state = State::Busy(blocking::spawn(async move { + let res = inner.file_type(); + let _ = s.send(res); + State::Idle(Some(inner)) + })); + + return Poll::Ready(Some(r)); + } + }, + // Poll the asynchronous operation the file is currently blocked on. + State::Busy(task) => *state = futures::ready!(Pin::new(task).poll(cx)), + } + } + }) + .map(|opt| opt.ok_or_else(|| io_error("invalid state"))) + .await? + .map_err(|_| io_error("blocking task failed")) + .await? + } + + /// Returns the bare name of this entry without the leading path. + /// + /// # Examples + /// + /// ```no_run + /// # #![feature(async_await)] + /// use async_std::fs::read_dir; + /// use async_std::prelude::*; + /// + /// # futures::executor::block_on(async { + /// let mut dir = read_dir(".").await?; + /// + /// while let Some(entry) = dir.next().await { + /// let entry = entry?; + /// println!("{:?}", entry.file_name()); + /// } + /// # std::io::Result::Ok(()) + /// # }).unwrap(); + /// ``` + pub fn file_name(&self) -> OsString { + self.file_name.clone() + } +} + +/// Creates a custom `io::Error` with an arbitrary error type. +fn io_error(err: impl Into>) -> io::Error { + io::Error::new(io::ErrorKind::Other, err) +} + +cfg_if! { + if #[cfg(feature = "docs.rs")] { + use crate::os::unix::fs::DirEntryExt; + } else if #[cfg(unix)] { + use std::os::unix::fs::DirEntryExt; + } +} + +#[cfg_attr(feature = "docs.rs", doc(cfg(unix)))] +cfg_if! { + if #[cfg(any(unix, feature = "docs.rs"))] { + impl DirEntryExt for DirEntry { + fn ino(&self) -> u64 { + self.ino + } + } + } +} diff --git a/src/fs/file.rs b/src/fs/file.rs new file mode 100644 index 00000000..4e3303cb --- /dev/null +++ b/src/fs/file.rs @@ -0,0 +1,826 @@ +//! Types for working with files. + +use std::fs; +use std::future::Future; +use std::io::{self, SeekFrom}; +use std::path::Path; +use std::pin::Pin; +use std::sync::Mutex; +use std::task::{Context, Poll}; + +use cfg_if::cfg_if; +use futures::io::Initializer; +use futures::prelude::*; + +use crate::task::blocking; + +/// A reference to a file on the filesystem. +/// +/// An instance of a `File` can be read and/or written depending on what options it was opened +/// with. +/// +/// Files are automatically closed when they go out of scope. Errors detected on closing are +/// ignored by the implementation of `Drop`. Use the method [`sync_all`] if these errors must be +/// manually handled. +/// +/// This type is an async version of [`std::fs::File`]. +/// +/// [`sync_all`]: struct.File.html#method.sync_all +/// [`std::fs::File`]: https://doc.rust-lang.org/std/fs/struct.File.html +/// +/// # Examples +/// +/// Create a new file and write some bytes to it: +/// +/// ```no_run +/// # #![feature(async_await)] +/// use async_std::fs::File; +/// use async_std::prelude::*; +/// +/// # futures::executor::block_on(async { +/// let mut file = File::create("foo.txt").await?; +/// file.write_all(b"Hello, world!").await?; +/// # std::io::Result::Ok(()) +/// # }).unwrap(); +/// ``` +/// +/// Read the contents of a file into a `Vec`: +/// +/// ```no_run +/// # #![feature(async_await)] +/// use async_std::fs::File; +/// use async_std::prelude::*; +/// +/// # futures::executor::block_on(async { +/// let mut file = File::open("foo.txt").await?; +/// let mut contents = Vec::new(); +/// file.read_to_end(&mut contents).await?; +/// # std::io::Result::Ok(()) +/// # }).unwrap(); +/// ``` +#[derive(Debug)] +pub struct File { + mutex: Mutex, + + #[cfg(unix)] + raw_fd: std::os::unix::io::RawFd, + + #[cfg(windows)] + raw_handle: UnsafeShared, +} + +/// The state of an asynchronous file. +/// +/// The file can be either idle or busy performing an asynchronous operation. +#[derive(Debug)] +enum State { + /// The file is idle. + /// + /// If the inner representation is `None`, that means the file is closed. + Idle(Option), + + /// The file is blocked on an asynchronous operation. + /// + /// Awaiting this operation will result in the new state of the file. + Busy(blocking::JoinHandle), +} + +/// Inner representation of an asynchronous file. +#[derive(Debug)] +struct Inner { + /// The blocking file handle. + file: fs::File, + + /// The read/write buffer. + buf: Vec, + + /// The result of the last asynchronous operation on the file. + last_op: Option, +} + +/// Possible results of an asynchronous operation on a file. +#[derive(Debug)] +enum Operation { + Read(io::Result), + Write(io::Result), + Seek(io::Result), + Flush(io::Result<()>), +} + +impl File { + /// Opens a file in read-only mode. + /// + /// See the [`OpenOptions::open`] method for more details. + /// + /// # Errors + /// + /// This function will return an error if `path` does not already exist. + /// Other errors may also be returned according to [`OpenOptions::open`]. + /// + /// [`OpenOptions::open`]: https://doc.rust-lang.org/std/fs/struct.OpenOptions.html + /// + /// # Examples + /// + /// ```no_run + /// # #![feature(async_await)] + /// use async_std::fs::File; + /// + /// # futures::executor::block_on(async { + /// let file = File::open("foo.txt").await?; + /// # std::io::Result::Ok(()) + /// # }).unwrap(); + /// ``` + pub async fn open>(path: P) -> io::Result { + let path = path.as_ref().to_owned(); + let file = blocking::spawn(async move { fs::File::open(&path) }).await?; + + #[cfg(unix)] + let file = File { + raw_fd: file.as_raw_fd(), + mutex: Mutex::new(State::Idle(Some(Inner { + file, + buf: Vec::new(), + last_op: None, + }))), + }; + + #[cfg(windows)] + let file = File { + raw_handle: UnsafeShared(file.as_raw_handle()), + mutex: Mutex::new(State::Idle(Some(Inner { + file, + buf: Vec::new(), + last_op: None, + }))), + }; + + Ok(file) + } + + /// Opens a file in write-only mode. + /// + /// This function will create a file if it does not exist, and will truncate it if it does. + /// + /// See the [`OpenOptions::open`] function for more details. + /// + /// [`OpenOptions::open`]: https://doc.rust-lang.org/std/fs/struct.OpenOptions.html + /// + /// # Examples + /// + /// ```no_run + /// # #![feature(async_await)] + /// use async_std::fs::File; + /// + /// # futures::executor::block_on(async { + /// let file = File::create("foo.txt").await?; + /// # std::io::Result::Ok(()) + /// # }).unwrap(); + /// ``` + pub async fn create>(path: P) -> io::Result { + let path = path.as_ref().to_owned(); + let file = blocking::spawn(async move { fs::File::create(&path) }).await?; + + #[cfg(unix)] + let file = File { + raw_fd: file.as_raw_fd(), + mutex: Mutex::new(State::Idle(Some(Inner { + file, + buf: Vec::new(), + last_op: None, + }))), + }; + + #[cfg(windows)] + let file = File { + raw_handle: UnsafeShared(file.as_raw_handle()), + mutex: Mutex::new(State::Idle(Some(Inner { + file, + buf: Vec::new(), + last_op: None, + }))), + }; + + Ok(file) + } + + /// Attempts to synchronize all OS-internal metadata to disk. + /// + /// This function will attempt to ensure that all in-memory data reaches the filesystem before + /// returning. + /// + /// This can be used to handle errors that would otherwise only be caught when the `File` is + /// closed. Dropping a file will ignore errors in synchronizing this in-memory data. + /// + /// # Examples + /// + /// ```no_run + /// # #![feature(async_await)] + /// use async_std::fs::File; + /// use async_std::prelude::*; + /// + /// # futures::executor::block_on(async { + /// let mut file = File::create("foo.txt").await?; + /// file.write_all(b"Hello, world!").await?; + /// file.sync_all().await?; + /// # std::io::Result::Ok(()) + /// # }).unwrap(); + /// ``` + pub async fn sync_all(&self) -> io::Result<()> { + future::poll_fn(|cx| { + let state = &mut *self.mutex.lock().unwrap(); + + loop { + match state { + State::Idle(opt) => match opt.take() { + None => return Poll::Ready(None), + Some(inner) => { + let (s, r) = futures::channel::oneshot::channel(); + + // Start the operation asynchronously. + *state = State::Busy(blocking::spawn(async move { + let res = inner.file.sync_all(); + let _ = s.send(res); + State::Idle(Some(inner)) + })); + + return Poll::Ready(Some(r)); + } + }, + // Poll the asynchronous operation the file is currently blocked on. + State::Busy(task) => *state = futures::ready!(Pin::new(task).poll(cx)), + } + } + }) + .map(|opt| opt.ok_or_else(|| io_error("file closed"))) + .await? + .map_err(|_| io_error("blocking task failed")) + .await? + } + + /// Similar to [`sync_all`], except that it may not synchronize file metadata. + /// + /// This is intended for use cases that must synchronize content, but don't need the metadata + /// on disk. The goal of this method is to reduce disk operations. + /// + /// Note that some platforms may simply implement this in terms of [`sync_all`]. + /// + /// [`sync_all`]: struct.File.html#method.sync_all + /// + /// # Examples + /// + /// ```no_run + /// # #![feature(async_await)] + /// use async_std::fs::File; + /// use async_std::prelude::*; + /// + /// # futures::executor::block_on(async { + /// let mut file = File::create("foo.txt").await?; + /// file.write_all(b"Hello, world!").await?; + /// file.sync_data().await?; + /// # std::io::Result::Ok(()) + /// # }).unwrap(); + /// ``` + pub async fn sync_data(&self) -> io::Result<()> { + future::poll_fn(|cx| { + let state = &mut *self.mutex.lock().unwrap(); + + loop { + match state { + State::Idle(opt) => match opt.take() { + None => return Poll::Ready(None), + Some(inner) => { + let (s, r) = futures::channel::oneshot::channel(); + + // Start the operation asynchronously. + *state = State::Busy(blocking::spawn(async move { + let res = inner.file.sync_data(); + let _ = s.send(res); + State::Idle(Some(inner)) + })); + + return Poll::Ready(Some(r)); + } + }, + // Poll the asynchronous operation the file is currently blocked on. + State::Busy(task) => *state = futures::ready!(Pin::new(task).poll(cx)), + } + } + }) + .map(|opt| opt.ok_or_else(|| io_error("file closed"))) + .await? + .map_err(|_| io_error("blocking task failed")) + .await? + } + + /// Truncates or extends the underlying file. + /// + /// If the `size` is less than the current file's size, then the file will be truncated. If it + /// is greater than the current file's size, then the file will be extended to `size` and have + /// all of the intermediate data filled in with zeros. + /// + /// The file's cursor isn't changed. In particular, if the cursor was at the end and the file + /// is truncated using this operation, the cursor will now be past the end. + /// + /// # Errors + /// + /// This function will return an error if the file is not opened for writing. + /// + /// # Examples + /// + /// ```no_run + /// # #![feature(async_await)] + /// use async_std::fs::File; + /// use async_std::prelude::*; + /// + /// # futures::executor::block_on(async { + /// let mut file = File::create("foo.txt").await?; + /// file.set_len(10).await?; + /// # std::io::Result::Ok(()) + /// # }).unwrap(); + /// ``` + pub async fn set_len(&self, size: u64) -> io::Result<()> { + future::poll_fn(|cx| { + let state = &mut *self.mutex.lock().unwrap(); + + loop { + match state { + State::Idle(opt) => match opt.take() { + None => return Poll::Ready(None), + Some(inner) => { + let (s, r) = futures::channel::oneshot::channel(); + + // Start the operation asynchronously. + *state = State::Busy(blocking::spawn(async move { + let res = inner.file.set_len(size); + let _ = s.send(res); + State::Idle(Some(inner)) + })); + + return Poll::Ready(Some(r)); + } + }, + // Poll the asynchronous operation the file is currently blocked on. + State::Busy(task) => *state = futures::ready!(Pin::new(task).poll(cx)), + } + } + }) + .map(|opt| opt.ok_or_else(|| io_error("file closed"))) + .await? + .map_err(|_| io_error("blocking task failed")) + .await? + } + + /// Queries metadata about the file. + /// + /// # Examples + /// + /// ```no_run + /// # #![feature(async_await)] + /// use async_std::fs::File; + /// + /// # futures::executor::block_on(async { + /// let file = File::open("foo.txt").await?; + /// let metadata = file.metadata().await?; + /// # std::io::Result::Ok(()) + /// # }).unwrap(); + /// ``` + pub async fn metadata(&self) -> io::Result { + future::poll_fn(|cx| { + let state = &mut *self.mutex.lock().unwrap(); + + loop { + match state { + State::Idle(opt) => match opt.take() { + None => return Poll::Ready(None), + Some(inner) => { + let (s, r) = futures::channel::oneshot::channel(); + + // Start the operation asynchronously. + *state = State::Busy(blocking::spawn(async move { + let res = inner.file.metadata(); + let _ = s.send(res); + State::Idle(Some(inner)) + })); + + return Poll::Ready(Some(r)); + } + }, + // Poll the asynchronous operation the file is currently blocked on. + State::Busy(task) => *state = futures::ready!(Pin::new(task).poll(cx)), + } + } + }) + .map(|opt| opt.ok_or_else(|| io_error("file closed"))) + .await? + .map_err(|_| io_error("blocking task failed")) + .await? + } + + /// Changes the permissions on the underlying file. + /// + /// # Errors + /// + /// This function will return an error if the user lacks permission to change attributes on the + /// underlying file, but may also return an error in other OS-specific cases. + /// + /// # Examples + /// + /// ```no_run + /// # #![feature(async_await)] + /// use async_std::fs::File; + /// use async_std::prelude::*; + /// + /// # futures::executor::block_on(async { + /// let mut file = File::create("foo.txt").await?; + /// let mut perms = file.metadata().await?.permissions(); + /// perms.set_readonly(true); + /// file.set_permissions(perms).await?; + /// # std::io::Result::Ok(()) + /// # }).unwrap(); + /// ``` + pub async fn set_permissions(&self, perm: fs::Permissions) -> io::Result<()> { + let mut perm = Some(perm); + + future::poll_fn(|cx| { + let state = &mut *self.mutex.lock().unwrap(); + + loop { + match state { + State::Idle(opt) => match opt.take() { + None => return Poll::Ready(None), + Some(inner) => { + let (s, r) = futures::channel::oneshot::channel(); + let perm = perm.take().unwrap(); + + // Start the operation asynchronously. + *state = State::Busy(blocking::spawn(async move { + let res = inner.file.set_permissions(perm); + let _ = s.send(res); + State::Idle(Some(inner)) + })); + + return Poll::Ready(Some(r)); + } + }, + // Poll the asynchronous operation the file is currently blocked on. + State::Busy(task) => *state = futures::ready!(Pin::new(task).poll(cx)), + } + } + }) + .map(|opt| opt.ok_or_else(|| io_error("file closed"))) + .await? + .map_err(|_| io_error("blocking task failed")) + .await? + } +} + +impl AsyncRead for File { + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut [u8], + ) -> Poll> { + Pin::new(&mut &*self).poll_read(cx, buf) + } + + #[inline] + unsafe fn initializer(&self) -> Initializer { + Initializer::nop() + } +} + +impl AsyncRead for &File { + fn poll_read( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut [u8], + ) -> Poll> { + let state = &mut *self.mutex.lock().unwrap(); + + loop { + match state { + State::Idle(opt) => { + // Grab a reference to the inner representation of the file or return an error + // if the file is closed. + let inner = opt.as_mut().ok_or_else(|| io_error("file closed"))?; + let mut offset = 0; + + // Check if the operation has completed. + if let Some(Operation::Read(res)) = inner.last_op.take() { + let n = res?; + + if n <= buf.len() { + // Copy the read data into the buffer and return. + buf[..n].copy_from_slice(&inner.buf[..n]); + return Poll::Ready(Ok(n)); + } + + // If more data was read than fits into the buffer, let's retry the read + // operation, but first move the cursor where it was before the previous + // read. + offset = n; + } + + let mut inner = opt.take().unwrap(); + + // Set the length of the inner buffer to the length of the provided buffer. + if inner.buf.len() < buf.len() { + inner.buf.reserve(buf.len() - inner.buf.len()); + } + unsafe { + inner.buf.set_len(buf.len()); + } + + // Start the operation asynchronously. + *state = State::Busy(blocking::spawn(async move { + if offset > 0 { + let pos = SeekFrom::Current(-(offset as i64)); + let _ = io::Seek::seek(&mut inner.file, pos); + } + + let res = io::Read::read(&mut inner.file, &mut inner.buf); + inner.last_op = Some(Operation::Read(res)); + State::Idle(Some(inner)) + })); + } + // Poll the asynchronous operation the file is currently blocked on. + State::Busy(task) => *state = futures::ready!(Pin::new(task).poll(cx)), + } + } + } + + #[inline] + unsafe fn initializer(&self) -> Initializer { + Initializer::nop() + } +} + +impl AsyncWrite for File { + fn poll_write( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + Pin::new(&mut &*self).poll_write(cx, buf) + } + + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + Pin::new(&mut &*self).poll_flush(cx) + } + + fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + Pin::new(&mut &*self).poll_close(cx) + } +} + +impl AsyncWrite for &File { + fn poll_write( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + let state = &mut *self.mutex.lock().unwrap(); + + loop { + match state { + State::Idle(opt) => { + // Grab a reference to the inner representation of the file or return an error + // if the file is closed. + let inner = opt.as_mut().ok_or_else(|| io_error("file closed"))?; + + // Check if the operation has completed. + if let Some(Operation::Write(res)) = inner.last_op.take() { + let n = res?; + + // If more data was written than is available in the buffer, let's retry + // the write operation. + if n <= buf.len() { + return Poll::Ready(Ok(n)); + } + } else { + let mut inner = opt.take().unwrap(); + + // Set the length of the inner buffer to the length of the provided buffer. + if inner.buf.len() < buf.len() { + inner.buf.reserve(buf.len() - inner.buf.len()); + } + unsafe { + inner.buf.set_len(buf.len()); + } + + // Copy the data to write into the inner buffer. + inner.buf[..buf.len()].copy_from_slice(buf); + + // Start the operation asynchronously. + *state = State::Busy(blocking::spawn(async move { + let res = io::Write::write(&mut inner.file, &mut inner.buf); + inner.last_op = Some(Operation::Write(res)); + State::Idle(Some(inner)) + })); + } + } + // Poll the asynchronous operation the file is currently blocked on. + State::Busy(task) => *state = futures::ready!(Pin::new(task).poll(cx)), + } + } + } + + fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let state = &mut *self.mutex.lock().unwrap(); + + loop { + match state { + State::Idle(opt) => { + // Grab a reference to the inner representation of the file or return if the + // file is closed. + let inner = match opt.as_mut() { + None => return Poll::Ready(Ok(())), + Some(s) => s, + }; + + // Check if the operation has completed. + if let Some(Operation::Flush(res)) = inner.last_op.take() { + return Poll::Ready(res); + } else { + let mut inner = opt.take().unwrap(); + + // Start the operation asynchronously. + *state = State::Busy(blocking::spawn(async move { + let res = io::Write::flush(&mut inner.file); + inner.last_op = Some(Operation::Flush(res)); + State::Idle(Some(inner)) + })); + } + } + // Poll the asynchronous operation the file is currently blocked on. + State::Busy(task) => *state = futures::ready!(Pin::new(task).poll(cx)), + } + } + } + + fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let state = &mut *self.mutex.lock().unwrap(); + + loop { + match state { + State::Idle(opt) => { + // Grab a reference to the inner representation of the file or return if the + // file is closed. + let inner = match opt.take() { + None => return Poll::Ready(Ok(())), + Some(s) => s, + }; + + // Start the operation asynchronously. + *state = State::Busy(blocking::spawn(async move { + drop(inner); + State::Idle(None) + })); + } + // Poll the asynchronous operation the file is currently blocked on. + State::Busy(task) => *state = futures::ready!(Pin::new(task).poll(cx)), + } + } + } +} + +impl AsyncSeek for File { + fn poll_seek( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + pos: SeekFrom, + ) -> Poll> { + Pin::new(&mut &*self).poll_seek(cx, pos) + } +} + +impl AsyncSeek for &File { + fn poll_seek( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + pos: SeekFrom, + ) -> Poll> { + let state = &mut *self.mutex.lock().unwrap(); + + loop { + match state { + State::Idle(opt) => { + // Grab a reference to the inner representation of the file or return an error + // if the file is closed. + let inner = opt.as_mut().ok_or_else(|| io_error("file closed"))?; + + // Check if the operation has completed. + if let Some(Operation::Seek(res)) = inner.last_op.take() { + return Poll::Ready(res); + } else { + let mut inner = opt.take().unwrap(); + + // Start the operation asynchronously. + *state = State::Busy(blocking::spawn(async move { + let res = io::Seek::seek(&mut inner.file, pos); + inner.last_op = Some(Operation::Seek(res)); + State::Idle(Some(inner)) + })); + } + } + // Poll the asynchronous operation the file is currently blocked on. + State::Busy(task) => *state = futures::ready!(Pin::new(task).poll(cx)), + } + } + } +} + +/// Creates a custom `io::Error` with an arbitrary error type. +fn io_error(err: impl Into>) -> io::Error { + io::Error::new(io::ErrorKind::Other, err) +} + +impl From for File { + /// Converts a `std::fs::File` into its asynchronous equivalent. + fn from(file: fs::File) -> File { + #[cfg(unix)] + let file = File { + raw_fd: file.as_raw_fd(), + mutex: Mutex::new(State::Idle(Some(Inner { + file, + buf: Vec::new(), + last_op: None, + }))), + }; + + #[cfg(windows)] + let file = File { + raw_handle: UnsafeShared(file.as_raw_handle()), + mutex: Mutex::new(State::Idle(Some(Inner { + file, + buf: Vec::new(), + last_op: None, + }))), + }; + + file + } +} + +cfg_if! { + if #[cfg(feature = "docs.rs")] { + use crate::os::unix::io::{AsRawFd, FromRawFd, IntoRawFd, RawFd}; + use crate::os::windows::io::{AsRawHandle, FromRawHandle, IntoRawHandle, RawHandle}; + } else if #[cfg(unix)] { + use std::os::unix::io::{AsRawFd, FromRawFd, IntoRawFd, RawFd}; + } else if #[cfg(windows)] { + use std::os::windows::io::{AsRawHandle, FromRawHandle, IntoRawHandle, RawHandle}; + } +} + +#[cfg_attr(feature = "docs.rs", doc(cfg(unix)))] +cfg_if! { + if #[cfg(any(unix, feature = "docs.rs"))] { + impl AsRawFd for File { + fn as_raw_fd(&self) -> RawFd { + self.raw_fd + } + } + + impl FromRawFd for File { + unsafe fn from_raw_fd(fd: RawFd) -> File { + fs::File::from_raw_fd(fd).into() + } + } + + impl IntoRawFd for File { + fn into_raw_fd(self) -> RawFd { + self.raw_fd + } + } + } +} + +#[cfg_attr(feature = "docs.rs", doc(cfg(windows)))] +cfg_if! { + if #[cfg(any(windows, feature = "docs.rs"))] { + impl AsRawHandle for File { + fn as_raw_handle(&self) -> RawHandle { + self.raw_handle.0 + } + } + + impl FromRawHandle for File { + unsafe fn from_raw_handle(handle: RawHandle) -> File { + fs::File::from_raw_handle(handle).into() + } + } + + impl IntoRawHandle for File { + fn into_raw_handle(self) -> RawHandle { + self.raw_handle.0 + } + } + + #[derive(Debug)] + struct UnsafeShared(T); + + unsafe impl Send for UnsafeShared {} + unsafe impl Sync for UnsafeShared {} + } +} diff --git a/src/fs/mod.rs b/src/fs/mod.rs new file mode 100644 index 00000000..f944f436 --- /dev/null +++ b/src/fs/mod.rs @@ -0,0 +1,578 @@ +//! Filesystem manipulation operations. +//! +//! This module is an async version of [`std::fs`]. +//! +//! [`std::fs`]: https://doc.rust-lang.org/std/fs/index.html +//! +//! # Examples +//! +//! Create a new file and write some bytes to it: +//! +//! ```no_run +//! # #![feature(async_await)] +//! use async_std::fs::File; +//! use async_std::prelude::*; +//! +//! # futures::executor::block_on(async { +//! let mut file = File::create("foo.txt").await?; +//! file.write_all(b"Hello, world!").await?; +//! # std::io::Result::Ok(()) +//! # }).unwrap(); +//! ``` + +use std::fs; +use std::io; +use std::path::{Path, PathBuf}; + +use crate::task::blocking; + +pub use dir_builder::DirBuilder; +pub use dir_entry::DirEntry; +pub use file::File; +pub use open_options::OpenOptions; +pub use read_dir::ReadDir; + +mod dir_builder; +mod dir_entry; +mod file; +mod open_options; +mod read_dir; + +#[doc(inline)] +pub use std::fs::{FileType, Metadata, Permissions}; + +/// Returns the canonical form of a path. +/// +/// The returned path is in absolute form with all intermediate components normalized and symbolic +/// links resolved. +/// +/// This function is an async version of [`std::fs::canonicalize`]. +/// +/// [`std::fs::canonicalize`]: https://doc.rust-lang.org/std/fs/fn.canonicalize.html +/// +/// # Errors +/// +/// An error will be returned in the following situations (not an exhaustive list): +/// +/// * `path` does not exist. +/// * A non-final component in path is not a directory. +/// +/// # Examples +/// +/// ```no_run +/// # #![feature(async_await)] +/// use async_std::fs::canonicalize; +/// +/// # futures::executor::block_on(async { +/// let path = canonicalize(".").await?; +/// # std::io::Result::Ok(()) +/// # }).unwrap(); +/// ``` +pub async fn canonicalize>(path: P) -> io::Result { + let path = path.as_ref().to_owned(); + blocking::spawn(async move { fs::canonicalize(path) }).await +} + +/// Creates a new, empty directory. +/// +/// This function is an async version of [`std::fs::create_dir`]. +/// +/// [`std::fs::create_dir`]: https://doc.rust-lang.org/std/fs/fn.create_dir.html +/// +/// # Errors +/// +/// An error will be returned in the following situations (not an exhaustive list): +/// +/// * `path` already exists. +/// * A parent of the given path does not exist. +/// * The current process lacks permissions to create directory at `path`. +/// +/// # Examples +/// +/// ```no_run +/// # #![feature(async_await)] +/// use async_std::fs::create_dir; +/// +/// # futures::executor::block_on(async { +/// create_dir("./some/dir").await?; +/// # std::io::Result::Ok(()) +/// # }).unwrap(); +/// ``` +pub async fn create_dir>(path: P) -> io::Result<()> { + let path = path.as_ref().to_owned(); + blocking::spawn(async move { fs::create_dir(path) }).await +} + +/// Creates a new, empty directory and all of its parents if they are missing. +/// +/// This function is an async version of [`std::fs::create_dir_all`]. +/// +/// [`std::fs::create_dir_all`]: https://doc.rust-lang.org/std/fs/fn.create_dir_all.html +/// +/// # Errors +/// +/// An error will be returned in the following situations (not an exhaustive list): +/// +/// * The parent directories do not exists and couldn't be created. +/// * The current process lacks permissions to create directory at `path`. +/// +/// # Examples +/// +/// ```no_run +/// # #![feature(async_await)] +/// use async_std::fs::create_dir_all; +/// +/// # futures::executor::block_on(async { +/// create_dir_all("./some/dir").await?; +/// # std::io::Result::Ok(()) +/// # }).unwrap(); +/// ``` +pub async fn create_dir_all>(path: P) -> io::Result<()> { + let path = path.as_ref().to_owned(); + blocking::spawn(async move { fs::create_dir_all(path) }).await +} + +/// Creates a new hard link on the filesystem. +/// +/// The `dst` path will be a link pointing to the `src` path. Note that systems often require these +/// two paths to both be located on the same filesystem. +/// +/// This function is an async version of [`std::fs::hard_link`]. +/// +/// [`std::fs::hard_link`]: https://doc.rust-lang.org/std/fs/fn.hard_link.html +/// +/// # Errors +/// +/// An error will be returned in the following situations (not an exhaustive list): +/// +/// * The `src` path is not a file or doesn't exist. +/// +/// # Examples +/// +/// ```no_run +/// # #![feature(async_await)] +/// use async_std::fs::hard_link; +/// +/// # futures::executor::block_on(async { +/// hard_link("a.txt", "b.txt").await?; +/// # std::io::Result::Ok(()) +/// # }).unwrap(); +/// ``` +pub async fn hard_link, Q: AsRef>(from: P, to: Q) -> io::Result<()> { + let from = from.as_ref().to_owned(); + let to = to.as_ref().to_owned(); + blocking::spawn(async move { fs::hard_link(&from, &to) }).await +} + +/// Copies the contents and permissions of one file to another. +/// +/// On success, the total number of bytes copied is returned and equals the length of the `from` +/// file. +/// +/// The old contents of `to` will be overwritten. If `from` and `to` both point to the same file, +/// then the file will likely get truncated by this operation. +/// +/// This function is an async version of [`std::fs::copy`]. +/// +/// [`std::fs::copy`]: https://doc.rust-lang.org/std/fs/fn.copy.html +/// +/// # Errors +/// +/// An error will be returned in the following situations (not an exhaustive list): +/// +/// * The `from` path is not a file. +/// * The `from` file does not exist. +/// * The current process lacks permissions to access `from` or write `to`. +/// +/// # Examples +/// +/// ```no_run +/// # #![feature(async_await)] +/// use async_std::fs::copy; +/// +/// # futures::executor::block_on(async { +/// let bytes_copied = copy("foo.txt", "bar.txt").await?; +/// # std::io::Result::Ok(()) +/// # }).unwrap(); +/// ``` +pub async fn copy, Q: AsRef>(from: P, to: Q) -> io::Result { + let from = from.as_ref().to_owned(); + let to = to.as_ref().to_owned(); + blocking::spawn(async move { fs::copy(&from, &to) }).await +} + +/// Queries the metadata for a path. +/// +/// This function will traverse symbolic links to query information about the file or directory. +/// +/// This function is an async version of [`std::fs::metadata`]. +/// +/// [`std::fs::metadata`]: https://doc.rust-lang.org/std/fs/fn.metadata.html +/// +/// # Errors +/// +/// An error will be returned in the following situations (not an exhaustive list): +/// +/// * `path` does not exist. +/// * The current process lacks permissions to query metadata for `path`. +/// +/// # Examples +/// +/// ```no_run +/// # #![feature(async_await)] +/// use async_std::fs::metadata; +/// +/// # futures::executor::block_on(async { +/// let perm = metadata("foo.txt").await?.permissions(); +/// # std::io::Result::Ok(()) +/// # }).unwrap(); +/// ``` +pub async fn metadata>(path: P) -> io::Result { + let path = path.as_ref().to_owned(); + blocking::spawn(async move { fs::metadata(path) }).await +} + +/// Read the entire contents of a file into a bytes vector. +/// +/// This is a convenience function for reading entire files. It pre-allocates a buffer based on the +/// file size when available, so it is generally faster than manually opening a file and reading +/// into a `Vec`. +/// +/// This function is an async version of [`std::fs::read`]. +/// +/// [`std::fs::read`]: https://doc.rust-lang.org/std/fs/fn.read.html +/// +/// # Errors +/// +/// An error will be returned in the following situations (not an exhaustive list): +/// +/// * `path` does not exist. +/// * The current process lacks permissions to read `path`. +/// +/// # Examples +/// +/// ```no_run +/// # #![feature(async_await)] +/// use async_std::fs::read; +/// +/// # futures::executor::block_on(async { +/// let contents = read("foo.txt").await?; +/// # std::io::Result::Ok(()) +/// # }).unwrap(); +/// ``` +pub async fn read>(path: P) -> io::Result> { + let path = path.as_ref().to_owned(); + blocking::spawn(async move { fs::read(path) }).await +} + +/// Returns a stream over the entries within a directory. +/// +/// The stream yields items of type [`io::Result`]`<`[`DirEntry`]`>`. New errors may be encountered +/// after a stream is initially constructed. +/// +/// This function is an async version of [`std::fs::read_dir`]. +/// +/// [`io::Result`]: https://doc.rust-lang.org/std/io/type.Result.html +/// [`DirEntry`]: struct.DirEntry.html +/// [`std::fs::read_dir`]: https://doc.rust-lang.org/std/fs/fn.read_dir.html +/// +/// # Errors +/// +/// An error will be returned in the following situations (not an exhaustive list): +/// +/// * `path` does not exist. +/// * `path` does not point at a directory. +/// * The current process lacks permissions to view the contents of `path`. +/// +/// # Examples +/// +/// ```no_run +/// # #![feature(async_await)] +/// use async_std::fs::read_dir; +/// use async_std::prelude::*; +/// +/// # futures::executor::block_on(async { +/// let mut dir = read_dir(".").await?; +/// +/// while let Some(entry) = dir.next().await { +/// let entry = entry?; +/// println!("{:?}", entry.file_name()); +/// } +/// # std::io::Result::Ok(()) +/// # }).unwrap(); +/// ``` +pub async fn read_dir>(path: P) -> io::Result { + let path = path.as_ref().to_owned(); + blocking::spawn(async move { fs::read_dir(path) }) + .await + .map(ReadDir::new) +} + +/// Reads a symbolic link, returning the path it points to. +/// +/// This function is an async version of [`std::fs::read_link`]. +/// +/// [`std::fs::read_link`]: https://doc.rust-lang.org/std/fs/fn.read_link.html +/// +/// # Errors +/// +/// An error will be returned in the following situations (not an exhaustive list): +/// +/// * `path` is not a symbolic link. +/// * `path` does not exist. +/// +/// # Examples +/// +/// ```no_run +/// # #![feature(async_await)] +/// use async_std::fs::read_link; +/// +/// # futures::executor::block_on(async { +/// let path = read_link("foo.txt").await?; +/// # std::io::Result::Ok(()) +/// # }).unwrap(); +/// ``` +pub async fn read_link>(path: P) -> io::Result { + let path = path.as_ref().to_owned(); + blocking::spawn(async move { fs::read_link(path) }).await +} + +/// Read the entire contents of a file into a string. +/// +/// This function is an async version of [`std::fs::read_to_string`]. +/// +/// [`std::fs::read_to_string`]: https://doc.rust-lang.org/std/fs/fn.read_to_string.html +/// +/// # Errors +/// +/// An error will be returned in the following situations (not an exhaustive list): +/// +/// * `path` is not a file. +/// * The current process lacks permissions to read `path`. +/// +/// # Examples +/// +/// ```no_run +/// # #![feature(async_await)] +/// use async_std::fs::read_to_string; +/// +/// # futures::executor::block_on(async { +/// let contents = read_to_string("foo.txt").await?; +/// # std::io::Result::Ok(()) +/// # }).unwrap(); +/// ``` +pub async fn read_to_string>(path: P) -> io::Result { + let path = path.as_ref().to_owned(); + blocking::spawn(async move { fs::read_to_string(path) }).await +} + +/// Removes an existing, empty directory. +/// +/// This function is an async version of [`std::fs::remove_dir`]. +/// +/// [`std::fs::remove_dir`]: https://doc.rust-lang.org/std/fs/fn.remove_dir.html +/// +/// # Errors +/// +/// An error will be returned in the following situations (not an exhaustive list): +/// +/// * `path` is not an empty directory. +/// * The current process lacks permissions to remove directory at `path`. +/// +/// # Examples +/// +/// ```no_run +/// # #![feature(async_await)] +/// use async_std::fs::remove_dir; +/// +/// # futures::executor::block_on(async { +/// remove_dir("./some/dir").await?; +/// # std::io::Result::Ok(()) +/// # }).unwrap(); +/// ``` +pub async fn remove_dir>(path: P) -> io::Result<()> { + let path = path.as_ref().to_owned(); + blocking::spawn(async move { fs::remove_dir(path) }).await +} + +/// Removes an directory and all of its contents. +/// +/// This function is an async version of [`std::fs::remove_dir_all`]. +/// +/// [`std::fs::remove_dir_all`]: https://doc.rust-lang.org/std/fs/fn.remove_dir_all.html +/// +/// # Errors +/// +/// An error will be returned in the following situations (not an exhaustive list): +/// +/// * `path` is not a directory. +/// * The current process lacks permissions to remove directory at `path`. +/// +/// # Examples +/// +/// ```no_run +/// # #![feature(async_await)] +/// use async_std::fs::remove_dir_all; +/// +/// # futures::executor::block_on(async { +/// remove_dir_all("./some/dir").await?; +/// # std::io::Result::Ok(()) +/// # }).unwrap(); +/// ``` +pub async fn remove_dir_all>(path: P) -> io::Result<()> { + let path = path.as_ref().to_owned(); + blocking::spawn(async move { fs::remove_dir_all(path) }).await +} + +/// Removes a file from the filesystem. +/// +/// This function is an async version of [`std::fs::remove_file`]. +/// +/// [`std::fs::remove_file`]: https://doc.rust-lang.org/std/fs/fn.remove_file.html +/// +/// # Errors +/// +/// An error will be returned in the following situations (not an exhaustive list): +/// +/// * `path` is not a file. +/// * The current process lacks permissions to remove file at `path`. +/// +/// # Examples +/// +/// ```no_run +/// # #![feature(async_await)] +/// use async_std::fs::remove_file; +/// +/// # futures::executor::block_on(async { +/// remove_file("foo.txt").await?; +/// # std::io::Result::Ok(()) +/// # }).unwrap(); +/// ``` +pub async fn remove_file>(path: P) -> io::Result<()> { + let path = path.as_ref().to_owned(); + blocking::spawn(async move { fs::remove_file(path) }).await +} + +/// Renames a file or directory to a new name, replacing the original if it already exists. +/// +/// This function is an async version of [`std::fs::rename`]. +/// +/// [`std::fs::rename`]: https://doc.rust-lang.org/std/fs/fn.rename.html +/// +/// # Errors +/// +/// An error will be returned in the following situations (not an exhaustive list): +/// +/// * `from` does not exist. +/// * `from` and `to` are on different filesystems. +/// * The current process lacks permissions to rename `from` to `to`. +/// +/// # Examples +/// +/// ```no_run +/// # #![feature(async_await)] +/// use async_std::fs::rename; +/// +/// # futures::executor::block_on(async { +/// rename("a.txt", "b.txt").await?; +/// # std::io::Result::Ok(()) +/// # }).unwrap(); +/// ``` +pub async fn rename, Q: AsRef>(from: P, to: Q) -> io::Result<()> { + let from = from.as_ref().to_owned(); + let to = to.as_ref().to_owned(); + blocking::spawn(async move { fs::rename(&from, &to) }).await +} + +/// Changes the permissions on a file or directory. +/// +/// This function is an async version of [`std::fs::set_permissions`]. +/// +/// [`std::fs::set_permissions`]: https://doc.rust-lang.org/std/fs/fn.set_permissions.html +/// +/// # Errors +/// +/// An error will be returned in the following situations (not an exhaustive list): +/// +/// * `path` does not exist. +/// * The current process lacks permissions to change attributes of `path`. +/// +/// # Examples +/// +/// ```no_run +/// # #![feature(async_await)] +/// use async_std::fs::{metadata, set_permissions}; +/// +/// # futures::executor::block_on(async { +/// let mut perm = metadata("foo.txt").await?.permissions(); +/// perm.set_readonly(true); +/// +/// set_permissions("foo.txt", perm).await?; +/// # std::io::Result::Ok(()) +/// # }).unwrap(); +/// ``` +pub async fn set_permissions>(path: P, perm: fs::Permissions) -> io::Result<()> { + let path = path.as_ref().to_owned(); + blocking::spawn(async move { fs::set_permissions(path, perm) }).await +} + +/// Queries the metadata for a path without following symlinks. +/// +/// This function is an async version of [`std::fs::symlink_metadata`]. +/// +/// [`std::fs::symlink_metadata`]: https://doc.rust-lang.org/std/fs/fn.symlink_metadata.html +/// +/// # Errors +/// +/// An error will be returned in the following situations (not an exhaustive list): +/// +/// * `path` does not exist. +/// * The current process lacks permissions to query metadata for `path`. +/// +/// # Examples +/// +/// ```no_run +/// # #![feature(async_await)] +/// use async_std::fs::symlink_metadata; +/// +/// # futures::executor::block_on(async { +/// let perm = symlink_metadata("foo.txt").await?.permissions(); +/// # std::io::Result::Ok(()) +/// # }).unwrap(); +/// ``` +pub async fn symlink_metadata>(path: P) -> io::Result { + let path = path.as_ref().to_owned(); + blocking::spawn(async move { fs::symlink_metadata(path) }).await +} + +/// Writes a slice of bytes as the entire contents of a file. +/// +/// This function will create a file if it does not exist, and will entirely replace its contents +/// if it does. +/// +/// This function is an async version of [`std::fs::write`]. +/// +/// [`std::fs::write`]: https://doc.rust-lang.org/std/fs/fn.write.html +/// +/// # Errors +/// +/// An error will be returned in the following situations (not an exhaustive list): +/// +/// * The current process lacks permissions to write into `path`. +/// +/// # Examples +/// +/// ```no_run +/// # #![feature(async_await)] +/// use async_std::fs::write; +/// +/// # futures::executor::block_on(async { +/// write("foo.txt", b"Lorem ipsum").await?; +/// # std::io::Result::Ok(()) +/// # }).unwrap(); +/// ``` +pub async fn write, C: AsRef<[u8]>>(path: P, contents: C) -> io::Result<()> { + let path = path.as_ref().to_owned(); + let contents = contents.as_ref().to_owned(); + blocking::spawn(async move { fs::write(path, contents) }).await +} diff --git a/src/fs/open_options.rs b/src/fs/open_options.rs new file mode 100644 index 00000000..1606fe60 --- /dev/null +++ b/src/fs/open_options.rs @@ -0,0 +1,348 @@ +use std::fs; +use std::future::Future; +use std::io; +use std::path::Path; + +use cfg_if::cfg_if; + +use super::File; +use crate::task::blocking; + +/// Options and flags which for configuring how a file is opened. +/// +/// This builder exposes the ability to configure how a [`File`] is opened and what operations are +/// permitted on the open file. The [`File::open`] and [`File::create`] methods are aliases for +/// commonly used options with this builder. +/// +/// Generally speaking, when using `OpenOptions`, you'll first call [`new`], then chain calls to +/// methods to set each option, then call [`open`], passing the path of the file you're trying to +/// open. This will give you a [`File`] inside that you can further operate on. +/// +/// This type is an async version of [`std::fs::OpenOptions`]. +/// +/// [`new`]: struct.OpenOptions.html#method.new +/// [`open`]: struct.OpenOptions.html#method.open +/// [`File`]: struct.File.html +/// [`File::open`]: struct.File.html#method.open +/// [`File::create`]: struct.File.html#method.create +/// [`std::fs::OpenOptions`]: https://doc.rust-lang.org/std/fs/struct.OpenOptions.html +/// +/// # Examples +/// +/// Opening a file for reading: +/// +/// ```no_run +/// # #![feature(async_await)] +/// use async_std::fs::OpenOptions; +/// +/// # futures::executor::block_on(async { +/// let file = OpenOptions::new() +/// .read(true) +/// .open("foo.txt") +/// .await?; +/// # std::io::Result::Ok(()) +/// # }).unwrap(); +/// ``` +/// +/// Opening a file for both reading and writing, creating it if it doesn't exist: +/// +/// ```no_run +/// # #![feature(async_await)] +/// use async_std::fs::OpenOptions; +/// +/// # futures::executor::block_on(async { +/// let file = OpenOptions::new() +/// .read(true) +/// .write(true) +/// .create(true) +/// .open("foo.txt") +/// .await?; +/// # std::io::Result::Ok(()) +/// # }).unwrap(); +/// ``` +#[derive(Clone, Debug)] +pub struct OpenOptions(fs::OpenOptions); + +impl OpenOptions { + /// Creates a blank new set of options. + /// + /// All options are initially set to `false`. + /// + /// # Examples + /// + /// ```no_run + /// # #![feature(async_await)] + /// use async_std::fs::OpenOptions; + /// + /// # futures::executor::block_on(async { + /// let file = OpenOptions::new() + /// .read(true) + /// .open("foo.txt") + /// .await?; + /// # std::io::Result::Ok(()) + /// # }).unwrap(); + /// ``` + pub fn new() -> OpenOptions { + OpenOptions(fs::OpenOptions::new()) + } + + /// Sets the option for read access. + /// + /// This option, when `true`, will indicate that the file should be readable if opened. + /// + /// # Examples + /// + /// ```no_run + /// # #![feature(async_await)] + /// use async_std::fs::OpenOptions; + /// + /// # futures::executor::block_on(async { + /// let file = OpenOptions::new() + /// .read(true) + /// .open("foo.txt") + /// .await?; + /// # std::io::Result::Ok(()) + /// # }).unwrap(); + /// ``` + pub fn read(&mut self, read: bool) -> &mut OpenOptions { + self.0.read(read); + self + } + + /// Sets the option for write access. + /// + /// This option, when `true`, will indicate that the file should be writable if opened. + /// + /// If the file already exists, any write calls on it will overwrite its contents, without + /// truncating it. + /// + /// # Examples + /// + /// ```no_run + /// # #![feature(async_await)] + /// use async_std::fs::OpenOptions; + /// + /// # futures::executor::block_on(async { + /// let file = OpenOptions::new() + /// .write(true) + /// .open("foo.txt") + /// .await?; + /// # std::io::Result::Ok(()) + /// # }).unwrap(); + /// ``` + pub fn write(&mut self, write: bool) -> &mut OpenOptions { + self.0.write(write); + self + } + + /// Sets the option for append mode. + /// + /// This option, when `true`, means that writes will append to a file instead of overwriting + /// previous contents. Note that setting `.write(true).append(true)` has the same effect as + /// setting only `.append(true)`. + /// + /// For most filesystems, the operating system guarantees that all writes are atomic: no writes + /// get mangled because another process writes at the same time. + /// + /// One maybe obvious note when using append mode: make sure that all data that belongs + /// together is written to the file in one operation. This can be done by concatenating strings + /// before writing them, or using a buffered writer (with a buffer of adequate size), and + /// flushing when the message is complete. + /// + /// If a file is opened with both read and append access, beware that after opening and after + /// every write, the position for reading may be set at the end of the file. So, before + /// writing, save the current position by seeking with a zero offset, and restore it before the + /// next read. + /// + /// ## Note + /// + /// This function doesn't create the file if it doesn't exist. Use the [`create`] method to do + /// so. + /// + /// [`create`]: #method.create + /// + /// # Examples + /// + /// ```no_run + /// # #![feature(async_await)] + /// use async_std::fs::OpenOptions; + /// + /// # futures::executor::block_on(async { + /// let file = OpenOptions::new() + /// .append(true) + /// .open("foo.txt") + /// .await?; + /// # std::io::Result::Ok(()) + /// # }).unwrap(); + /// ``` + pub fn append(&mut self, append: bool) -> &mut OpenOptions { + self.0.append(append); + self + } + + /// Sets the option for truncating a previous file. + /// + /// If a file is successfully opened with this option set, it will truncate the file to 0 + /// length if it already exists. + /// + /// The file must be opened with write access for truncation to work. + /// + /// # Examples + /// + /// ```no_run + /// # #![feature(async_await)] + /// use async_std::fs::OpenOptions; + /// + /// # futures::executor::block_on(async { + /// let file = OpenOptions::new() + /// .write(true) + /// .truncate(true) + /// .open("foo.txt") + /// .await?; + /// # std::io::Result::Ok(()) + /// # }).unwrap(); + /// ``` + pub fn truncate(&mut self, truncate: bool) -> &mut OpenOptions { + self.0.truncate(truncate); + self + } + + /// Sets the option for creating a new file. + /// + /// This option indicates whether a new file will be created if the file does not yet exist. + /// + /// In order for the file to be created, [`write`] or [`append`] access must be used. + /// + /// [`write`]: #method.write + /// [`append`]: #method.append + /// + /// # Examples + /// + /// ```no_run + /// # #![feature(async_await)] + /// use async_std::fs::OpenOptions; + /// + /// # futures::executor::block_on(async { + /// let file = OpenOptions::new() + /// .write(true) + /// .create(true) + /// .open("foo.txt") + /// .await?; + /// # std::io::Result::Ok(()) + /// # }).unwrap(); + /// ``` + pub fn create(&mut self, create: bool) -> &mut OpenOptions { + self.0.create(create); + self + } + + /// Sets the option to always create a new file. + /// + /// This option indicates whether a new file will be created. No file is allowed to exist at + /// the target location, also no (dangling) symlink. + /// + /// This option is useful because it is atomic. Otherwise, between checking whether a file + /// exists and creating a new one, the file may have been created by another process (a TOCTOU + /// race condition / attack). + /// + /// If `.create_new(true)` is set, [`.create()`] and [`.truncate()`] are ignored. + /// + /// The file must be opened with write or append access in order to create a new file. + /// + /// [`.create()`]: #method.create + /// [`.truncate()`]: #method.truncate + /// + /// # Examples + /// + /// ```no_run + /// # #![feature(async_await)] + /// use async_std::fs::OpenOptions; + /// + /// # futures::executor::block_on(async { + /// let file = OpenOptions::new() + /// .write(true) + /// .create_new(true) + /// .open("foo.txt") + /// .await?; + /// # std::io::Result::Ok(()) + /// # }).unwrap(); + /// ``` + pub fn create_new(&mut self, create_new: bool) -> &mut OpenOptions { + self.0.create_new(create_new); + self + } + + /// Opens a file at specified path with the configured options. + /// + /// # Errors + /// + /// This function will return an error under a number of different circumstances. Some of these + /// error conditions are listed here, together with their [`ErrorKind`]. The mapping to + /// [`ErrorKind`]s is not part of the compatibility contract of the function, especially the + /// `Other` kind might change to more specific kinds in the future. + /// + /// * [`NotFound`]: The specified file does not exist and neither `create` or `create_new` is + /// set. + /// * [`NotFound`]: One of the directory components of the file path does not exist. + /// * [`PermissionDenied`]: The user lacks permission to get the specified access rights for + /// the file. + /// * [`PermissionDenied`]: The user lacks permission to open one of the directory components + /// of the specified path. + /// * [`AlreadyExists`]: `create_new` was specified and the file already exists. + /// * [`InvalidInput`]: Invalid combinations of open options (truncate without write access, no + /// access mode set, etc.). + /// * [`Other`]: One of the directory components of the specified file path was not, in fact, a + /// directory. + /// * [`Other`]: Filesystem-level errors: full disk, write permission requested on a read-only + /// file system, exceeded disk quota, too many open files, too long filename, too many + /// symbolic links in the specified path (Unix-like systems only), etc. + /// + /// [`ErrorKind`]: https://doc.rust-lang.org/std/io/enum.ErrorKind.html + /// [`AlreadyExists`]: https://doc.rust-lang.org/std/io/enum.ErrorKind.html#variant.AlreadyExists + /// [`InvalidInput`]: https://doc.rust-lang.org/std/io/enum.ErrorKind.html#variant.InvalidInput + /// [`NotFound`]: https://doc.rust-lang.org/std/io/enum.ErrorKind.html#variant.NotFound + /// [`Other`]: https://doc.rust-lang.org/std/io/enum.ErrorKind.html#variant.Other + /// [`PermissionDenied`]: https://doc.rust-lang.org/std/io/enum.ErrorKind.html#variant.PermissionDenied + /// + /// # Examples + /// + /// ```no_run + /// # #![feature(async_await)] + /// use async_std::fs::OpenOptions; + /// + /// # futures::executor::block_on(async { + /// let file = OpenOptions::new().open("foo.txt").await?; + /// # std::io::Result::Ok(()) + /// # }).unwrap(); + /// ``` + pub fn open>(&self, path: P) -> impl Future> { + let path = path.as_ref().to_owned(); + let options = self.0.clone(); + async move { blocking::spawn(async move { options.open(path).map(|f| f.into()) }).await } + } +} + +cfg_if! { + if #[cfg(feature = "docs.rs")] { + use crate::os::unix::fs::OpenOptionsExt; + } else if #[cfg(unix)] { + use std::os::unix::fs::OpenOptionsExt; + } +} + +#[cfg_attr(feature = "docs.rs", doc(cfg(unix)))] +cfg_if! { + if #[cfg(any(unix, feature = "docs.rs"))] { + impl OpenOptionsExt for OpenOptions { + fn mode(&mut self, mode: u32) -> &mut Self { + self.0.mode(mode); + self + } + + fn custom_flags(&mut self, flags: i32) -> &mut Self { + self.0.custom_flags(flags); + self + } + } + } +} diff --git a/src/fs/read_dir.rs b/src/fs/read_dir.rs new file mode 100644 index 00000000..e663c14f --- /dev/null +++ b/src/fs/read_dir.rs @@ -0,0 +1,95 @@ +use std::fs; +use std::future::Future; +use std::io; + +use std::pin::Pin; +use std::sync::Mutex; +use std::task::{Context, Poll}; + +use futures::Stream; + +use super::DirEntry; +use crate::task::blocking; + +/// A stream over entries in a directory. +/// +/// This stream is returned by [`read_dir`] and yields items of type +/// [`io::Result`]`<`[`DirEntry`]`>`. Each [`DirEntry`] can then retrieve information like entry's +/// path or metadata. +/// +/// This type is an async version of [`std::fs::ReadDir`]. +/// +/// [`read_dir`]: fn.read_dir.html +/// [`io::Result`]: https://doc.rust-lang.org/std/io/type.Result.html +/// [`DirEntry`]: struct.DirEntry.html +/// [`std::fs::ReadDir`]: https://doc.rust-lang.org/std/fs/struct.ReadDir.html +#[derive(Debug)] +pub struct ReadDir(Mutex); + +/// The state of an asynchronous `ReadDir`. +/// +/// The `ReadDir` can be either idle or busy performing an asynchronous operation. +#[derive(Debug)] +enum State { + Idle(Option), + Busy(blocking::JoinHandle), +} + +/// Inner representation of an asynchronous `DirEntry`. +#[derive(Debug)] +struct Inner { + /// The blocking handle. + read_dir: fs::ReadDir, + + /// The next item in the stream. + item: Option>, +} + +impl ReadDir { + /// Creates an asynchronous `ReadDir` from a synchronous handle. + pub(crate) fn new(inner: fs::ReadDir) -> ReadDir { + ReadDir(Mutex::new(State::Idle(Some(Inner { + read_dir: inner, + item: None, + })))) + } +} + +impl Stream for ReadDir { + type Item = io::Result; + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let state = &mut *self.0.lock().unwrap(); + + loop { + match state { + State::Idle(opt) => { + let inner = match opt.as_mut() { + None => return Poll::Ready(None), + Some(inner) => inner, + }; + + // Check if the operation has completed. + if let Some(res) = inner.item.take() { + return Poll::Ready(Some(res)); + } else { + let mut inner = opt.take().unwrap(); + + // Start the operation asynchronously. + *state = State::Busy(blocking::spawn(async move { + match inner.read_dir.next() { + None => State::Idle(None), + Some(res) => { + inner.item = Some(res.map(DirEntry::new)); + State::Idle(Some(inner)) + } + } + })); + } + } + // Poll the asynchronous operation the file is currently blocked on. + State::Busy(task) => *state = futures::ready!(Pin::new(task).poll(cx)), + } + } + } +} diff --git a/src/future/mod.rs b/src/future/mod.rs new file mode 100644 index 00000000..42531877 --- /dev/null +++ b/src/future/mod.rs @@ -0,0 +1,42 @@ +//! Asynchronous values. + +#[doc(inline)] +pub use std::future::Future; + +/// Never resolves to a value. +/// +/// # Examples +/// ``` +/// # #![feature(async_await)] +/// use async_std::future::pending; +/// use async_std::prelude::*; +/// use std::time::Duration; +/// +/// # async_std::task::block_on(async { +/// let dur = Duration::from_secs(1); +/// assert!(pending::<()>().timeout(dur).await.is_err()); +/// # }) +/// ``` +pub async fn pending() -> T { + futures::future::pending::().await +} + +/// Resolves to the provided value. +/// +/// This function is an async version of [`std::convert::identity`]. +/// +/// [`std::convert::identity`]: https://doc.rust-lang.org/std/convert/fn.identity.html +/// +/// # Examples +/// +/// ``` +/// # #![feature(async_await)] +/// use async_std::future::ready; +/// +/// # async_std::task::block_on(async { +/// assert_eq!(ready(10).await, 10); +/// # }) +/// ``` +pub async fn ready(val: T) -> T { + val +} diff --git a/src/io/copy.rs b/src/io/copy.rs new file mode 100644 index 00000000..6c966053 --- /dev/null +++ b/src/io/copy.rs @@ -0,0 +1,52 @@ +use futures::prelude::*; +use std::io; + +/// Copies the entire contents of a reader into a writer. +/// +/// This function will continuously read data from `reader` and then +/// write it into `writer` in a streaming fashion until `reader` +/// returns EOF. +/// +/// On success, the total number of bytes that were copied from +/// `reader` to `writer` is returned. +/// +/// If you’re wanting to copy the contents of one file to another and you’re +/// working with filesystem paths, see the [`fs::copy`] function. +/// +/// This function is an async version of [`std::fs::write`]. +/// +/// [`std::io::copy`]: https://doc.rust-lang.org/std/io/fn.copy.html +/// [`fs::copy`]: ../fs/fn.copy.html +/// +/// # Errors +/// +/// This function will return an error immediately if any call to `read` or +/// `write` returns an error. All instances of `ErrorKind::Interrupted` are +/// handled by this function and the underlying operation is retried. +/// +/// # Examples +/// +/// ``` +/// # #![feature(async_await)] +/// use async_std::{io, task}; +/// +/// fn main() -> std::io::Result<()> { +/// task::block_on(async { +/// let mut reader: &[u8] = b"hello"; +/// let mut writer: Vec = vec![]; +/// +/// io::copy(&mut reader, &mut writer).await?; +/// +/// assert_eq!(&b"hello"[..], &writer[..]); +/// Ok(()) +/// }) +/// } +/// ``` +pub async fn copy(reader: &mut R, writer: &mut W) -> io::Result +where + R: AsyncRead + Unpin + ?Sized, + W: AsyncWrite + Unpin + ?Sized, +{ + let bytes_read = reader.copy_into(writer).await?; + Ok(bytes_read) +} diff --git a/src/io/mod.rs b/src/io/mod.rs new file mode 100644 index 00000000..f8f2dd3f --- /dev/null +++ b/src/io/mod.rs @@ -0,0 +1,37 @@ +//! Basic input and output. +//! +//! This module is an async version of [`std::io`]. +//! +//! [`std::io`]: https://doc.rust-lang.org/std/io/index.html +//! +//! # Examples +//! +//! Read a line from the standard input: +//! +//! ```no_run +//! # #![feature(async_await)] +//! use async_std::io; +//! +//! # futures::executor::block_on(async { +//! let stdin = io::stdin(); +//! let mut line = String::new(); +//! stdin.read_line(&mut line).await?; +//! # std::io::Result::Ok(()) +//! # }).unwrap(); +//! ``` + +#[doc(inline)] +pub use futures::io::{AsyncBufRead, AsyncRead, AsyncSeek, AsyncWrite, SeekFrom}; + +pub use copy::copy; +pub use stderr::{stderr, Stderr}; +pub use stdin::{stdin, Stdin}; +pub use stdout::{stdout, Stdout}; + +mod copy; +mod stderr; +mod stdin; +mod stdout; + +#[doc(inline)] +pub use std::io::{empty, sink, Cursor, Empty, Error, ErrorKind, Result, Sink}; diff --git a/src/io/stderr.rs b/src/io/stderr.rs new file mode 100644 index 00000000..9f7660e7 --- /dev/null +++ b/src/io/stderr.rs @@ -0,0 +1,198 @@ +use std::future::Future; +use std::io; +use std::pin::Pin; +use std::sync::Mutex; +use std::task::{Context, Poll}; + +use cfg_if::cfg_if; +use futures::prelude::*; + +use crate::task::blocking; + +/// Constructs a new handle to the standard error of the current process. +/// +/// This function is an async version of [`std::io::stderr`]. +/// +/// [`std::io::stderr`]: https://doc.rust-lang.org/std/io/fn.stderr.html +/// +/// # Examples +/// +/// ```no_run +/// # #![feature(async_await)] +/// use async_std::io::stderr; +/// use async_std::prelude::*; +/// +/// # futures::executor::block_on(async { +/// let mut stderr = stderr(); +/// stderr.write_all(b"Hello, world!").await?; +/// # std::io::Result::Ok(()) +/// # }).unwrap(); +/// ``` +pub fn stderr() -> Stderr { + Stderr(Mutex::new(State::Idle(Some(Inner { + stderr: io::stderr(), + buf: Vec::new(), + last_op: None, + })))) +} + +/// A handle to the standard error of the current process. +/// +/// Created by the [`stderr`] function. +/// +/// This type is an async version of [`std::io::Stderr`]. +/// +/// [`stderr`]: fn.stderr.html +/// [`std::io::Stderr`]: https://doc.rust-lang.org/std/io/struct.Stderr.html +#[derive(Debug)] +pub struct Stderr(Mutex); + +/// The state of the asynchronous stderr. +/// +/// The stderr can be either idle or busy performing an asynchronous operation. +#[derive(Debug)] +enum State { + /// The stderr is idle. + Idle(Option), + + /// The stderr is blocked on an asynchronous operation. + /// + /// Awaiting this operation will result in the new state of the stderr. + Busy(blocking::JoinHandle), +} + +/// Inner representation of the asynchronous stderr. +#[derive(Debug)] +struct Inner { + /// The blocking stderr handle. + stderr: io::Stderr, + + /// The write buffer. + buf: Vec, + + /// The result of the last asynchronous operation on the stderr. + last_op: Option, +} + +/// Possible results of an asynchronous operation on the stderr. +#[derive(Debug)] +enum Operation { + Write(io::Result), + Flush(io::Result<()>), +} + +impl AsyncWrite for Stderr { + fn poll_write( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + let state = &mut *self.0.lock().unwrap(); + + loop { + match state { + State::Idle(opt) => { + let inner = opt.as_mut().unwrap(); + + // Check if the operation has completed. + if let Some(Operation::Write(res)) = inner.last_op.take() { + let n = res?; + + // If more data was written than is available in the buffer, let's retry + // the write operation. + if n <= buf.len() { + return Poll::Ready(Ok(n)); + } + } else { + let mut inner = opt.take().unwrap(); + + // Set the length of the inner buffer to the length of the provided buffer. + if inner.buf.len() < buf.len() { + inner.buf.reserve(buf.len() - inner.buf.len()); + } + unsafe { + inner.buf.set_len(buf.len()); + } + + // Copy the data to write into the inner buffer. + inner.buf[..buf.len()].copy_from_slice(buf); + + // Start the operation asynchronously. + *state = State::Busy(blocking::spawn(async move { + let res = io::Write::write(&mut inner.stderr, &mut inner.buf); + inner.last_op = Some(Operation::Write(res)); + State::Idle(Some(inner)) + })); + } + } + // Poll the asynchronous operation the stderr is currently blocked on. + State::Busy(task) => *state = futures::ready!(Pin::new(task).poll(cx)), + } + } + } + + fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let state = &mut *self.0.lock().unwrap(); + + loop { + match state { + State::Idle(opt) => { + let inner = opt.as_mut().unwrap(); + + // Check if the operation has completed. + if let Some(Operation::Flush(res)) = inner.last_op.take() { + return Poll::Ready(res); + } else { + let mut inner = opt.take().unwrap(); + + // Start the operation asynchronously. + *state = State::Busy(blocking::spawn(async move { + let res = io::Write::flush(&mut inner.stderr); + inner.last_op = Some(Operation::Flush(res)); + State::Idle(Some(inner)) + })); + } + } + // Poll the asynchronous operation the stderr is currently blocked on. + State::Busy(task) => *state = futures::ready!(Pin::new(task).poll(cx)), + } + } + } + + fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.poll_flush(cx) + } +} + +cfg_if! { + if #[cfg(feature = "docs.rs")] { + use crate::os::unix::io::{AsRawFd, RawFd}; + use crate::os::windows::io::{AsRawHandle, RawHandle}; + } else if #[cfg(unix)] { + use std::os::unix::io::{AsRawFd, RawFd}; + } else if #[cfg(windows)] { + use std::os::windows::io::{AsRawHandle, RawHandle}; + } +} + +#[cfg_attr(feature = "docs.rs", doc(cfg(unix)))] +cfg_if! { + if #[cfg(any(unix, feature = "docs.rs"))] { + impl AsRawFd for Stderr { + fn as_raw_fd(&self) -> RawFd { + io::stderr().as_raw_fd() + } + } + } +} + +#[cfg_attr(feature = "docs.rs", doc(cfg(unix)))] +cfg_if! { + if #[cfg(any(windows, feature = "docs.rs"))] { + impl AsRawHandle for Stderr { + fn as_raw_handle(&self) -> RawHandle { + io::stderr().as_raw_handle() + } + } + } +} diff --git a/src/io/stdin.rs b/src/io/stdin.rs new file mode 100644 index 00000000..3c088f87 --- /dev/null +++ b/src/io/stdin.rs @@ -0,0 +1,228 @@ +use std::future::Future; +use std::io; +use std::pin::Pin; +use std::sync::Mutex; +use std::task::{Context, Poll}; + +use cfg_if::cfg_if; +use futures::io::Initializer; +use futures::prelude::*; + +use crate::task::blocking; + +/// Constructs a new handle to the standard input of the current process. +/// +/// This function is an async version of [`std::io::stdin`]. +/// +/// [`std::io::stdin`]: https://doc.rust-lang.org/std/io/fn.stdin.html +/// +/// # Examples +/// +/// ```no_run +/// # #![feature(async_await)] +/// use async_std::io::stdin; +/// +/// # futures::executor::block_on(async { +/// let stdin = stdin(); +/// let mut line = String::new(); +/// stdin.read_line(&mut line).await?; +/// # std::io::Result::Ok(()) +/// # }).unwrap(); +/// ``` +pub fn stdin() -> Stdin { + Stdin(Mutex::new(State::Idle(Some(Inner { + stdin: io::stdin(), + line: String::new(), + buf: Vec::new(), + last_op: None, + })))) +} + +/// A handle to the standard input of the current process. +/// +/// Created by the [`stdin`] function. +/// +/// This type is an async version of [`std::io::Stdin`]. +/// +/// [`stdin`]: fn.stdin.html +/// [`std::io::Stdin`]: https://doc.rust-lang.org/std/io/struct.Stdin.html +#[derive(Debug)] +pub struct Stdin(Mutex); + +/// The state of the asynchronous stdin. +/// +/// The stdin can be either idle or busy performing an asynchronous operation. +#[derive(Debug)] +enum State { + /// The stdin is idle. + Idle(Option), + + /// The stdin is blocked on an asynchronous operation. + /// + /// Awaiting this operation will result in the new state of the stdin. + Busy(blocking::JoinHandle), +} + +/// Inner representation of the asynchronous stdin. +#[derive(Debug)] +struct Inner { + /// The blocking stdin handle. + stdin: io::Stdin, + + /// The line buffer. + line: String, + + /// The write buffer. + buf: Vec, + + /// The result of the last asynchronous operation on the stdin. + last_op: Option, +} + +/// Possible results of an asynchronous operation on the stdin. +#[derive(Debug)] +enum Operation { + ReadLine(io::Result), + Read(io::Result), +} + +impl Stdin { + /// Reads a line of input into the specified buffer. + /// + /// # Examples + /// + /// ```no_run + /// # #![feature(async_await)] + /// use async_std::io::stdin; + /// + /// # futures::executor::block_on(async { + /// let stdin = stdin(); + /// let mut line = String::new(); + /// stdin.read_line(&mut line).await?; + /// # std::io::Result::Ok(()) + /// # }).unwrap(); + /// ``` + pub async fn read_line(&self, buf: &mut String) -> io::Result { + future::poll_fn(|cx| { + let state = &mut *self.0.lock().unwrap(); + + loop { + match state { + State::Idle(opt) => { + let inner = opt.as_mut().unwrap(); + + // Check if the operation has completed. + if let Some(Operation::ReadLine(res)) = inner.last_op.take() { + let n = res?; + + // Copy the read data into the buffer and return. + buf.push_str(&inner.line); + return Poll::Ready(Ok(n)); + } else { + let mut inner = opt.take().unwrap(); + + // Start the operation asynchronously. + *state = State::Busy(blocking::spawn(async move { + inner.line.clear(); + let res = inner.stdin.read_line(&mut inner.line); + inner.last_op = Some(Operation::ReadLine(res)); + State::Idle(Some(inner)) + })); + } + } + // Poll the asynchronous operation the stdin is currently blocked on. + State::Busy(task) => *state = futures::ready!(Pin::new(task).poll(cx)), + } + } + }) + .await + } +} + +impl AsyncRead for Stdin { + fn poll_read( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut [u8], + ) -> Poll> { + let state = &mut *self.0.lock().unwrap(); + + loop { + match state { + State::Idle(opt) => { + let inner = opt.as_mut().unwrap(); + + // Check if the operation has completed. + if let Some(Operation::Read(res)) = inner.last_op.take() { + let n = res?; + + // If more data was read than fits into the buffer, let's retry the read + // operation. + if n <= buf.len() { + // Copy the read data into the buffer and return. + buf[..n].copy_from_slice(&inner.buf[..n]); + return Poll::Ready(Ok(n)); + } + } else { + let mut inner = opt.take().unwrap(); + + // Set the length of the inner buffer to the length of the provided buffer. + if inner.buf.len() < buf.len() { + inner.buf.reserve(buf.len() - inner.buf.len()); + } + unsafe { + inner.buf.set_len(buf.len()); + } + + // Start the operation asynchronously. + *state = State::Busy(blocking::spawn(async move { + let res = io::Read::read(&mut inner.stdin, &mut inner.buf); + inner.last_op = Some(Operation::Read(res)); + State::Idle(Some(inner)) + })); + } + } + // Poll the asynchronous operation the stdin is currently blocked on. + State::Busy(task) => *state = futures::ready!(Pin::new(task).poll(cx)), + } + } + } + + #[inline] + unsafe fn initializer(&self) -> Initializer { + Initializer::nop() + } +} + +cfg_if! { + if #[cfg(feature = "docs.rs")] { + use crate::os::unix::io::{AsRawFd, RawFd}; + use crate::os::windows::io::{AsRawHandle, RawHandle}; + } else if #[cfg(unix)] { + use std::os::unix::io::{AsRawFd, RawFd}; + } else if #[cfg(windows)] { + use std::os::windows::io::{AsRawHandle, RawHandle}; + } +} + +#[cfg_attr(feature = "docs.rs", doc(cfg(unix)))] +cfg_if! { + if #[cfg(any(unix, feature = "docs.rs"))] { + impl AsRawFd for Stdin { + fn as_raw_fd(&self) -> RawFd { + io::stdin().as_raw_fd() + } + } + } +} + +#[cfg_attr(feature = "docs.rs", doc(cfg(unix)))] +cfg_if! { + if #[cfg(any(windows, feature = "docs.rs"))] { + impl AsRawHandle for Stdin { + fn as_raw_handle(&self) -> RawHandle { + io::stdin().as_raw_handle() + } + } + } +} diff --git a/src/io/stdout.rs b/src/io/stdout.rs new file mode 100644 index 00000000..7609b764 --- /dev/null +++ b/src/io/stdout.rs @@ -0,0 +1,198 @@ +use std::future::Future; +use std::io; +use std::pin::Pin; +use std::sync::Mutex; +use std::task::{Context, Poll}; + +use cfg_if::cfg_if; +use futures::prelude::*; + +use crate::task::blocking; + +/// Constructs a new handle to the standard output of the current process. +/// +/// This function is an async version of [`std::io::stdout`]. +/// +/// [`std::io::stdout`]: https://doc.rust-lang.org/std/io/fn.stdout.html +/// +/// # Examples +/// +/// ```no_run +/// # #![feature(async_await)] +/// use async_std::io::stdout; +/// use async_std::prelude::*; +/// +/// # futures::executor::block_on(async { +/// let mut stdout = stdout(); +/// stdout.write_all(b"Hello, world!").await?; +/// # std::io::Result::Ok(()) +/// # }).unwrap(); +/// ``` +pub fn stdout() -> Stdout { + Stdout(Mutex::new(State::Idle(Some(Inner { + stdout: io::stdout(), + buf: Vec::new(), + last_op: None, + })))) +} + +/// A handle to the standard output of the current process. +/// +/// Created by the [`stdout`] function. +/// +/// This type is an async version of [`std::io::Stdout`]. +/// +/// [`stdout`]: fn.stdout.html +/// [`std::io::Stdout`]: https://doc.rust-lang.org/std/io/struct.Stdout.html +#[derive(Debug)] +pub struct Stdout(Mutex); + +/// The state of the asynchronous stdout. +/// +/// The stdout can be either idle or busy performing an asynchronous operation. +#[derive(Debug)] +enum State { + /// The stdout is idle. + Idle(Option), + + /// The stdout is blocked on an asynchronous operation. + /// + /// Awaiting this operation will result in the new state of the stdout. + Busy(blocking::JoinHandle), +} + +/// Inner representation of the asynchronous stdout. +#[derive(Debug)] +struct Inner { + /// The blocking stdout handle. + stdout: io::Stdout, + + /// The write buffer. + buf: Vec, + + /// The result of the last asynchronous operation on the stdout. + last_op: Option, +} + +/// Possible results of an asynchronous operation on the stdout. +#[derive(Debug)] +enum Operation { + Write(io::Result), + Flush(io::Result<()>), +} + +impl AsyncWrite for Stdout { + fn poll_write( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + let state = &mut *self.0.lock().unwrap(); + + loop { + match state { + State::Idle(opt) => { + let inner = opt.as_mut().unwrap(); + + // Check if the operation has completed. + if let Some(Operation::Write(res)) = inner.last_op.take() { + let n = res?; + + // If more data was written than is available in the buffer, let's retry + // the write operation. + if n <= buf.len() { + return Poll::Ready(Ok(n)); + } + } else { + let mut inner = opt.take().unwrap(); + + // Set the length of the inner buffer to the length of the provided buffer. + if inner.buf.len() < buf.len() { + inner.buf.reserve(buf.len() - inner.buf.len()); + } + unsafe { + inner.buf.set_len(buf.len()); + } + + // Copy the data to write into the inner buffer. + inner.buf[..buf.len()].copy_from_slice(buf); + + // Start the operation asynchronously. + *state = State::Busy(blocking::spawn(async move { + let res = io::Write::write(&mut inner.stdout, &mut inner.buf); + inner.last_op = Some(Operation::Write(res)); + State::Idle(Some(inner)) + })); + } + } + // Poll the asynchronous operation the stdout is currently blocked on. + State::Busy(task) => *state = futures::ready!(Pin::new(task).poll(cx)), + } + } + } + + fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let state = &mut *self.0.lock().unwrap(); + + loop { + match state { + State::Idle(opt) => { + let inner = opt.as_mut().unwrap(); + + // Check if the operation has completed. + if let Some(Operation::Flush(res)) = inner.last_op.take() { + return Poll::Ready(res); + } else { + let mut inner = opt.take().unwrap(); + + // Start the operation asynchronously. + *state = State::Busy(blocking::spawn(async move { + let res = io::Write::flush(&mut inner.stdout); + inner.last_op = Some(Operation::Flush(res)); + State::Idle(Some(inner)) + })); + } + } + // Poll the asynchronous operation the stdout is currently blocked on. + State::Busy(task) => *state = futures::ready!(Pin::new(task).poll(cx)), + } + } + } + + fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.poll_flush(cx) + } +} + +cfg_if! { + if #[cfg(feature = "docs.rs")] { + use crate::os::unix::io::{AsRawFd, RawFd}; + use crate::os::windows::io::{AsRawHandle, RawHandle}; + } else if #[cfg(unix)] { + use std::os::unix::io::{AsRawFd, RawFd}; + } else if #[cfg(windows)] { + use std::os::windows::io::{AsRawHandle, RawHandle}; + } +} + +#[cfg_attr(feature = "docs.rs", doc(cfg(unix)))] +cfg_if! { + if #[cfg(any(unix, feature = "docs.rs"))] { + impl AsRawFd for Stdout { + fn as_raw_fd(&self) -> RawFd { + io::stdout().as_raw_fd() + } + } + } +} + +#[cfg_attr(feature = "docs.rs", doc(cfg(unix)))] +cfg_if! { + if #[cfg(any(windows, feature = "docs.rs"))] { + impl AsRawHandle for Stdout { + fn as_raw_handle(&self) -> RawHandle { + io::stdout().as_raw_handle() + } + } + } +} diff --git a/src/lib.rs b/src/lib.rs new file mode 100644 index 00000000..5315cabe --- /dev/null +++ b/src/lib.rs @@ -0,0 +1,38 @@ +//! Asynchronous standard library. +//! +//! This crate is an async version of [`std`]. +//! +//! [`std`]: https://doc.rust-lang.org/std/index.html +//! +//! # Examples +//! +//! Spawn a task and block the current thread on its result: +//! +//! ``` +//! # #![feature(async_await)] +//! use async_std::task; +//! +//! fn main() { +//! task::block_on(async { +//! println!("Hello, world!"); +//! }) +//! } +//! ``` + +#![feature(async_await)] +#![cfg_attr(feature = "docs.rs", feature(doc_cfg))] +#![warn(missing_docs, missing_debug_implementations, rust_2018_idioms)] +#![doc(html_playground_url = "https://play.rust-lang.org")] + +pub mod fs; +pub mod future; +pub mod io; +pub mod net; +pub mod os; +pub mod prelude; +pub mod stream; +pub mod sync; +pub mod task; +pub mod time; + +pub(crate) mod utils; diff --git a/src/net/driver.rs b/src/net/driver.rs new file mode 100644 index 00000000..337cb4e4 --- /dev/null +++ b/src/net/driver.rs @@ -0,0 +1,431 @@ +use std::fmt; +use std::io::{self, prelude::*}; +use std::pin::Pin; +use std::sync::atomic::{AtomicUsize, Ordering}; +use std::sync::{Arc, Mutex}; +use std::task::{Context, Poll, Waker}; + +use futures::{prelude::*, ready}; +use lazy_static::lazy_static; +use mio::{self, Evented}; +use slab::Slab; + +use crate::utils::abort_on_panic; + +/// Data associated with a registered I/O handle. +#[derive(Debug)] +struct Entry { + /// A unique identifier. + token: mio::Token, + + /// Indicates whether this I/O handle is ready for reading, writing, or if it is disconnected. + readiness: AtomicUsize, + + /// Tasks that are blocked on reading from this I/O handle. + readers: Mutex>, + + /// Thasks that are blocked on writing to this I/O handle. + writers: Mutex>, +} + +/// The state of a networking driver. +struct Reactor { + /// A mio instance that polls for new events. + poller: mio::Poll, + + /// A collection of registered I/O handles. + entries: Mutex>>, + + /// Dummy I/O handle that is only used to wake up the polling thread. + notify_reg: (mio::Registration, mio::SetReadiness), + + /// An identifier for the notification handle. + notify_token: mio::Token, +} + +impl Reactor { + /// Creates a new reactor for polling I/O events. + fn new() -> io::Result { + let poller = mio::Poll::new()?; + let notify_reg = mio::Registration::new2(); + + let mut reactor = Reactor { + poller, + entries: Mutex::new(Slab::new()), + notify_reg, + notify_token: mio::Token(0), + }; + + // Register a dummy I/O handle for waking up the polling thread. + let entry = reactor.register(&reactor.notify_reg.0)?; + reactor.notify_token = entry.token; + + Ok(reactor) + } + + /// Registers an I/O event source and returns its associated entry. + fn register(&self, source: &dyn Evented) -> io::Result> { + let mut entries = self.entries.lock().unwrap(); + + // Reserve a vacant spot in the slab and use its key as the token value. + let vacant = entries.vacant_entry(); + let token = mio::Token(vacant.key()); + + // Allocate an entry and insert it into the slab. + let entry = Arc::new(Entry { + token, + readiness: AtomicUsize::new(mio::Ready::empty().as_usize()), + readers: Mutex::new(Vec::new()), + writers: Mutex::new(Vec::new()), + }); + vacant.insert(entry.clone()); + + // Register the I/O event source in the poller. + let interest = mio::Ready::all(); + let opts = mio::PollOpt::edge(); + self.poller.register(source, token, interest, opts)?; + + Ok(entry) + } + + /// Deregisters an I/O event source associated with an entry. + fn deregister(&self, source: &dyn Evented, entry: &Entry) -> io::Result<()> { + // Deregister the I/O object from the mio instance. + self.poller.deregister(source)?; + + // Remove the entry associated with the I/O object. + self.entries.lock().unwrap().remove(entry.token.0); + + Ok(()) + } + + // fn notify(&self) { + // self.notify_reg + // .1 + // .set_readiness(mio::Ready::readable()) + // .unwrap(); + // } +} + +lazy_static! { + /// The state of the global networking driver. + static ref REACTOR: Reactor = { + // Spawn a thread that waits on the poller for new events and wakes up tasks blocked on I/O + // handles. + std::thread::Builder::new() + .name("async-net-driver".to_string()) + .spawn(move || { + // If the driver thread panics, there's not much we can do. It is not a + // recoverable error and there is no place to propagate it into so we just abort. + abort_on_panic(|| { + main_loop().expect("async networking thread has panicked"); + }) + }) + .expect("cannot start a thread driving blocking tasks"); + + Reactor::new().expect("cannot initialize reactor") + }; +} + +/// Waits on the poller for new events and wakes up tasks blocked on I/O handles. +fn main_loop() -> io::Result<()> { + let reactor = &REACTOR; + let mut events = mio::Events::with_capacity(1000); + + loop { + // Block on the poller until at least one new event comes in. + reactor.poller.poll(&mut events, None)?; + + // Lock the entire entry table while we're processing new events. + let entries = reactor.entries.lock().unwrap(); + + for event in events.iter() { + let token = event.token(); + + if token == reactor.notify_token { + // If this is the notification token, we just need the notification state. + reactor.notify_reg.1.set_readiness(mio::Ready::empty())?; + } else { + // Otherwise, look for the entry associated with this token. + if let Some(entry) = entries.get(token.0) { + // Set the readiness flags from this I/O event. + let readiness = event.readiness(); + entry + .readiness + .fetch_or(readiness.as_usize(), Ordering::SeqCst); + + // Wake up reader tasks blocked on this I/O handle. + if !(readiness & reader_interests()).is_empty() { + for w in entry.readers.lock().unwrap().drain(..) { + w.wake(); + } + } + + // Wake up writer tasks blocked on this I/O handle. + if !(readiness & writer_interests()).is_empty() { + for w in entry.writers.lock().unwrap().drain(..) { + w.wake(); + } + } + } + } + } + } +} + +/// An I/O handle powered by the networking driver. +/// +/// This handle wraps an I/O event source and exposes a "futurized" interface on top of it, +/// implementing traits `AsyncRead` and `AsyncWrite`. +pub struct IoHandle { + /// Data associated with the I/O handle. + entry: Arc, + + /// The I/O event source. + source: T, +} + +impl IoHandle { + /// Creates a new I/O handle. + /// + /// The provided I/O event source will be kept registered inside the reactor's poller for the + /// lifetime of the returned I/O handle. + pub fn new(source: T) -> IoHandle { + IoHandle { + entry: REACTOR + .register(&source) + .expect("cannot register an I/O event source"), + source, + } + } + + /// Returns a reference to the inner I/O event source. + pub fn get_ref(&self) -> &T { + &self.source + } + + /// Polls the I/O handle for reading. + /// + /// If reading from the I/O handle would block, `Poll::Pending` will be returned. + pub fn poll_readable(&self, cx: &mut Context<'_>) -> Poll> { + let mask = reader_interests(); + let mut readiness = mio::Ready::from_usize(self.entry.readiness.load(Ordering::SeqCst)); + + if (readiness & mask).is_empty() { + self.entry.readers.lock().unwrap().push(cx.waker().clone()); + readiness = mio::Ready::from_usize(self.entry.readiness.fetch_or(0, Ordering::SeqCst)); + } + + if (readiness & mask).is_empty() { + Poll::Pending + } else { + Poll::Ready(Ok(())) + } + } + + /// Clears the readability status. + /// + /// This method is usually called when an attempt at reading from the OS-level I/O handle + /// returns `io::ErrorKind::WouldBlock`. + pub fn clear_readable(&self, cx: &mut Context<'_>) -> io::Result<()> { + let mask = reader_interests() - hup(); + self.entry + .readiness + .fetch_and(!mask.as_usize(), Ordering::SeqCst); + + if self.poll_readable(cx)?.is_ready() { + // Wake the current task. + cx.waker().wake_by_ref(); + } + + Ok(()) + } + + /// Polls the I/O handle for writing. + /// + /// If writing into the I/O handle would block, `Poll::Pending` will be returned. + pub fn poll_writable(&self, cx: &mut Context<'_>) -> Poll> { + let mask = writer_interests(); + let mut readiness = mio::Ready::from_usize(self.entry.readiness.load(Ordering::SeqCst)); + + if (readiness & mask).is_empty() { + self.entry.writers.lock().unwrap().push(cx.waker().clone()); + readiness = mio::Ready::from_usize(self.entry.readiness.fetch_or(0, Ordering::SeqCst)); + } + + if (readiness & mask).is_empty() { + Poll::Pending + } else { + Poll::Ready(Ok(())) + } + } + + /// Clears the writability status. + /// + /// This method is usually called when an attempt at writing from the OS-level I/O handle + /// returns `io::ErrorKind::WouldBlock`. + pub fn clear_writable(&self, cx: &mut Context<'_>) -> io::Result<()> { + let mask = writer_interests() - hup(); + self.entry + .readiness + .fetch_and(!mask.as_usize(), Ordering::SeqCst); + + if self.poll_writable(cx)?.is_ready() { + // Wake the current task. + cx.waker().wake_by_ref(); + } + + Ok(()) + } +} + +impl Drop for IoHandle { + fn drop(&mut self) { + REACTOR + .deregister(&self.source, &self.entry) + .expect("cannot deregister I/O event source"); + } +} + +impl fmt::Debug for IoHandle { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("IoHandle") + .field("entry", &self.entry) + .field("source", &self.source) + .finish() + } +} + +impl AsyncRead for IoHandle { + fn poll_read( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut [u8], + ) -> Poll> { + ready!(Pin::new(&mut *self).poll_readable(cx)?); + + match self.source.read(buf) { + Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => { + self.clear_readable(cx)?; + Poll::Pending + } + res => Poll::Ready(res), + } + } +} + +impl<'a, T: Evented + Unpin> AsyncRead for &'a IoHandle +where + &'a T: Read, +{ + fn poll_read( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut [u8], + ) -> Poll> { + ready!(Pin::new(&mut *self).poll_readable(cx)?); + + match (&self.source).read(buf) { + Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => { + self.clear_readable(cx)?; + Poll::Pending + } + res => Poll::Ready(res), + } + } +} + +impl AsyncWrite for IoHandle { + fn poll_write( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + ready!(self.poll_writable(cx)?); + + match self.source.write(buf) { + Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => { + self.clear_writable(cx)?; + Poll::Pending + } + res => Poll::Ready(res), + } + } + + fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + ready!(self.poll_writable(cx)?); + + match self.source.flush() { + Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => { + self.clear_writable(cx)?; + Poll::Pending + } + res => Poll::Ready(res), + } + } + + fn poll_close(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } +} + +impl<'a, T: Evented + Unpin> AsyncWrite for &'a IoHandle +where + &'a T: Write, +{ + fn poll_write( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + ready!(self.poll_writable(cx)?); + + match (&self.source).write(buf) { + Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => { + self.clear_writable(cx)?; + Poll::Pending + } + res => Poll::Ready(res), + } + } + + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + ready!(self.poll_writable(cx)?); + + match (&self.source).flush() { + Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => { + self.clear_writable(cx)?; + Poll::Pending + } + res => Poll::Ready(res), + } + } + + fn poll_close(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } +} + +/// Returns a mask containing flags that interest tasks reading from I/O handles. +#[inline] +fn reader_interests() -> mio::Ready { + mio::Ready::all() - mio::Ready::writable() +} + +/// Returns a mask containing flags that interest tasks writing into I/O handles. +#[inline] +fn writer_interests() -> mio::Ready { + mio::Ready::writable() | hup() +} + +/// Returns a flag containing the hangup status. +#[inline] +fn hup() -> mio::Ready { + #[cfg(unix)] + let ready = mio::unix::UnixReady::hup().into(); + + #[cfg(not(unix))] + let ready = mio::Ready::empty(); + + ready +} diff --git a/src/net/mod.rs b/src/net/mod.rs new file mode 100644 index 00000000..dcac0bbc --- /dev/null +++ b/src/net/mod.rs @@ -0,0 +1,35 @@ +//! Networking primitives for TCP/UDP communication. +//! +//! For OS-specific networking primitives like Unix domain sockets, refer to the [`async_std::os`] +//! module. +//! +//! This module is an async version of [`std::net`]. +//! +//! [`async_std::os`]: ../os/index.html +//! [`std::net`]: https://doc.rust-lang.org/std/net/index.html +//! +//! ## Examples +//! +//! A simple UDP echo server: +//! +//! ```no_run +//! # #![feature(async_await)] +//! use async_std::net::UdpSocket; +//! +//! # futures::executor::block_on(async { +//! let socket = UdpSocket::bind("127.0.0.1:8080").await?; +//! let mut buf = vec![0u8; 1024]; +//! loop { +//! let (n, peer) = socket.recv_from(&mut buf).await?; +//! socket.send_to(&buf[..n], &peer).await?; +//! } +//! # std::io::Result::Ok(()) +//! # }).unwrap(); +//! ``` + +pub use tcp::{Incoming, TcpListener, TcpStream}; +pub use udp::UdpSocket; + +pub(crate) mod driver; +mod tcp; +mod udp; diff --git a/src/net/tcp.rs b/src/net/tcp.rs new file mode 100644 index 00000000..1b2499a9 --- /dev/null +++ b/src/net/tcp.rs @@ -0,0 +1,807 @@ +use std::io::{self, IoSlice, IoSliceMut}; +use std::mem; +use std::net::{self, SocketAddr, ToSocketAddrs}; +use std::pin::Pin; +use std::task::{Context, Poll}; + +use cfg_if::cfg_if; +use futures::{prelude::*, ready}; + +use crate::net::driver::IoHandle; + +/// A TCP stream between a local and a remote socket. +/// +/// A `TcpStream` can either be created by connecting to an endpoint, via the [`connect`] method, +/// or by [accepting] a connection from a [listener]. It can be read or written to using the +/// [`AsyncRead`], [`AsyncWrite`], and related extension traits in [`futures::io`]. +/// +/// The connection will be closed when the value is dropped. The reading and writing portions of +/// the connection can also be shut down individually with the [`shutdown`] method. +/// +/// This type is an async version of [`std::net::TcpStream`]. +/// +/// [`connect`]: struct.TcpStream.html#method.connect +/// [accepting]: struct.TcpListener.html#method.accept +/// [listener]: struct.TcpListener.html +/// [`AsyncRead`]: https://docs.rs/futures-preview/0.3.0-alpha.13/futures/io/trait.AsyncRead.html +/// [`AsyncWrite`]: https://docs.rs/futures-preview/0.3.0-alpha.13/futures/io/trait.AsyncRead.html +/// [`futures::io`]: https://docs.rs/futures-preview/0.3.0-alpha.13/futures/io +/// [`shutdown`]: struct.TcpStream.html#method.shutdown +/// [`std::net::TcpStream`]: https://doc.rust-lang.org/std/net/struct.TcpStream.html +/// +/// ## Examples +/// +/// ```no_run +/// # #![feature(async_await)] +/// use async_std::net::TcpStream; +/// use async_std::prelude::*; +/// +/// # futures::executor::block_on(async { +/// let mut stream = TcpStream::connect("127.0.0.1:8080").await?; +/// println!("Connected to {}", &stream.peer_addr()?); +/// +/// let msg = "hello world"; +/// println!("<- {}", msg); +/// stream.write_all(msg.as_bytes()).await?; +/// +/// let mut buf = vec![0u8; 1024]; +/// let n = stream.read(&mut buf).await?; +/// println!("-> {}\n", std::str::from_utf8(&buf[..n])?); +/// # Ok::<_, Box>(()) +/// # }).unwrap(); +/// ``` +#[derive(Debug)] +pub struct TcpStream { + io_handle: IoHandle, + + #[cfg(unix)] + raw_fd: std::os::unix::io::RawFd, + // #[cfg(windows)] + // raw_socket: std::os::windows::io::RawSocket, +} + +impl TcpStream { + /// Creates a new TCP stream connected to the specified address. + /// + /// This method will create a new TCP socket and attempt to connect it to the `addr` + /// provided. The [returned future] will be resolved once the stream has successfully + /// connected, or it will return an error if one occurs. + /// + /// [returned future]: struct.Connect.html + /// + /// # Examples + /// + /// ```no_run + /// # #![feature(async_await)] + /// use async_std::net::TcpStream; + /// + /// # futures::executor::block_on(async { + /// let stream = TcpStream::connect("127.0.0.1:0").await?; + /// # std::io::Result::Ok(()) + /// # }).unwrap(); + /// ``` + pub async fn connect(addrs: A) -> io::Result { + enum State { + Waiting(TcpStream), + Error(io::Error), + Done, + } + + let mut last_err = None; + + for addr in addrs.to_socket_addrs()? { + let mut state = { + match mio::net::TcpStream::connect(&addr) { + Ok(mio_stream) => { + #[cfg(unix)] + let stream = TcpStream { + raw_fd: mio_stream.as_raw_fd(), + io_handle: IoHandle::new(mio_stream), + }; + + #[cfg(windows)] + let stream = TcpStream { + // raw_socket: mio_stream.as_raw_socket(), + io_handle: IoHandle::new(mio_stream), + }; + + State::Waiting(stream) + } + Err(err) => State::Error(err), + } + }; + + let res = future::poll_fn(|cx| { + match mem::replace(&mut state, State::Done) { + State::Waiting(stream) => { + // Once we've connected, wait for the stream to be writable as that's when + // the actual connection has been initiated. Once we're writable we check + // for `take_socket_error` to see if the connect actually hit an error or + // not. + // + // If all that succeeded then we ship everything on up. + if let Poll::Pending = stream.io_handle.poll_writable(cx)? { + state = State::Waiting(stream); + return Poll::Pending; + } + + if let Some(err) = stream.io_handle.get_ref().take_error()? { + return Poll::Ready(Err(err)); + } + + Poll::Ready(Ok(stream)) + } + State::Error(err) => Poll::Ready(Err(err)), + State::Done => panic!("`TcpStream::connect()` future polled after completion"), + } + }) + .await; + + match res { + Ok(stream) => return Ok(stream), + Err(err) => last_err = Some(err), + } + } + + Err(last_err.unwrap_or_else(|| { + io::Error::new( + io::ErrorKind::InvalidInput, + "could not resolve to any addresses", + ) + })) + } + + /// Returns the local address that this stream is connected to. + /// + /// ## Examples + /// + /// ```no_run + /// # #![feature(async_await)] + /// use async_std::net::TcpStream; + /// use std::net::{IpAddr, Ipv4Addr}; + /// + /// # futures::executor::block_on(async { + /// let stream = TcpStream::connect("127.0.0.1:8080").await?; + /// + /// let expected = IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)); + /// assert_eq!(stream.local_addr()?.ip(), expected); + /// # std::io::Result::Ok(()) + /// # }).unwrap(); + /// ``` + pub fn local_addr(&self) -> io::Result { + self.io_handle.get_ref().local_addr() + } + + /// Returns the remote address that this stream is connected to. + /// + /// ## Examples + /// + /// ```no_run + /// # #![feature(async_await)] + /// use async_std::net::TcpStream; + /// use std::net::{IpAddr, Ipv4Addr}; + /// + /// # futures::executor::block_on(async { + /// let stream = TcpStream::connect("127.0.0.1:8080").await?; + /// + /// let expected = IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)); + /// assert_eq!(stream.peer_addr()?.ip(), expected); + /// # std::io::Result::Ok(()) + /// # }).unwrap(); + /// ``` + pub fn peer_addr(&self) -> io::Result { + self.io_handle.get_ref().peer_addr() + } + + /// Gets the value of the `IP_TTL` option for this socket. + /// + /// For more information about this option, see [`set_ttl`]. + /// + /// [`set_ttl`]: #method.set_ttl + /// + /// # Examples + /// + /// ```no_run + /// # #![feature(async_await)] + /// use async_std::net::TcpStream; + /// + /// # futures::executor::block_on(async { + /// let stream = TcpStream::connect("127.0.0.1:8080").await?; + /// + /// stream.set_ttl(100)?; + /// assert_eq!(stream.ttl()?, 100); + /// # std::io::Result::Ok(()) + /// # }).unwrap(); + /// ``` + pub fn ttl(&self) -> io::Result { + self.io_handle.get_ref().ttl() + } + + /// Sets the value for the `IP_TTL` option on this socket. + /// + /// This value sets the time-to-live field that is used in every packet sent + /// from this socket. + /// + /// # Examples + /// + /// ```no_run + /// # #![feature(async_await)] + /// use async_std::net::TcpStream; + /// + /// # futures::executor::block_on(async { + /// let stream = TcpStream::connect("127.0.0.1:8080").await?; + /// + /// stream.set_ttl(100)?; + /// assert_eq!(stream.ttl()?, 100); + /// # std::io::Result::Ok(()) + /// # }).unwrap(); + /// ``` + pub fn set_ttl(&self, ttl: u32) -> io::Result<()> { + self.io_handle.get_ref().set_ttl(ttl) + } + + /// Receives data on the socket from the remote address to which it is connected, without + /// removing that data from the queue. On success, returns the number of bytes peeked. + /// + /// Successive calls return the same data. This is accomplished by passing `MSG_PEEK` as a flag + /// to the underlying `recv` system call. + /// + /// # Examples + /// + /// ```no_run + /// # #![feature(async_await)] + /// use async_std::net::TcpStream; + /// + /// # futures::executor::block_on(async { + /// let stream = TcpStream::connect("127.0.0.1:8000").await?; + /// + /// let mut buf = [0; 10]; + /// let len = stream.peek(&mut buf).await?; + /// # std::io::Result::Ok(()) + /// # }).unwrap(); + /// ``` + pub async fn peek(&self, buf: &mut [u8]) -> io::Result { + let res = future::poll_fn(|cx| { + ready!(self.io_handle.poll_readable(cx)?); + match self.io_handle.get_ref().peek(buf) { + Ok(len) => Poll::Ready(Ok(len)), + Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => { + self.io_handle.clear_readable(cx)?; + Poll::Pending + } + Err(e) => Poll::Ready(Err(e)), + } + }) + .await?; + Ok(res) + } + + /// Gets the value of the `TCP_NODELAY` option on this socket. + /// + /// For more information about this option, see [`set_nodelay`]. + /// + /// [`set_nodelay`]: #method.set_nodelay + /// + /// # Examples + /// + /// ```no_run + /// # #![feature(async_await)] + /// use async_std::net::TcpStream; + /// + /// # futures::executor::block_on(async { + /// let stream = TcpStream::connect("127.0.0.1:8080").await?; + /// + /// stream.set_nodelay(true)?; + /// assert_eq!(stream.nodelay()?, true); + /// # std::io::Result::Ok(()) + /// # }).unwrap(); + /// ``` + pub fn nodelay(&self) -> io::Result { + self.io_handle.get_ref().nodelay() + } + + /// Sets the value of the `TCP_NODELAY` option on this socket. + /// + /// If set, this option disables the Nagle algorithm. This means that + /// segments are always sent as soon as possible, even if there is only a + /// small amount of data. When not set, data is buffered until there is a + /// sufficient amount to send out, thereby avoiding the frequent sending of + /// small packets. + /// + /// # Examples + /// + /// ```no_run + /// # #![feature(async_await)] + /// use async_std::net::TcpStream; + /// + /// # futures::executor::block_on(async { + /// let stream = TcpStream::connect("127.0.0.1:8080").await?; + /// + /// stream.set_nodelay(true)?; + /// assert_eq!(stream.nodelay()?, true); + /// # std::io::Result::Ok(()) + /// # }).unwrap(); + /// ``` + pub fn set_nodelay(&self, nodelay: bool) -> io::Result<()> { + self.io_handle.get_ref().set_nodelay(nodelay) + } + + /// Shuts down the read, write, or both halves of this connection. + /// + /// This method will cause all pending and future I/O on the specified portions to return + /// immediately with an appropriate value (see the documentation of [`Shutdown`]). + /// + /// [`Shutdown`]: https://doc.rust-lang.org/std/net/enum.Shutdown.html + /// + /// # Examples + /// + /// ```no_run + /// # #![feature(async_await)] + /// use async_std::net::TcpStream; + /// use std::net::Shutdown; + /// + /// # futures::executor::block_on(async { + /// let stream = TcpStream::connect("127.0.0.1:8080").await?; + /// stream.shutdown(Shutdown::Both)?; + /// # std::io::Result::Ok(()) + /// # }).unwrap(); + /// ``` + pub fn shutdown(&self, how: std::net::Shutdown) -> std::io::Result<()> { + self.io_handle.get_ref().shutdown(how) + } +} + +impl AsyncRead for TcpStream { + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut [u8], + ) -> Poll> { + Pin::new(&mut &*self).poll_read(cx, buf) + } + + fn poll_read_vectored( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + bufs: &mut [IoSliceMut<'_>], + ) -> Poll> { + Pin::new(&mut &*self).poll_read_vectored(cx, bufs) + } +} + +impl AsyncRead for &TcpStream { + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut [u8], + ) -> Poll> { + Pin::new(&mut &self.io_handle).poll_read(cx, buf) + } + + fn poll_read_vectored( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + bufs: &mut [IoSliceMut<'_>], + ) -> Poll> { + Pin::new(&mut &self.io_handle).poll_read_vectored(cx, bufs) + } +} + +impl AsyncWrite for TcpStream { + fn poll_write( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + Pin::new(&mut &*self).poll_write(cx, buf) + } + + fn poll_write_vectored( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + bufs: &[IoSlice<'_>], + ) -> Poll> { + Pin::new(&mut &*self).poll_write_vectored(cx, bufs) + } + + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + Pin::new(&mut &*self).poll_flush(cx) + } + + fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + Pin::new(&mut &*self).poll_close(cx) + } +} + +impl AsyncWrite for &TcpStream { + fn poll_write( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + Pin::new(&mut &self.io_handle).poll_write(cx, buf) + } + + fn poll_write_vectored( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + bufs: &[IoSlice<'_>], + ) -> Poll> { + Pin::new(&mut &self.io_handle).poll_write_vectored(cx, bufs) + } + + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + Pin::new(&mut &self.io_handle).poll_flush(cx) + } + + fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + Pin::new(&mut &self.io_handle).poll_close(cx) + } +} + +/// A TCP socket server, listening for connections. +/// +/// After creating a `TcpListener` by [`bind`]ing it to a socket address, it listens for incoming +/// TCP connections. These can be accepted by awaiting elements from the async stream of +/// [`incoming`] connections. +/// +/// The socket will be closed when the value is dropped. +/// +/// The Transmission Control Protocol is specified in [IETF RFC 793]. +/// +/// This type is an async version of [`std::net::TcpListener`]. +/// +/// [`bind`]: #method.bind +/// [`incoming`]: #method.incoming +/// [IETF RFC 793]: https://tools.ietf.org/html/rfc793 +/// [`std::net::TcpListener`]: https://doc.rust-lang.org/std/net/struct.TcpListener.html +/// +/// # Examples +/// +/// ```no_run +/// # #![feature(async_await)] +/// use async_std::io; +/// use async_std::net::TcpListener; +/// use async_std::prelude::*; +/// +/// # futures::executor::block_on(async { +/// let listener = TcpListener::bind("127.0.0.1:8080").await?; +/// println!("Listening on {}", listener.local_addr()?); +/// +/// let mut incoming = listener.incoming(); +/// while let Some(stream) = incoming.next().await { +/// let stream = stream?; +/// println!("Accepting from: {}", stream.peer_addr()?); +/// +/// let (reader, writer) = &mut (&stream, &stream); +/// io::copy(reader, writer).await?; +/// } +/// # std::io::Result::Ok(()) +/// # }).unwrap(); +/// ``` +#[derive(Debug)] +pub struct TcpListener { + io_handle: IoHandle, + + #[cfg(unix)] + raw_fd: std::os::unix::io::RawFd, + // #[cfg(windows)] + // raw_socket: std::os::windows::io::RawSocket, +} + +impl TcpListener { + /// Creates a new `TcpListener` which will be bound to the specified address. + /// + /// The returned listener is ready for accepting connections. + /// + /// Binding with a port number of 0 will request that the OS assigns a port to this listener. + /// The port allocated can be queried via the [`local_addr`] method. + /// + /// # Examples + /// Create a TCP listener bound to 127.0.0.1:0: + /// + /// ```no_run + /// # #![feature(async_await)] + /// use async_std::net::TcpListener; + /// + /// # futures::executor::block_on(async { + /// let listener = TcpListener::bind("127.0.0.1:0").await?; + /// # std::io::Result::Ok(()) + /// # }).unwrap(); + /// ``` + /// + /// [`local_addr`]: #method.local_addr + pub async fn bind(addrs: A) -> io::Result { + let mut last_err = None; + + for addr in addrs.to_socket_addrs()? { + match mio::net::TcpListener::bind(&addr) { + Ok(mio_listener) => { + #[cfg(unix)] + let listener = TcpListener { + raw_fd: mio_listener.as_raw_fd(), + io_handle: IoHandle::new(mio_listener), + }; + + #[cfg(windows)] + let listener = TcpListener { + // raw_socket: mio_listener.as_raw_socket(), + io_handle: IoHandle::new(mio_listener), + }; + return Ok(listener); + } + Err(err) => last_err = Some(err), + } + } + + Err(last_err.unwrap_or_else(|| { + io::Error::new( + io::ErrorKind::InvalidInput, + "could not resolve to any addresses", + ) + })) + } + + /// Accepts a new incoming connection to this listener. + /// + /// When a connection is established, the corresponding stream and address will be returned. + /// + /// ## Examples + /// + /// ```no_run + /// # #![feature(async_await)] + /// use async_std::net::TcpListener; + /// + /// # futures::executor::block_on(async { + /// let listener = TcpListener::bind("127.0.0.1:0").await?; + /// let (stream, addr) = listener.accept().await?; + /// # std::io::Result::Ok(()) + /// # }).unwrap(); + /// ``` + pub async fn accept(&self) -> io::Result<(TcpStream, SocketAddr)> { + future::poll_fn(|cx| { + ready!(self.io_handle.poll_readable(cx)?); + + match self.io_handle.get_ref().accept_std() { + Ok((io, addr)) => { + let mio_stream = mio::net::TcpStream::from_stream(io)?; + + #[cfg(unix)] + let stream = TcpStream { + raw_fd: mio_stream.as_raw_fd(), + io_handle: IoHandle::new(mio_stream), + }; + + #[cfg(windows)] + let stream = TcpStream { + // raw_socket: mio_stream.as_raw_socket(), + io_handle: IoHandle::new(mio_stream), + }; + + Poll::Ready(Ok((stream, addr))) + } + Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => { + self.io_handle.clear_readable(cx)?; + Poll::Pending + } + Err(err) => Poll::Ready(Err(err)), + } + }) + .await + } + + /// Returns a stream of incoming connections. + /// + /// Iterating over this stream is equivalent to calling [`accept`] in a loop. The stream of + /// connections is infinite, i.e awaiting the next connection will never result in [`None`]. + /// + /// [`accept`]: #method.accept + /// [`None`]: https://doc.rust-lang.org/std/option/enum.Option.html#variant.None + /// + /// ## Examples + /// + /// ```no_run + /// # #![feature(async_await)] + /// use async_std::net::TcpListener; + /// use async_std::prelude::*; + /// + /// # futures::executor::block_on(async { + /// let listener = TcpListener::bind("127.0.0.1:0").await?; + /// let mut incoming = listener.incoming(); + /// + /// while let Some(stream) = incoming.next().await { + /// let mut stream = stream?; + /// stream.write_all(b"hello world").await?; + /// } + /// # std::io::Result::Ok(()) + /// # }).unwrap(); + /// ``` + pub fn incoming(&self) -> Incoming<'_> { + Incoming(self) + } + + /// Returns the local address that this listener is bound to. + /// + /// This can be useful, for example, to identify when binding to port 0 which port was assigned + /// by the OS. + /// + /// # Examples + /// + /// ```no_run + /// # #![feature(async_await)] + /// use async_std::net::TcpListener; + /// use std::net::{Ipv4Addr, SocketAddr, SocketAddrV4}; + /// + /// # futures::executor::block_on(async { + /// let listener = TcpListener::bind("127.0.0.1:8080").await?; + /// + /// let expected = SocketAddrV4::new(Ipv4Addr::new(127, 0, 0, 1), 8080); + /// assert_eq!(listener.local_addr()?, SocketAddr::V4(expected)); + /// # std::io::Result::Ok(()) + /// # }).unwrap(); + /// ``` + pub fn local_addr(&self) -> io::Result { + self.io_handle.get_ref().local_addr() + } +} + +/// A stream of incoming TCP connections. +/// +/// This stream is infinite, i.e awaiting the next connection will never result in [`None`]. It is +/// created by the [`incoming`] method on [`TcpListener`]. +/// +/// This type is an async version of [`std::net::Incoming`]. +/// +/// [`None`]: https://doc.rust-lang.org/std/option/enum.Option.html#variant.None +/// [`incoming`]: struct.TcpListener.html#method.incoming +/// [`TcpListener`]: struct.TcpListener.html +/// [`std::net::Incoming`]: https://doc.rust-lang.org/std/net/struct.Incoming.html +#[derive(Debug)] +pub struct Incoming<'a>(&'a TcpListener); + +impl<'a> Stream for Incoming<'a> { + type Item = io::Result; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let future = self.0.accept(); + pin_utils::pin_mut!(future); + + let (socket, _) = ready!(future.poll(cx))?; + Poll::Ready(Some(Ok(socket))) + } +} + +impl From for TcpStream { + /// Converts a `std::net::TcpStream` into its asynchronous equivalent. + fn from(stream: net::TcpStream) -> TcpStream { + let mio_stream = mio::net::TcpStream::from_stream(stream).unwrap(); + + #[cfg(unix)] + let stream = TcpStream { + raw_fd: mio_stream.as_raw_fd(), + io_handle: IoHandle::new(mio_stream), + }; + + #[cfg(windows)] + let stream = TcpStream { + // raw_socket: mio_stream.as_raw_socket(), + io_handle: IoHandle::new(mio_stream), + }; + + stream + } +} + +impl From for TcpListener { + /// Converts a `std::net::TcpListener` into its asynchronous equivalent. + fn from(listener: net::TcpListener) -> TcpListener { + let mio_listener = mio::net::TcpListener::from_std(listener).unwrap(); + + #[cfg(unix)] + let listener = TcpListener { + raw_fd: mio_listener.as_raw_fd(), + io_handle: IoHandle::new(mio_listener), + }; + + #[cfg(windows)] + let listener = TcpListener { + // raw_socket: mio_listener.as_raw_socket(), + io_handle: IoHandle::new(mio_listener), + }; + + listener + } +} + +cfg_if! { + if #[cfg(feature = "docs.rs")] { + use crate::os::unix::io::{AsRawFd, FromRawFd, IntoRawFd, RawFd}; + // use crate::os::windows::io::{AsRawHandle, FromRawHandle, IntoRawHandle, RawHandle}; + } else if #[cfg(unix)] { + use std::os::unix::io::{AsRawFd, FromRawFd, IntoRawFd, RawFd}; + } else if #[cfg(windows)] { + // use std::os::windows::io::{AsRawHandle, FromRawHandle, IntoRawHandle, RawHandle}; + } +} + +#[cfg_attr(feature = "docs.rs", doc(cfg(unix)))] +cfg_if! { + if #[cfg(any(unix, feature = "docs.rs"))] { + impl AsRawFd for TcpListener { + fn as_raw_fd(&self) -> RawFd { + self.raw_fd + } + } + + impl FromRawFd for TcpListener { + unsafe fn from_raw_fd(fd: RawFd) -> TcpListener { + net::TcpListener::from_raw_fd(fd).into() + } + } + + impl IntoRawFd for TcpListener { + fn into_raw_fd(self) -> RawFd { + self.raw_fd + } + } + + impl AsRawFd for TcpStream { + fn as_raw_fd(&self) -> RawFd { + self.raw_fd + } + } + + impl FromRawFd for TcpStream { + unsafe fn from_raw_fd(fd: RawFd) -> TcpStream { + net::TcpStream::from_raw_fd(fd).into() + } + } + + impl IntoRawFd for TcpStream { + fn into_raw_fd(self) -> RawFd { + self.raw_fd + } + } + } +} + +#[cfg_attr(feature = "docs.rs", doc(cfg(windows)))] +cfg_if! { + if #[cfg(any(windows, feature = "docs.rs"))] { + // impl AsRawSocket for TcpListener { + // fn as_raw_socket(&self) -> RawSocket { + // self.raw_socket + // } + // } + // + // impl FromRawSocket for TcpListener { + // unsafe fn from_raw_socket(handle: RawSocket) -> TcpListener { + // net::TcpListener::from_raw_socket(handle).try_into().unwrap() + // } + // } + // + // impl IntoRawSocket for TcpListener { + // fn into_raw_socket(self) -> RawSocket { + // self.raw_socket + // } + // } + // + // impl AsRawSocket for TcpStream { + // fn as_raw_socket(&self) -> RawSocket { + // self.raw_socket + // } + // } + // + // impl FromRawSocket for TcpStream { + // unsafe fn from_raw_socket(handle: RawSocket) -> TcpStream { + // net::TcpStream::from_raw_socket(handle).try_into().unwrap() + // } + // } + // + // impl IntoRawSocket for TcpListener { + // fn into_raw_socket(self) -> RawSocket { + // self.raw_socket + // } + // } + } +} diff --git a/src/net/udp.rs b/src/net/udp.rs new file mode 100644 index 00000000..7e5e17dd --- /dev/null +++ b/src/net/udp.rs @@ -0,0 +1,588 @@ +use std::io; +use std::net::{self, Ipv4Addr, Ipv6Addr, SocketAddr, ToSocketAddrs}; +use std::task::Poll; + +use cfg_if::cfg_if; +use futures::{prelude::*, ready}; + +use crate::net::driver::IoHandle; + +/// A UDP socket. +/// +/// After creating a `UdpSocket` by [`bind`]ing it to a socket address, data can be [sent to] and +/// [received from] any other socket address. +/// +/// As stated in the User Datagram Protocol's specification in [IETF RFC 768], UDP is an unordered, +/// unreliable protocol. Refer to [`TcpListener`] and [`TcpStream`] for async TCP primitives. +/// +/// This type is an async version of [`std::net::UdpSocket`]. +/// +/// [`bind`]: #method.bind +/// [received from]: #method.recv_from +/// [sent to]: #method.send_to +/// [`TcpListener`]: struct.TcpListener.html +/// [`TcpStream`]: struct.TcpStream.html +/// [`std::net`]: https://doc.rust-lang.org/std/net/index.html +/// [IETF RFC 768]: https://tools.ietf.org/html/rfc768 +/// [`std::net::UdpSocket`]: https://doc.rust-lang.org/std/net/struct.UdpSocket.html +/// +/// ## Examples +/// +/// ```no_run +/// # #![feature(async_await)] +/// use async_std::net::UdpSocket; +/// +/// # futures::executor::block_on(async { +/// let socket = UdpSocket::bind("127.0.0.1:8080").await?; +/// let mut buf = vec![0u8; 1024]; +/// +/// println!("Listening on {}", socket.local_addr()?); +/// +/// loop { +/// let (n, peer) = socket.recv_from(&mut buf).await?; +/// let sent = socket.send_to(&buf[..n], &peer).await?; +/// println!("Sent {} out of {} bytes to {}", sent, n, peer); +/// } +/// # std::io::Result::Ok(()) +/// # }).unwrap(); +/// ``` +#[derive(Debug)] +pub struct UdpSocket { + io_handle: IoHandle, + + #[cfg(unix)] + raw_fd: std::os::unix::io::RawFd, + // #[cfg(windows)] + // raw_socket: std::os::windows::io::RawSocket, +} + +impl UdpSocket { + /// Creates a UDP socket from the given address. + /// + /// Binding with a port number of 0 will request that the OS assigns a port to this socket. The + /// port allocated can be queried via the [`local_addr`] method. + /// + /// [`local_addr`]: #method.local_addr + /// + /// # Examples + /// + /// ```no_run + /// # #![feature(async_await)] + /// use async_std::net::UdpSocket; + /// + /// # futures::executor::block_on(async { + /// let socket = UdpSocket::bind("127.0.0.1:0").await?; + /// # std::io::Result::Ok(()) + /// # }).unwrap(); + /// ``` + pub async fn bind(addr: A) -> io::Result { + let mut last_err = None; + + for addr in addr.to_socket_addrs()? { + match mio::net::UdpSocket::bind(&addr) { + Ok(mio_socket) => { + #[cfg(unix)] + let socket = UdpSocket { + raw_fd: mio_socket.as_raw_fd(), + io_handle: IoHandle::new(mio_socket), + }; + + #[cfg(windows)] + let socket = UdpSocket { + // raw_socket: mio_socket.as_raw_socket(), + io_handle: IoHandle::new(mio_socket), + }; + + return Ok(socket); + } + Err(err) => last_err = Some(err), + } + } + + Err(last_err.unwrap_or_else(|| { + io::Error::new( + io::ErrorKind::InvalidInput, + "could not resolve to any addresses", + ) + })) + } + + /// Returns the local address that this listener is bound to. + /// + /// This can be useful, for example, when binding to port 0 to figure out which port was + /// actually bound. + /// + /// # Examples + /// + /// ```no_run + /// # #![feature(async_await)] + /// use async_std::net::UdpSocket; + /// + /// # futures::executor::block_on(async { + /// let socket = UdpSocket::bind("127.0.0.1:0").await?; + /// println!("Address: {:?}", socket.local_addr()); + /// # std::io::Result::Ok(()) + /// # }).unwrap(); + /// ``` + pub fn local_addr(&self) -> io::Result { + self.io_handle.get_ref().local_addr() + } + + /// Sends data on the socket to the given address. + /// + /// On success, returns the number of bytes written. + /// + /// # Examples + /// + /// ```no_run + /// # #![feature(async_await)] + /// use async_std::net::UdpSocket; + /// + /// const THE_MERCHANT_OF_VENICE: &[u8] = b" + /// If you prick us, do we not bleed? + /// If you tickle us, do we not laugh? + /// If you poison us, do we not die? + /// And if you wrong us, shall we not revenge? + /// "; + /// + /// # futures::executor::block_on(async { + /// let socket = UdpSocket::bind("127.0.0.1:0").await?; + /// + /// let addr = "127.0.0.1:7878"; + /// let sent = socket.send_to(THE_MERCHANT_OF_VENICE, &addr).await?; + /// println!("Sent {} bytes to {}", sent, addr); + /// # Ok::<_, Box>(()) + /// # }).unwrap(); + /// ``` + pub async fn send_to(&self, buf: &[u8], addrs: A) -> io::Result { + let addr = match addrs.to_socket_addrs()?.next() { + Some(addr) => addr, + None => { + return Err(io::Error::new( + io::ErrorKind::InvalidInput, + "no addresses to send data to", + )); + } + }; + + future::poll_fn(|cx| { + ready!(self.io_handle.poll_writable(cx)?); + + match self.io_handle.get_ref().send_to(buf, &addr) { + Ok(n) => Poll::Ready(Ok(n)), + Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => { + self.io_handle.clear_writable(cx)?; + Poll::Pending + } + Err(err) => Poll::Ready(Err(err)), + } + }) + .await + } + + /// Receives data from the socket. + /// + /// On success, returns the number of bytes read and the origin. + /// + /// # Examples + /// + /// ```no_run + /// # #![feature(async_await)] + /// use async_std::net::UdpSocket; + /// + /// # futures::executor::block_on(async { + /// let socket = UdpSocket::bind("127.0.0.1:0").await?; + /// + /// let mut buf = vec![0; 1024]; + /// let (n, peer) = socket.recv_from(&mut buf).await?; + /// println!("Received {} bytes from {}", n, peer); + /// # Ok::<_, Box>(()) + /// # }).unwrap(); + /// ``` + pub async fn recv_from(&self, buf: &mut [u8]) -> io::Result<(usize, SocketAddr)> { + future::poll_fn(|cx| { + ready!(self.io_handle.poll_readable(cx)?); + + match self.io_handle.get_ref().recv_from(buf) { + Ok(n) => Poll::Ready(Ok(n)), + Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => { + self.io_handle.clear_readable(cx)?; + Poll::Pending + } + Err(err) => Poll::Ready(Err(err)), + } + }) + .await + } + + /// Connects the UDP socket to a remote address. + /// + /// When connected, methods [`send`] and [`recv`] will use the specified address for sending + /// and receiving messages. Additionally, a filter will be applied to [`recv_from`] so that it + /// only receives messages from that same address. + /// + /// [`send`]: #method.send + /// [`recv`]: #method.recv + /// [`recv_from`]: #method.recv_from + /// + /// # Examples + /// + /// ```no_run + /// # #![feature(async_await)] + /// use async_std::net::UdpSocket; + /// + /// # futures::executor::block_on(async { + /// let socket = UdpSocket::bind("127.0.0.1:0").await?; + /// socket.connect("127.0.0.1:8080").await?; + /// # std::io::Result::Ok(()) + /// # }).unwrap(); + /// ``` + pub async fn connect(&self, addrs: A) -> io::Result<()> { + let mut last_err = None; + + for addr in addrs.to_socket_addrs()? { + match self.io_handle.get_ref().connect(addr) { + Ok(()) => return Ok(()), + Err(err) => last_err = Some(err), + } + } + + Err(last_err.unwrap_or_else(|| { + io::Error::new( + io::ErrorKind::InvalidInput, + "could not resolve to any addresses", + ) + })) + } + + /// Sends data on the socket to the given address. + /// + /// On success, returns the number of bytes written. + /// + /// # Examples + /// + /// ```no_run + /// # #![feature(async_await)] + /// use async_std::net::UdpSocket; + /// + /// const THE_MERCHANT_OF_VENICE: &[u8] = b" + /// If you prick us, do we not bleed? + /// If you tickle us, do we not laugh? + /// If you poison us, do we not die? + /// And if you wrong us, shall we not revenge? + /// "; + /// + /// # futures::executor::block_on(async { + /// let socket = UdpSocket::bind("127.0.0.1:0").await?; + /// + /// let addr = "127.0.0.1:7878"; + /// let sent = socket.send_to(THE_MERCHANT_OF_VENICE, &addr).await?; + /// println!("Sent {} bytes to {}", sent, addr); + /// # Ok::<_, Box>(()) + /// # }).unwrap(); + /// ``` + pub async fn send(&self, buf: &[u8]) -> io::Result { + future::poll_fn(|cx| { + ready!(self.io_handle.poll_writable(cx)?); + + match self.io_handle.get_ref().send(buf) { + Ok(n) => Poll::Ready(Ok(n)), + Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => { + self.io_handle.clear_writable(cx)?; + Poll::Pending + } + Err(err) => Poll::Ready(Err(err)), + } + }) + .await + } + + /// Receives data from the socket. + /// + /// On success, returns the number of bytes read and the origin. + /// + /// # Examples + /// + /// ```no_run + /// # #![feature(async_await)] + /// use async_std::net::UdpSocket; + /// + /// # futures::executor::block_on(async { + /// let socket = UdpSocket::bind("127.0.0.1:0").await?; + /// + /// let mut buf = vec![0; 1024]; + /// let (n, peer) = socket.recv_from(&mut buf).await?; + /// println!("Received {} bytes from {}", n, peer); + /// # Ok::<_, Box>(()) + /// # }).unwrap(); + /// ``` + pub async fn recv(&self, buf: &mut [u8]) -> io::Result { + future::poll_fn(|cx| { + ready!(self.io_handle.poll_readable(cx)?); + + match self.io_handle.get_ref().recv(buf) { + Ok(n) => Poll::Ready(Ok(n)), + Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => { + self.io_handle.clear_readable(cx)?; + Poll::Pending + } + Err(err) => Poll::Ready(Err(err)), + } + }) + .await + } + + /// Gets the value of the `SO_BROADCAST` option for this socket. + /// + /// For more information about this option, see [`set_broadcast`]. + /// + /// [`set_broadcast`]: #method.set_broadcast + pub fn broadcast(&self) -> io::Result { + self.io_handle.get_ref().broadcast() + } + + /// Sets the value of the `SO_BROADCAST` option for this socket. + /// + /// When enabled, this socket is allowed to send packets to a broadcast address. + pub fn set_broadcast(&self, on: bool) -> io::Result<()> { + self.io_handle.get_ref().set_broadcast(on) + } + + /// Gets the value of the `IP_MULTICAST_LOOP` option for this socket. + /// + /// For more information about this option, see [`set_multicast_loop_v4`]. + /// + /// [`set_multicast_loop_v4`]: #method.set_multicast_loop_v4 + pub fn multicast_loop_v4(&self) -> io::Result { + self.io_handle.get_ref().multicast_loop_v4() + } + + /// Sets the value of the `IP_MULTICAST_LOOP` option for this socket. + /// + /// If enabled, multicast packets will be looped back to the local socket. + /// + /// # Note + /// + /// This may not have any affect on IPv6 sockets. + pub fn set_multicast_loop_v4(&self, on: bool) -> io::Result<()> { + self.io_handle.get_ref().set_multicast_loop_v4(on) + } + + /// Gets the value of the `IP_MULTICAST_TTL` option for this socket. + /// + /// For more information about this option, see [`set_multicast_ttl_v4`]. + /// + /// [`set_multicast_ttl_v4`]: #method.set_multicast_ttl_v4 + pub fn multicast_ttl_v4(&self) -> io::Result { + self.io_handle.get_ref().multicast_ttl_v4() + } + + /// Sets the value of the `IP_MULTICAST_TTL` option for this socket. + /// + /// Indicates the time-to-live value of outgoing multicast packets for this socket. The default + /// value is 1 which means that multicast packets don't leave the local network unless + /// explicitly requested. + /// + /// # Note + /// + /// This may not have any affect on IPv6 sockets. + pub fn set_multicast_ttl_v4(&self, ttl: u32) -> io::Result<()> { + self.io_handle.get_ref().set_multicast_ttl_v4(ttl) + } + + /// Gets the value of the `IPV6_MULTICAST_LOOP` option for this socket. + /// + /// For more information about this option, see [`set_multicast_loop_v6`]. + /// + /// [`set_multicast_loop_v6`]: #method.set_multicast_loop_v6 + pub fn multicast_loop_v6(&self) -> io::Result { + self.io_handle.get_ref().multicast_loop_v6() + } + + /// Sets the value of the `IPV6_MULTICAST_LOOP` option for this socket. + /// + /// Controls whether this socket sees the multicast packets it sends itself. + /// + /// # Note + /// + /// This may not have any affect on IPv4 sockets. + pub fn set_multicast_loop_v6(&self, on: bool) -> io::Result<()> { + self.io_handle.get_ref().set_multicast_loop_v6(on) + } + + /// Gets the value of the `IP_TTL` option for this socket. + /// + /// For more information about this option, see [`set_ttl`]. + /// + /// [`set_ttl`]: #method.set_ttl + pub fn ttl(&self) -> io::Result { + self.io_handle.get_ref().ttl() + } + + /// Sets the value for the `IP_TTL` option on this socket. + /// + /// This value sets the time-to-live field that is used in every packet sent + /// from this socket. + pub fn set_ttl(&self, ttl: u32) -> io::Result<()> { + self.io_handle.get_ref().set_ttl(ttl) + } + + /// Executes an operation of the `IP_ADD_MEMBERSHIP` type. + /// + /// This method specifies a new multicast group for this socket to join. The address must be + /// a valid multicast address, and `interface` is the address of the local interface with which + /// the system should join the multicast group. If it's equal to `INADDR_ANY` then an + /// appropriate interface is chosen by the system. + /// + /// # Examples + /// + /// ```no_run + /// # #![feature(async_await)] + /// use async_std::net::UdpSocket; + /// use std::net::Ipv4Addr; + /// + /// # futures::executor::block_on(async { + /// let interface = Ipv4Addr::new(0, 0, 0, 0); + /// let mdns_addr = Ipv4Addr::new(224, 0, 0, 123); + /// + /// let socket = UdpSocket::bind("127.0.0.1:0").await?; + /// socket.join_multicast_v4(&mdns_addr, &interface)?; + /// # Ok::<_, Box>(()) + /// # }).unwrap(); + /// ``` + pub fn join_multicast_v4(&self, multiaddr: &Ipv4Addr, interface: &Ipv4Addr) -> io::Result<()> { + self.io_handle + .get_ref() + .join_multicast_v4(multiaddr, interface) + } + + /// Executes an operation of the `IPV6_ADD_MEMBERSHIP` type. + /// + /// This method specifies a new multicast group for this socket to join. The address must be + /// a valid multicast address, and `interface` is the index of the interface to join/leave (or + /// 0 to indicate any interface). + /// + /// # Examples + /// + /// ```no_run + /// # #![feature(async_await)] + /// use async_std::net::UdpSocket; + /// use std::net::{Ipv6Addr, SocketAddr}; + /// + /// # futures::executor::block_on(async { + /// let socket_addr = SocketAddr::new(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 0).into(), 0); + /// let mdns_addr = Ipv6Addr::new(0xFF02, 0, 0, 0, 0, 0, 0, 0x0123) ; + /// let socket = UdpSocket::bind(&socket_addr).await?; + /// + /// socket.join_multicast_v6(&mdns_addr, 0)?; + /// # Ok::<_, Box>(()) + /// # }).unwrap(); + /// ``` + pub fn join_multicast_v6(&self, multiaddr: &Ipv6Addr, interface: u32) -> io::Result<()> { + self.io_handle + .get_ref() + .join_multicast_v6(multiaddr, interface) + } + + /// Executes an operation of the `IP_DROP_MEMBERSHIP` type. + /// + /// For more information about this option, see [`join_multicast_v4`]. + /// + /// [`join_multicast_v4`]: #method.join_multicast_v4 + pub fn leave_multicast_v4(&self, multiaddr: &Ipv4Addr, interface: &Ipv4Addr) -> io::Result<()> { + self.io_handle + .get_ref() + .leave_multicast_v4(multiaddr, interface) + } + + /// Executes an operation of the `IPV6_DROP_MEMBERSHIP` type. + /// + /// For more information about this option, see [`join_multicast_v6`]. + /// + /// [`join_multicast_v6`]: #method.join_multicast_v6 + pub fn leave_multicast_v6(&self, multiaddr: &Ipv6Addr, interface: u32) -> io::Result<()> { + self.io_handle + .get_ref() + .leave_multicast_v6(multiaddr, interface) + } +} + +impl From for UdpSocket { + /// Converts a `std::net::UdpSocket` into its asynchronous equivalent. + fn from(socket: net::UdpSocket) -> UdpSocket { + let mio_socket = mio::net::UdpSocket::from_socket(socket).unwrap(); + + #[cfg(unix)] + let socket = UdpSocket { + raw_fd: mio_socket.as_raw_fd(), + io_handle: IoHandle::new(mio_socket), + }; + + #[cfg(windows)] + let socket = UdpSocket { + // raw_socket: mio_socket.as_raw_socket(), + io_handle: IoHandle::new(mio_socket), + }; + + socket + } +} + +cfg_if! { + if #[cfg(feature = "docs.rs")] { + use crate::os::unix::io::{AsRawFd, FromRawFd, IntoRawFd, RawFd}; + // use crate::os::windows::io::{AsRawHandle, FromRawHandle, IntoRawHandle, RawHandle}; + } else if #[cfg(unix)] { + use std::os::unix::io::{AsRawFd, FromRawFd, IntoRawFd, RawFd}; + } else if #[cfg(windows)] { + // use std::os::windows::io::{AsRawHandle, FromRawHandle, IntoRawHandle, RawHandle}; + } +} + +#[cfg_attr(feature = "docs.rs", doc(cfg(unix)))] +cfg_if! { + if #[cfg(any(unix, feature = "docs.rs"))] { + impl AsRawFd for UdpSocket { + fn as_raw_fd(&self) -> RawFd { + self.raw_fd + } + } + + impl FromRawFd for UdpSocket { + unsafe fn from_raw_fd(fd: RawFd) -> UdpSocket { + net::UdpSocket::from_raw_fd(fd).into() + } + } + + impl IntoRawFd for UdpSocket { + fn into_raw_fd(self) -> RawFd { + self.raw_fd + } + } + } +} + +#[cfg_attr(feature = "docs.rs", doc(cfg(windows)))] +cfg_if! { + if #[cfg(any(windows, feature = "docs.rs"))] { + // use std::os::windows::io::{AsRawSocket, FromRawSocket, IntoRawSocket, RawSocket}; + // + // impl AsRawSocket for UdpSocket { + // fn as_raw_socket(&self) -> RawSocket { + // self.raw_socket + // } + // } + // + // impl FromRawSocket for UdpSocket { + // unsafe fn from_raw_socket(handle: RawSocket) -> UdpSocket { + // net::UdpSocket::from_raw_socket(handle).into() + // } + // } + // + // impl IntoRawSocket for UdpSocket { + // fn into_raw_socket(self) -> RawSocket { + // self.raw_socket + // } + // } + } +} diff --git a/src/os/mod.rs b/src/os/mod.rs new file mode 100644 index 00000000..e81f86b6 --- /dev/null +++ b/src/os/mod.rs @@ -0,0 +1,9 @@ +//! OS-specific extensions. + +#[cfg(any(unix, feature = "docs.rs"))] +#[cfg_attr(feature = "docs.rs", doc(cfg(unix)))] +pub mod unix; + +#[cfg(any(windows, feature = "docs.rs"))] +#[cfg_attr(feature = "docs.rs", doc(cfg(windows)))] +pub mod windows; diff --git a/src/os/unix/fs.rs b/src/os/unix/fs.rs new file mode 100644 index 00000000..55b60b2b --- /dev/null +++ b/src/os/unix/fs.rs @@ -0,0 +1,75 @@ +//! Unix-specific filesystem extensions. + +use std::io; +use std::path::Path; + +use cfg_if::cfg_if; + +use crate::task::blocking; + +/// Creates a new symbolic link on the filesystem. +/// +/// The `dst` path will be a symbolic link pointing to the `src` path. +/// +/// This function is an async version of [`std::os::unix::fs::symlink`]. +/// +/// [`std::os::unix::fs::symlink`]: https://doc.rust-lang.org/std/os/unix/fs/fn.symlink.html +/// +/// # Examples +/// +/// ```no_run +/// # #![feature(async_await)] +/// use async_std::os::unix::fs::symlink; +/// +/// # futures::executor::block_on(async { +/// symlink("a.txt", "b.txt").await?; +/// # std::io::Result::Ok(()) +/// # }).unwrap(); +/// ``` +pub async fn symlink, Q: AsRef>(src: P, dst: Q) -> io::Result<()> { + let src = src.as_ref().to_owned(); + let dst = dst.as_ref().to_owned(); + blocking::spawn(async move { std::os::unix::fs::symlink(&src, &dst) }).await +} + +cfg_if! { + if #[cfg(feature = "docs.rs")] { + /// Unix-specific extensions to `DirBuilder`. + pub trait DirBuilderExt { + /// Sets the mode to create new directories with. This option defaults to + /// `0o777`. + fn mode(&mut self, mode: u32) -> &mut Self; + } + + /// Unix-specific extension methods for `DirEntry`. + pub trait DirEntryExt { + /// Returns the underlying `d_ino` field in the contained `dirent` + /// structure. + fn ino(&self) -> u64; + } + + /// Unix-specific extensions to `OpenOptions`. + pub trait OpenOptionsExt { + /// Sets the mode bits that a new file will be created with. + /// + /// If a new file is created as part of a `File::open_opts` call then this + /// specified `mode` will be used as the permission bits for the new file. + /// If no `mode` is set, the default of `0o666` will be used. + /// The operating system masks out bits with the systems `umask`, to produce + /// the final permissions. + fn mode(&mut self, mode: u32) -> &mut Self; + + /// Pass custom flags to the `flags` argument of `open`. + /// + /// The bits that define the access mode are masked out with `O_ACCMODE`, to + /// ensure they do not interfere with the access mode set by Rusts options. + /// + /// Custom flags can only set flags, not remove flags set by Rusts options. + /// This options overwrites any previously set custom flags. + fn custom_flags(&mut self, flags: i32) -> &mut Self; + } + } else { + #[doc(inline)] + pub use std::os::unix::fs::{DirBuilderExt, OpenOptionsExt}; + } +} diff --git a/src/os/unix/io.rs b/src/os/unix/io.rs new file mode 100644 index 00000000..fb0d255f --- /dev/null +++ b/src/os/unix/io.rs @@ -0,0 +1,57 @@ +//! Unix-specific I/O extensions. + +use cfg_if::cfg_if; + +cfg_if! { + if #[cfg(feature = "docs.rs")] { + /// Raw file descriptors. + pub type RawFd = std::os::raw::c_int; + + /// A trait to extract the raw unix file descriptor from an underlying + /// object. + /// + /// This is only available on unix platforms and must be imported in order + /// to call the method. Windows platforms have a corresponding `AsRawHandle` + /// and `AsRawSocket` set of traits. + pub trait AsRawFd { + /// Extracts the raw file descriptor. + /// + /// This method does **not** pass ownership of the raw file descriptor + /// to the caller. The descriptor is only guaranteed to be valid while + /// the original object has not yet been destroyed. + fn as_raw_fd(&self) -> RawFd; + } + + /// A trait to express the ability to construct an object from a raw file + /// descriptor. + pub trait FromRawFd { + /// Constructs a new instance of `Self` from the given raw file + /// descriptor. + /// + /// This function **consumes ownership** of the specified file + /// descriptor. The returned object will take responsibility for closing + /// it when the object goes out of scope. + /// + /// This function is also unsafe as the primitives currently returned + /// have the contract that they are the sole owner of the file + /// descriptor they are wrapping. Usage of this function could + /// accidentally allow violating this contract which can cause memory + /// unsafety in code that relies on it being true. + unsafe fn from_raw_fd(fd: RawFd) -> Self; + } + + /// A trait to express the ability to consume an object and acquire ownership of + /// its raw file descriptor. + pub trait IntoRawFd { + /// Consumes this object, returning the raw underlying file descriptor. + /// + /// This function **transfers ownership** of the underlying file descriptor + /// to the caller. Callers are then the unique owners of the file descriptor + /// and must close the descriptor once it's no longer needed. + fn into_raw_fd(self) -> RawFd; + } + } else { + #[doc(inline)] + pub use std::os::unix::io::{AsRawFd, FromRawFd, IntoRawFd, RawFd}; + } +} diff --git a/src/os/unix/mod.rs b/src/os/unix/mod.rs new file mode 100644 index 00000000..722cfe6b --- /dev/null +++ b/src/os/unix/mod.rs @@ -0,0 +1,5 @@ +//! Platform-specific extensions for Unix platforms. + +pub mod fs; +pub mod io; +pub mod net; diff --git a/src/os/unix/net.rs b/src/os/unix/net.rs new file mode 100644 index 00000000..095b612f --- /dev/null +++ b/src/os/unix/net.rs @@ -0,0 +1,985 @@ +//! Unix-specific networking extensions. + +use std::fmt; +use std::io; +use std::mem; +use std::net::Shutdown; +use std::path::Path; +use std::pin::Pin; +use std::task::{Context, Poll}; + +use cfg_if::cfg_if; +use futures::{prelude::*, ready}; +use mio_uds; + +use crate::net::driver::IoHandle; +use crate::os::unix::io::{AsRawFd, FromRawFd, IntoRawFd, RawFd}; +use crate::task::blocking; + +/// A Unix datagram socket. +/// +/// After creating a `UnixDatagram` by [`bind`]ing it to a path, data can be [sent to] and +/// [received from] any other socket address. +/// +/// This type is an async version of [`std::os::unix::net::UnixDatagram`]. +/// +/// [`std::os::unix::net::UnixDatagram`]: +/// https://doc.rust-lang.org/std/os/unix/net/struct.UnixDatagram.html +/// [`bind`]: #method.bind +/// [received from]: #method.recv_from +/// [sent to]: #method.send_to +/// +/// ## Examples +/// +/// ```no_run +/// # #![feature(async_await)] +/// use async_std::os::unix::net::UnixDatagram; +/// +/// # futures::executor::block_on(async { +/// let socket = UnixDatagram::bind("/tmp/socket1").await?; +/// socket.send_to(b"hello world", "/tmp/socket2").await?; +/// +/// let mut buf = vec![0u8; 1024]; +/// let (n, peer) = socket.recv_from(&mut buf).await?; +/// println!("Received {} bytes from {:?}", n, peer); +/// # std::io::Result::Ok(()) +/// # }).unwrap(); +/// ``` +pub struct UnixDatagram { + #[cfg(not(feature = "docs.rs"))] + io_handle: IoHandle, + + raw_fd: RawFd, +} + +impl UnixDatagram { + #[cfg(not(feature = "docs.rs"))] + fn new(socket: mio_uds::UnixDatagram) -> UnixDatagram { + UnixDatagram { + raw_fd: socket.as_raw_fd(), + io_handle: IoHandle::new(socket), + } + } + + /// Creates a Unix datagram socket bound to the given path. + /// + /// # Examples + /// + /// ```no_run + /// # #![feature(async_await)] + /// use async_std::os::unix::net::UnixDatagram; + /// + /// # futures::executor::block_on(async { + /// let socket = UnixDatagram::bind("/tmp/socket").await?; + /// # std::io::Result::Ok(()) + /// # }).unwrap(); + /// ``` + pub async fn bind>(path: P) -> io::Result { + let path = path.as_ref().to_owned(); + let socket = blocking::spawn(async move { mio_uds::UnixDatagram::bind(path) }).await?; + Ok(UnixDatagram::new(socket)) + } + + /// Creates a Unix datagram which is not bound to any address. + /// + /// # Examples + /// + /// ```no_run + /// # #![feature(async_await)] + /// use async_std::os::unix::net::UnixDatagram; + /// + /// # futures::executor::block_on(async { + /// let socket = UnixDatagram::unbound()?; + /// # std::io::Result::Ok(()) + /// # }).unwrap(); + /// ``` + pub fn unbound() -> io::Result { + let socket = mio_uds::UnixDatagram::unbound()?; + Ok(UnixDatagram::new(socket)) + } + + /// Creates an unnamed pair of connected sockets. + /// + /// Returns two sockets which are connected to each other. + /// + /// # Examples + /// + /// ```no_run + /// # #![feature(async_await)] + /// use async_std::os::unix::net::UnixDatagram; + /// + /// # futures::executor::block_on(async { + /// let (socket1, socket2) = UnixDatagram::pair()?; + /// # std::io::Result::Ok(()) + /// # }).unwrap(); + /// ``` + pub fn pair() -> io::Result<(UnixDatagram, UnixDatagram)> { + let (a, b) = mio_uds::UnixDatagram::pair()?; + let a = UnixDatagram::new(a); + let b = UnixDatagram::new(b); + Ok((a, b)) + } + + /// Connects the socket to the specified address. + /// + /// The [`send`] method may be used to send data to the specified address. [`recv`] and + /// [`recv_from`] will only receive data from that address. + /// + /// [`send`]: #method.send + /// [`recv`]: #method.recv + /// [`recv_from`]: #method.recv_from + /// + /// # Examples + /// + /// ```no_run + /// # #![feature(async_await)] + /// use async_std::os::unix::net::UnixDatagram; + /// + /// # futures::executor::block_on(async { + /// let socket = UnixDatagram::unbound()?; + /// socket.connect("/tmp/socket").await?; + /// # std::io::Result::Ok(()) + /// # }).unwrap(); + /// ``` + pub async fn connect>(&self, path: P) -> io::Result<()> { + // TODO(stjepang): Connect the socket on a blocking pool. + let p = path.as_ref(); + self.io_handle.get_ref().connect(p) + } + + /// Returns the address of this socket. + /// + /// # Examples + /// + /// ```no_run + /// # #![feature(async_await)] + /// use async_std::os::unix::net::UnixDatagram; + /// + /// # futures::executor::block_on(async { + /// let socket = UnixDatagram::bind("/tmp/socket").await?; + /// let addr = socket.local_addr()?; + /// # std::io::Result::Ok(()) + /// # }).unwrap(); + /// ``` + pub fn local_addr(&self) -> io::Result { + self.io_handle.get_ref().local_addr() + } + + /// Returns the address of this socket's peer. + /// + /// The [`connect`] method will connect the socket to a peer. + /// + /// [`connect`]: #method.connect + /// + /// # Examples + /// + /// ```no_run + /// # #![feature(async_await)] + /// use async_std::os::unix::net::UnixDatagram; + /// + /// # futures::executor::block_on(async { + /// let mut socket = UnixDatagram::unbound()?; + /// socket.connect("/tmp/socket").await?; + /// let peer = socket.peer_addr()?; + /// # std::io::Result::Ok(()) + /// # }).unwrap(); + /// ``` + pub fn peer_addr(&self) -> io::Result { + self.io_handle.get_ref().peer_addr() + } + + /// Receives data from the socket. + /// + /// On success, returns the number of bytes read and the address from where the data came. + /// + /// # Examples + /// + /// ```no_run + /// # #![feature(async_await)] + /// use async_std::os::unix::net::UnixDatagram; + /// + /// # futures::executor::block_on(async { + /// let mut socket = UnixDatagram::unbound()?; + /// let mut buf = vec![0; 1024]; + /// let (n, peer) = socket.recv_from(&mut buf).await?; + /// # std::io::Result::Ok(()) + /// # }).unwrap(); + /// ``` + pub async fn recv_from(&self, buf: &mut [u8]) -> io::Result<(usize, SocketAddr)> { + future::poll_fn(|cx| { + ready!(self.io_handle.poll_readable(cx)?); + + match self.io_handle.get_ref().recv_from(buf) { + Ok(n) => Poll::Ready(Ok(n)), + Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => { + self.io_handle.clear_readable(cx)?; + Poll::Pending + } + Err(err) => Poll::Ready(Err(err)), + } + }) + .await + } + + /// Receives data from the socket. + /// + /// On success, returns the number of bytes read. + /// + /// # Examples + /// + /// ```no_run + /// # #![feature(async_await)] + /// use async_std::os::unix::net::UnixDatagram; + /// + /// # futures::executor::block_on(async { + /// let socket = UnixDatagram::bind("/tmp/socket").await?; + /// let mut buf = vec![0; 1024]; + /// let n = socket.recv(&mut buf).await?; + /// # std::io::Result::Ok(()) + /// # }).unwrap(); + /// ``` + pub async fn recv(&self, buf: &mut [u8]) -> io::Result { + future::poll_fn(|cx| { + ready!(self.io_handle.poll_writable(cx)?); + + match self.io_handle.get_ref().recv(buf) { + Ok(n) => Poll::Ready(Ok(n)), + Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => { + self.io_handle.clear_writable(cx)?; + Poll::Pending + } + Err(err) => Poll::Ready(Err(err)), + } + }) + .await + } + + /// Sends data on the socket to the specified address. + /// + /// On success, returns the number of bytes written. + /// + /// # Examples + /// + /// ```no_run + /// # #![feature(async_await)] + /// use async_std::os::unix::net::UnixDatagram; + /// + /// # futures::executor::block_on(async { + /// let mut socket = UnixDatagram::unbound()?; + /// socket.send_to(b"hello world", "/tmp/socket").await?; + /// # std::io::Result::Ok(()) + /// # }).unwrap(); + /// ``` + pub async fn send_to>(&self, buf: &[u8], path: P) -> io::Result { + future::poll_fn(|cx| { + ready!(self.io_handle.poll_writable(cx)?); + + match self.io_handle.get_ref().send_to(buf, path.as_ref()) { + Ok(n) => Poll::Ready(Ok(n)), + Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => { + self.io_handle.clear_writable(cx)?; + Poll::Pending + } + Err(err) => Poll::Ready(Err(err)), + } + }) + .await + } + + /// Sends data on the socket to the socket's peer. + /// + /// On success, returns the number of bytes written. + /// + /// # Examples + /// + /// ```no_run + /// # #![feature(async_await)] + /// use async_std::os::unix::net::UnixDatagram; + /// + /// # futures::executor::block_on(async { + /// let mut socket = UnixDatagram::unbound()?; + /// socket.connect("/tmp/socket").await?; + /// socket.send(b"hello world").await?; + /// # std::io::Result::Ok(()) + /// # }).unwrap(); + /// ``` + pub async fn send(&self, buf: &[u8]) -> io::Result { + future::poll_fn(|cx| { + ready!(self.io_handle.poll_writable(cx)?); + + match self.io_handle.get_ref().send(buf) { + Ok(n) => Poll::Ready(Ok(n)), + Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => { + self.io_handle.clear_writable(cx)?; + Poll::Pending + } + Err(err) => Poll::Ready(Err(err)), + } + }) + .await + } + + /// Shut down the read, write, or both halves of this connection. + /// + /// This function will cause all pending and future I/O calls on the specified portions to + /// immediately return with an appropriate value (see the documentation of [`Shutdown`]). + /// + /// [`Shutdown`]: https://doc.rust-lang.org/std/net/enum.Shutdown.html + /// + /// ## Examples + /// + /// ```no_run + /// # #![feature(async_await)] + /// use async_std::os::unix::net::UnixDatagram; + /// use std::net::Shutdown; + /// + /// # futures::executor::block_on(async { + /// let socket = UnixDatagram::unbound()?; + /// socket.shutdown(Shutdown::Both)?; + /// # std::io::Result::Ok(()) + /// # }).unwrap(); + /// ``` + pub fn shutdown(&self, how: Shutdown) -> io::Result<()> { + self.io_handle.get_ref().shutdown(how) + } +} + +impl fmt::Debug for UnixDatagram { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let mut builder = f.debug_struct("UnixDatagram"); + builder.field("fd", &self.as_raw_fd()); + + if let Ok(addr) = self.local_addr() { + builder.field("local", &addr); + } + + if let Ok(addr) = self.peer_addr() { + builder.field("peer", &addr); + } + + builder.finish() + } +} + +/// A Unix domain socket server, listening for connections. +/// +/// After creating a `UnixListener` by [`bind`]ing it to a socket address, it listens for incoming +/// connections. These can be accepted by awaiting elements from the async stream of [`incoming`] +/// connections. +/// +/// The socket will be closed when the value is dropped. +/// +/// This type is an async version of [`std::os::unix::net::UnixListener`]. +/// +/// [`std::os::unix::net::UnixListener`]: +/// https://doc.rust-lang.org/std/os/unix/net/struct.UnixListener.html +/// [`bind`]: #method.bind +/// [`incoming`]: #method.incoming +/// +/// # Examples +/// +/// ```no_run +/// # #![feature(async_await)] +/// use async_std::os::unix::net::UnixListener; +/// use async_std::prelude::*; +/// +/// # futures::executor::block_on(async { +/// let listener = UnixListener::bind("/tmp/socket").await?; +/// let mut incoming = listener.incoming(); +/// +/// while let Some(stream) = incoming.next().await { +/// let mut stream = stream?; +/// stream.write_all(b"hello world").await?; +/// } +/// # std::io::Result::Ok(()) +/// # }).unwrap(); +/// ``` +pub struct UnixListener { + #[cfg(not(feature = "docs.rs"))] + io_handle: IoHandle, + + raw_fd: RawFd, +} + +impl UnixListener { + /// Creates a Unix datagram listener bound to the given path. + /// + /// # Examples + /// + /// ```no_run + /// # #![feature(async_await)] + /// use async_std::os::unix::net::UnixListener; + /// + /// # futures::executor::block_on(async { + /// let listener = UnixListener::bind("/tmp/socket").await?; + /// # std::io::Result::Ok(()) + /// # }).unwrap(); + /// ``` + pub async fn bind>(path: P) -> io::Result { + let path = path.as_ref().to_owned(); + let listener = blocking::spawn(async move { mio_uds::UnixListener::bind(path) }).await?; + + Ok(UnixListener { + raw_fd: listener.as_raw_fd(), + io_handle: IoHandle::new(listener), + }) + } + + /// Accepts a new incoming connection to this listener. + /// + /// When a connection is established, the corresponding stream and address will be returned. + /// + /// # Examples + /// + /// ```no_run + /// # #![feature(async_await)] + /// use async_std::os::unix::net::UnixListener; + /// + /// # futures::executor::block_on(async { + /// let listener = UnixListener::bind("/tmp/socket").await?; + /// let (socket, addr) = listener.accept().await?; + /// # std::io::Result::Ok(()) + /// # }).unwrap(); + /// ``` + pub async fn accept(&self) -> io::Result<(UnixStream, SocketAddr)> { + future::poll_fn(|cx| { + ready!(self.io_handle.poll_readable(cx)?); + + match self.io_handle.get_ref().accept_std() { + Ok(Some((io, addr))) => { + let mio_stream = mio_uds::UnixStream::from_stream(io)?; + let stream = UnixStream { + raw_fd: mio_stream.as_raw_fd(), + io_handle: IoHandle::new(mio_stream), + }; + Poll::Ready(Ok((stream, addr))) + } + Ok(None) => { + self.io_handle.clear_readable(cx)?; + Poll::Pending + } + Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => { + self.io_handle.clear_readable(cx)?; + Poll::Pending + } + Err(err) => Poll::Ready(Err(err)), + } + }) + .await + } + + /// Returns a stream of incoming connections. + /// + /// Iterating over this stream is equivalent to calling [`accept`] in a loop. The stream of + /// connections is infinite, i.e awaiting the next connection will never result in [`None`]. + /// + /// [`accept`]: #method.accept + /// [`None`]: https://doc.rust-lang.org/std/option/enum.Option.html#variant.None + /// + /// # Examples + /// + /// ```no_run + /// # #![feature(async_await)] + /// use async_std::os::unix::net::UnixListener; + /// use async_std::prelude::*; + /// + /// # futures::executor::block_on(async { + /// let listener = UnixListener::bind("/tmp/socket").await?; + /// let mut incoming = listener.incoming(); + /// + /// while let Some(stream) = incoming.next().await { + /// let mut stream = stream?; + /// stream.write_all(b"hello world").await?; + /// } + /// # std::io::Result::Ok(()) + /// # }).unwrap(); + /// ``` + pub fn incoming(&self) -> Incoming<'_> { + Incoming(self) + } + + /// Returns the local socket address of this listener. + /// + /// # Examples + /// + /// ```no_run + /// # #![feature(async_await)] + /// use async_std::os::unix::net::UnixListener; + /// + /// # futures::executor::block_on(async { + /// let listener = UnixListener::bind("/tmp/socket").await?; + /// let addr = listener.local_addr()?; + /// # std::io::Result::Ok(()) + /// # }).unwrap(); + /// ``` + pub fn local_addr(&self) -> io::Result { + self.io_handle.get_ref().local_addr() + } +} + +impl fmt::Debug for UnixListener { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let mut builder = f.debug_struct("UnixListener"); + builder.field("fd", &self.as_raw_fd()); + + if let Ok(addr) = self.local_addr() { + builder.field("local", &addr); + } + + builder.finish() + } +} + +/// A stream of incoming Unix domain socket connections. +/// +/// This stream is infinite, i.e awaiting the next connection will never result in [`None`]. It is +/// created by the [`incoming`] method on [`UnixListener`]. +/// +/// This type is an async version of [`std::os::unix::net::Incoming`]. +/// +/// [`std::os::unix::net::Incoming`]: https://doc.rust-lang.org/std/os/unix/net/struct.Incoming.html +/// [`None`]: https://doc.rust-lang.org/std/option/enum.Option.html#variant.None +/// [`incoming`]: struct.UnixListener.html#method.incoming +/// [`UnixListener`]: struct.UnixListener.html +#[derive(Debug)] +pub struct Incoming<'a>(&'a UnixListener); + +impl Stream for Incoming<'_> { + type Item = io::Result; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let future = self.0.accept(); + futures::pin_mut!(future); + + let (socket, _) = ready!(future.poll(cx))?; + Poll::Ready(Some(Ok(socket))) + } +} + +/// A Unix stream socket. +/// +/// This type is an async version of [`std::os::unix::net::UnixStream`]. +/// +/// [`std::os::unix::net::UnixStream`]: +/// https://doc.rust-lang.org/std/os/unix/net/struct.UnixStream.html +/// +/// # Examples +/// +/// ```no_run +/// # #![feature(async_await)] +/// use async_std::os::unix::net::UnixStream; +/// use async_std::prelude::*; +/// +/// # futures::executor::block_on(async { +/// let mut stream = UnixStream::connect("/tmp/socket").await?; +/// stream.write_all(b"hello world").await?; +/// +/// let mut response = Vec::new(); +/// stream.read_to_end(&mut response).await?; +/// # std::io::Result::Ok(()) +/// # }).unwrap(); +/// ``` +pub struct UnixStream { + #[cfg(not(feature = "docs.rs"))] + io_handle: IoHandle, + + raw_fd: RawFd, +} + +impl UnixStream { + /// Connects to the socket to the specified address. + /// + /// # Examples + /// + /// ```no_run + /// # #![feature(async_await)] + /// use async_std::os::unix::net::UnixStream; + /// + /// # futures::executor::block_on(async { + /// let stream = UnixStream::connect("/tmp/socket").await?; + /// # std::io::Result::Ok(()) + /// # }).unwrap(); + /// ``` + pub async fn connect>(path: P) -> io::Result { + enum State { + Waiting(UnixStream), + Error(io::Error), + Done, + } + + let path = path.as_ref().to_owned(); + let mut state = { + match blocking::spawn(async move { mio_uds::UnixStream::connect(path) }).await { + Ok(mio_stream) => State::Waiting(UnixStream { + raw_fd: mio_stream.as_raw_fd(), + io_handle: IoHandle::new(mio_stream), + }), + Err(err) => State::Error(err), + } + }; + + future::poll_fn(|cx| { + match &mut state { + State::Waiting(stream) => { + ready!(stream.io_handle.poll_writable(cx)?); + + if let Some(err) = stream.io_handle.get_ref().take_error()? { + return Poll::Ready(Err(err)); + } + } + State::Error(_) => { + let err = match mem::replace(&mut state, State::Done) { + State::Error(err) => err, + _ => unreachable!(), + }; + + return Poll::Ready(Err(err)); + } + State::Done => panic!("`UnixStream::connect()` future polled after completion"), + } + + match mem::replace(&mut state, State::Done) { + State::Waiting(stream) => Poll::Ready(Ok(stream)), + _ => unreachable!(), + } + }) + .await + } + + /// Creates an unnamed pair of connected sockets. + /// + /// Returns two streams which are connected to each other. + /// + /// # Examples + /// + /// ```no_run + /// # #![feature(async_await)] + /// use async_std::os::unix::net::UnixStream; + /// + /// # futures::executor::block_on(async { + /// let stream = UnixStream::pair()?; + /// # std::io::Result::Ok(()) + /// # }).unwrap(); + /// ``` + pub fn pair() -> io::Result<(UnixStream, UnixStream)> { + let (a, b) = mio_uds::UnixStream::pair()?; + let a = UnixStream { + raw_fd: a.as_raw_fd(), + io_handle: IoHandle::new(a), + }; + let b = UnixStream { + raw_fd: b.as_raw_fd(), + io_handle: IoHandle::new(b), + }; + Ok((a, b)) + } + + /// Returns the socket address of the local half of this connection. + /// + /// # Examples + /// + /// ```no_run + /// # #![feature(async_await)] + /// use async_std::os::unix::net::UnixStream; + /// + /// # futures::executor::block_on(async { + /// let stream = UnixStream::connect("/tmp/socket").await?; + /// let addr = stream.local_addr()?; + /// # std::io::Result::Ok(()) + /// # }).unwrap(); + /// ``` + pub fn local_addr(&self) -> io::Result { + self.io_handle.get_ref().local_addr() + } + + /// Returns the socket address of the remote half of this connection. + /// + /// # Examples + /// + /// ```no_run + /// # #![feature(async_await)] + /// use async_std::os::unix::net::UnixStream; + /// + /// # futures::executor::block_on(async { + /// let stream = UnixStream::connect("/tmp/socket").await?; + /// let peer = stream.peer_addr()?; + /// # std::io::Result::Ok(()) + /// # }).unwrap(); + /// ``` + pub fn peer_addr(&self) -> io::Result { + self.io_handle.get_ref().peer_addr() + } + + /// Shuts down the read, write, or both halves of this connection. + /// + /// This function will cause all pending and future I/O calls on the specified portions to + /// immediately return with an appropriate value (see the documentation of [`Shutdown`]). + /// + /// [`Shutdown`]: https://doc.rust-lang.org/std/net/enum.Shutdown.html + /// + /// ```no_run + /// # #![feature(async_await)] + /// use async_std::os::unix::net::UnixStream; + /// use std::net::Shutdown; + /// + /// # futures::executor::block_on(async { + /// let stream = UnixStream::connect("/tmp/socket").await?; + /// stream.shutdown(Shutdown::Both)?; + /// # std::io::Result::Ok(()) + /// # }).unwrap(); + /// ``` + pub fn shutdown(&self, how: Shutdown) -> io::Result<()> { + self.io_handle.get_ref().shutdown(how) + } +} + +impl AsyncRead for UnixStream { + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut [u8], + ) -> Poll> { + Pin::new(&mut &*self).poll_read(cx, buf) + } +} + +impl AsyncRead for &UnixStream { + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut [u8], + ) -> Poll> { + Pin::new(&mut &self.io_handle).poll_read(cx, buf) + } +} + +impl AsyncWrite for UnixStream { + fn poll_write( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + Pin::new(&mut &*self).poll_write(cx, buf) + } + + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + Pin::new(&mut &*self).poll_flush(cx) + } + + fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + Pin::new(&mut &*self).poll_close(cx) + } +} + +impl AsyncWrite for &UnixStream { + fn poll_write( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + Pin::new(&mut &self.io_handle).poll_write(cx, buf) + } + + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + Pin::new(&mut &self.io_handle).poll_flush(cx) + } + + fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + Pin::new(&mut &self.io_handle).poll_close(cx) + } +} + +impl fmt::Debug for UnixStream { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let mut builder = f.debug_struct("UnixStream"); + builder.field("fd", &self.as_raw_fd()); + + if let Ok(addr) = self.local_addr() { + builder.field("local", &addr); + } + + if let Ok(addr) = self.peer_addr() { + builder.field("peer", &addr); + } + + builder.finish() + } +} + +#[cfg(unix)] +impl From for UnixStream { + /// Converts a `std::os::unix::net::UnixStream` into its asynchronous equivalent. + fn from(stream: std::os::unix::net::UnixStream) -> UnixStream { + let mio_stream = mio_uds::UnixStream::from_stream(stream).unwrap(); + UnixStream { + raw_fd: mio_stream.as_raw_fd(), + io_handle: IoHandle::new(mio_stream), + } + } +} + +#[cfg(unix)] +impl From for UnixDatagram { + /// Converts a `std::os::unix::net::UnixDatagram` into its asynchronous equivalent. + fn from(datagram: std::os::unix::net::UnixDatagram) -> UnixDatagram { + let mio_datagram = mio_uds::UnixDatagram::from_datagram(datagram).unwrap(); + UnixDatagram { + raw_fd: mio_datagram.as_raw_fd(), + io_handle: IoHandle::new(mio_datagram), + } + } +} + +#[cfg(unix)] +impl From for UnixListener { + /// Converts a `std::os::unix::net::UnixListener` into its asynchronous equivalent. + fn from(listener: std::os::unix::net::UnixListener) -> UnixListener { + let mio_listener = mio_uds::UnixListener::from_listener(listener).unwrap(); + UnixListener { + raw_fd: mio_listener.as_raw_fd(), + io_handle: IoHandle::new(mio_listener), + } + } +} + +impl AsRawFd for UnixListener { + fn as_raw_fd(&self) -> RawFd { + self.raw_fd + } +} + +impl FromRawFd for UnixListener { + unsafe fn from_raw_fd(fd: RawFd) -> UnixListener { + let listener = std::os::unix::net::UnixListener::from_raw_fd(fd); + listener.into() + } +} + +impl IntoRawFd for UnixListener { + fn into_raw_fd(self) -> RawFd { + self.raw_fd + } +} + +impl AsRawFd for UnixStream { + fn as_raw_fd(&self) -> RawFd { + self.raw_fd + } +} + +impl FromRawFd for UnixStream { + unsafe fn from_raw_fd(fd: RawFd) -> UnixStream { + let stream = std::os::unix::net::UnixStream::from_raw_fd(fd); + stream.into() + } +} + +impl IntoRawFd for UnixStream { + fn into_raw_fd(self) -> RawFd { + self.raw_fd + } +} + +impl AsRawFd for UnixDatagram { + fn as_raw_fd(&self) -> RawFd { + self.raw_fd + } +} + +impl FromRawFd for UnixDatagram { + unsafe fn from_raw_fd(fd: RawFd) -> UnixDatagram { + let datagram = std::os::unix::net::UnixDatagram::from_raw_fd(fd); + datagram.into() + } +} + +impl IntoRawFd for UnixDatagram { + fn into_raw_fd(self) -> RawFd { + self.raw_fd + } +} + +cfg_if! { + if #[cfg(feature = "docs.rs")] { + /// An address associated with a Unix socket. + /// + /// # Examples + /// + /// ``` + /// use async_std::os::unix::net::UnixListener; + /// + /// let socket = UnixListener::bind("/tmp/socket").await?; + /// let addr = socket.local_addr()?; + /// ``` + #[derive(Clone)] + pub struct SocketAddr { + _private: (), + } + + impl SocketAddr { + /// Returns `true` if the address is unnamed. + /// + /// # Examples + /// + /// A named address: + /// + /// ```no_run + /// use async_std::os::unix::net::UnixListener; + /// + /// let socket = UnixListener::bind("/tmp/socket").await?; + /// let addr = socket.local_addr()?; + /// assert_eq!(addr.is_unnamed(), false); + /// ``` + /// + /// An unnamed address: + /// + /// ```no_run + /// use async_std::os::unix::net::UnixDatagram; + /// + /// let socket = UnixDatagram::unbound().await?; + /// let addr = socket.local_addr()?; + /// assert_eq!(addr.is_unnamed(), true); + /// ``` + pub fn is_unnamed(&self) -> bool { + unreachable!() + } + + /// Returns the contents of this address if it is a `pathname` address. + /// + /// # Examples + /// + /// With a pathname: + /// + /// ```no_run + /// use async_std::os::unix::net::UnixListener; + /// use std::path::Path; + /// + /// let socket = UnixListener::bind("/tmp/socket").await?; + /// let addr = socket.local_addr()?; + /// assert_eq!(addr.as_pathname(), Some(Path::new("/tmp/socket"))); + /// ``` + /// + /// Without a pathname: + /// + /// ``` + /// use async_std::os::unix::net::UnixDatagram; + /// + /// let socket = UnixDatagram::unbound()?; + /// let addr = socket.local_addr()?; + /// assert_eq!(addr.as_pathname(), None); + /// ``` + pub fn as_pathname(&self) -> Option<&Path> { + unreachable!() + } + } + + impl fmt::Debug for SocketAddr { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + unreachable!() + } + } + } else { + #[doc(inline)] + pub use std::os::unix::net::SocketAddr; + } +} diff --git a/src/os/windows/io.rs b/src/os/windows/io.rs new file mode 100644 index 00000000..37761478 --- /dev/null +++ b/src/os/windows/io.rs @@ -0,0 +1,51 @@ +//! Windows-specific I/O extensions. + +use cfg_if::cfg_if; + +cfg_if! { + if #[cfg(feature = "docs.rs")] { + /// Raw HANDLEs. + pub type RawHandle = *mut std::os::raw::c_void; + + /// Raw SOCKETs. + pub type RawSocket = u64; + + /// Extracts raw handles. + pub trait AsRawHandle { + /// Extracts the raw handle, without taking any ownership. + fn as_raw_handle(&self) -> RawHandle; + } + + /// Construct I/O objects from raw handles. + pub trait FromRawHandle { + /// Constructs a new I/O object from the specified raw handle. + /// + /// This function will **consume ownership** of the handle given, + /// passing responsibility for closing the handle to the returned + /// object. + /// + /// This function is also unsafe as the primitives currently returned + /// have the contract that they are the sole owner of the file + /// descriptor they are wrapping. Usage of this function could + /// accidentally allow violating this contract which can cause memory + /// unsafety in code that relies on it being true. + unsafe fn from_raw_handle(handle: RawHandle) -> Self; + } + + /// A trait to express the ability to consume an object and acquire ownership of + /// its raw `HANDLE`. + pub trait IntoRawHandle { + /// Consumes this object, returning the raw underlying handle. + /// + /// This function **transfers ownership** of the underlying handle to the + /// caller. Callers are then the unique owners of the handle and must close + /// it once it's no longer needed. + fn into_raw_handle(self) -> RawHandle; + } + } else { + #[doc(inline)] + pub use std::os::windows::io::{ + AsRawHandle, FromRawHandle, IntoRawHandle, RawHandle, RawSocket, + }; + } +} diff --git a/src/os/windows/mod.rs b/src/os/windows/mod.rs new file mode 100644 index 00000000..30218f0e --- /dev/null +++ b/src/os/windows/mod.rs @@ -0,0 +1,3 @@ +//! Platform-specific extensions for Windows. + +pub mod io; diff --git a/src/prelude.rs b/src/prelude.rs new file mode 100644 index 00000000..14221e99 --- /dev/null +++ b/src/prelude.rs @@ -0,0 +1,44 @@ +//! The async prelude. +//! +//! The prelude re-exports the most commonly used traits in async programming. +//! +//! # Examples +//! +//! Import the prelude to use the [`timeout`] combinator: +//! +//! ```no_run +//! # #![feature(async_await)] +//! use async_std::{io, prelude::*}; +//! use std::time::Duration; +//! +//! # async_std::task::block_on(async { +//! let stdin = io::stdin(); +//! let mut line = String::new(); +//! let dur = Duration::from_secs(5); +//! +//! stdin.read_line(&mut line).timeout(dur).await??; +//! # std::io::Result::Ok(()) +//! # }).unwrap(); +//! ``` +//! +//! [`timeout`]: ../time/trait.Timeout.html#method.timeout + +#[doc(no_inline)] +pub use futures::future::FutureExt as _; +#[doc(no_inline)] +pub use futures::future::TryFutureExt as _; +#[doc(no_inline)] +pub use futures::io::AsyncBufReadExt as _; +#[doc(no_inline)] +pub use futures::io::AsyncReadExt as _; +#[doc(no_inline)] +pub use futures::io::AsyncSeekExt as _; +#[doc(no_inline)] +pub use futures::io::AsyncWriteExt as _; +#[doc(no_inline)] +pub use futures::stream::StreamExt as _; +#[doc(no_inline)] +pub use futures::stream::TryStreamExt as _; + +#[doc(no_inline)] +pub use crate::time::Timeout as _; diff --git a/src/stream/mod.rs b/src/stream/mod.rs new file mode 100644 index 00000000..119a6b73 --- /dev/null +++ b/src/stream/mod.rs @@ -0,0 +1,24 @@ +//! Composable asynchronous iteration. +//! +//! This module is an async version of [`std::iter`]. +//! +//! [`std::iter`]: https://doc.rust-lang.org/std/iter/index.html +//! +//! # Examples +//! +//! ``` +//! # #![feature(async_await)] +//! # use async_std::prelude::*; +//! # async_std::task::block_on(async { +//! use async_std::stream; +//! +//! let mut stream = stream::repeat(9).take(3); +//! while let Some(num) = stream.next().await { +//! assert_eq!(num, 9); +//! } +//! # std::io::Result::Ok(()) +//! # }).unwrap(); +//! ``` + +#[doc(inline)] +pub use futures::stream::{empty, once, repeat, Empty, Once, Repeat, Stream}; diff --git a/src/sync/mod.rs b/src/sync/mod.rs new file mode 100644 index 00000000..5d3abdef --- /dev/null +++ b/src/sync/mod.rs @@ -0,0 +1,33 @@ +//! Synchronization primitives. +//! +//! This module is an async version of [`std::sync`]. +//! +//! [`std::sync`]: https://doc.rust-lang.org/std/sync/index.html +//! +//! # Examples +//! +//! Spawn a task that updates an integer protected by a mutex: +//! +//! ``` +//! # #![feature(async_await)] +//! use async_std::{sync::Mutex, task}; +//! use std::sync::Arc; +//! +//! # futures::executor::block_on(async { +//! let m1 = Arc::new(Mutex::new(0)); +//! let m2 = m1.clone(); +//! +//! task::spawn(async move { +//! *m2.lock().await = 1; +//! }) +//! .await; +//! +//! assert_eq!(*m1.lock().await, 1); +//! # }) +//! ``` + +pub use mutex::{Mutex, MutexGuard}; +pub use rwlock::{RwLock, RwLockReadGuard, RwLockWriteGuard}; + +mod mutex; +mod rwlock; diff --git a/src/sync/mutex.rs b/src/sync/mutex.rs new file mode 100644 index 00000000..727141c5 --- /dev/null +++ b/src/sync/mutex.rs @@ -0,0 +1,336 @@ +use std::cell::UnsafeCell; +use std::fmt; +use std::future::Future; +use std::ops::{Deref, DerefMut}; +use std::pin::Pin; +use std::sync::atomic::{AtomicUsize, Ordering}; +use std::task::{Context, Poll, Waker}; + +use slab::Slab; + +/// Set if the mutex is locked. +const LOCK: usize = 1 << 0; + +/// Set if there are tasks blocked on the mutex. +const BLOCKED: usize = 1 << 1; + +/// A mutual exclusion primitive for protecting shared data. +/// +/// This type is an async version of [`std::sync::Mutex`]. +/// +/// [`std::sync::Mutex`]: https://doc.rust-lang.org/std/sync/struct.Mutex.html +/// +/// # Examples +/// +/// ``` +/// # #![feature(async_await)] +/// use async_std::{sync::Mutex, task}; +/// use std::sync::Arc; +/// +/// # futures::executor::block_on(async { +/// let m = Arc::new(Mutex::new(0)); +/// let mut tasks = vec![]; +/// +/// for _ in 0..10 { +/// let m = m.clone(); +/// tasks.push(task::spawn(async move { +/// *m.lock().await += 1; +/// })); +/// } +/// +/// for t in tasks { +/// t.await; +/// } +/// assert_eq!(*m.lock().await, 10); +/// # }) +/// ``` +pub struct Mutex { + state: AtomicUsize, + blocked: std::sync::Mutex>>, + value: UnsafeCell, +} + +unsafe impl Send for Mutex {} +unsafe impl Sync for Mutex {} + +impl Mutex { + /// Creates a new mutex. + /// + /// # Examples + /// + /// ``` + /// use async_std::sync::Mutex; + /// + /// let mutex = Mutex::new(0); + /// ``` + pub fn new(t: T) -> Mutex { + Mutex { + state: AtomicUsize::new(0), + blocked: std::sync::Mutex::new(Slab::new()), + value: UnsafeCell::new(t), + } + } + + /// Acquires the lock. + /// + /// Returns a guard that releases the lock when dropped. + /// + /// # Examples + /// + /// ``` + /// # #![feature(async_await)] + /// use async_std::{sync::Mutex, task}; + /// use std::sync::Arc; + /// + /// # futures::executor::block_on(async { + /// let m1 = Arc::new(Mutex::new(10)); + /// let m2 = m1.clone(); + /// + /// task::spawn(async move { + /// *m1.lock().await = 20; + /// }) + /// .await; + /// + /// assert_eq!(*m2.lock().await, 20); + /// # }) + /// ``` + pub async fn lock(&self) -> MutexGuard<'_, T> { + pub struct LockFuture<'a, T> { + mutex: &'a Mutex, + opt_key: Option, + acquired: bool, + } + + impl<'a, T> Future for LockFuture<'a, T> { + type Output = MutexGuard<'a, T>; + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + match self.mutex.try_lock() { + Some(guard) => { + self.acquired = true; + Poll::Ready(guard) + } + None => { + let mut blocked = self.mutex.blocked.lock().unwrap(); + + // Register the current task. + match self.opt_key { + None => { + // Insert a new entry into the list of blocked tasks. + let w = cx.waker().clone(); + let key = blocked.insert(Some(w)); + self.opt_key = Some(key); + + if blocked.len() == 1 { + self.mutex.state.fetch_or(BLOCKED, Ordering::Relaxed); + } + } + Some(key) => { + // There is already an entry in the list of blocked tasks. Just + // reset the waker if it was removed. + if blocked[key].is_none() { + let w = cx.waker().clone(); + blocked[key] = Some(w); + } + } + } + + // Try locking again because it's possible the mutex got unlocked just + // before the current task was registered as a blocked task. + match self.mutex.try_lock() { + Some(guard) => { + self.acquired = true; + Poll::Ready(guard) + } + None => Poll::Pending, + } + } + } + } + } + + impl Drop for LockFuture<'_, T> { + fn drop(&mut self) { + if let Some(key) = self.opt_key { + let mut blocked = self.mutex.blocked.lock().unwrap(); + let opt_waker = blocked.remove(key); + + if opt_waker.is_none() && !self.acquired { + // We were awoken but didn't acquire the lock. Wake up another task. + if let Some((_, opt_waker)) = blocked.iter_mut().next() { + if let Some(w) = opt_waker.take() { + w.wake(); + } + } + } + + if blocked.is_empty() { + self.mutex.state.fetch_and(!BLOCKED, Ordering::Relaxed); + } + } + } + } + + LockFuture { + mutex: self, + opt_key: None, + acquired: false, + } + .await + } + + /// Attempts to acquire the lock. + /// + /// If the lock could not be acquired at this time, then [`None`] is returned. Otherwise, a + /// guard is returned that releases the lock when dropped. + /// + /// [`None`]: https://doc.rust-lang.org/std/option/enum.Option.html#variant.None + /// + /// # Examples + /// + /// ``` + /// # #![feature(async_await)] + /// use async_std::{sync::Mutex, task}; + /// use std::sync::Arc; + /// + /// # futures::executor::block_on(async { + /// let m1 = Arc::new(Mutex::new(10)); + /// let m2 = m1.clone(); + /// + /// task::spawn(async move { + /// if let Some(mut guard) = m1.try_lock() { + /// *guard = 20; + /// } else { + /// println!("try_lock failed"); + /// } + /// }) + /// .await; + /// + /// assert_eq!(*m2.lock().await, 20); + /// # }) + /// ``` + pub fn try_lock(&self) -> Option> { + if self.state.fetch_or(LOCK, Ordering::Acquire) & LOCK == 0 { + Some(MutexGuard(self)) + } else { + None + } + } + + /// Consumes the mutex, returning the underlying data. + /// + /// # Examples + /// + /// ``` + /// # #![feature(async_await)] + /// use async_std::sync::Mutex; + /// + /// let mutex = Mutex::new(10); + /// assert_eq!(mutex.into_inner(), 10); + /// ``` + pub fn into_inner(self) -> T { + self.value.into_inner() + } + + /// Returns a mutable reference to the underlying data. + /// + /// Since this call borrows the mutex mutably, no actual locking takes place -- the mutable + /// borrow statically guarantees no locks exist. + /// + /// # Examples + /// + /// ``` + /// # #![feature(async_await)] + /// use async_std::sync::Mutex; + /// + /// # futures::executor::block_on(async { + /// let mut mutex = Mutex::new(0); + /// *mutex.get_mut() = 10; + /// assert_eq!(*mutex.lock().await, 10); + /// }); + /// ``` + pub fn get_mut(&mut self) -> &mut T { + unsafe { &mut *self.value.get() } + } +} + +impl fmt::Debug for Mutex { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self.try_lock() { + None => { + struct LockedPlaceholder; + impl fmt::Debug for LockedPlaceholder { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_str("") + } + } + f.debug_struct("Mutex") + .field("data", &LockedPlaceholder) + .finish() + } + Some(guard) => f.debug_struct("Mutex").field("data", &&*guard).finish(), + } + } +} + +impl From for Mutex { + fn from(val: T) -> Mutex { + Mutex::new(val) + } +} + +impl Default for Mutex { + fn default() -> Mutex { + Mutex::new(Default::default()) + } +} + +/// A guard that releases the lock when dropped. +pub struct MutexGuard<'a, T>(&'a Mutex); + +unsafe impl Send for MutexGuard<'_, T> {} +unsafe impl Sync for MutexGuard<'_, T> {} + +impl Drop for MutexGuard<'_, T> { + fn drop(&mut self) { + let state = self.0.state.fetch_and(!LOCK, Ordering::AcqRel); + + // If there are any blocked tasks, wake one of them up. + if state & BLOCKED != 0 { + let mut blocked = self.0.blocked.lock().unwrap(); + + if let Some((_, opt_waker)) = blocked.iter_mut().next() { + // If there is no waker in this entry, that means it was already woken. + if let Some(w) = opt_waker.take() { + w.wake(); + } + } + } + } +} + +impl fmt::Debug for MutexGuard<'_, T> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt::Debug::fmt(&**self, f) + } +} + +impl fmt::Display for MutexGuard<'_, T> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + (**self).fmt(f) + } +} + +impl Deref for MutexGuard<'_, T> { + type Target = T; + + fn deref(&self) -> &T { + unsafe { &*self.0.value.get() } + } +} + +impl DerefMut for MutexGuard<'_, T> { + fn deref_mut(&mut self) -> &mut T { + unsafe { &mut *self.0.value.get() } + } +} diff --git a/src/sync/rwlock.rs b/src/sync/rwlock.rs new file mode 100644 index 00000000..6ad6e38f --- /dev/null +++ b/src/sync/rwlock.rs @@ -0,0 +1,573 @@ +use std::cell::UnsafeCell; +use std::fmt; +use std::future::Future; +use std::ops::{Deref, DerefMut}; +use std::pin::Pin; +use std::sync::atomic::{AtomicUsize, Ordering}; +use std::task::{Context, Poll, Waker}; + +use slab::Slab; + +/// Set if a write lock is held. +const WRITE_LOCK: usize = 1 << 0; + +/// Set if there are read operations blocked on the lock. +const BLOCKED_READS: usize = 1 << 1; + +/// Set if there are write operations blocked on the lock. +const BLOCKED_WRITES: usize = 1 << 2; + +/// The value of a single blocked read contributing to the read count. +const ONE_READ: usize = 1 << 3; + +/// The bits in which the read count is stored. +const READ_COUNT_MASK: usize = !(ONE_READ - 1); + +/// A reader-writer lock for protecting shared data. +/// +/// This type is an async version of [`std::sync::RwLock`]. +/// +/// [`std::sync::RwLock`]: https://doc.rust-lang.org/std/sync/struct.RwLock.html +/// +/// # Examples +/// +/// ``` +/// # #![feature(async_await)] +/// use async_std::sync::RwLock; +/// +/// # futures::executor::block_on(async { +/// let lock = RwLock::new(5); +/// +/// // Multiple read locks can be held at a time. +/// let r1 = lock.read().await; +/// let r2 = lock.read().await; +/// assert_eq!(*r1, 5); +/// assert_eq!(*r2, 5); +/// drop((r1, r2)); +/// +/// // Only one write locks can be held at a time. +/// let mut w = lock.write().await; +/// *w += 1; +/// assert_eq!(*w, 6); +/// # }) +/// ``` +pub struct RwLock { + state: AtomicUsize, + reads: std::sync::Mutex>>, + writes: std::sync::Mutex>>, + value: UnsafeCell, +} + +unsafe impl Send for RwLock {} +unsafe impl Sync for RwLock {} + +impl RwLock { + /// Creates a new reader-writer lock. + /// + /// # Examples + /// + /// ``` + /// use async_std::sync::RwLock; + /// + /// let lock = RwLock::new(0); + /// ``` + pub fn new(t: T) -> RwLock { + RwLock { + state: AtomicUsize::new(0), + reads: std::sync::Mutex::new(Slab::new()), + writes: std::sync::Mutex::new(Slab::new()), + value: UnsafeCell::new(t), + } + } + + /// Acquires a read lock. + /// + /// Returns a guard that releases the lock when dropped. + /// + /// # Examples + /// + /// ``` + /// # #![feature(async_await)] + /// use async_std::sync::RwLock; + /// + /// # futures::executor::block_on(async { + /// let lock = RwLock::new(1); + /// + /// let n = lock.read().await; + /// assert_eq!(*n, 1); + /// + /// assert!(lock.try_read().is_some()); + /// # }) + /// ``` + pub async fn read(&self) -> RwLockReadGuard<'_, T> { + pub struct LockFuture<'a, T> { + lock: &'a RwLock, + opt_key: Option, + acquired: bool, + } + + impl<'a, T> Future for LockFuture<'a, T> { + type Output = RwLockReadGuard<'a, T>; + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + match self.lock.try_read() { + Some(guard) => { + self.acquired = true; + Poll::Ready(guard) + } + None => { + let mut reads = self.lock.reads.lock().unwrap(); + + // Register the current task. + match self.opt_key { + None => { + // Insert a new entry into the list of blocked reads. + let w = cx.waker().clone(); + let key = reads.insert(Some(w)); + self.opt_key = Some(key); + + if reads.len() == 1 { + self.lock.state.fetch_or(BLOCKED_READS, Ordering::Relaxed); + } + } + Some(key) => { + // There is already an entry in the list of blocked reads. Just + // reset the waker if it was removed. + if reads[key].is_none() { + let w = cx.waker().clone(); + reads[key] = Some(w); + } + } + } + + // Try locking again because it's possible the lock got unlocked just + // before the current task was registered as a blocked task. + match self.lock.try_read() { + Some(guard) => { + self.acquired = true; + Poll::Ready(guard) + } + None => Poll::Pending, + } + } + } + } + } + + impl Drop for LockFuture<'_, T> { + fn drop(&mut self) { + if let Some(key) = self.opt_key { + let mut reads = self.lock.reads.lock().unwrap(); + let opt_waker = reads.remove(key); + + if reads.is_empty() { + self.lock.state.fetch_and(!BLOCKED_READS, Ordering::Relaxed); + } + + if opt_waker.is_none() { + // We were awoken. Wake up another blocked read. + if let Some((_, opt_waker)) = reads.iter_mut().next() { + if let Some(w) = opt_waker.take() { + w.wake(); + return; + } + } + drop(reads); + + if !self.acquired { + // We didn't acquire the lock and didn't wake another blocked read. + // Wake a blocked write instead. + let mut writes = self.lock.writes.lock().unwrap(); + if let Some((_, opt_waker)) = writes.iter_mut().next() { + if let Some(w) = opt_waker.take() { + w.wake(); + return; + } + } + } + } + } + } + } + + LockFuture { + lock: self, + opt_key: None, + acquired: false, + } + .await + } + + /// Attempts to acquire a read lock. + /// + /// If a read lock could not be acquired at this time, then [`None`] is returned. Otherwise, a + /// guard is returned that releases the lock when dropped. + /// + /// [`None`]: https://doc.rust-lang.org/std/option/enum.Option.html#variant.None + /// + /// # Examples + /// + /// ``` + /// # #![feature(async_await)] + /// use async_std::sync::RwLock; + /// + /// # futures::executor::block_on(async { + /// let lock = RwLock::new(1); + /// + /// let mut n = lock.read().await; + /// assert_eq!(*n, 1); + /// + /// assert!(lock.try_read().is_some()); + /// # }) + /// ``` + pub fn try_read(&self) -> Option> { + let mut state = self.state.load(Ordering::Acquire); + + loop { + // If a write lock is currently held, then a read lock cannot be acquired. + if state & WRITE_LOCK != 0 { + return None; + } + + // Increment the number of active reads. + match self.state.compare_exchange_weak( + state, + state + ONE_READ, + Ordering::AcqRel, + Ordering::Acquire, + ) { + Ok(_) => return Some(RwLockReadGuard(self)), + Err(s) => state = s, + } + } + } + + /// Acquires a write lock. + /// + /// Returns a guard that releases the lock when dropped. + /// + /// # Examples + /// + /// ``` + /// # #![feature(async_await)] + /// use async_std::sync::RwLock; + /// + /// # futures::executor::block_on(async { + /// let lock = RwLock::new(1); + /// + /// let mut n = lock.write().await; + /// *n = 2; + /// + /// assert!(lock.try_read().is_none()); + /// # }) + /// ``` + pub async fn write(&self) -> RwLockWriteGuard<'_, T> { + pub struct LockFuture<'a, T> { + lock: &'a RwLock, + opt_key: Option, + acquired: bool, + } + + impl<'a, T> Future for LockFuture<'a, T> { + type Output = RwLockWriteGuard<'a, T>; + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + match self.lock.try_write() { + Some(guard) => { + self.acquired = true; + Poll::Ready(guard) + } + None => { + let mut writes = self.lock.writes.lock().unwrap(); + + // Register the current task. + match self.opt_key { + None => { + // Insert a new entry into the list of blocked writes. + let w = cx.waker().clone(); + let key = writes.insert(Some(w)); + self.opt_key = Some(key); + + if writes.len() == 1 { + self.lock.state.fetch_or(BLOCKED_WRITES, Ordering::Relaxed); + } + } + Some(key) => { + // There is already an entry in the list of blocked writes. Just + // reset the waker if it was removed. + if writes[key].is_none() { + let w = cx.waker().clone(); + writes[key] = Some(w); + } + } + } + + // Try locking again because it's possible the lock got unlocked just + // before the current task was registered as a blocked task. + match self.lock.try_write() { + Some(guard) => { + self.acquired = true; + Poll::Ready(guard) + } + None => Poll::Pending, + } + } + } + } + } + + impl Drop for LockFuture<'_, T> { + fn drop(&mut self) { + if let Some(key) = self.opt_key { + let mut writes = self.lock.writes.lock().unwrap(); + let opt_waker = writes.remove(key); + + if writes.is_empty() { + self.lock + .state + .fetch_and(!BLOCKED_WRITES, Ordering::Relaxed); + } + + if opt_waker.is_none() && !self.acquired { + // We were awoken but didn't acquire the lock. Wake up another write. + if let Some((_, opt_waker)) = writes.iter_mut().next() { + if let Some(w) = opt_waker.take() { + w.wake(); + return; + } + } + drop(writes); + + // There are no blocked writes. Wake a blocked read instead. + let mut reads = self.lock.reads.lock().unwrap(); + if let Some((_, opt_waker)) = reads.iter_mut().next() { + if let Some(w) = opt_waker.take() { + w.wake(); + return; + } + } + } + } + } + } + + LockFuture { + lock: self, + opt_key: None, + acquired: false, + } + .await + } + + /// Attempts to acquire a write lock. + /// + /// If a write lock could not be acquired at this time, then [`None`] is returned. Otherwise, a + /// guard is returned that releases the lock when dropped. + /// + /// [`None`]: https://doc.rust-lang.org/std/option/enum.Option.html#variant.None + /// + /// # Examples + /// + /// ``` + /// # #![feature(async_await)] + /// use async_std::sync::RwLock; + /// + /// # futures::executor::block_on(async { + /// let lock = RwLock::new(1); + /// + /// let mut n = lock.read().await; + /// assert_eq!(*n, 1); + /// + /// assert!(lock.try_write().is_none()); + /// # }) + /// ``` + pub fn try_write(&self) -> Option> { + let mut state = self.state.load(Ordering::Acquire); + + loop { + // If any kind of lock is currently held, then a write lock cannot be acquired. + if state & (WRITE_LOCK | READ_COUNT_MASK) != 0 { + return None; + } + + // Set the write lock. + match self.state.compare_exchange_weak( + state, + state | WRITE_LOCK, + Ordering::AcqRel, + Ordering::Acquire, + ) { + Ok(_) => return Some(RwLockWriteGuard(self)), + Err(s) => state = s, + } + } + } + + /// Consumes the lock, returning the underlying data. + /// + /// # Examples + /// + /// ``` + /// # #![feature(async_await)] + /// use async_std::sync::RwLock; + /// + /// let lock = RwLock::new(10); + /// assert_eq!(lock.into_inner(), 10); + /// ``` + pub fn into_inner(self) -> T { + self.value.into_inner() + } + + /// Returns a mutable reference to the underlying data. + /// + /// Since this call borrows the lock mutably, no actual locking takes place -- the mutable + /// borrow statically guarantees no locks exist. + /// + /// # Examples + /// + /// ``` + /// # #![feature(async_await)] + /// use async_std::sync::RwLock; + /// + /// # futures::executor::block_on(async { + /// let mut lock = RwLock::new(0); + /// *lock.get_mut() = 10; + /// assert_eq!(*lock.write().await, 10); + /// }); + /// ``` + pub fn get_mut(&mut self) -> &mut T { + unsafe { &mut *self.value.get() } + } +} + +impl fmt::Debug for RwLock { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self.try_read() { + None => { + struct LockedPlaceholder; + impl fmt::Debug for LockedPlaceholder { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_str("") + } + } + f.debug_struct("RwLock") + .field("data", &LockedPlaceholder) + .finish() + } + Some(guard) => f.debug_struct("RwLock").field("data", &&*guard).finish(), + } + } +} + +impl From for RwLock { + fn from(val: T) -> RwLock { + RwLock::new(val) + } +} + +impl Default for RwLock { + fn default() -> RwLock { + RwLock::new(Default::default()) + } +} + +/// A guard that releases the read lock when dropped. +pub struct RwLockReadGuard<'a, T>(&'a RwLock); + +unsafe impl Send for RwLockReadGuard<'_, T> {} +unsafe impl Sync for RwLockReadGuard<'_, T> {} + +impl Drop for RwLockReadGuard<'_, T> { + fn drop(&mut self) { + let state = self.0.state.fetch_sub(ONE_READ, Ordering::AcqRel); + + // If this was the last read and there are blocked writes, wake one of them up. + if (state & READ_COUNT_MASK) == ONE_READ && state & BLOCKED_WRITES != 0 { + let mut writes = self.0.writes.lock().unwrap(); + + if let Some((_, opt_waker)) = writes.iter_mut().next() { + // If there is no waker in this entry, that means it was already woken. + if let Some(w) = opt_waker.take() { + w.wake(); + } + } + } + } +} + +impl fmt::Debug for RwLockReadGuard<'_, T> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt::Debug::fmt(&**self, f) + } +} + +impl fmt::Display for RwLockReadGuard<'_, T> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + (**self).fmt(f) + } +} + +impl Deref for RwLockReadGuard<'_, T> { + type Target = T; + + fn deref(&self) -> &T { + unsafe { &*self.0.value.get() } + } +} + +/// A guard that releases the write lock when dropped. +pub struct RwLockWriteGuard<'a, T>(&'a RwLock); + +unsafe impl Send for RwLockWriteGuard<'_, T> {} +unsafe impl Sync for RwLockWriteGuard<'_, T> {} + +impl Drop for RwLockWriteGuard<'_, T> { + fn drop(&mut self) { + let state = self.0.state.fetch_and(!WRITE_LOCK, Ordering::AcqRel); + + let mut guard = None; + + // Check if there are any blocked reads or writes. + if state & BLOCKED_READS != 0 { + guard = Some(self.0.reads.lock().unwrap()); + } else if state & BLOCKED_WRITES != 0 { + guard = Some(self.0.writes.lock().unwrap()); + } + + // Wake up a single blocked task. + if let Some(mut guard) = guard { + if let Some((_, opt_waker)) = guard.iter_mut().next() { + // If there is no waker in this entry, that means it was already woken. + if let Some(w) = opt_waker.take() { + w.wake(); + } + } + } + } +} + +impl fmt::Debug for RwLockWriteGuard<'_, T> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt::Debug::fmt(&**self, f) + } +} + +impl fmt::Display for RwLockWriteGuard<'_, T> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + (**self).fmt(f) + } +} + +impl Deref for RwLockWriteGuard<'_, T> { + type Target = T; + + fn deref(&self) -> &T { + unsafe { &*self.0.value.get() } + } +} + +impl DerefMut for RwLockWriteGuard<'_, T> { + fn deref_mut(&mut self) -> &mut T { + unsafe { &mut *self.0.value.get() } + } +} diff --git a/src/task/blocking.rs b/src/task/blocking.rs new file mode 100644 index 00000000..ccffa7db --- /dev/null +++ b/src/task/blocking.rs @@ -0,0 +1,70 @@ +//! A thread pool for running blocking functions asynchronously. + +use std::fmt; +use std::future::Future; +use std::pin::Pin; +use std::task::{Context, Poll}; +use std::thread; + +use crossbeam::channel::{unbounded, Receiver, Sender}; +use lazy_static::lazy_static; + +use crate::utils::abort_on_panic; + +struct Pool { + sender: Sender>, + receiver: Receiver>, +} + +lazy_static! { + static ref POOL: Pool = { + for _ in 0..2 { + thread::Builder::new() + .name("async-blocking-driver".to_string()) + .spawn(|| { + for task in &POOL.receiver { + abort_on_panic(|| task.run()); + } + }) + .expect("cannot start a thread driving blocking tasks"); + } + + let (sender, receiver) = unbounded(); + Pool { sender, receiver } + }; +} + +/// Spawns a blocking task. +/// +/// The task will be spawned onto a thread pool specifically dedicated to blocking tasks. +pub fn spawn(future: F) -> JoinHandle +where + F: Future + Send + 'static, + R: Send + 'static, +{ + let schedule = |t| POOL.sender.send(t).unwrap(); + let (task, handle) = async_task::spawn(future, schedule, ()); + task.schedule(); + JoinHandle(handle) +} + +/// A handle to a blocking task. +pub struct JoinHandle(async_task::JoinHandle); + +impl Unpin for JoinHandle {} + +impl Future for JoinHandle { + type Output = R; + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + Pin::new(&mut self.0).poll(cx).map(|out| out.unwrap()) + } +} + +impl fmt::Debug for JoinHandle { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("JoinHandle") + .field("handle", &self.0) + .finish() + } +} diff --git a/src/task/local.rs b/src/task/local.rs new file mode 100644 index 00000000..3a8601c2 --- /dev/null +++ b/src/task/local.rs @@ -0,0 +1,250 @@ +use std::cell::UnsafeCell; +use std::error::Error; +use std::fmt; +use std::sync::atomic::{AtomicUsize, Ordering}; +use std::sync::Mutex; + +use lazy_static::lazy_static; + +use super::pool; + +/// Declares task-local values. +/// +/// The macro wraps any number of static declarations and makes them task-local. Attributes and +/// visibility modifiers are allowed. +/// +/// Each declared value is of the accessor type [`LocalKey`]. +/// +/// [`LocalKey`]: task/struct.LocalKey.html +/// +/// # Examples +/// +/// ``` +/// # #![feature(async_await)] +/// use async_std::{task, task_local}; +/// use std::cell::Cell; +/// +/// task_local! { +/// static VAL: Cell = Cell::new(5); +/// } +/// +/// task::block_on(async { +/// let v = VAL.with(|c| c.get()); +/// assert_eq!(v, 5); +/// }); +/// ``` +#[macro_export] +macro_rules! task_local { + () => (); + + ($(#[$attr:meta])* $vis:vis static $name:ident: $t:ty = $init:expr) => ( + $(#[$attr])* $vis static $name: $crate::task::LocalKey<$t> = { + #[inline] + fn __init() -> $t { + $init + } + + $crate::task::LocalKey { + __init, + __key: ::std::sync::atomic::AtomicUsize::new(0), + } + }; + ); + + ($(#[$attr:meta])* $vis:vis static $name:ident: $t:ty = $init:expr; $($rest:tt)*) => ( + $crate::task_local!($(#[$attr])* $vis static $name: $t = $init); + $crate::task_local!($($rest)*); + ); +} + +/// The key for accessing a task-local value. +/// +/// Every task-local value is lazily initialized on first access and destroyed when the task +/// completes. +#[derive(Debug)] +pub struct LocalKey { + #[doc(hidden)] + pub __init: fn() -> T, + + #[doc(hidden)] + pub __key: AtomicUsize, +} + +impl LocalKey { + /// Gets a reference to the task-local value with this key. + /// + /// The passed closure receives a reference to the task-local value. + /// + /// The task-local value will be lazily initialized if this task has not accessed it before. + /// + /// # Panics + /// + /// This function will panic if not called within the context of a task created by + /// [`block_on`], [`spawn`], or [`Builder::spawn`]. + /// + /// [`block_on`]: fn.block_on.html + /// [`spawn`]: fn.spawn.html + /// [`Builder::spawn`]: struct.Builder.html#method.spawn + /// + /// # Examples + /// + /// ``` + /// # #![feature(async_await)] + /// use async_std::{task, task_local}; + /// use std::cell::Cell; + /// + /// task_local! { + /// static FOO: Cell = Cell::new(5); + /// } + /// + /// task::block_on(async { + /// let v = FOO.with(|c| c.get()); + /// assert_eq!(v, 5); + /// }); + /// ``` + pub fn with(&'static self, f: F) -> R + where + F: FnOnce(&T) -> R, + { + self.try_with(f) + .expect("`LocalKey::with` called outside the context of a task") + } + + /// Attempts to get a reference to the task-local value with this key. + /// + /// The passed closure receives a reference to the task-local value. + /// + /// The task-local value will be lazily initialized if this task has not accessed it before. + /// + /// This function returns an error if not called within the context of a task created by + /// [`block_on`], [`spawn`], or [`Builder::spawn`]. + /// + /// [`block_on`]: fn.block_on.html + /// [`spawn`]: fn.spawn.html + /// [`Builder::spawn`]: struct.Builder.html#method.spawn + /// + /// # Examples + /// + /// ``` + /// # #![feature(async_await)] + /// use async_std::{task, task_local}; + /// use std::cell::Cell; + /// + /// task_local! { + /// static VAL: Cell = Cell::new(5); + /// } + /// + /// task::block_on(async { + /// let v = VAL.try_with(|c| c.get()); + /// assert_eq!(v, Ok(5)); + /// }); + /// + /// // Returns an error because not called within the context of a task. + /// assert!(VAL.try_with(|c| c.get()).is_err()); + /// ``` + pub fn try_with(&'static self, f: F) -> Result + where + F: FnOnce(&T) -> R, + { + pool::get_task(|task| unsafe { + // Prepare the numeric key, initialization function, and the map of task-locals. + let key = self.key(); + let init = || Box::new((self.__init)()) as Box; + let map = &task.metadata().local_map; + + // Get the value in the map of task-locals, or initialize and insert one. + let value: *const dyn Send = map.get_or_insert(key, init); + + // Call the closure with the value passed as an argument. + f(&*(value as *const T)) + }) + .ok_or(AccessError { _private: () }) + } + + /// Returns the numeric key associated with this task-local. + #[inline] + fn key(&self) -> usize { + #[cold] + fn init(key: &AtomicUsize) -> usize { + lazy_static! { + static ref COUNTER: Mutex = Mutex::new(1); + } + + let mut counter = COUNTER.lock().unwrap(); + let prev = key.compare_and_swap(0, *counter, Ordering::AcqRel); + + if prev == 0 { + *counter += 1; + *counter - 1 + } else { + prev + } + } + + let key = self.__key.load(Ordering::Acquire); + if key == 0 { + init(&self.__key) + } else { + key + } + } +} + +/// An error returned by [`LocalKey::try_with`]. +/// +/// [`LocalKey::try_with`]: struct.LocalKey.html#method.try_with +#[derive(Clone, Copy, Eq, PartialEq)] +pub struct AccessError { + _private: (), +} + +impl fmt::Debug for AccessError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("AccessError").finish() + } +} + +impl fmt::Display for AccessError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + "already destroyed or called outside the context of a task".fmt(f) + } +} + +impl Error for AccessError {} + +/// A map that holds task-locals. +pub(crate) struct Map { + /// A list of `(key, value)` entries sorted by the key. + entries: UnsafeCell)>>, +} + +impl Map { + /// Creates an empty map of task-locals. + pub fn new() -> Map { + Map { + entries: UnsafeCell::new(Vec::new()), + } + } + + /// Returns a thread-local value associated with `key` or inserts one constructed by `init`. + #[inline] + pub fn get_or_insert(&self, key: usize, init: impl FnOnce() -> Box) -> &dyn Send { + let entries = unsafe { &mut *self.entries.get() }; + + let index = match entries.binary_search_by_key(&key, |e| e.0) { + Ok(i) => i, + Err(i) => { + entries.insert(i, (key, init())); + i + } + }; + + &*entries[index].1 + } + + /// Clears the map and drops all task-locals. + pub fn clear(&self) { + let entries = unsafe { &mut *self.entries.get() }; + entries.clear(); + } +} diff --git a/src/task/mod.rs b/src/task/mod.rs new file mode 100644 index 00000000..711d1393 --- /dev/null +++ b/src/task/mod.rs @@ -0,0 +1,37 @@ +//! Asynchronous tasks. +//! +//! This module is similar to [`std::thread`], except it uses asynchronous tasks in place of +//! threads. +//! +//! [`std::thread`]: https://doc.rust-lang.org/std/thread/index.html +//! +//! # Examples +//! +//! Spawn a task and await its result: +//! +//! ``` +//! # #![feature(async_await)] +//! use async_std::task; +//! +//! # async_std::task::block_on(async { +//! let handle = task::spawn(async { +//! 1 + 2 +//! }); +//! assert_eq!(handle.await, 3); +//! # }); +//! ``` + +#[doc(inline)] +pub use futures::task::{Context, Poll, Waker}; + +pub use local::{AccessError, LocalKey}; +pub use pool::{block_on, current, spawn, Builder}; +pub use sleep::sleep; +pub use task::{JoinHandle, Task, TaskId}; + +mod local; +mod pool; +mod sleep; +mod task; + +pub(crate) mod blocking; diff --git a/src/task/pool.rs b/src/task/pool.rs new file mode 100644 index 00000000..cf3da7f9 --- /dev/null +++ b/src/task/pool.rs @@ -0,0 +1,289 @@ +use std::cell::{Cell, UnsafeCell}; +use std::fmt::Arguments; +use std::future::Future; +use std::io; +use std::mem; +use std::panic::{self, AssertUnwindSafe}; +use std::pin::Pin; +use std::ptr; +use std::thread; + +use crossbeam::channel::{unbounded, Sender}; +use futures::prelude::*; +use lazy_static::lazy_static; + +use super::task; +use super::{JoinHandle, Task}; + +/// Returns a handle to the current task. +/// +/// # Panics +/// +/// This function will panic if not called within the context of a task created by [`block_on`], +/// [`spawn`], or [`Builder::spawn`]. +/// +/// [`block_on`]: fn.block_on.html +/// [`spawn`]: fn.spawn.html +/// [`Builder::spawn`]: struct.Builder.html#method.spawn +/// +/// # Examples +/// +/// ``` +/// # #![feature(async_await)] +/// use async_std::task::current; +/// +/// # async_std::task::block_on(async { +/// println!("The name of this task is {:?}", current().name()); +/// # }); +pub fn current() -> Task { + get_task(|task| task.clone()).expect("`task::current()` called outside the context of a task") +} + +/// Spawns a task. +/// +/// This function is similar to [`std::thread::spawn`], except it spawns an asynchronous task. +/// +/// [`std::thread`]: https://doc.rust-lang.org/std/thread/fn.spawn.html +/// +/// # Examples +/// +/// ``` +/// # #![feature(async_await)] +/// use async_std::task; +/// +/// # async_std::task::block_on(async { +/// let handle = task::spawn(async { +/// 1 + 2 +/// }); +/// +/// assert_eq!(handle.await, 3); +/// # }); +/// ``` +pub fn spawn(future: F) -> JoinHandle +where + F: Future + Send + 'static, + T: Send + 'static, +{ + spawn_with_builder(Builder::new(), future, "spawn") +} + +/// Spawns a task and blocks the current thread on its result. +/// +/// Calling this function is similar to [spawning] a thread and immediately [joining] it, except an +/// asynchronous task will be spawned. +/// +/// [spawning]: https://doc.rust-lang.org/std/thread/fn.spawn.html +/// [joining]: https://doc.rust-lang.org/std/thread/struct.JoinHandle.html#method.join +/// +/// # Examples +/// +/// ```no_run +/// # #![feature(async_await)] +/// use async_std::task; +/// +/// fn main() { +/// task::block_on(async { +/// println!("Hello, world!"); +/// }) +/// } +/// ``` +pub fn block_on(future: F) -> T +where + F: Future + Send, + T: Send, +{ + unsafe { + // A place on the stack where the result will be stored. + let out = &mut UnsafeCell::new(None); + + // Wrap the future into one that stores the result into `out`. + let future = { + let out = out.get(); + async move { + let v = AssertUnwindSafe(future).catch_unwind().await; + *out = Some(v); + } + }; + + // Pin the future onto the stack. + futures::pin_mut!(future); + + // Transmute the future into one that is static and sendable. + let future = mem::transmute::< + Pin<&mut dyn Future>, + Pin<&'static mut (dyn Future + Send)>, + >(future); + + // Spawn the future and wait for it to complete. + futures::executor::block_on(spawn_with_builder(Builder::new(), future, "block_on")); + + // Take out the result. + match (*out.get()).take().unwrap() { + Ok(v) => v, + Err(err) => panic::resume_unwind(err), + } + } +} + +/// Task builder that configures the settings of a new task. +#[derive(Debug)] +pub struct Builder { + pub(crate) name: Option, +} + +impl Builder { + /// Creates a new builder. + pub fn new() -> Builder { + Builder { name: None } + } + + /// Configures the name of the task. + pub fn name(mut self, name: String) -> Builder { + self.name = Some(name); + self + } + + /// Spawns a task with the configured settings. + pub fn spawn(self, future: F) -> io::Result> + where + F: Future + Send + 'static, + T: Send + 'static, + { + Ok(spawn_with_builder(self, future, "spawn")) + } +} + +pub(crate) fn spawn_with_builder( + builder: Builder, + future: F, + fn_name: &'static str, +) -> JoinHandle +where + F: Future + Send + 'static, + T: Send + 'static, +{ + let Builder { name } = builder; + + type Job = async_task::Task; + + lazy_static! { + static ref QUEUE: Sender = { + let (sender, receiver) = unbounded::(); + + for _ in 0..num_cpus::get().max(1) { + let receiver = receiver.clone(); + thread::Builder::new() + .name("async-task-driver".to_string()) + .spawn(|| { + TAG.with(|tag| { + for job in receiver { + tag.set(job.tag()); + abort_on_panic(|| job.run()); + tag.set(ptr::null()); + } + }); + }) + .expect("cannot start a thread driving tasks"); + } + + sender + }; + } + + let tag = task::Tag::new(name); + let schedule = |job| QUEUE.send(job).unwrap(); + + let child_id = tag.task_id().as_u64(); + let parent_id = get_task(|t| t.id().as_u64()).unwrap_or(0); + print( + format_args!("{}", fn_name), + LogData { + parent_id, + child_id, + }, + ); + + // Wrap the future into one that drops task-local variables on exit. + let future = async move { + let res = future.await; + + // Abort on panic because thread-local variables behave the same way. + abort_on_panic(|| get_task(|task| task.metadata().local_map.clear())); + + print( + format_args!("{} completed", fn_name), + LogData { + parent_id, + child_id, + }, + ); + res + }; + + let (task, handle) = async_task::spawn(future, schedule, tag); + task.schedule(); + JoinHandle::new(handle) +} + +thread_local! { + static TAG: Cell<*const task::Tag> = Cell::new(ptr::null_mut()); +} + +pub(crate) fn get_task R, R>(f: F) -> Option { + let res = TAG.try_with(|tag| unsafe { tag.get().as_ref().map(task::Tag::task).map(f) }); + + match res { + Ok(Some(val)) => Some(val), + Ok(None) | Err(_) => None, + } +} + +/// Calls a function and aborts if it panics. +/// +/// This is useful in unsafe code where we can't recover from panics. +#[inline] +fn abort_on_panic(f: impl FnOnce() -> T) -> T { + struct Bomb; + + impl Drop for Bomb { + fn drop(&mut self) { + std::process::abort(); + } + } + + let bomb = Bomb; + let t = f(); + mem::forget(bomb); + t +} + +/// This struct only exists because kv logging isn't supported from the macros right now. +struct LogData { + parent_id: u64, + child_id: u64, +} + +impl<'a> log::kv::Source for LogData { + fn visit<'kvs>( + &'kvs self, + visitor: &mut dyn log::kv::Visitor<'kvs>, + ) -> Result<(), log::kv::Error> { + visitor.visit_pair("parent_id".into(), self.parent_id.into())?; + visitor.visit_pair("child_id".into(), self.child_id.into())?; + Ok(()) + } +} + +fn print(msg: Arguments<'_>, key_values: impl log::kv::Source) { + log::logger().log( + &log::Record::builder() + .args(msg) + .key_values(&key_values) + .level(log::Level::Trace) + .target(module_path!()) + .module_path(Some(module_path!())) + .file(Some(file!())) + .line(Some(line!())) + .build(), + ); +} diff --git a/src/task/sleep.rs b/src/task/sleep.rs new file mode 100644 index 00000000..814ca6b4 --- /dev/null +++ b/src/task/sleep.rs @@ -0,0 +1,28 @@ +use std::time::Duration; + +use futures::prelude::*; + +use crate::time::Timeout; + +/// Sleeps for the specified amount of time. +/// +/// This function might sleep for slightly longer than the specified duration but never less. +/// +/// This function is an async version of [`std::thread::sleep`]. +/// +/// [`std::thread::sleep`]: https://doc.rust-lang.org/std/thread/fn.sleep.html +/// +/// # Examples +/// +/// ``` +/// # #![feature(async_await)] +/// use async_std::task; +/// use std::time::Duration; +/// +/// # async_std::task::block_on(async { +/// task::sleep(Duration::from_secs(1)).await; +/// # }); +/// ``` +pub async fn sleep(dur: Duration) { + let _ = future::pending::<()>().timeout(dur).await; +} diff --git a/src/task/task.rs b/src/task/task.rs new file mode 100644 index 00000000..02a26e80 --- /dev/null +++ b/src/task/task.rs @@ -0,0 +1,205 @@ +use std::fmt; +use std::future::Future; +use std::i64; +use std::mem; +use std::num::NonZeroU64; +use std::pin::Pin; +use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering}; +use std::sync::Arc; +use std::task::{Context, Poll}; + +use super::local; + +/// A handle to a task. +#[derive(Clone)] +pub struct Task(Arc); + +unsafe impl Send for Task {} +unsafe impl Sync for Task {} + +impl Task { + /// Returns a reference to task metadata. + pub(crate) fn metadata(&self) -> &Metadata { + &self.0 + } + + /// Gets the task's unique identifier. + pub fn id(&self) -> TaskId { + self.metadata().task_id + } + + /// Returns the name of this task. + /// + /// The name is configured by [`Builder::name`] before spawning. + /// + /// [`Builder::name`]: struct.Builder.html#method.name + pub fn name(&self) -> Option<&str> { + self.metadata().name.as_ref().map(|s| s.as_str()) + } +} + +impl fmt::Debug for Task { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("Task").field("name", &self.name()).finish() + } +} + +/// A handle that awaits the result of a task. +/// +/// Created when a task is [spawned]. +/// +/// [spawned]: fn.spawn.html +#[derive(Debug)] +pub struct JoinHandle(async_task::JoinHandle); + +unsafe impl Send for JoinHandle {} +unsafe impl Sync for JoinHandle {} + +impl JoinHandle { + pub(crate) fn new(inner: async_task::JoinHandle) -> JoinHandle { + JoinHandle(inner) + } + + /// Returns a handle to the underlying task. + /// + /// # Examples + /// + /// ``` + /// # #![feature(async_await)] + /// use async_std::task; + /// + /// # async_std::task::block_on(async { + /// let handle = task::spawn(async { + /// 1 + 2 + /// }); + /// println!("id = {}", handle.task().id()); + /// # }); + pub fn task(&self) -> &Task { + self.0.tag().task() + } +} + +impl Future for JoinHandle { + type Output = T; + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + match Pin::new(&mut self.0).poll(cx) { + Poll::Pending => Poll::Pending, + Poll::Ready(None) => panic!("task has panicked"), + Poll::Ready(Some(val)) => Poll::Ready(val), + } + } +} + +/// A unique identifier for a task. +/// +/// # Examples +/// +/// ``` +/// # #![feature(async_await)] +/// use async_std::task; +/// +/// # async_std::task::block_on(async { +/// task::block_on(async { +/// println!("id = {:?}", task::current().id()); +/// }) +/// # }); +/// ``` +#[derive(Eq, PartialEq, Clone, Copy, Hash, Debug)] +pub struct TaskId(NonZeroU64); + +impl TaskId { + pub(crate) fn new() -> TaskId { + static COUNTER: AtomicU64 = AtomicU64::new(1); + + let id = COUNTER.fetch_add(1, Ordering::Relaxed); + + if id > i64::MAX as u64 { + std::process::abort(); + } + unsafe { TaskId(NonZeroU64::new_unchecked(id)) } + } + + pub(crate) fn as_u64(&self) -> u64 { + self.0.get() + } +} + +impl fmt::Display for TaskId { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{}", self.0) + } +} + +pub(crate) struct Metadata { + pub task_id: TaskId, + pub name: Option, + pub local_map: local::Map, +} + +pub(crate) struct Tag { + task_id: TaskId, + raw_metadata: AtomicUsize, +} + +impl Tag { + pub fn new(name: Option) -> Tag { + let task_id = TaskId::new(); + + let opt_task = name.map(|name| { + Task(Arc::new(Metadata { + task_id, + name: Some(name), + local_map: local::Map::new(), + })) + }); + + Tag { + task_id, + raw_metadata: AtomicUsize::new(unsafe { + mem::transmute::, usize>(opt_task) + }), + } + } + + pub fn task(&self) -> &Task { + unsafe { + let raw = self.raw_metadata.load(Ordering::Acquire); + + if mem::transmute::<&usize, &Option>(&raw).is_none() { + let new = Some(Task(Arc::new(Metadata { + task_id: TaskId::new(), + name: None, + local_map: local::Map::new(), + }))); + + let new_raw = mem::transmute::, usize>(new); + + if self + .raw_metadata + .compare_exchange(raw, new_raw, Ordering::AcqRel, Ordering::Acquire) + .is_err() + { + let new = mem::transmute::>(new_raw); + drop(new); + } + }; + + mem::transmute::<&AtomicUsize, &Option>(&self.raw_metadata) + .as_ref() + .unwrap() + } + } + + pub fn task_id(&self) -> TaskId { + self.task_id + } +} + +impl Drop for Tag { + fn drop(&mut self) { + let raw = *self.raw_metadata.get_mut(); + let opt_task = unsafe { mem::transmute::>(raw) }; + drop(opt_task); + } +} diff --git a/src/time/mod.rs b/src/time/mod.rs new file mode 100644 index 00000000..41104502 --- /dev/null +++ b/src/time/mod.rs @@ -0,0 +1,135 @@ +//! Timeouts for async operations. +//! +//! This module is an async extension of [`std::time`]. +//! +//! [`std::time`]: https://doc.rust-lang.org/std/time/index.html +//! +//! # Examples +//! +//! Read a line from stdin with a timeout of 5 seconds. +//! +//! ```no_run +//! # #![feature(async_await)] +//! # fn main() -> std::io::Result<()> { async_std::task::block_on(async { +//! # +//! use async_std::{io, prelude::*}; +//! use std::time::Duration; +//! +//! let stdin = io::stdin(); +//! let mut line = String::new(); +//! +//! let n = stdin +//! .read_line(&mut line) +//! .timeout(Duration::from_secs(5)) +//! .await??; +//! # +//! # Ok(()) }) } +//! ``` + +use std::error::Error; +use std::fmt; +use std::future::Future; +use std::io; +use std::pin::Pin; +use std::task::{Context, Poll}; +use std::time::Duration; + +use cfg_if::cfg_if; +use futures_timer::Delay; +use pin_utils::unsafe_pinned; + +/// An error returned when a future times out. +#[derive(Clone, Copy, Debug, Eq, PartialEq)] +pub struct TimeoutError; + +impl Error for TimeoutError {} + +impl fmt::Display for TimeoutError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + "future has timed out".fmt(f) + } +} + +impl From for io::Error { + fn from(_: TimeoutError) -> io::Error { + io::Error::new(io::ErrorKind::TimedOut, "future has timed out") + } +} + +cfg_if! { + if #[cfg(feature = "docs.rs")] { + #[doc(hidden)] + pub struct ImplFuture(std::marker::PhantomData); + + /// An extension trait that configures timeouts for futures. + pub trait Timeout: Future + Sized { + /// Awaits a future to completion or times out after a duration of time. + /// + /// # Examples + /// + /// ```no_run + /// # #![feature(async_await)] + /// # fn main() -> io::Result<()> { async_std::task::block_on(async { + /// # + /// use async_std::{io, prelude::*}; + /// use std::time::Duration; + /// + /// let stdin = io::stdin(); + /// let mut line = String::new(); + /// + /// let n = stdin + /// .read_line(&mut line) + /// .timeout(Duration::from_secs(5)) + /// .await??; + /// # + /// # Ok(()) }) } + /// ``` + fn timeout(self, dur: Duration) -> ImplFuture> { + TimeoutFuture { + future: self, + delay: Delay::new(dur), + } + } + } + } else { + /// An extension trait that configures timeouts for futures. + pub trait Timeout: Future + Sized { + /// Awaits a future to completion or times out after a duration of time. + fn timeout(self, dur: Duration) -> TimeoutFuture { + TimeoutFuture { + future: self, + delay: Delay::new(dur), + } + } + } + + /// A future that times out after a duration of time. + #[doc(hidden)] + #[derive(Debug)] + pub struct TimeoutFuture { + future: F, + delay: Delay, + } + + impl TimeoutFuture { + unsafe_pinned!(future: F); + unsafe_pinned!(delay: Delay); + } + + impl Future for TimeoutFuture { + type Output = Result; + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + match self.as_mut().future().poll(cx) { + Poll::Ready(v) => Poll::Ready(Ok(v)), + Poll::Pending => match self.delay().poll(cx) { + Poll::Ready(_) => Poll::Ready(Err(TimeoutError)), + Poll::Pending => Poll::Pending, + }, + } + } + } + } +} + +impl Timeout for F {} diff --git a/src/utils.rs b/src/utils.rs new file mode 100644 index 00000000..258042d8 --- /dev/null +++ b/src/utils.rs @@ -0,0 +1,21 @@ +use std::mem; +use std::process; + +/// Calls a function and aborts if it panics. +/// +/// This is useful in unsafe code where we can't recover from panics. +#[inline] +pub fn abort_on_panic(f: impl FnOnce() -> T) -> T { + struct Bomb; + + impl Drop for Bomb { + fn drop(&mut self) { + process::abort(); + } + } + + let bomb = Bomb; + let t = f(); + mem::forget(bomb); + t +} diff --git a/tests/block_on.rs b/tests/block_on.rs new file mode 100644 index 00000000..baa12402 --- /dev/null +++ b/tests/block_on.rs @@ -0,0 +1,18 @@ +#![feature(async_await)] + +use async_std::task; + +#[test] +fn smoke() { + let res = task::block_on(async { 1 + 2 }); + assert_eq!(res, 3); +} + +#[test] +#[should_panic = "foo"] +fn panic() { + task::block_on(async { + // This panic should get propagated into the parent thread. + panic!("foo"); + }); +} diff --git a/tests/mutex.rs b/tests/mutex.rs new file mode 100644 index 00000000..7e418946 --- /dev/null +++ b/tests/mutex.rs @@ -0,0 +1,66 @@ +#![feature(async_await)] + +use std::sync::Arc; + +use async_std::sync::Mutex; +use async_std::task; +use futures::channel::mpsc; +use futures::prelude::*; + +#[test] +fn smoke() { + task::block_on(async { + let m = Mutex::new(()); + drop(m.lock().await); + drop(m.lock().await); + }) +} + +#[test] +fn try_lock() { + let m = Mutex::new(()); + *m.try_lock().unwrap() = (); +} + +#[test] +fn into_inner() { + let m = Mutex::new(10); + assert_eq!(m.into_inner(), 10); +} + +#[test] +fn get_mut() { + let mut m = Mutex::new(10); + *m.get_mut() = 20; + assert_eq!(m.into_inner(), 20); +} + +#[test] +fn contention() { + task::block_on(async { + let (tx, mut rx) = mpsc::unbounded(); + + let tx = Arc::new(tx); + let mutex = Arc::new(Mutex::new(0)); + let num_tasks = 10000; + + for _ in 0..num_tasks { + let tx = tx.clone(); + let mutex = mutex.clone(); + + task::spawn(async move { + let mut lock = mutex.lock().await; + *lock += 1; + tx.unbounded_send(()).unwrap(); + drop(lock); + }); + } + + for _ in 0..num_tasks { + rx.next().await.unwrap(); + } + + let lock = mutex.lock().await; + assert_eq!(num_tasks, *lock); + }); +} diff --git a/tests/rwlock.rs b/tests/rwlock.rs new file mode 100644 index 00000000..2e083834 --- /dev/null +++ b/tests/rwlock.rs @@ -0,0 +1,186 @@ +#![feature(async_await)] + +use std::cell::Cell; +use std::num::Wrapping; +use std::pin::Pin; +use std::sync::atomic::{AtomicUsize, Ordering}; +use std::sync::Arc; +use std::task::{Context, Poll}; + +use async_std::sync::RwLock; +use async_std::task; +use futures::channel::mpsc; +use futures::prelude::*; + +/// Generates a random number in `0..n`. +pub fn random(n: u32) -> u32 { + thread_local! { + static RNG: Cell> = Cell::new(Wrapping(1406868647)); + } + + RNG.with(|rng| { + // This is the 32-bit variant of Xorshift. + // + // Source: https://en.wikipedia.org/wiki/Xorshift + let mut x = rng.get(); + x ^= x << 13; + x ^= x >> 17; + x ^= x << 5; + rng.set(x); + + // This is a fast alternative to `x % n`. + // + // Author: Daniel Lemire + // Source: https://lemire.me/blog/2016/06/27/a-fast-alternative-to-the-modulo-reduction/ + ((x.0 as u64).wrapping_mul(n as u64) >> 32) as u32 + }) +} + +#[test] +fn smoke() { + task::block_on(async { + let lock = RwLock::new(()); + drop(lock.read().await); + drop(lock.write().await); + drop((lock.read().await, lock.read().await)); + drop(lock.write().await); + }); +} + +#[test] +fn try_write() { + task::block_on(async { + let lock = RwLock::new(0isize); + let read_guard = lock.read().await; + assert!(lock.try_write().is_none()); + drop(read_guard); + }); +} + +#[test] +fn into_inner() { + let lock = RwLock::new(10); + assert_eq!(lock.into_inner(), 10); +} + +#[test] +fn into_inner_and_drop() { + struct Counter(Arc); + + impl Drop for Counter { + fn drop(&mut self) { + self.0.fetch_add(1, Ordering::SeqCst); + } + } + + let cnt = Arc::new(AtomicUsize::new(0)); + let lock = RwLock::new(Counter(cnt.clone())); + assert_eq!(cnt.load(Ordering::SeqCst), 0); + + { + let _inner = lock.into_inner(); + assert_eq!(cnt.load(Ordering::SeqCst), 0); + } + + assert_eq!(cnt.load(Ordering::SeqCst), 1); +} + +#[test] +fn get_mut() { + let mut lock = RwLock::new(10); + *lock.get_mut() = 20; + assert_eq!(lock.into_inner(), 20); +} + +#[test] +fn contention() { + const N: u32 = 10; + const M: usize = 1000; + + let (tx, mut rx) = mpsc::unbounded(); + let tx = Arc::new(tx); + let rw = Arc::new(RwLock::new(())); + + // Spawn N tasks that randomly acquire the lock M times. + for _ in 0..N { + let tx = tx.clone(); + let rw = rw.clone(); + + task::spawn(async move { + for _ in 0..M { + if random(N) == 0 { + drop(rw.write().await); + } else { + drop(rw.read().await); + } + } + tx.unbounded_send(()).unwrap(); + }); + } + + task::block_on(async { + for _ in 0..N { + rx.next().await.unwrap(); + } + }); +} + +#[test] +fn writer_and_readers() { + #[derive(Default)] + struct Yield(Cell); + + impl Future for Yield { + type Output = (); + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + if self.0.get() { + Poll::Ready(()) + } else { + self.0.set(true); + cx.waker().wake_by_ref(); + Poll::Pending + } + } + } + + let lock = Arc::new(RwLock::new(0i32)); + let (tx, mut rx) = mpsc::unbounded(); + + // Spawn a writer task. + task::spawn({ + let lock = lock.clone(); + async move { + let mut lock = lock.write().await; + for _ in 0..10 { + let tmp = *lock; + *lock = -1; + Yield::default().await; + *lock = tmp + 1; + } + tx.unbounded_send(()).unwrap(); + } + }); + + // Readers try to catch the writer in the act. + let mut readers = Vec::new(); + for _ in 0..5 { + let lock = lock.clone(); + readers.push(task::spawn(async move { + let lock = lock.read().await; + assert!(*lock >= 0); + })); + } + + task::block_on(async { + // Wait for readers to pass their asserts. + for r in readers { + r.await; + } + + // Wait for writer to finish. + rx.next().await.unwrap(); + let lock = lock.read().await; + assert_eq!(*lock, 10); + }); +} diff --git a/tests/task_local.rs b/tests/task_local.rs new file mode 100644 index 00000000..eb18637c --- /dev/null +++ b/tests/task_local.rs @@ -0,0 +1,35 @@ +#![feature(async_await)] + +use std::sync::atomic::{AtomicBool, Ordering}; + +use async_std::{task, task_local}; + +#[test] +fn drop_local() { + static DROP_LOCAL: AtomicBool = AtomicBool::new(false); + + struct Local; + + impl Drop for Local { + fn drop(&mut self) { + DROP_LOCAL.store(true, Ordering::SeqCst); + } + } + + task_local! { + static LOCAL: Local = Local; + } + + // Spawn a task that just touches its task-local. + let handle = task::spawn(async { + LOCAL.with(|_| ()); + }); + let task = handle.task().clone(); + + // Wait for the task to finish and make sure its task-local has been dropped. + task::block_on(async { + handle.await; + assert!(DROP_LOCAL.load(Ordering::SeqCst)); + drop(task); + }); +} diff --git a/tests/tcp.rs b/tests/tcp.rs new file mode 100644 index 00000000..a2f47ce6 --- /dev/null +++ b/tests/tcp.rs @@ -0,0 +1,53 @@ +#![feature(async_await)] + +use async_std::io; +use async_std::net::{TcpListener, TcpStream}; +use async_std::prelude::*; +use async_std::task; + +const THE_WINTERS_TALE: &[u8] = b" + Each your doing, + So singular in each particular, + Crowns what you are doing in the present deed, + That all your acts are queens. +"; + +#[test] +fn connect() -> io::Result<()> { + task::block_on(async { + let listener = TcpListener::bind("127.0.0.1:0").await?; + let addr = listener.local_addr()?; + let t = task::spawn(async move { listener.accept().await }); + + let stream2 = TcpStream::connect(&addr).await?; + let stream1 = t.await?.0; + + assert_eq!(stream1.peer_addr()?, stream2.local_addr()?); + assert_eq!(stream2.peer_addr()?, stream1.local_addr()?); + + Ok(()) + }) +} + +#[test] +fn incoming_read() -> io::Result<()> { + task::block_on(async { + let listener = TcpListener::bind("127.0.0.1:0").await?; + let addr = listener.local_addr()?; + + task::spawn(async move { + let mut stream = TcpStream::connect(&addr).await?; + stream.write_all(THE_WINTERS_TALE).await?; + io::Result::Ok(()) + }); + + let mut buf = vec![0; 1024]; + let mut incoming = listener.incoming(); + let mut stream = incoming.next().await.unwrap()?; + + let n = stream.read(&mut buf).await?; + assert_eq!(&buf[..n], THE_WINTERS_TALE); + + Ok(()) + }) +} diff --git a/tests/udp.rs b/tests/udp.rs new file mode 100644 index 00000000..24d8e39c --- /dev/null +++ b/tests/udp.rs @@ -0,0 +1,31 @@ +#![feature(async_await)] + +use async_std::io; +use async_std::net::UdpSocket; +use async_std::task; + +const THE_MERCHANT_OF_VENICE: &[u8] = b" + If you prick us, do we not bleed? + If you tickle us, do we not laugh? + If you poison us, do we not die? + And if you wrong us, shall we not revenge? +"; + +#[test] +fn send_recv() -> io::Result<()> { + task::block_on(async { + let socket1 = UdpSocket::bind("127.0.0.1:0").await?; + let socket2 = UdpSocket::bind("127.0.0.1:0").await?; + + socket1.connect(socket2.local_addr()?).await?; + socket2.connect(socket1.local_addr()?).await?; + + socket1.send(THE_MERCHANT_OF_VENICE).await?; + + let mut buf = [0u8; 1024]; + let n = socket2.recv(&mut buf).await?; + assert_eq!(&buf[..n], THE_MERCHANT_OF_VENICE); + + Ok(()) + }) +} diff --git a/tests/uds.rs b/tests/uds.rs new file mode 100644 index 00000000..f062bf98 --- /dev/null +++ b/tests/uds.rs @@ -0,0 +1,25 @@ +#![cfg(unix)] +#![feature(async_await)] + +use async_std::io; +use async_std::os::unix::net::UnixDatagram; +use async_std::task; + +const JULIUS_CAESAR: &[u8] = b" + Friends, Romans, countrymen - lend me your ears! + I come not to praise Caesar, but to bury him. +"; + +#[test] +fn send_recv() -> io::Result<()> { + task::block_on(async { + let (socket1, socket2) = UnixDatagram::pair().unwrap(); + socket1.send(JULIUS_CAESAR).await?; + + let mut buf = vec![0; 1024]; + let n = socket2.recv(&mut buf).await?; + assert_eq!(&buf[..n], JULIUS_CAESAR); + + Ok(()) + }) +}