bevy_tasks/task_pool.rs
use std::{
    future::Future,
    marker::PhantomData,
    mem,
    panic::AssertUnwindSafe,
    sync::Arc,
    thread::{self, JoinHandle},
};

use async_executor::FallibleTask;
use concurrent_queue::ConcurrentQueue;
use futures_lite::FutureExt;

use crate::{
    block_on,
    thread_executor::{ThreadExecutor, ThreadExecutorTicker},
    Task,
};

/// Calls the wrapped callback (if any) when dropped. Used to run the
/// `on_thread_destroy` callback when a task pool thread exits.
struct CallOnDrop(Option<Arc<dyn Fn() + Send + Sync + 'static>>);

impl Drop for CallOnDrop {
    fn drop(&mut self) {
        if let Some(call) = self.0.as_ref() {
            call();
        }
    }
}

/// Used to create a [`TaskPool`]
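///
/// # Example
///
/// A minimal sketch of configuring a pool with the builder (the thread count
/// and name here are illustrative, not defaults):
///
/// ```
/// use bevy_tasks::TaskPoolBuilder;
///
/// let pool = TaskPoolBuilder::new()
///     // Use at most 4 threads instead of the logical core count.
///     .num_threads(4)
///     // Worker threads will be named "MyPool (0)", "MyPool (1)", ...
///     .thread_name("MyPool".to_string())
///     .build();
/// assert_eq!(pool.thread_num(), 4);
/// ```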
#[derive(Default)]
#[must_use]
pub struct TaskPoolBuilder {
    /// If set, we'll set up the thread pool to use at most `num_threads` threads.
    /// Otherwise use the logical core count of the system
    num_threads: Option<usize>,
    /// If set, we'll use the given stack size rather than the system default
    stack_size: Option<usize>,
    /// Allows customizing the name of the threads - helpful for debugging. If set, threads will
    /// be named `<thread_name> (<thread_index>)`, e.g. `"MyThreadPool (2)"`.
    thread_name: Option<String>,

    on_thread_spawn: Option<Arc<dyn Fn() + Send + Sync + 'static>>,
    on_thread_destroy: Option<Arc<dyn Fn() + Send + Sync + 'static>>,
}

impl TaskPoolBuilder {
    /// Creates a new [`TaskPoolBuilder`] instance
    pub fn new() -> Self {
        Self::default()
    }

    /// Override the number of threads created for the pool. If unset, we default to the number
    /// of logical cores of the system
    pub fn num_threads(mut self, num_threads: usize) -> Self {
        self.num_threads = Some(num_threads);
        self
    }

    /// Override the stack size of the threads created for the pool
    pub fn stack_size(mut self, stack_size: usize) -> Self {
        self.stack_size = Some(stack_size);
        self
    }

    /// Override the name of the threads created for the pool. If set, threads will
    /// be named `<thread_name> (<thread_index>)`, e.g. `MyThreadPool (2)`
    pub fn thread_name(mut self, thread_name: String) -> Self {
        self.thread_name = Some(thread_name);
        self
    }

    /// Sets a callback that is invoked once for every created thread as it starts.
    ///
    /// This is called on the thread itself and has access to all thread-local storage.
    /// This will block running async tasks on the thread until the callback completes.
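    ///
    /// A minimal sketch pairing the spawn and destroy callbacks (the printed
    /// messages are illustrative; real callers typically initialize
    /// thread-local state here):
    ///
    /// ```
    /// use bevy_tasks::TaskPoolBuilder;
    ///
    /// let pool = TaskPoolBuilder::new()
    ///     .on_thread_spawn(|| println!("worker thread started"))
    ///     .on_thread_destroy(|| println!("worker thread shutting down"))
    ///     .build();
    /// ```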
    pub fn on_thread_spawn(mut self, f: impl Fn() + Send + Sync + 'static) -> Self {
        self.on_thread_spawn = Some(Arc::new(f));
        self
    }

    /// Sets a callback that is invoked once for every created thread as it terminates.
    ///
    /// This is called on the thread itself and has access to all thread-local storage.
    /// This will block thread termination until the callback completes.
    pub fn on_thread_destroy(mut self, f: impl Fn() + Send + Sync + 'static) -> Self {
        self.on_thread_destroy = Some(Arc::new(f));
        self
    }

    /// Creates a new [`TaskPool`] based on the current options.
    pub fn build(self) -> TaskPool {
        TaskPool::new_internal(self)
    }
}

/// A thread pool for executing tasks.
///
/// While futures usually need to be polled to be executed, Bevy tasks are
/// automatically driven by the pool on threads owned by the pool. The [`Task`]
/// future only needs to be polled in order to receive the result. (For that
/// purpose, it is often stored in a component or resource, see the
/// `async_compute` example.)
///
/// If the result is not required, one may also use [`Task::detach`] and the pool
/// will still execute the task, even if it is dropped.
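///
/// A sketch of the fire-and-forget pattern mentioned above:
///
/// ```
/// use bevy_tasks::TaskPool;
///
/// let pool = TaskPool::new();
/// // The pool keeps driving the task even though the handle is dropped.
/// pool.spawn(async { /* background work */ }).detach();
/// ```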
#[derive(Debug)]
pub struct TaskPool {
    /// The executor for the pool.
    executor: Arc<async_executor::Executor<'static>>,

    // The inner state of the pool.
    threads: Vec<JoinHandle<()>>,
    shutdown_tx: async_channel::Sender<()>,
}

impl TaskPool {
    thread_local! {
        static LOCAL_EXECUTOR: async_executor::LocalExecutor<'static> = const { async_executor::LocalExecutor::new() };
        static THREAD_EXECUTOR: Arc<ThreadExecutor<'static>> = Arc::new(ThreadExecutor::new());
    }

    /// Gets the [`ThreadExecutor`] for the current thread. Each thread should only create
    /// one `ThreadExecutor`; otherwise there is a good chance they will deadlock.
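    ///
    /// A minimal sketch showing that repeated calls on one thread hand back the
    /// same executor (comparing identity via `ThreadExecutor::is_same`):
    ///
    /// ```
    /// use bevy_tasks::TaskPool;
    ///
    /// let a = TaskPool::get_thread_executor();
    /// let b = TaskPool::get_thread_executor();
    /// assert!(a.is_same(&b));
    /// ```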
    pub fn get_thread_executor() -> Arc<ThreadExecutor<'static>> {
        Self::THREAD_EXECUTOR.with(|executor| executor.clone())
    }

    /// Create a `TaskPool` with the default configuration.
    pub fn new() -> Self {
        TaskPoolBuilder::new().build()
    }

    fn new_internal(builder: TaskPoolBuilder) -> Self {
        let (shutdown_tx, shutdown_rx) = async_channel::unbounded::<()>();

        let executor = Arc::new(async_executor::Executor::new());

        let num_threads = builder
            .num_threads
            .unwrap_or_else(crate::available_parallelism);

        let threads = (0..num_threads)
            .map(|i| {
                let ex = Arc::clone(&executor);
                let shutdown_rx = shutdown_rx.clone();

                let thread_name = if let Some(thread_name) = builder.thread_name.as_deref() {
                    format!("{thread_name} ({i})")
                } else {
                    format!("TaskPool ({i})")
                };
                let mut thread_builder = thread::Builder::new().name(thread_name);

                if let Some(stack_size) = builder.stack_size {
                    thread_builder = thread_builder.stack_size(stack_size);
                }

                let on_thread_spawn = builder.on_thread_spawn.clone();
                let on_thread_destroy = builder.on_thread_destroy.clone();

                thread_builder
                    .spawn(move || {
                        TaskPool::LOCAL_EXECUTOR.with(|local_executor| {
                            if let Some(on_thread_spawn) = on_thread_spawn {
                                on_thread_spawn();
                                drop(on_thread_spawn);
                            }
                            let _destructor = CallOnDrop(on_thread_destroy);
                            // Tick the thread-local and global executors until the
                            // shutdown channel closes, restarting them if a task panics.
                            loop {
                                let res = std::panic::catch_unwind(|| {
                                    let tick_forever = async move {
                                        loop {
                                            local_executor.tick().await;
                                        }
                                    };
                                    block_on(ex.run(tick_forever.or(shutdown_rx.recv())))
                                });
                                if let Ok(value) = res {
                                    // Use unwrap_err because we expect a Closed error
                                    value.unwrap_err();
                                    break;
                                }
                            }
                        });
                    })
                    .expect("Failed to spawn thread.")
            })
            .collect();

        Self {
            executor,
            threads,
            shutdown_tx,
        }
    }

    /// Return the number of threads owned by the task pool
    pub fn thread_num(&self) -> usize {
        self.threads.len()
    }

    /// Allows spawning non-`'static` futures on the thread pool. The function takes a callback,
    /// passing a scope object into it. The scope object provided to the callback can be used
    /// to spawn tasks. This function will await the completion of all tasks before returning.
    ///
    /// This is similar to [`thread::scope`] and `rayon::scope`.
    ///
    /// # Example
    ///
    /// ```
    /// use bevy_tasks::TaskPool;
    ///
    /// let pool = TaskPool::new();
    /// let mut x = 0;
    /// let results = pool.scope(|s| {
    ///     s.spawn(async {
    ///         // you can borrow the spawner inside a task and spawn tasks from within the task
    ///         s.spawn(async {
    ///             // borrow x and mutate it.
    ///             x = 2;
    ///             // return a value from the task
    ///             1
    ///         });
    ///         // return some other value from the first task
    ///         0
    ///     });
    /// });
    ///
    /// // The ordering of results is non-deterministic if you spawn from within tasks as above.
    /// // If you're doing this, you'll have to write your code to not depend on the ordering.
    /// assert!(results.contains(&0));
    /// assert!(results.contains(&1));
    ///
    /// // The ordering is deterministic if you only spawn directly from the closure function.
    /// let results = pool.scope(|s| {
    ///     s.spawn(async { 0 });
    ///     s.spawn(async { 1 });
    /// });
    /// assert_eq!(&results[..], &[0, 1]);
    ///
    /// // You can access x after scope runs, since it was only temporarily borrowed in the scope.
    /// assert_eq!(x, 2);
    /// ```
    ///
    /// # Lifetimes
    ///
    /// The [`Scope`] object takes two lifetimes: `'scope` and `'env`.
    ///
    /// The `'scope` lifetime represents the lifetime of the scope. That is the time during
    /// which the provided closure and tasks that are spawned into the scope are run.
    ///
    /// The `'env` lifetime represents the lifetime of whatever is borrowed by the scope.
    /// Thus this lifetime must outlive `'scope`.
    ///
    /// ```compile_fail
    /// use bevy_tasks::TaskPool;
    /// fn scope_escapes_closure() {
    ///     let pool = TaskPool::new();
    ///     let foo = Box::new(42);
    ///     pool.scope(|scope| {
    ///         std::thread::spawn(move || {
    ///             // UB. This could spawn on the scope after `.scope` returns and the internal Scope is dropped.
    ///             scope.spawn(async move {
    ///                 assert_eq!(*foo, 42);
    ///             });
    ///         });
    ///     });
    /// }
    /// ```
    ///
    /// ```compile_fail
    /// use bevy_tasks::TaskPool;
    /// fn cannot_borrow_from_closure() {
    ///     let pool = TaskPool::new();
    ///     pool.scope(|scope| {
    ///         let x = 1;
    ///         let y = &x;
    ///         scope.spawn(async move {
    ///             assert_eq!(*y, 1);
    ///         });
    ///     });
    /// }
    /// ```
    pub fn scope<'env, F, T>(&self, f: F) -> Vec<T>
    where
        F: for<'scope> FnOnce(&'scope Scope<'scope, 'env, T>),
        T: Send + 'static,
    {
        Self::THREAD_EXECUTOR.with(|scope_executor| {
            self.scope_with_executor_inner(true, scope_executor, scope_executor, f)
        })
    }

    /// This allows passing an external executor to spawn tasks on. When you pass an external
    /// executor, tasks spawned with [`Scope::spawn_on_scope`] are run on the thread that the
    /// [`ThreadExecutor`] is being ticked on. If [`None`] is passed, the scope will use a
    /// [`ThreadExecutor`] that is ticked on the current thread.
    ///
    /// When `tick_task_pool_executor` is set to `true`, the multithreaded task stealing executor
    /// is ticked on the scope thread. Disabling this can be useful when finishing the scope is
    /// latency sensitive. Pulling tasks from the global executor can run tasks unrelated to the
    /// scope and delay when the scope returns.
    ///
    /// See [`Self::scope`] for more details in general about how scopes work.
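    ///
    /// A minimal sketch of passing an external executor (`true` here keeps the
    /// global executor ticking on the scope thread, as in [`Self::scope`]):
    ///
    /// ```
    /// use bevy_tasks::{TaskPool, ThreadExecutor};
    ///
    /// let pool = TaskPool::new();
    /// let external = ThreadExecutor::new();
    /// let results = pool.scope_with_executor(true, Some(&external), |s| {
    ///     s.spawn(async { 1 });
    /// });
    /// assert_eq!(&results[..], &[1]);
    /// ```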
    pub fn scope_with_executor<'env, F, T>(
        &self,
        tick_task_pool_executor: bool,
        external_executor: Option<&ThreadExecutor>,
        f: F,
    ) -> Vec<T>
    where
        F: for<'scope> FnOnce(&'scope Scope<'scope, 'env, T>),
        T: Send + 'static,
    {
        Self::THREAD_EXECUTOR.with(|scope_executor| {
            // If an `external_executor` is passed, use that. Otherwise get the executor stored
            // in the `THREAD_EXECUTOR` thread local.
            if let Some(external_executor) = external_executor {
                self.scope_with_executor_inner(
                    tick_task_pool_executor,
                    external_executor,
                    scope_executor,
                    f,
                )
            } else {
                self.scope_with_executor_inner(
                    tick_task_pool_executor,
                    scope_executor,
                    scope_executor,
                    f,
                )
            }
        })
    }

    #[allow(unsafe_code)]
    fn scope_with_executor_inner<'env, F, T>(
        &self,
        tick_task_pool_executor: bool,
        external_executor: &ThreadExecutor,
        scope_executor: &ThreadExecutor,
        f: F,
    ) -> Vec<T>
    where
        F: for<'scope> FnOnce(&'scope Scope<'scope, 'env, T>),
        T: Send + 'static,
    {
        // SAFETY: This safety comment applies to all references transmuted to 'env.
        // Any futures spawned with these references need to return before this function completes.
        // This is guaranteed because we drive all the futures spawned onto the Scope
        // to completion in this function. However, Rust has no way of knowing this so we
        // transmute the lifetimes to 'env here to appease the compiler as it is unable to validate safety.
        // Any usages of the references passed into `Scope` must be accessed through
        // the transmuted reference for the rest of this function.
        let executor: &async_executor::Executor = &self.executor;
        // SAFETY: As above, all futures must complete in this function so we can change the lifetime
        let executor: &'env async_executor::Executor = unsafe { mem::transmute(executor) };
        // SAFETY: As above, all futures must complete in this function so we can change the lifetime
        let external_executor: &'env ThreadExecutor<'env> =
            unsafe { mem::transmute(external_executor) };
        // SAFETY: As above, all futures must complete in this function so we can change the lifetime
        let scope_executor: &'env ThreadExecutor<'env> = unsafe { mem::transmute(scope_executor) };
        let spawned: ConcurrentQueue<FallibleTask<Result<T, Box<(dyn std::any::Any + Send)>>>> =
            ConcurrentQueue::unbounded();
        // shadow the variable so that the owned value cannot be used for the rest of the function
        // SAFETY: As above, all futures must complete in this function so we can change the lifetime
        let spawned: &'env ConcurrentQueue<
            FallibleTask<Result<T, Box<(dyn std::any::Any + Send)>>>,
        > = unsafe { mem::transmute(&spawned) };

        let scope = Scope {
            executor,
            external_executor,
            scope_executor,
            spawned,
            scope: PhantomData,
            env: PhantomData,
        };

        // shadow the variable so that the owned value cannot be used for the rest of the function
        // SAFETY: As above, all futures must complete in this function so we can change the lifetime
        let scope: &'env Scope<'_, 'env, T> = unsafe { mem::transmute(&scope) };

        f(scope);

        if spawned.is_empty() {
            Vec::new()
        } else {
            block_on(async move {
                let get_results = async {
                    let mut results = Vec::with_capacity(spawned.len());
                    while let Ok(task) = spawned.pop() {
                        if let Some(res) = task.await {
                            match res {
                                Ok(res) => results.push(res),
                                Err(payload) => std::panic::resume_unwind(payload),
                            }
                        } else {
                            panic!("Failed to catch panic!");
                        }
                    }
                    results
                };

                let tick_task_pool_executor = tick_task_pool_executor || self.threads.is_empty();

                // we get this from a thread local so we should always be on the scope executor's thread.
                // note: it is possible `scope_executor` and `external_executor` are the same executor;
                // in that case, we should only tick one of them, otherwise it may cause a deadlock.
                let scope_ticker = scope_executor.ticker().unwrap();
                let external_ticker = if !external_executor.is_same(scope_executor) {
                    external_executor.ticker()
                } else {
                    None
                };

                match (external_ticker, tick_task_pool_executor) {
                    (Some(external_ticker), true) => {
                        Self::execute_global_external_scope(
                            executor,
                            external_ticker,
                            scope_ticker,
                            get_results,
                        )
                        .await
                    }
                    (Some(external_ticker), false) => {
                        Self::execute_external_scope(external_ticker, scope_ticker, get_results)
                            .await
                    }
                    // either external_executor is none or it is the same as scope_executor
                    (None, true) => {
                        Self::execute_global_scope(executor, scope_ticker, get_results).await
                    }
                    (None, false) => Self::execute_scope(scope_ticker, get_results).await,
                }
            })
        }
    }

    #[inline]
    async fn execute_global_external_scope<'scope, 'ticker, T>(
        executor: &'scope async_executor::Executor<'scope>,
        external_ticker: ThreadExecutorTicker<'scope, 'ticker>,
        scope_ticker: ThreadExecutorTicker<'scope, 'ticker>,
        get_results: impl Future<Output = Vec<T>>,
    ) -> Vec<T> {
        // We restart the executors if a task panics. If a scoped task panics,
        // it will be propagated to the caller by `get_results`.
        let execute_forever = async move {
            loop {
                let tick_forever = async {
                    loop {
                        external_ticker.tick().or(scope_ticker.tick()).await;
                    }
                };
                // We don't care if this panics; a panicking scoped task will
                // propagate through `get_results`.
                let _result = AssertUnwindSafe(executor.run(tick_forever))
                    .catch_unwind()
                    .await
                    .is_ok();
            }
        };
        execute_forever.or(get_results).await
    }

    #[inline]
    async fn execute_external_scope<'scope, 'ticker, T>(
        external_ticker: ThreadExecutorTicker<'scope, 'ticker>,
        scope_ticker: ThreadExecutorTicker<'scope, 'ticker>,
        get_results: impl Future<Output = Vec<T>>,
    ) -> Vec<T> {
        let execute_forever = async {
            loop {
                let tick_forever = async {
                    loop {
                        external_ticker.tick().or(scope_ticker.tick()).await;
                    }
                };
                let _result = AssertUnwindSafe(tick_forever).catch_unwind().await.is_ok();
            }
        };
        execute_forever.or(get_results).await
    }

    #[inline]
    async fn execute_global_scope<'scope, 'ticker, T>(
        executor: &'scope async_executor::Executor<'scope>,
        scope_ticker: ThreadExecutorTicker<'scope, 'ticker>,
        get_results: impl Future<Output = Vec<T>>,
    ) -> Vec<T> {
        let execute_forever = async {
            loop {
                let tick_forever = async {
                    loop {
                        scope_ticker.tick().await;
                    }
                };
                let _result = AssertUnwindSafe(executor.run(tick_forever))
                    .catch_unwind()
                    .await
                    .is_ok();
            }
        };
        execute_forever.or(get_results).await
    }

    #[inline]
    async fn execute_scope<'scope, 'ticker, T>(
        scope_ticker: ThreadExecutorTicker<'scope, 'ticker>,
        get_results: impl Future<Output = Vec<T>>,
    ) -> Vec<T> {
        let execute_forever = async {
            loop {
                let tick_forever = async {
                    loop {
                        scope_ticker.tick().await;
                    }
                };
                let _result = AssertUnwindSafe(tick_forever).catch_unwind().await.is_ok();
            }
        };
        execute_forever.or(get_results).await
    }

    /// Spawns a static future onto the thread pool. The returned [`Task`] is a
    /// future that can be polled for the result. It can also be canceled and
    /// "detached", allowing the task to continue running even if dropped. In
    /// any case, the pool will execute the task even without polling by the
    /// end-user.
    ///
    /// If the provided future is non-`Send`, [`TaskPool::spawn_local`] should
    /// be used instead.
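    ///
    /// A minimal sketch; real callers usually poll the returned [`Task`] from a
    /// system instead of blocking on it:
    ///
    /// ```
    /// use bevy_tasks::{block_on, TaskPool};
    ///
    /// let pool = TaskPool::new();
    /// let task = pool.spawn(async { 1 + 1 });
    /// assert_eq!(block_on(task), 2);
    /// ```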
    pub fn spawn<T>(&self, future: impl Future<Output = T> + Send + 'static) -> Task<T>
    where
        T: Send + 'static,
    {
        Task::new(self.executor.spawn(future))
    }

    /// Spawns a static future on the thread-local async executor for the
    /// current thread. The task will run entirely on the thread the task was
    /// spawned on.
    ///
    /// The returned [`Task`] is a future that can be polled for the
    /// result. It can also be canceled and "detached", allowing the task to
    /// continue running even if dropped. In any case, the pool will execute the
    /// task even without polling by the end-user.
    ///
    /// Users should generally prefer to use [`TaskPool::spawn`] instead,
    /// unless the provided future is not `Send`.
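    ///
    /// A minimal sketch with a non-`Send` future. Note the local executor only
    /// makes progress while it is ticked on this thread:
    ///
    /// ```
    /// use bevy_tasks::{block_on, TaskPool};
    /// use std::rc::Rc;
    ///
    /// let pool = TaskPool::new();
    /// // `Rc` is not `Send`, so this future could not be given to `spawn`.
    /// let data = Rc::new(5);
    /// let task = pool.spawn_local(async move { *data + 1 });
    /// // Tick the thread-local executor until the task has run.
    /// pool.with_local_executor(|local| while local.try_tick() {});
    /// assert_eq!(block_on(task), 6);
    /// ```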
    pub fn spawn_local<T>(&self, future: impl Future<Output = T> + 'static) -> Task<T>
    where
        T: 'static,
    {
        Task::new(TaskPool::LOCAL_EXECUTOR.with(|executor| executor.spawn(future)))
    }

    /// Runs a function with the local executor. Typically used to tick
    /// the local executor on the main thread as it needs to share time with
    /// other things.
    ///
    /// ```
    /// use bevy_tasks::TaskPool;
    ///
    /// TaskPool::new().with_local_executor(|local_executor| {
    ///     local_executor.try_tick();
    /// });
    /// ```
    pub fn with_local_executor<F, R>(&self, f: F) -> R
    where
        F: FnOnce(&async_executor::LocalExecutor) -> R,
    {
        Self::LOCAL_EXECUTOR.with(f)
    }
}

impl Default for TaskPool {
    fn default() -> Self {
        Self::new()
    }
}

impl Drop for TaskPool {
    fn drop(&mut self) {
        self.shutdown_tx.close();

        let panicking = thread::panicking();
        for join_handle in self.threads.drain(..) {
            let res = join_handle.join();
            if !panicking {
                res.expect("Task thread panicked while executing.");
            }
        }
    }
}

/// A [`TaskPool`] scope for running one or more non-`'static` futures.
///
/// For more information, see [`TaskPool::scope`].
#[derive(Debug)]
pub struct Scope<'scope, 'env: 'scope, T> {
    executor: &'scope async_executor::Executor<'scope>,
    external_executor: &'scope ThreadExecutor<'scope>,
    scope_executor: &'scope ThreadExecutor<'scope>,
    spawned: &'scope ConcurrentQueue<FallibleTask<Result<T, Box<(dyn std::any::Any + Send)>>>>,
    // make `Scope` invariant over 'scope and 'env
    scope: PhantomData<&'scope mut &'scope ()>,
    env: PhantomData<&'env mut &'env ()>,
}

impl<'scope, 'env, T: Send + 'scope> Scope<'scope, 'env, T> {
    /// Spawns a scoped future onto the thread pool. The scope *must* outlive
    /// the provided future. The results of the future will be returned as a part of
    /// [`TaskPool::scope`]'s return value.
    ///
    /// For futures that should run on the thread `scope` is called on,
    /// [`Scope::spawn_on_scope`] should be used instead.
    ///
    /// For more information, see [`TaskPool::scope`].
    pub fn spawn<Fut: Future<Output = T> + 'scope + Send>(&self, f: Fut) {
        let task = self
            .executor
            .spawn(AssertUnwindSafe(f).catch_unwind())
            .fallible();
        // ConcurrentQueue only errors when closed or full, but we never
        // close and use an unbounded queue, so it is safe to unwrap
        self.spawned.push(task).unwrap();
    }

    /// Spawns a scoped future onto the thread the scope is run on. The scope *must* outlive
    /// the provided future. The results of the future will be returned as a part of
    /// [`TaskPool::scope`]'s return value. Users should generally prefer to use
    /// [`Scope::spawn`] instead, unless the provided future needs to run on the scope's thread.
    ///
    /// For more information, see [`TaskPool::scope`].
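    ///
    /// A minimal sketch:
    ///
    /// ```
    /// use bevy_tasks::TaskPool;
    ///
    /// let pool = TaskPool::new();
    /// let results = pool.scope(|s| {
    ///     // Runs on the thread that called `scope`, not on a pool thread.
    ///     s.spawn_on_scope(async { 1 });
    /// });
    /// assert_eq!(&results[..], &[1]);
    /// ```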
    pub fn spawn_on_scope<Fut: Future<Output = T> + 'scope + Send>(&self, f: Fut) {
        let task = self
            .scope_executor
            .spawn(AssertUnwindSafe(f).catch_unwind())
            .fallible();
        // ConcurrentQueue only errors when closed or full, but we never
        // close and use an unbounded queue, so it is safe to unwrap
        self.spawned.push(task).unwrap();
    }

    /// Spawns a scoped future onto the thread of the external thread executor.
    /// This is typically the main thread. The scope *must* outlive
    /// the provided future. The results of the future will be returned as a part of
    /// [`TaskPool::scope`]'s return value. Users should generally prefer to use
    /// [`Scope::spawn`] instead, unless the provided future needs to run on the external thread.
    ///
    /// For more information, see [`TaskPool::scope`].
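    ///
    /// A minimal sketch using [`TaskPool::scope_with_executor`] so that the
    /// external executor is ticked on the current thread:
    ///
    /// ```
    /// use bevy_tasks::{TaskPool, ThreadExecutor};
    ///
    /// let pool = TaskPool::new();
    /// let external = ThreadExecutor::new();
    /// let results = pool.scope_with_executor(true, Some(&external), |s| {
    ///     // Runs on the thread ticking `external` (here, the current thread).
    ///     s.spawn_on_external(async { 7 });
    /// });
    /// assert_eq!(&results[..], &[7]);
    /// ```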
    pub fn spawn_on_external<Fut: Future<Output = T> + 'scope + Send>(&self, f: Fut) {
        let task = self
            .external_executor
            .spawn(AssertUnwindSafe(f).catch_unwind())
            .fallible();
        // ConcurrentQueue only errors when closed or full, but we never
        // close and use an unbounded queue, so it is safe to unwrap
        self.spawned.push(task).unwrap();
    }
}

impl<'scope, 'env, T> Drop for Scope<'scope, 'env, T>
where
    T: 'scope,
{
    fn drop(&mut self) {
        block_on(async {
            while let Ok(task) = self.spawned.pop() {
                task.cancel().await;
            }
        });
    }
}

#[cfg(test)]
#[allow(clippy::disallowed_types)]
mod tests {
    use super::*;
    use std::sync::{
        atomic::{AtomicBool, AtomicI32, Ordering},
        Barrier,
    };

    #[test]
    fn test_spawn() {
        let pool = TaskPool::new();

        let foo = Box::new(42);
        let foo = &*foo;

        let count = Arc::new(AtomicI32::new(0));

        let outputs = pool.scope(|scope| {
            for _ in 0..100 {
                let count_clone = count.clone();
                scope.spawn(async move {
                    if *foo != 42 {
                        panic!("not 42!?!?")
                    } else {
                        count_clone.fetch_add(1, Ordering::Relaxed);
                        *foo
                    }
                });
            }
        });

        for output in &outputs {
            assert_eq!(*output, 42);
        }

        assert_eq!(outputs.len(), 100);
        assert_eq!(count.load(Ordering::Relaxed), 100);
    }

    #[test]
    fn test_thread_callbacks() {
        let counter = Arc::new(AtomicI32::new(0));
        let start_counter = counter.clone();
        {
            let barrier = Arc::new(Barrier::new(11));
            let last_barrier = barrier.clone();
            // Build and immediately drop to terminate
            let _pool = TaskPoolBuilder::new()
                .num_threads(10)
                .on_thread_spawn(move || {
                    start_counter.fetch_add(1, Ordering::Relaxed);
                    barrier.clone().wait();
                })
                .build();
            last_barrier.wait();
            assert_eq!(10, counter.load(Ordering::Relaxed));
        }
        assert_eq!(10, counter.load(Ordering::Relaxed));
        let end_counter = counter.clone();
        {
            let _pool = TaskPoolBuilder::new()
                .num_threads(20)
                .on_thread_destroy(move || {
                    end_counter.fetch_sub(1, Ordering::Relaxed);
                })
                .build();
            assert_eq!(10, counter.load(Ordering::Relaxed));
        }
        assert_eq!(-10, counter.load(Ordering::Relaxed));
        let start_counter = counter.clone();
        let end_counter = counter.clone();
        {
            let barrier = Arc::new(Barrier::new(6));
            let last_barrier = barrier.clone();
            let _pool = TaskPoolBuilder::new()
                .num_threads(5)
                .on_thread_spawn(move || {
                    start_counter.fetch_add(1, Ordering::Relaxed);
                    barrier.wait();
                })
                .on_thread_destroy(move || {
                    end_counter.fetch_sub(1, Ordering::Relaxed);
                })
                .build();
            last_barrier.wait();
            assert_eq!(-5, counter.load(Ordering::Relaxed));
        }
        assert_eq!(-10, counter.load(Ordering::Relaxed));
    }

    #[test]
    fn test_mixed_spawn_on_scope_and_spawn() {
        let pool = TaskPool::new();

        let foo = Box::new(42);
        let foo = &*foo;

        let local_count = Arc::new(AtomicI32::new(0));
        let non_local_count = Arc::new(AtomicI32::new(0));

        let outputs = pool.scope(|scope| {
            for i in 0..100 {
                if i % 2 == 0 {
                    let count_clone = non_local_count.clone();
                    scope.spawn(async move {
                        if *foo != 42 {
                            panic!("not 42!?!?")
                        } else {
                            count_clone.fetch_add(1, Ordering::Relaxed);
                            *foo
                        }
                    });
                } else {
                    let count_clone = local_count.clone();
                    scope.spawn_on_scope(async move {
                        if *foo != 42 {
                            panic!("not 42!?!?")
                        } else {
                            count_clone.fetch_add(1, Ordering::Relaxed);
                            *foo
                        }
                    });
                }
            }
        });

        for output in &outputs {
            assert_eq!(*output, 42);
        }

        assert_eq!(outputs.len(), 100);
        assert_eq!(local_count.load(Ordering::Relaxed), 50);
        assert_eq!(non_local_count.load(Ordering::Relaxed), 50);
    }

    #[test]
    fn test_thread_locality() {
        let pool = Arc::new(TaskPool::new());
        let count = Arc::new(AtomicI32::new(0));
        let barrier = Arc::new(Barrier::new(101));
        let thread_check_failed = Arc::new(AtomicBool::new(false));

        for _ in 0..100 {
            let inner_barrier = barrier.clone();
            let count_clone = count.clone();
            let inner_pool = pool.clone();
            let inner_thread_check_failed = thread_check_failed.clone();
            thread::spawn(move || {
                inner_pool.scope(|scope| {
                    let inner_count_clone = count_clone.clone();
                    scope.spawn(async move {
                        inner_count_clone.fetch_add(1, Ordering::Release);
                    });
                    let spawner = thread::current().id();
                    let inner_count_clone = count_clone.clone();
                    scope.spawn_on_scope(async move {
                        inner_count_clone.fetch_add(1, Ordering::Release);
                        if thread::current().id() != spawner {
                            // NOTE: This check is using an atomic rather than simply panicking the
                            // thread to avoid deadlocking the barrier on failure
                            inner_thread_check_failed.store(true, Ordering::Release);
                        }
                    });
                });
                inner_barrier.wait();
            });
        }
        barrier.wait();
        assert!(!thread_check_failed.load(Ordering::Acquire));
        assert_eq!(count.load(Ordering::Acquire), 200);
    }

    #[test]
    fn test_nested_spawn() {
        let pool = TaskPool::new();

        let foo = Box::new(42);
        let foo = &*foo;

        let count = Arc::new(AtomicI32::new(0));

        let outputs: Vec<i32> = pool.scope(|scope| {
            for _ in 0..10 {
                let count_clone = count.clone();
                scope.spawn(async move {
                    for _ in 0..10 {
                        let count_clone_clone = count_clone.clone();
                        scope.spawn(async move {
                            if *foo != 42 {
                                panic!("not 42!?!?")
                            } else {
                                count_clone_clone.fetch_add(1, Ordering::Relaxed);
                                *foo
                            }
                        });
                    }
                    *foo
                });
            }
        });

        for output in &outputs {
            assert_eq!(*output, 42);
        }

        // Each of the 10 outer tasks spawns 10 inner tasks, so there are
        // 100 inner outputs plus 10 outer outputs.
        assert_eq!(outputs.len(), 110);
        assert_eq!(count.load(Ordering::Relaxed), 100);
    }

    #[test]
    fn test_nested_locality() {
        let pool = Arc::new(TaskPool::new());
        let count = Arc::new(AtomicI32::new(0));
        let barrier = Arc::new(Barrier::new(101));
        let thread_check_failed = Arc::new(AtomicBool::new(false));

        for _ in 0..100 {
            let inner_barrier = barrier.clone();
            let count_clone = count.clone();
            let inner_pool = pool.clone();
            let inner_thread_check_failed = thread_check_failed.clone();
            thread::spawn(move || {
                inner_pool.scope(|scope| {
                    let spawner = thread::current().id();
                    let inner_count_clone = count_clone.clone();
                    scope.spawn(async move {
                        inner_count_clone.fetch_add(1, Ordering::Release);

                        // spawning on the scope from another thread runs the futures on the scope's thread
                        scope.spawn_on_scope(async move {
                            inner_count_clone.fetch_add(1, Ordering::Release);
                            if thread::current().id() != spawner {
                                // NOTE: This check is using an atomic rather than simply panicking the
                                // thread to avoid deadlocking the barrier on failure
                                inner_thread_check_failed.store(true, Ordering::Release);
                            }
                        });
                    });
                });
                inner_barrier.wait();
            });
        }
        barrier.wait();
        assert!(!thread_check_failed.load(Ordering::Acquire));
        assert_eq!(count.load(Ordering::Acquire), 200);
    }

    // This test will often freeze on other executors.
    #[test]
    fn test_nested_scopes() {
        let pool = TaskPool::new();
        let count = Arc::new(AtomicI32::new(0));

        pool.scope(|scope| {
            scope.spawn(async {
                pool.scope(|scope| {
                    scope.spawn(async {
                        count.fetch_add(1, Ordering::Relaxed);
                    });
                });
            });
        });

        assert_eq!(count.load(Ordering::Acquire), 1);
    }
}
943}