// bevy_tasks/thread_executor.rs

1use std::{
2    marker::PhantomData,
3    thread::{self, ThreadId},
4};
5
6use async_executor::{Executor, Task};
7use futures_lite::Future;
8
9/// An executor that can only be ticked on the thread it was instantiated on. But
10/// can spawn `Send` tasks from other threads.
11///
12/// # Example
13/// ```
14/// # use std::sync::{Arc, atomic::{AtomicI32, Ordering}};
15/// use bevy_tasks::ThreadExecutor;
16///
17/// let thread_executor = ThreadExecutor::new();
18/// let count = Arc::new(AtomicI32::new(0));
19///
20/// // create some owned values that can be moved into another thread
21/// let count_clone = count.clone();
22///
23/// std::thread::scope(|scope| {
24///     scope.spawn(|| {
25///         // we cannot get the ticker from another thread
26///         let not_thread_ticker = thread_executor.ticker();
27///         assert!(not_thread_ticker.is_none());
28///         
29///         // but we can spawn tasks from another thread
30///         thread_executor.spawn(async move {
31///             count_clone.fetch_add(1, Ordering::Relaxed);
32///         }).detach();
33///     });
34/// });
35///
36/// // the tasks do not make progress unless the executor is manually ticked
37/// assert_eq!(count.load(Ordering::Relaxed), 0);
38///
39/// // tick the ticker until task finishes
40/// let thread_ticker = thread_executor.ticker().unwrap();
41/// thread_ticker.try_tick();
42/// assert_eq!(count.load(Ordering::Relaxed), 1);
43/// ```
44#[derive(Debug)]
45pub struct ThreadExecutor<'task> {
46    executor: Executor<'task>,
47    thread_id: ThreadId,
48}
49
50impl<'task> Default for ThreadExecutor<'task> {
51    fn default() -> Self {
52        Self {
53            executor: Executor::new(),
54            thread_id: thread::current().id(),
55        }
56    }
57}
58
59impl<'task> ThreadExecutor<'task> {
60    /// create a new [`ThreadExecutor`]
61    pub fn new() -> Self {
62        Self::default()
63    }
64
65    /// Spawn a task on the thread executor
66    pub fn spawn<T: Send + 'task>(
67        &self,
68        future: impl Future<Output = T> + Send + 'task,
69    ) -> Task<T> {
70        self.executor.spawn(future)
71    }
72
73    /// Gets the [`ThreadExecutorTicker`] for this executor.
74    /// Use this to tick the executor.
75    /// It only returns the ticker if it's on the thread the executor was created on
76    /// and returns `None` otherwise.
77    pub fn ticker<'ticker>(&'ticker self) -> Option<ThreadExecutorTicker<'task, 'ticker>> {
78        if thread::current().id() == self.thread_id {
79            return Some(ThreadExecutorTicker {
80                executor: self,
81                _marker: PhantomData,
82            });
83        }
84        None
85    }
86
87    /// Returns true if `self` and `other`'s executor is same
88    pub fn is_same(&self, other: &Self) -> bool {
89        std::ptr::eq(self, other)
90    }
91}
92
93/// Used to tick the [`ThreadExecutor`]. The executor does not
94/// make progress unless it is manually ticked on the thread it was
95/// created on.
96#[derive(Debug)]
97pub struct ThreadExecutorTicker<'task, 'ticker> {
98    executor: &'ticker ThreadExecutor<'task>,
99    // make type not send or sync
100    _marker: PhantomData<*const ()>,
101}
102impl<'task, 'ticker> ThreadExecutorTicker<'task, 'ticker> {
103    /// Tick the thread executor.
104    pub async fn tick(&self) {
105        self.executor.executor.tick().await;
106    }
107
108    /// Synchronously try to tick a task on the executor.
109    /// Returns false if does not find a task to tick.
110    pub fn try_tick(&self) -> bool {
111        self.executor.executor.try_tick()
112    }
113}
114
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// The ticker is available on the creating thread and nowhere else.
    #[test]
    fn test_ticker() {
        let executor = ThreadExecutor::new();
        assert!(executor.ticker().is_some());

        // `thread::scope` lets the spawned thread borrow `executor` directly,
        // so there is no need to wrap it in an `Arc`.
        thread::scope(|s| {
            s.spawn(|| {
                assert!(executor.ticker().is_none());
            });
        });
    }

    /// Tasks spawned from another thread only run once the owning thread ticks.
    #[test]
    fn test_spawn_from_other_thread() {
        // Declared before `executor` so it outlives the executor and the task
        // that borrows it.
        let count = AtomicUsize::new(0);
        let executor = ThreadExecutor::new();

        // Shared references are `Copy`, so the `move` closures below copy them.
        let (executor_ref, count_ref) = (&executor, &count);
        thread::scope(|s| {
            s.spawn(move || {
                executor_ref
                    .spawn(async move {
                        count_ref.fetch_add(1, Ordering::Relaxed);
                    })
                    .detach();
            });
        });

        // Nothing has ticked the executor yet.
        assert_eq!(count.load(Ordering::Relaxed), 0);

        let ticker = executor
            .ticker()
            .expect("the test body runs on the thread that created the executor");
        assert!(ticker.try_tick());
        assert_eq!(count.load(Ordering::Relaxed), 1);
    }

    /// `is_same` compares identity, not contents.
    #[test]
    fn test_is_same() {
        let a = ThreadExecutor::new();
        let b = ThreadExecutor::new();
        assert!(a.is_same(&a));
        assert!(!a.is_same(&b));
    }
}