// bevy_tasks/task_pool.rs

1use std::{
2    future::Future,
3    marker::PhantomData,
4    mem,
5    panic::AssertUnwindSafe,
6    sync::Arc,
7    thread::{self, JoinHandle},
8};
9
10use async_executor::FallibleTask;
11use concurrent_queue::ConcurrentQueue;
12use futures_lite::FutureExt;
13
14use crate::{
15    block_on,
16    thread_executor::{ThreadExecutor, ThreadExecutorTicker},
17    Task,
18};
19
/// A guard that invokes the wrapped callback, if one is present, when it is dropped.
///
/// Worker threads of the task pool keep one alive for their whole run so that the
/// `on_thread_destroy` hook fires when the thread's run loop ends.
struct CallOnDrop(Option<Arc<dyn Fn() + Send + Sync + 'static>>);

impl Drop for CallOnDrop {
    fn drop(&mut self) {
        // An empty guard is a no-op.
        if let Some(callback) = &self.0 {
            callback();
        }
    }
}

/// Used to create a [`TaskPool`].
///
/// All options start unset (see [`Default`]).
#[derive(Default)]
#[must_use]
pub struct TaskPoolBuilder {
    /// If set, the pool spawns exactly `num_threads` worker threads.
    /// Otherwise the logical core count of the system is used.
    num_threads: Option<usize>,
    /// If set, we'll use the given stack size rather than the system default.
    stack_size: Option<usize>,
    /// Allows customizing the name of the threads - helpful for debugging. If set, threads will
    /// be named `<thread_name> (<thread_index>)`, e.g. `"MyThreadPool (2)"`.
    thread_name: Option<String>,

    /// Callback invoked on each worker thread as it starts, before it runs any tasks.
    on_thread_spawn: Option<Arc<dyn Fn() + Send + Sync + 'static>>,
    /// Callback invoked on each worker thread as it terminates.
    on_thread_destroy: Option<Arc<dyn Fn() + Send + Sync + 'static>>,
}

47impl TaskPoolBuilder {
48    /// Creates a new [`TaskPoolBuilder`] instance
49    pub fn new() -> Self {
50        Self::default()
51    }
52
53    /// Override the number of threads created for the pool. If unset, we default to the number
54    /// of logical cores of the system
55    pub fn num_threads(mut self, num_threads: usize) -> Self {
56        self.num_threads = Some(num_threads);
57        self
58    }
59
60    /// Override the stack size of the threads created for the pool
61    pub fn stack_size(mut self, stack_size: usize) -> Self {
62        self.stack_size = Some(stack_size);
63        self
64    }
65
66    /// Override the name of the threads created for the pool. If set, threads will
67    /// be named `<thread_name> (<thread_index>)`, i.e. `MyThreadPool (2)`
68    pub fn thread_name(mut self, thread_name: String) -> Self {
69        self.thread_name = Some(thread_name);
70        self
71    }
72
73    /// Sets a callback that is invoked once for every created thread as it starts.
74    ///
75    /// This is called on the thread itself and has access to all thread-local storage.
76    /// This will block running async tasks on the thread until the callback completes.
77    pub fn on_thread_spawn(mut self, f: impl Fn() + Send + Sync + 'static) -> Self {
78        self.on_thread_spawn = Some(Arc::new(f));
79        self
80    }
81
82    /// Sets a callback that is invoked once for every created thread as it terminates.
83    ///
84    /// This is called on the thread itself and has access to all thread-local storage.
85    /// This will block thread termination until the callback completes.
86    pub fn on_thread_destroy(mut self, f: impl Fn() + Send + Sync + 'static) -> Self {
87        self.on_thread_destroy = Some(Arc::new(f));
88        self
89    }
90
91    /// Creates a new [`TaskPool`] based on the current options.
92    pub fn build(self) -> TaskPool {
93        TaskPool::new_internal(self)
94    }
95}
96
97/// A thread pool for executing tasks.
98///
99/// While futures usually need to be polled to be executed, Bevy tasks are being
100/// automatically driven by the pool on threads owned by the pool. The [`Task`]
101/// future only needs to be polled in order to receive the result. (For that
102/// purpose, it is often stored in a component or resource, see the
103/// `async_compute` example.)
104///
105/// If the result is not required, one may also use [`Task::detach`] and the pool
106/// will still execute a task, even if it is dropped.
107#[derive(Debug)]
108pub struct TaskPool {
109    /// The executor for the pool.
110    executor: Arc<async_executor::Executor<'static>>,
111
112    // The inner state of the pool.
113    threads: Vec<JoinHandle<()>>,
114    shutdown_tx: async_channel::Sender<()>,
115}
116
117impl TaskPool {
118    thread_local! {
119        static LOCAL_EXECUTOR: async_executor::LocalExecutor<'static> = const { async_executor::LocalExecutor::new() };
120        static THREAD_EXECUTOR: Arc<ThreadExecutor<'static>> = Arc::new(ThreadExecutor::new());
121    }
122
123    /// Each thread should only create one `ThreadExecutor`, otherwise, there are good chances they will deadlock
124    pub fn get_thread_executor() -> Arc<ThreadExecutor<'static>> {
125        Self::THREAD_EXECUTOR.with(|executor| executor.clone())
126    }
127
128    /// Create a `TaskPool` with the default configuration.
129    pub fn new() -> Self {
130        TaskPoolBuilder::new().build()
131    }
132
133    fn new_internal(builder: TaskPoolBuilder) -> Self {
134        let (shutdown_tx, shutdown_rx) = async_channel::unbounded::<()>();
135
136        let executor = Arc::new(async_executor::Executor::new());
137
138        let num_threads = builder
139            .num_threads
140            .unwrap_or_else(crate::available_parallelism);
141
142        let threads = (0..num_threads)
143            .map(|i| {
144                let ex = Arc::clone(&executor);
145                let shutdown_rx = shutdown_rx.clone();
146
147                let thread_name = if let Some(thread_name) = builder.thread_name.as_deref() {
148                    format!("{thread_name} ({i})")
149                } else {
150                    format!("TaskPool ({i})")
151                };
152                let mut thread_builder = thread::Builder::new().name(thread_name);
153
154                if let Some(stack_size) = builder.stack_size {
155                    thread_builder = thread_builder.stack_size(stack_size);
156                }
157
158                let on_thread_spawn = builder.on_thread_spawn.clone();
159                let on_thread_destroy = builder.on_thread_destroy.clone();
160
161                thread_builder
162                    .spawn(move || {
163                        TaskPool::LOCAL_EXECUTOR.with(|local_executor| {
164                            if let Some(on_thread_spawn) = on_thread_spawn {
165                                on_thread_spawn();
166                                drop(on_thread_spawn);
167                            }
168                            let _destructor = CallOnDrop(on_thread_destroy);
169                            loop {
170                                let res = std::panic::catch_unwind(|| {
171                                    let tick_forever = async move {
172                                        loop {
173                                            local_executor.tick().await;
174                                        }
175                                    };
176                                    block_on(ex.run(tick_forever.or(shutdown_rx.recv())))
177                                });
178                                if let Ok(value) = res {
179                                    // Use unwrap_err because we expect a Closed error
180                                    value.unwrap_err();
181                                    break;
182                                }
183                            }
184                        });
185                    })
186                    .expect("Failed to spawn thread.")
187            })
188            .collect();
189
190        Self {
191            executor,
192            threads,
193            shutdown_tx,
194        }
195    }
196
197    /// Return the number of threads owned by the task pool
198    pub fn thread_num(&self) -> usize {
199        self.threads.len()
200    }
201
202    /// Allows spawning non-`'static` futures on the thread pool. The function takes a callback,
203    /// passing a scope object into it. The scope object provided to the callback can be used
204    /// to spawn tasks. This function will await the completion of all tasks before returning.
205    ///
206    /// This is similar to [`thread::scope`] and `rayon::scope`.
207    ///
208    /// # Example
209    ///
210    /// ```
211    /// use bevy_tasks::TaskPool;
212    ///
213    /// let pool = TaskPool::new();
214    /// let mut x = 0;
215    /// let results = pool.scope(|s| {
216    ///     s.spawn(async {
217    ///         // you can borrow the spawner inside a task and spawn tasks from within the task
218    ///         s.spawn(async {
219    ///             // borrow x and mutate it.
220    ///             x = 2;
221    ///             // return a value from the task
222    ///             1
223    ///         });
224    ///         // return some other value from the first task
225    ///         0
226    ///     });
227    /// });
228    ///
229    /// // The ordering of results is non-deterministic if you spawn from within tasks as above.
230    /// // If you're doing this, you'll have to write your code to not depend on the ordering.
231    /// assert!(results.contains(&0));
232    /// assert!(results.contains(&1));
233    ///
234    /// // The ordering is deterministic if you only spawn directly from the closure function.
235    /// let results = pool.scope(|s| {
236    ///     s.spawn(async { 0 });
237    ///     s.spawn(async { 1 });
238    /// });
239    /// assert_eq!(&results[..], &[0, 1]);
240    ///
241    /// // You can access x after scope runs, since it was only temporarily borrowed in the scope.
242    /// assert_eq!(x, 2);
243    /// ```
244    ///
245    /// # Lifetimes
246    ///
247    /// The [`Scope`] object takes two lifetimes: `'scope` and `'env`.
248    ///
249    /// The `'scope` lifetime represents the lifetime of the scope. That is the time during
250    /// which the provided closure and tasks that are spawned into the scope are run.
251    ///
252    /// The `'env` lifetime represents the lifetime of whatever is borrowed by the scope.
253    /// Thus this lifetime must outlive `'scope`.
254    ///
255    /// ```compile_fail
256    /// use bevy_tasks::TaskPool;
257    /// fn scope_escapes_closure() {
258    ///     let pool = TaskPool::new();
259    ///     let foo = Box::new(42);
260    ///     pool.scope(|scope| {
261    ///         std::thread::spawn(move || {
262    ///             // UB. This could spawn on the scope after `.scope` returns and the internal Scope is dropped.
263    ///             scope.spawn(async move {
264    ///                 assert_eq!(*foo, 42);
265    ///             });
266    ///         });
267    ///     });
268    /// }
269    /// ```
270    ///
271    /// ```compile_fail
272    /// use bevy_tasks::TaskPool;
273    /// fn cannot_borrow_from_closure() {
274    ///     let pool = TaskPool::new();
275    ///     pool.scope(|scope| {
276    ///         let x = 1;
277    ///         let y = &x;
278    ///         scope.spawn(async move {
279    ///             assert_eq!(*y, 1);
280    ///         });
281    ///     });
282    /// }
283    pub fn scope<'env, F, T>(&self, f: F) -> Vec<T>
284    where
285        F: for<'scope> FnOnce(&'scope Scope<'scope, 'env, T>),
286        T: Send + 'static,
287    {
288        Self::THREAD_EXECUTOR.with(|scope_executor| {
289            self.scope_with_executor_inner(true, scope_executor, scope_executor, f)
290        })
291    }
292
293    /// This allows passing an external executor to spawn tasks on. When you pass an external executor
294    /// [`Scope::spawn_on_scope`] spawns is then run on the thread that [`ThreadExecutor`] is being ticked on.
295    /// If [`None`] is passed the scope will use a [`ThreadExecutor`] that is ticked on the current thread.
296    ///
297    /// When `tick_task_pool_executor` is set to `true`, the multithreaded task stealing executor is ticked on the scope
298    /// thread. Disabling this can be useful when finishing the scope is latency sensitive. Pulling tasks from
299    /// global executor can run tasks unrelated to the scope and delay when the scope returns.
300    ///
301    /// See [`Self::scope`] for more details in general about how scopes work.
302    pub fn scope_with_executor<'env, F, T>(
303        &self,
304        tick_task_pool_executor: bool,
305        external_executor: Option<&ThreadExecutor>,
306        f: F,
307    ) -> Vec<T>
308    where
309        F: for<'scope> FnOnce(&'scope Scope<'scope, 'env, T>),
310        T: Send + 'static,
311    {
312        Self::THREAD_EXECUTOR.with(|scope_executor| {
313            // If a `external_executor` is passed use that. Otherwise get the executor stored
314            // in the `THREAD_EXECUTOR` thread local.
315            if let Some(external_executor) = external_executor {
316                self.scope_with_executor_inner(
317                    tick_task_pool_executor,
318                    external_executor,
319                    scope_executor,
320                    f,
321                )
322            } else {
323                self.scope_with_executor_inner(
324                    tick_task_pool_executor,
325                    scope_executor,
326                    scope_executor,
327                    f,
328                )
329            }
330        })
331    }
332
333    #[allow(unsafe_code)]
334    fn scope_with_executor_inner<'env, F, T>(
335        &self,
336        tick_task_pool_executor: bool,
337        external_executor: &ThreadExecutor,
338        scope_executor: &ThreadExecutor,
339        f: F,
340    ) -> Vec<T>
341    where
342        F: for<'scope> FnOnce(&'scope Scope<'scope, 'env, T>),
343        T: Send + 'static,
344    {
345        // SAFETY: This safety comment applies to all references transmuted to 'env.
346        // Any futures spawned with these references need to return before this function completes.
347        // This is guaranteed because we drive all the futures spawned onto the Scope
348        // to completion in this function. However, rust has no way of knowing this so we
349        // transmute the lifetimes to 'env here to appease the compiler as it is unable to validate safety.
350        // Any usages of the references passed into `Scope` must be accessed through
351        // the transmuted reference for the rest of this function.
352        let executor: &async_executor::Executor = &self.executor;
353        // SAFETY: As above, all futures must complete in this function so we can change the lifetime
354        let executor: &'env async_executor::Executor = unsafe { mem::transmute(executor) };
355        // SAFETY: As above, all futures must complete in this function so we can change the lifetime
356        let external_executor: &'env ThreadExecutor<'env> =
357            unsafe { mem::transmute(external_executor) };
358        // SAFETY: As above, all futures must complete in this function so we can change the lifetime
359        let scope_executor: &'env ThreadExecutor<'env> = unsafe { mem::transmute(scope_executor) };
360        let spawned: ConcurrentQueue<FallibleTask<Result<T, Box<(dyn std::any::Any + Send)>>>> =
361            ConcurrentQueue::unbounded();
362        // shadow the variable so that the owned value cannot be used for the rest of the function
363        // SAFETY: As above, all futures must complete in this function so we can change the lifetime
364        let spawned: &'env ConcurrentQueue<
365            FallibleTask<Result<T, Box<(dyn std::any::Any + Send)>>>,
366        > = unsafe { mem::transmute(&spawned) };
367
368        let scope = Scope {
369            executor,
370            external_executor,
371            scope_executor,
372            spawned,
373            scope: PhantomData,
374            env: PhantomData,
375        };
376
377        // shadow the variable so that the owned value cannot be used for the rest of the function
378        // SAFETY: As above, all futures must complete in this function so we can change the lifetime
379        let scope: &'env Scope<'_, 'env, T> = unsafe { mem::transmute(&scope) };
380
381        f(scope);
382
383        if spawned.is_empty() {
384            Vec::new()
385        } else {
386            block_on(async move {
387                let get_results = async {
388                    let mut results = Vec::with_capacity(spawned.len());
389                    while let Ok(task) = spawned.pop() {
390                        if let Some(res) = task.await {
391                            match res {
392                                Ok(res) => results.push(res),
393                                Err(payload) => std::panic::resume_unwind(payload),
394                            }
395                        } else {
396                            panic!("Failed to catch panic!");
397                        }
398                    }
399                    results
400                };
401
402                let tick_task_pool_executor = tick_task_pool_executor || self.threads.is_empty();
403
404                // we get this from a thread local so we should always be on the scope executors thread.
405                // note: it is possible `scope_executor` and `external_executor` is the same executor,
406                // in that case, we should only tick one of them, otherwise, it may cause deadlock.
407                let scope_ticker = scope_executor.ticker().unwrap();
408                let external_ticker = if !external_executor.is_same(scope_executor) {
409                    external_executor.ticker()
410                } else {
411                    None
412                };
413
414                match (external_ticker, tick_task_pool_executor) {
415                    (Some(external_ticker), true) => {
416                        Self::execute_global_external_scope(
417                            executor,
418                            external_ticker,
419                            scope_ticker,
420                            get_results,
421                        )
422                        .await
423                    }
424                    (Some(external_ticker), false) => {
425                        Self::execute_external_scope(external_ticker, scope_ticker, get_results)
426                            .await
427                    }
428                    // either external_executor is none or it is same as scope_executor
429                    (None, true) => {
430                        Self::execute_global_scope(executor, scope_ticker, get_results).await
431                    }
432                    (None, false) => Self::execute_scope(scope_ticker, get_results).await,
433                }
434            })
435        }
436    }
437
438    #[inline]
439    async fn execute_global_external_scope<'scope, 'ticker, T>(
440        executor: &'scope async_executor::Executor<'scope>,
441        external_ticker: ThreadExecutorTicker<'scope, 'ticker>,
442        scope_ticker: ThreadExecutorTicker<'scope, 'ticker>,
443        get_results: impl Future<Output = Vec<T>>,
444    ) -> Vec<T> {
445        // we restart the executors if a task errors. if a scoped
446        // task errors it will panic the scope on the call to get_results
447        let execute_forever = async move {
448            loop {
449                let tick_forever = async {
450                    loop {
451                        external_ticker.tick().or(scope_ticker.tick()).await;
452                    }
453                };
454                // we don't care if it errors. If a scoped task errors it will propagate
455                // to get_results
456                let _result = AssertUnwindSafe(executor.run(tick_forever))
457                    .catch_unwind()
458                    .await
459                    .is_ok();
460            }
461        };
462        execute_forever.or(get_results).await
463    }
464
465    #[inline]
466    async fn execute_external_scope<'scope, 'ticker, T>(
467        external_ticker: ThreadExecutorTicker<'scope, 'ticker>,
468        scope_ticker: ThreadExecutorTicker<'scope, 'ticker>,
469        get_results: impl Future<Output = Vec<T>>,
470    ) -> Vec<T> {
471        let execute_forever = async {
472            loop {
473                let tick_forever = async {
474                    loop {
475                        external_ticker.tick().or(scope_ticker.tick()).await;
476                    }
477                };
478                let _result = AssertUnwindSafe(tick_forever).catch_unwind().await.is_ok();
479            }
480        };
481        execute_forever.or(get_results).await
482    }
483
484    #[inline]
485    async fn execute_global_scope<'scope, 'ticker, T>(
486        executor: &'scope async_executor::Executor<'scope>,
487        scope_ticker: ThreadExecutorTicker<'scope, 'ticker>,
488        get_results: impl Future<Output = Vec<T>>,
489    ) -> Vec<T> {
490        let execute_forever = async {
491            loop {
492                let tick_forever = async {
493                    loop {
494                        scope_ticker.tick().await;
495                    }
496                };
497                let _result = AssertUnwindSafe(executor.run(tick_forever))
498                    .catch_unwind()
499                    .await
500                    .is_ok();
501            }
502        };
503        execute_forever.or(get_results).await
504    }
505
506    #[inline]
507    async fn execute_scope<'scope, 'ticker, T>(
508        scope_ticker: ThreadExecutorTicker<'scope, 'ticker>,
509        get_results: impl Future<Output = Vec<T>>,
510    ) -> Vec<T> {
511        let execute_forever = async {
512            loop {
513                let tick_forever = async {
514                    loop {
515                        scope_ticker.tick().await;
516                    }
517                };
518                let _result = AssertUnwindSafe(tick_forever).catch_unwind().await.is_ok();
519            }
520        };
521        execute_forever.or(get_results).await
522    }
523
524    /// Spawns a static future onto the thread pool. The returned [`Task`] is a
525    /// future that can be polled for the result. It can also be canceled and
526    /// "detached", allowing the task to continue running even if dropped. In
527    /// any case, the pool will execute the task even without polling by the
528    /// end-user.
529    ///
530    /// If the provided future is non-`Send`, [`TaskPool::spawn_local`] should
531    /// be used instead.
532    pub fn spawn<T>(&self, future: impl Future<Output = T> + Send + 'static) -> Task<T>
533    where
534        T: Send + 'static,
535    {
536        Task::new(self.executor.spawn(future))
537    }
538
539    /// Spawns a static future on the thread-local async executor for the
540    /// current thread. The task will run entirely on the thread the task was
541    /// spawned on.
542    ///
543    /// The returned [`Task`] is a future that can be polled for the
544    /// result. It can also be canceled and "detached", allowing the task to
545    /// continue running even if dropped. In any case, the pool will execute the
546    /// task even without polling by the end-user.
547    ///
548    /// Users should generally prefer to use [`TaskPool::spawn`] instead,
549    /// unless the provided future is not `Send`.
550    pub fn spawn_local<T>(&self, future: impl Future<Output = T> + 'static) -> Task<T>
551    where
552        T: 'static,
553    {
554        Task::new(TaskPool::LOCAL_EXECUTOR.with(|executor| executor.spawn(future)))
555    }
556
557    /// Runs a function with the local executor. Typically used to tick
558    /// the local executor on the main thread as it needs to share time with
559    /// other things.
560    ///
561    /// ```
562    /// use bevy_tasks::TaskPool;
563    ///
564    /// TaskPool::new().with_local_executor(|local_executor| {
565    ///     local_executor.try_tick();
566    /// });
567    /// ```
568    pub fn with_local_executor<F, R>(&self, f: F) -> R
569    where
570        F: FnOnce(&async_executor::LocalExecutor) -> R,
571    {
572        Self::LOCAL_EXECUTOR.with(f)
573    }
574}
575
576impl Default for TaskPool {
577    fn default() -> Self {
578        Self::new()
579    }
580}
581
582impl Drop for TaskPool {
583    fn drop(&mut self) {
584        self.shutdown_tx.close();
585
586        let panicking = thread::panicking();
587        for join_handle in self.threads.drain(..) {
588            let res = join_handle.join();
589            if !panicking {
590                res.expect("Task thread panicked while executing.");
591            }
592        }
593    }
594}
595
596/// A [`TaskPool`] scope for running one or more non-`'static` futures.
597///
598/// For more information, see [`TaskPool::scope`].
599#[derive(Debug)]
600pub struct Scope<'scope, 'env: 'scope, T> {
601    executor: &'scope async_executor::Executor<'scope>,
602    external_executor: &'scope ThreadExecutor<'scope>,
603    scope_executor: &'scope ThreadExecutor<'scope>,
604    spawned: &'scope ConcurrentQueue<FallibleTask<Result<T, Box<(dyn std::any::Any + Send)>>>>,
605    // make `Scope` invariant over 'scope and 'env
606    scope: PhantomData<&'scope mut &'scope ()>,
607    env: PhantomData<&'env mut &'env ()>,
608}
609
610impl<'scope, 'env, T: Send + 'scope> Scope<'scope, 'env, T> {
611    /// Spawns a scoped future onto the thread pool. The scope *must* outlive
612    /// the provided future. The results of the future will be returned as a part of
613    /// [`TaskPool::scope`]'s return value.
614    ///
615    /// For futures that should run on the thread `scope` is called on [`Scope::spawn_on_scope`] should be used
616    /// instead.
617    ///
618    /// For more information, see [`TaskPool::scope`].
619    pub fn spawn<Fut: Future<Output = T> + 'scope + Send>(&self, f: Fut) {
620        let task = self
621            .executor
622            .spawn(AssertUnwindSafe(f).catch_unwind())
623            .fallible();
624        // ConcurrentQueue only errors when closed or full, but we never
625        // close and use an unbounded queue, so it is safe to unwrap
626        self.spawned.push(task).unwrap();
627    }
628
629    /// Spawns a scoped future onto the thread the scope is run on. The scope *must* outlive
630    /// the provided future. The results of the future will be returned as a part of
631    /// [`TaskPool::scope`]'s return value.  Users should generally prefer to use
632    /// [`Scope::spawn`] instead, unless the provided future needs to run on the scope's thread.
633    ///
634    /// For more information, see [`TaskPool::scope`].
635    pub fn spawn_on_scope<Fut: Future<Output = T> + 'scope + Send>(&self, f: Fut) {
636        let task = self
637            .scope_executor
638            .spawn(AssertUnwindSafe(f).catch_unwind())
639            .fallible();
640        // ConcurrentQueue only errors when closed or full, but we never
641        // close and use an unbounded queue, so it is safe to unwrap
642        self.spawned.push(task).unwrap();
643    }
644
645    /// Spawns a scoped future onto the thread of the external thread executor.
646    /// This is typically the main thread. The scope *must* outlive
647    /// the provided future. The results of the future will be returned as a part of
648    /// [`TaskPool::scope`]'s return value.  Users should generally prefer to use
649    /// [`Scope::spawn`] instead, unless the provided future needs to run on the external thread.
650    ///
651    /// For more information, see [`TaskPool::scope`].
652    pub fn spawn_on_external<Fut: Future<Output = T> + 'scope + Send>(&self, f: Fut) {
653        let task = self
654            .external_executor
655            .spawn(AssertUnwindSafe(f).catch_unwind())
656            .fallible();
657        // ConcurrentQueue only errors when closed or full, but we never
658        // close and use an unbounded queue, so it is safe to unwrap
659        self.spawned.push(task).unwrap();
660    }
661}
662
663impl<'scope, 'env, T> Drop for Scope<'scope, 'env, T>
664where
665    T: 'scope,
666{
667    fn drop(&mut self) {
668        block_on(async {
669            while let Ok(task) = self.spawned.pop() {
670                task.cancel().await;
671            }
672        });
673    }
674}
675
676#[cfg(test)]
677#[allow(clippy::disallowed_types)]
678mod tests {
679    use super::*;
680    use std::sync::{
681        atomic::{AtomicBool, AtomicI32, Ordering},
682        Barrier,
683    };
684
685    #[test]
686    fn test_spawn() {
687        let pool = TaskPool::new();
688
689        let foo = Box::new(42);
690        let foo = &*foo;
691
692        let count = Arc::new(AtomicI32::new(0));
693
694        let outputs = pool.scope(|scope| {
695            for _ in 0..100 {
696                let count_clone = count.clone();
697                scope.spawn(async move {
698                    if *foo != 42 {
699                        panic!("not 42!?!?")
700                    } else {
701                        count_clone.fetch_add(1, Ordering::Relaxed);
702                        *foo
703                    }
704                });
705            }
706        });
707
708        for output in &outputs {
709            assert_eq!(*output, 42);
710        }
711
712        assert_eq!(outputs.len(), 100);
713        assert_eq!(count.load(Ordering::Relaxed), 100);
714    }
715
716    #[test]
717    fn test_thread_callbacks() {
718        let counter = Arc::new(AtomicI32::new(0));
719        let start_counter = counter.clone();
720        {
721            let barrier = Arc::new(Barrier::new(11));
722            let last_barrier = barrier.clone();
723            // Build and immediately drop to terminate
724            let _pool = TaskPoolBuilder::new()
725                .num_threads(10)
726                .on_thread_spawn(move || {
727                    start_counter.fetch_add(1, Ordering::Relaxed);
728                    barrier.clone().wait();
729                })
730                .build();
731            last_barrier.wait();
732            assert_eq!(10, counter.load(Ordering::Relaxed));
733        }
734        assert_eq!(10, counter.load(Ordering::Relaxed));
735        let end_counter = counter.clone();
736        {
737            let _pool = TaskPoolBuilder::new()
738                .num_threads(20)
739                .on_thread_destroy(move || {
740                    end_counter.fetch_sub(1, Ordering::Relaxed);
741                })
742                .build();
743            assert_eq!(10, counter.load(Ordering::Relaxed));
744        }
745        assert_eq!(-10, counter.load(Ordering::Relaxed));
746        let start_counter = counter.clone();
747        let end_counter = counter.clone();
748        {
749            let barrier = Arc::new(Barrier::new(6));
750            let last_barrier = barrier.clone();
751            let _pool = TaskPoolBuilder::new()
752                .num_threads(5)
753                .on_thread_spawn(move || {
754                    start_counter.fetch_add(1, Ordering::Relaxed);
755                    barrier.wait();
756                })
757                .on_thread_destroy(move || {
758                    end_counter.fetch_sub(1, Ordering::Relaxed);
759                })
760                .build();
761            last_barrier.wait();
762            assert_eq!(-5, counter.load(Ordering::Relaxed));
763        }
764        assert_eq!(-10, counter.load(Ordering::Relaxed));
765    }
766
767    #[test]
768    fn test_mixed_spawn_on_scope_and_spawn() {
769        let pool = TaskPool::new();
770
771        let foo = Box::new(42);
772        let foo = &*foo;
773
774        let local_count = Arc::new(AtomicI32::new(0));
775        let non_local_count = Arc::new(AtomicI32::new(0));
776
777        let outputs = pool.scope(|scope| {
778            for i in 0..100 {
779                if i % 2 == 0 {
780                    let count_clone = non_local_count.clone();
781                    scope.spawn(async move {
782                        if *foo != 42 {
783                            panic!("not 42!?!?")
784                        } else {
785                            count_clone.fetch_add(1, Ordering::Relaxed);
786                            *foo
787                        }
788                    });
789                } else {
790                    let count_clone = local_count.clone();
791                    scope.spawn_on_scope(async move {
792                        if *foo != 42 {
793                            panic!("not 42!?!?")
794                        } else {
795                            count_clone.fetch_add(1, Ordering::Relaxed);
796                            *foo
797                        }
798                    });
799                }
800            }
801        });
802
803        for output in &outputs {
804            assert_eq!(*output, 42);
805        }
806
807        assert_eq!(outputs.len(), 100);
808        assert_eq!(local_count.load(Ordering::Relaxed), 50);
809        assert_eq!(non_local_count.load(Ordering::Relaxed), 50);
810    }
811
812    #[test]
813    fn test_thread_locality() {
814        let pool = Arc::new(TaskPool::new());
815        let count = Arc::new(AtomicI32::new(0));
816        let barrier = Arc::new(Barrier::new(101));
817        let thread_check_failed = Arc::new(AtomicBool::new(false));
818
819        for _ in 0..100 {
820            let inner_barrier = barrier.clone();
821            let count_clone = count.clone();
822            let inner_pool = pool.clone();
823            let inner_thread_check_failed = thread_check_failed.clone();
824            thread::spawn(move || {
825                inner_pool.scope(|scope| {
826                    let inner_count_clone = count_clone.clone();
827                    scope.spawn(async move {
828                        inner_count_clone.fetch_add(1, Ordering::Release);
829                    });
830                    let spawner = thread::current().id();
831                    let inner_count_clone = count_clone.clone();
832                    scope.spawn_on_scope(async move {
833                        inner_count_clone.fetch_add(1, Ordering::Release);
834                        if thread::current().id() != spawner {
835                            // NOTE: This check is using an atomic rather than simply panicking the
836                            // thread to avoid deadlocking the barrier on failure
837                            inner_thread_check_failed.store(true, Ordering::Release);
838                        }
839                    });
840                });
841                inner_barrier.wait();
842            });
843        }
844        barrier.wait();
845        assert!(!thread_check_failed.load(Ordering::Acquire));
846        assert_eq!(count.load(Ordering::Acquire), 200);
847    }
848
849    #[test]
850    fn test_nested_spawn() {
851        let pool = TaskPool::new();
852
853        let foo = Box::new(42);
854        let foo = &*foo;
855
856        let count = Arc::new(AtomicI32::new(0));
857
858        let outputs: Vec<i32> = pool.scope(|scope| {
859            for _ in 0..10 {
860                let count_clone = count.clone();
861                scope.spawn(async move {
862                    for _ in 0..10 {
863                        let count_clone_clone = count_clone.clone();
864                        scope.spawn(async move {
865                            if *foo != 42 {
866                                panic!("not 42!?!?")
867                            } else {
868                                count_clone_clone.fetch_add(1, Ordering::Relaxed);
869                                *foo
870                            }
871                        });
872                    }
873                    *foo
874                });
875            }
876        });
877
878        for output in &outputs {
879            assert_eq!(*output, 42);
880        }
881
882        // the inner loop runs 100 times and the outer one runs 10. 100 + 10
883        assert_eq!(outputs.len(), 110);
884        assert_eq!(count.load(Ordering::Relaxed), 100);
885    }
886
887    #[test]
888    fn test_nested_locality() {
889        let pool = Arc::new(TaskPool::new());
890        let count = Arc::new(AtomicI32::new(0));
891        let barrier = Arc::new(Barrier::new(101));
892        let thread_check_failed = Arc::new(AtomicBool::new(false));
893
894        for _ in 0..100 {
895            let inner_barrier = barrier.clone();
896            let count_clone = count.clone();
897            let inner_pool = pool.clone();
898            let inner_thread_check_failed = thread_check_failed.clone();
899            thread::spawn(move || {
900                inner_pool.scope(|scope| {
901                    let spawner = thread::current().id();
902                    let inner_count_clone = count_clone.clone();
903                    scope.spawn(async move {
904                        inner_count_clone.fetch_add(1, Ordering::Release);
905
906                        // spawning on the scope from another thread runs the futures on the scope's thread
907                        scope.spawn_on_scope(async move {
908                            inner_count_clone.fetch_add(1, Ordering::Release);
909                            if thread::current().id() != spawner {
910                                // NOTE: This check is using an atomic rather than simply panicking the
911                                // thread to avoid deadlocking the barrier on failure
912                                inner_thread_check_failed.store(true, Ordering::Release);
913                            }
914                        });
915                    });
916                });
917                inner_barrier.wait();
918            });
919        }
920        barrier.wait();
921        assert!(!thread_check_failed.load(Ordering::Acquire));
922        assert_eq!(count.load(Ordering::Acquire), 200);
923    }
924
925    // This test will often freeze on other executors.
926    #[test]
927    fn test_nested_scopes() {
928        let pool = TaskPool::new();
929        let count = Arc::new(AtomicI32::new(0));
930
931        pool.scope(|scope| {
932            scope.spawn(async {
933                pool.scope(|scope| {
934                    scope.spawn(async {
935                        count.fetch_add(1, Ordering::Relaxed);
936                    });
937                });
938            });
939        });
940
941        assert_eq!(count.load(Ordering::Acquire), 1);
942    }
943}