bevy_ecs_macros/
query_data.rs

1use bevy_macro_utils::ensure_no_collision;
2use proc_macro::TokenStream;
3use proc_macro2::{Ident, Span};
4use quote::{format_ident, quote, ToTokens};
5use syn::{
6    parse::{Parse, ParseStream},
7    parse_macro_input, parse_quote,
8    punctuated::Punctuated,
9    token::Comma,
10    Attribute, Data, DataStruct, DeriveInput, Field, Index, Meta,
11};
12
13use crate::{
14    bevy_ecs_path,
15    world_query::{item_struct, world_query_impl},
16};
17
18#[derive(Default)]
19struct QueryDataAttributes {
20    pub is_mutable: bool,
21
22    pub derive_args: Punctuated<Meta, syn::token::Comma>,
23}
24
/// Option key inside `#[query_data(...)]` whose parenthesized list of paths is forwarded as a
/// `#[derive(...)]` on the generated item struct(s).
static DERIVE_ATTRIBUTE_NAME: &str = "derive";
/// Bare flag inside `#[query_data(...)]` requesting generation of separate read-only types.
static MUTABLE_ATTRIBUTE_NAME: &str = "mutable";
27
28mod field_attr_keywords {
29    syn::custom_keyword!(ignore);
30}
31
/// Name of the helper attribute (`#[query_data(...)]`) understood by `#[derive(QueryData)]`.
pub static QUERY_DATA_ATTRIBUTE_NAME: &str = "query_data";
33
34pub fn derive_query_data_impl(input: TokenStream) -> TokenStream {
35    let tokens = input.clone();
36
37    let ast = parse_macro_input!(input as DeriveInput);
38    let visibility = ast.vis;
39
40    let mut attributes = QueryDataAttributes::default();
41    for attr in &ast.attrs {
42        if !attr
43            .path()
44            .get_ident()
45            .map_or(false, |ident| ident == QUERY_DATA_ATTRIBUTE_NAME)
46        {
47            continue;
48        }
49
50        attr.parse_args_with(|input: ParseStream| {
51            let meta = input.parse_terminated(syn::Meta::parse, Comma)?;
52            for meta in meta {
53                let ident = meta.path().get_ident().unwrap_or_else(|| {
54                    panic!(
55                        "Unrecognized attribute: `{}`",
56                        meta.path().to_token_stream()
57                    )
58                });
59                if ident == MUTABLE_ATTRIBUTE_NAME {
60                    if let Meta::Path(_) = meta {
61                        attributes.is_mutable = true;
62                    } else {
63                        panic!(
64                            "The `{MUTABLE_ATTRIBUTE_NAME}` attribute is expected to have no value or arguments",
65                        );
66                    }
67                }
68                else if ident == DERIVE_ATTRIBUTE_NAME {
69                    if let Meta::List(meta_list) = meta {
70                        meta_list.parse_nested_meta(|meta| {
71                            attributes.derive_args.push(Meta::Path(meta.path));
72                            Ok(())
73                        })?;
74                    } else {
75                        panic!(
76                            "Expected a structured list within the `{DERIVE_ATTRIBUTE_NAME}` attribute",
77                        );
78                    }
79                } else {
80                    panic!(
81                        "Unrecognized attribute: `{}`",
82                        meta.path().to_token_stream()
83                    );
84                }
85            }
86            Ok(())
87        })
88        .unwrap_or_else(|_| panic!("Invalid `{QUERY_DATA_ATTRIBUTE_NAME}` attribute format"));
89    }
90
91    let path = bevy_ecs_path();
92
93    let user_generics = ast.generics.clone();
94    let (user_impl_generics, user_ty_generics, user_where_clauses) = user_generics.split_for_impl();
95    let user_generics_with_world = {
96        let mut generics = ast.generics;
97        generics.params.insert(0, parse_quote!('__w));
98        generics
99    };
100    let (user_impl_generics_with_world, user_ty_generics_with_world, user_where_clauses_with_world) =
101        user_generics_with_world.split_for_impl();
102
103    let struct_name = ast.ident;
104    let read_only_struct_name = if attributes.is_mutable {
105        Ident::new(&format!("{struct_name}ReadOnly"), Span::call_site())
106    } else {
107        #[allow(clippy::redundant_clone)]
108        struct_name.clone()
109    };
110
111    let item_struct_name = Ident::new(&format!("{struct_name}Item"), Span::call_site());
112    let read_only_item_struct_name = if attributes.is_mutable {
113        Ident::new(&format!("{struct_name}ReadOnlyItem"), Span::call_site())
114    } else {
115        #[allow(clippy::redundant_clone)]
116        item_struct_name.clone()
117    };
118
119    let fetch_struct_name = Ident::new(&format!("{struct_name}Fetch"), Span::call_site());
120    let fetch_struct_name = ensure_no_collision(fetch_struct_name, tokens.clone());
121    let read_only_fetch_struct_name = if attributes.is_mutable {
122        let new_ident = Ident::new(&format!("{struct_name}ReadOnlyFetch"), Span::call_site());
123        ensure_no_collision(new_ident, tokens.clone())
124    } else {
125        #[allow(clippy::redundant_clone)]
126        fetch_struct_name.clone()
127    };
128
129    let marker_name =
130        ensure_no_collision(format_ident!("_world_query_derive_marker"), tokens.clone());
131
132    // Generate a name for the state struct that doesn't conflict
133    // with the struct definition.
134    let state_struct_name = Ident::new(&format!("{struct_name}State"), Span::call_site());
135    let state_struct_name = ensure_no_collision(state_struct_name, tokens);
136
137    let Data::Struct(DataStruct { fields, .. }) = &ast.data else {
138        return syn::Error::new(
139            Span::call_site(),
140            "#[derive(QueryData)]` only supports structs",
141        )
142        .into_compile_error()
143        .into();
144    };
145
146    let mut field_attrs = Vec::new();
147    let mut field_visibilities = Vec::new();
148    let mut field_idents = Vec::new();
149    let mut named_field_idents = Vec::new();
150    let mut field_types = Vec::new();
151    let mut read_only_field_types = Vec::new();
152    for (i, field) in fields.iter().enumerate() {
153        let attrs = match read_world_query_field_info(field) {
154            Ok(QueryDataFieldInfo { attrs }) => attrs,
155            Err(e) => return e.into_compile_error().into(),
156        };
157
158        let named_field_ident = field
159            .ident
160            .as_ref()
161            .cloned()
162            .unwrap_or_else(|| format_ident!("f{i}"));
163        let i = Index::from(i);
164        let field_ident = field
165            .ident
166            .as_ref()
167            .map_or(quote! { #i }, |i| quote! { #i });
168        field_idents.push(field_ident);
169        named_field_idents.push(named_field_ident);
170        field_attrs.push(attrs);
171        field_visibilities.push(field.vis.clone());
172        let field_ty = field.ty.clone();
173        field_types.push(quote!(#field_ty));
174        read_only_field_types.push(quote!(<#field_ty as #path::query::QueryData>::ReadOnly));
175    }
176
177    let derive_args = &attributes.derive_args;
178    // `#[derive()]` is valid syntax
179    let derive_macro_call = quote! { #[derive(#derive_args)] };
180
181    let mutable_item_struct = item_struct(
182        &path,
183        fields,
184        &derive_macro_call,
185        &struct_name,
186        &visibility,
187        &item_struct_name,
188        &field_types,
189        &user_impl_generics_with_world,
190        &field_attrs,
191        &field_visibilities,
192        &field_idents,
193        &user_ty_generics,
194        &user_ty_generics_with_world,
195        user_where_clauses_with_world,
196    );
197    let mutable_world_query_impl = world_query_impl(
198        &path,
199        &struct_name,
200        &visibility,
201        &item_struct_name,
202        &fetch_struct_name,
203        &field_types,
204        &user_impl_generics,
205        &user_impl_generics_with_world,
206        &field_idents,
207        &user_ty_generics,
208        &user_ty_generics_with_world,
209        &named_field_idents,
210        &marker_name,
211        &state_struct_name,
212        user_where_clauses,
213        user_where_clauses_with_world,
214    );
215
216    let (read_only_struct, read_only_impl) = if attributes.is_mutable {
217        // If the query is mutable, we need to generate a separate readonly version of some things
218        let readonly_item_struct = item_struct(
219            &path,
220            fields,
221            &derive_macro_call,
222            &read_only_struct_name,
223            &visibility,
224            &read_only_item_struct_name,
225            &read_only_field_types,
226            &user_impl_generics_with_world,
227            &field_attrs,
228            &field_visibilities,
229            &field_idents,
230            &user_ty_generics,
231            &user_ty_generics_with_world,
232            user_where_clauses_with_world,
233        );
234        let readonly_world_query_impl = world_query_impl(
235            &path,
236            &read_only_struct_name,
237            &visibility,
238            &read_only_item_struct_name,
239            &read_only_fetch_struct_name,
240            &read_only_field_types,
241            &user_impl_generics,
242            &user_impl_generics_with_world,
243            &field_idents,
244            &user_ty_generics,
245            &user_ty_generics_with_world,
246            &named_field_idents,
247            &marker_name,
248            &state_struct_name,
249            user_where_clauses,
250            user_where_clauses_with_world,
251        );
252        let read_only_structs = quote! {
253            #[doc = "Automatically generated [`WorldQuery`] type for a read-only variant of [`"]
254            #[doc = stringify!(#struct_name)]
255            #[doc = "`]."]
256            #[automatically_derived]
257            #visibility struct #read_only_struct_name #user_impl_generics #user_where_clauses {
258                #(
259                    #[doc = "Automatically generated read-only field for accessing `"]
260                    #[doc = stringify!(#field_types)]
261                    #[doc = "`."]
262                    #field_visibilities #named_field_idents: #read_only_field_types,
263                )*
264            }
265
266            #readonly_item_struct
267        };
268        (read_only_structs, readonly_world_query_impl)
269    } else {
270        (quote! {}, quote! {})
271    };
272
273    let data_impl = {
274        let read_only_data_impl = if attributes.is_mutable {
275            quote! {
276                /// SAFETY: we assert fields are readonly below
277                unsafe impl #user_impl_generics #path::query::QueryData
278                for #read_only_struct_name #user_ty_generics #user_where_clauses {
279                    type ReadOnly = #read_only_struct_name #user_ty_generics;
280                }
281            }
282        } else {
283            quote! {}
284        };
285
286        quote! {
287            /// SAFETY: we assert fields are readonly below
288            unsafe impl #user_impl_generics #path::query::QueryData
289            for #struct_name #user_ty_generics #user_where_clauses {
290                type ReadOnly = #read_only_struct_name #user_ty_generics;
291            }
292
293            #read_only_data_impl
294        }
295    };
296
297    let read_only_data_impl = quote! {
298        /// SAFETY: we assert fields are readonly below
299        unsafe impl #user_impl_generics #path::query::ReadOnlyQueryData
300        for #read_only_struct_name #user_ty_generics #user_where_clauses {}
301    };
302
303    let read_only_asserts = if attributes.is_mutable {
304        quote! {
305            // Double-check that the data fetched by `<_ as WorldQuery>::ReadOnly` is read-only.
306            // This is technically unnecessary as `<_ as WorldQuery>::ReadOnly: ReadOnlyQueryData`
307            // but to protect against future mistakes we assert the assoc type implements `ReadOnlyQueryData` anyway
308            #( assert_readonly::<#read_only_field_types>(); )*
309        }
310    } else {
311        quote! {
312            // Statically checks that the safety guarantee of `ReadOnlyQueryData` for `$fetch_struct_name` actually holds true.
313            // We need this to make sure that we don't compile `ReadOnlyQueryData` if our struct contains nested `QueryData`
314            // members that don't implement it. I.e.:
315            // ```
316            // #[derive(QueryData)]
317            // pub struct Foo { a: &'static mut MyComponent }
318            // ```
319            #( assert_readonly::<#field_types>(); )*
320        }
321    };
322
323    let data_asserts = quote! {
324        #( assert_data::<#field_types>(); )*
325    };
326
327    TokenStream::from(quote! {
328        #mutable_item_struct
329
330        #read_only_struct
331
332        const _: () = {
333            #[doc(hidden)]
334            #[doc = "Automatically generated internal [`WorldQuery`] state type for [`"]
335            #[doc = stringify!(#struct_name)]
336            #[doc = "`], used for caching."]
337            #[automatically_derived]
338            #visibility struct #state_struct_name #user_impl_generics #user_where_clauses {
339                #(#named_field_idents: <#field_types as #path::query::WorldQuery>::State,)*
340            }
341
342            #mutable_world_query_impl
343
344            #read_only_impl
345
346            #data_impl
347
348            #read_only_data_impl
349        };
350
351        #[allow(dead_code)]
352        const _: () = {
353            fn assert_readonly<T>()
354            where
355                T: #path::query::ReadOnlyQueryData,
356            {
357            }
358
359            fn assert_data<T>()
360            where
361                T: #path::query::QueryData,
362            {
363            }
364
365            // We generate a readonly assertion for every struct member.
366            fn assert_all #user_impl_generics_with_world () #user_where_clauses_with_world {
367                #read_only_asserts
368                #data_asserts
369            }
370        };
371
372        // The original struct will most likely be left unused. As we don't want our users having
373        // to specify `#[allow(dead_code)]` for their custom queries, we are using this cursed
374        // workaround.
375        #[allow(dead_code)]
376        const _: () = {
377            fn dead_code_workaround #user_impl_generics (
378                q: #struct_name #user_ty_generics,
379                q2: #read_only_struct_name #user_ty_generics
380            ) #user_where_clauses {
381                #(q.#field_idents;)*
382                #(q2.#field_idents;)*
383            }
384        };
385    })
386}
387
388struct QueryDataFieldInfo {
389    /// All field attributes except for `query_data` ones.
390    attrs: Vec<Attribute>,
391}
392
393fn read_world_query_field_info(field: &Field) -> syn::Result<QueryDataFieldInfo> {
394    let mut attrs = Vec::new();
395    for attr in &field.attrs {
396        if attr
397            .path()
398            .get_ident()
399            .map_or(false, |ident| ident == QUERY_DATA_ATTRIBUTE_NAME)
400        {
401            return Err(syn::Error::new_spanned(
402                attr,
403                "#[derive(QueryData)] does not support field attributes.",
404            ));
405        }
406        attrs.push(attr.clone());
407    }
408
409    Ok(QueryDataFieldInfo { attrs })
410}