fix: allow for recursive block-on calls

Fixes #798,#795,#760
pull/809/head
Friedel Ziegelmayer 4 years ago committed by GitHub
parent 631105b650
commit e12cf80ab0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -75,7 +75,7 @@ futures-timer = { version = "3.0.2", optional = true }
surf = { version = "1.0.3", optional = true }
[target.'cfg(not(target_os = "unknown"))'.dependencies]
smol = { version = "0.1.10", optional = true }
smol = { version = "0.1.11", optional = true }
[target.'cfg(target_arch = "wasm32")'.dependencies]
futures-timer = { version = "3.0.2", optional = true, features = ["wasm-bindgen"] }
@ -103,3 +103,4 @@ required-features = ["unstable"]
[[example]]
name = "surf-web"
required-features = ["surf"]

@ -1,3 +1,4 @@
use std::cell::Cell;
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
@ -150,8 +151,31 @@ impl Builder {
parent_task_id: TaskLocalsWrapper::get_current(|t| t.id().0).unwrap_or(0),
});
thread_local! {
/// Tracks the number of nested block_on calls.
static NUM_NESTED_BLOCKING: Cell<usize> = Cell::new(0);
}
// Run the future as a task.
unsafe { TaskLocalsWrapper::set_current(&wrapped.tag, || smol::run(wrapped)) }
NUM_NESTED_BLOCKING.with(|num_nested_blocking| {
let count = num_nested_blocking.get();
let should_run = count == 0;
// increase the count
num_nested_blocking.replace(count + 1);
unsafe {
TaskLocalsWrapper::set_current(&wrapped.tag, || {
let res = if should_run {
// The first call should use run.
smol::run(wrapped)
} else {
smol::block_on(wrapped)
};
num_nested_blocking.replace(num_nested_blocking.get() - 1);
res
})
}
})
}
}

@ -1,18 +1,63 @@
#![cfg(not(target_os = "unknown"))]
use async_std::task;
use async_std::{future::ready, task::block_on};
#[test]
fn smoke() {
let res = task::block_on(async { 1 + 2 });
let res = block_on(async { 1 + 2 });
assert_eq!(res, 3);
}
#[test]
#[should_panic = "boom"]
fn panic() {
task::block_on(async {
block_on(async {
// This panic should get propagated into the parent thread.
panic!("boom");
});
}
#[cfg(feature = "unstable")]
#[test]
fn nested_block_on_local() {
use async_std::task::spawn_local;
let x = block_on(async {
let a = block_on(async { block_on(async { ready(3).await }) });
let b = spawn_local(async { block_on(async { ready(2).await }) }).await;
let c = block_on(async { block_on(async { ready(1).await }) });
a + b + c
});
assert_eq!(x, 3 + 2 + 1);
let y = block_on(async {
let a = block_on(async { block_on(async { ready(3).await }) });
let b = spawn_local(async { block_on(async { ready(2).await }) }).await;
let c = block_on(async { block_on(async { ready(1).await }) });
a + b + c
});
assert_eq!(y, 3 + 2 + 1);
}
#[test]
fn nested_block_on() {
let x = block_on(async {
let a = block_on(async { block_on(async { ready(3).await }) });
let b = block_on(async { block_on(async { ready(2).await }) });
let c = block_on(async { block_on(async { ready(1).await }) });
a + b + c
});
assert_eq!(x, 3 + 2 + 1);
let y = block_on(async {
let a = block_on(async { block_on(async { ready(3).await }) });
let b = block_on(async { block_on(async { ready(2).await }) });
let c = block_on(async { block_on(async { ready(1).await }) });
a + b + c
});
assert_eq!(y, 3 + 2 + 1);
}

Loading…
Cancel
Save