diff --git a/src/task/block_on.rs b/src/task/block_on.rs index 92c4b51..ae0a012 100644 --- a/src/task/block_on.rs +++ b/src/task/block_on.rs @@ -6,10 +6,12 @@ use std::sync::Arc; use std::task::{RawWaker, RawWakerVTable}; use std::thread::{self, Thread}; +use super::log_utils; use super::pool; -use super::Builder; +use super::task; use crate::future::Future; use crate::task::{Context, Poll, Waker}; +use crate::utils::abort_on_panic; /// Spawns a task and blocks the current thread on its result. /// @@ -32,8 +34,7 @@ use crate::task::{Context, Poll, Waker}; /// ``` pub fn block_on(future: F) -> T where - F: Future + Send, - T: Send, + F: Future, { unsafe { // A place on the stack where the result will be stored. @@ -51,17 +52,48 @@ where } }; + // Create a tag for the task. + let tag = task::Tag::new(None); + + // Log this `block_on` operation. + let child_id = tag.task_id().as_u64(); + let parent_id = pool::get_task(|t| t.id().as_u64()).unwrap_or(0); + log_utils::print( + format_args!("block_on"), + log_utils::LogData { + parent_id, + child_id, + }, + ); + + // Wrap the future into one that drops task-local variables on exit. + let future = async move { + let res = future.await; + + // Abort on panic because thread-local variables behave the same way. + abort_on_panic(|| pool::get_task(|task| task.metadata().local_map.clear())); + + log_utils::print( + format_args!("block_on completed"), + log_utils::LogData { + parent_id, + child_id, + }, + ); + res + }; + // Pin the future onto the stack. pin_utils::pin_mut!(future); - // Transmute the future into one that is static and sendable. + // Transmute the future into one that is static. let future = mem::transmute::< - Pin<&mut dyn Future>, - Pin<&'static mut (dyn Future + Send)>, + Pin<&'_ mut dyn Future>, + Pin<&'static mut dyn Future>, >(future); - // Spawn the future and wait for it to complete. - block(pool::spawn_with_builder(Builder::new(), future, "block_on")); + // Block on the future and and wait for it to complete. + pool::set_tag(&tag, || block(future)); // Take out the result. match (*out.get()).take().unwrap() { @@ -87,7 +119,10 @@ impl Future for CatchUnwindFuture { } } -fn block(f: F) -> F::Output { +fn block(f: F) -> T +where + F: Future, +{ thread_local! { static ARC_THREAD: Arc = Arc::new(thread::current()); } diff --git a/src/task/log_utils.rs b/src/task/log_utils.rs new file mode 100644 index 0000000..ad0fe8c --- /dev/null +++ b/src/task/log_utils.rs @@ -0,0 +1,32 @@ +use std::fmt::Arguments; + +/// This struct only exists because kv logging isn't supported from the macros right now. +pub(crate) struct LogData { + pub parent_id: u64, + pub child_id: u64, +} + +impl<'a> log::kv::Source for LogData { + fn visit<'kvs>( + &'kvs self, + visitor: &mut dyn log::kv::Visitor<'kvs>, + ) -> Result<(), log::kv::Error> { + visitor.visit_pair("parent_id".into(), self.parent_id.into())?; + visitor.visit_pair("child_id".into(), self.child_id.into())?; + Ok(()) + } +} + +pub fn print(msg: Arguments<'_>, key_values: impl log::kv::Source) { + log::logger().log( + &log::Record::builder() + .args(msg) + .key_values(&key_values) + .level(log::Level::Trace) + .target(module_path!()) + .module_path(Some(module_path!())) + .file(Some(file!())) + .line(Some(line!())) + .build(), + ); +} diff --git a/src/task/mod.rs b/src/task/mod.rs index eef7284..21b0533 100644 --- a/src/task/mod.rs +++ b/src/task/mod.rs @@ -32,6 +32,7 @@ pub use task::{JoinHandle, Task, TaskId}; mod block_on; mod local; +mod log_utils; mod pool; mod sleep; mod task; diff --git a/src/task/pool.rs b/src/task/pool.rs index 3640909..c996848 100644 --- a/src/task/pool.rs +++ b/src/task/pool.rs @@ -1,16 +1,16 @@ use std::cell::Cell; -use std::fmt::Arguments; -use std::mem; use std::ptr; use std::thread; use crossbeam_channel::{unbounded, Sender}; use lazy_static::lazy_static; +use super::log_utils; use super::task; use super::{JoinHandle, Task}; use crate::future::Future; use crate::io; +use crate::utils::abort_on_panic; /// Returns a handle to the current task. /// @@ -64,7 +64,7 @@ where F: Future + Send + 'static, T: Send + 'static, { - spawn_with_builder(Builder::new(), future, "spawn") + spawn_with_builder(Builder::new(), future) } /// Task builder that configures the settings of a new task. @@ -91,15 +91,11 @@ impl Builder { F: Future + Send + 'static, T: Send + 'static, { - Ok(spawn_with_builder(self, future, "spawn")) + Ok(spawn_with_builder(self, future)) } } -pub(crate) fn spawn_with_builder( - builder: Builder, - future: F, - fn_name: &'static str, -) -> JoinHandle +pub(crate) fn spawn_with_builder(builder: Builder, future: F) -> JoinHandle where F: Future + Send + 'static, T: Send + 'static, @@ -117,13 +113,9 @@ where thread::Builder::new() .name("async-task-driver".to_string()) .spawn(|| { - TAG.with(|tag| { - for job in receiver { - tag.set(job.tag()); - abort_on_panic(|| job.run()); - tag.set(ptr::null()); - } - }); + for job in receiver { + set_tag(job.tag(), || abort_on_panic(|| job.run())) + } }) .expect("cannot start a thread driving tasks"); } @@ -135,11 +127,12 @@ where let tag = task::Tag::new(name); let schedule = |job| QUEUE.send(job).unwrap(); + // Log this `spawn` operation. let child_id = tag.task_id().as_u64(); let parent_id = get_task(|t| t.id().as_u64()).unwrap_or(0); - print( - format_args!("{}", fn_name), - LogData { + log_utils::print( + format_args!("spawn"), + log_utils::LogData { parent_id, child_id, }, @@ -152,9 +145,9 @@ where // Abort on panic because thread-local variables behave the same way. abort_on_panic(|| get_task(|task| task.metadata().local_map.clear())); - print( - format_args!("{} completed", fn_name), - LogData { + log_utils::print( + format_args!("spawn completed"), + log_utils::LogData { parent_id, child_id, }, @@ -171,61 +164,34 @@ thread_local! { static TAG: Cell<*const task::Tag> = Cell::new(ptr::null_mut()); } -pub(crate) fn get_task R, R>(f: F) -> Option { - let res = TAG.try_with(|tag| unsafe { tag.get().as_ref().map(task::Tag::task).map(f) }); - - match res { - Ok(Some(val)) => Some(val), - Ok(None) | Err(_) => None, - } -} - -/// Calls a function and aborts if it panics. -/// -/// This is useful in unsafe code where we can't recover from panics. -#[inline] -fn abort_on_panic(f: impl FnOnce() -> T) -> T { - struct Bomb; +pub(crate) fn set_tag(tag: *const task::Tag, f: F) -> R +where + F: FnOnce() -> R, +{ + struct ResetTag<'a>(&'a Cell<*const task::Tag>); - impl Drop for Bomb { + impl Drop for ResetTag<'_> { fn drop(&mut self) { - std::process::abort(); + self.0.set(ptr::null()); } } - let bomb = Bomb; - let t = f(); - mem::forget(bomb); - t -} + TAG.with(|t| { + t.set(tag); + let _guard = ResetTag(t); -/// This struct only exists because kv logging isn't supported from the macros right now. -struct LogData { - parent_id: u64, - child_id: u64, + f() + }) } -impl<'a> log::kv::Source for LogData { - fn visit<'kvs>( - &'kvs self, - visitor: &mut dyn log::kv::Visitor<'kvs>, - ) -> Result<(), log::kv::Error> { - visitor.visit_pair("parent_id".into(), self.parent_id.into())?; - visitor.visit_pair("child_id".into(), self.child_id.into())?; - Ok(()) - } -} +pub(crate) fn get_task(f: F) -> Option +where + F: FnOnce(&Task) -> R, +{ + let res = TAG.try_with(|tag| unsafe { tag.get().as_ref().map(task::Tag::task).map(f) }); -fn print(msg: Arguments<'_>, key_values: impl log::kv::Source) { - log::logger().log( - &log::Record::builder() - .args(msg) - .key_values(&key_values) - .level(log::Level::Trace) - .target(module_path!()) - .module_path(Some(module_path!())) - .file(Some(file!())) - .line(Some(line!())) - .build(), - ); + match res { + Ok(Some(val)) => Some(val), + Ok(None) | Err(_) => None, + } }