bevy_ecs_macros/query_filter.rs

1use bevy_macro_utils::ensure_no_collision;
2use proc_macro::TokenStream;
3use proc_macro2::{Ident, Span};
4use quote::{format_ident, quote};
5use syn::{parse_macro_input, parse_quote, Data, DataStruct, DeriveInput, Index};
6
7use crate::{
8    bevy_ecs_path,
9    world_query::{item_struct, world_query_impl},
10};
11
12mod field_attr_keywords {
13    syn::custom_keyword!(ignore);
14}
15
16pub fn derive_query_filter_impl(input: TokenStream) -> TokenStream {
17    let tokens = input.clone();
18
19    let ast = parse_macro_input!(input as DeriveInput);
20    let visibility = ast.vis;
21
22    let path = bevy_ecs_path();
23
24    let user_generics = ast.generics.clone();
25    let (user_impl_generics, user_ty_generics, user_where_clauses) = user_generics.split_for_impl();
26    let user_generics_with_world = {
27        let mut generics = ast.generics;
28        generics.params.insert(0, parse_quote!('__w));
29        generics
30    };
31    let (user_impl_generics_with_world, user_ty_generics_with_world, user_where_clauses_with_world) =
32        user_generics_with_world.split_for_impl();
33
34    let struct_name = ast.ident;
35
36    let item_struct_name = Ident::new(&format!("{struct_name}Item"), Span::call_site());
37
38    let fetch_struct_name = Ident::new(&format!("{struct_name}Fetch"), Span::call_site());
39    let fetch_struct_name = ensure_no_collision(fetch_struct_name, tokens.clone());
40
41    let marker_name =
42        ensure_no_collision(format_ident!("_world_query_derive_marker"), tokens.clone());
43
44    // Generate a name for the state struct that doesn't conflict
45    // with the struct definition.
46    let state_struct_name = Ident::new(&format!("{struct_name}State"), Span::call_site());
47    let state_struct_name = ensure_no_collision(state_struct_name, tokens);
48
49    let Data::Struct(DataStruct { fields, .. }) = &ast.data else {
50        return syn::Error::new(
51            Span::call_site(),
52            "#[derive(WorldQuery)]` only supports structs",
53        )
54        .into_compile_error()
55        .into();
56    };
57
58    let mut field_attrs = Vec::new();
59    let mut field_visibilities = Vec::new();
60    let mut field_idents = Vec::new();
61    let mut named_field_idents = Vec::new();
62    let mut field_types = Vec::new();
63    for (i, field) in fields.iter().enumerate() {
64        let attrs = field.attrs.clone();
65
66        let named_field_ident = field
67            .ident
68            .as_ref()
69            .cloned()
70            .unwrap_or_else(|| format_ident!("f{i}"));
71        let i = Index::from(i);
72        let field_ident = field
73            .ident
74            .as_ref()
75            .map_or(quote! { #i }, |i| quote! { #i });
76        field_idents.push(field_ident);
77        named_field_idents.push(named_field_ident);
78        field_attrs.push(attrs);
79        field_visibilities.push(field.vis.clone());
80        let field_ty = field.ty.clone();
81        field_types.push(quote!(#field_ty));
82    }
83
84    let derive_macro_call = quote!();
85
86    let item_struct = item_struct(
87        &path,
88        fields,
89        &derive_macro_call,
90        &struct_name,
91        &visibility,
92        &item_struct_name,
93        &field_types,
94        &user_impl_generics_with_world,
95        &field_attrs,
96        &field_visibilities,
97        &field_idents,
98        &user_ty_generics,
99        &user_ty_generics_with_world,
100        user_where_clauses_with_world,
101    );
102
103    let world_query_impl = world_query_impl(
104        &path,
105        &struct_name,
106        &visibility,
107        &item_struct_name,
108        &fetch_struct_name,
109        &field_types,
110        &user_impl_generics,
111        &user_impl_generics_with_world,
112        &field_idents,
113        &user_ty_generics,
114        &user_ty_generics_with_world,
115        &named_field_idents,
116        &marker_name,
117        &state_struct_name,
118        user_where_clauses,
119        user_where_clauses_with_world,
120    );
121
122    let filter_impl = quote! {
123        impl #user_impl_generics #path::query::QueryFilter
124        for #struct_name #user_ty_generics #user_where_clauses {
125            const IS_ARCHETYPAL: bool = true #(&& <#field_types>::IS_ARCHETYPAL)*;
126
127            #[allow(unused_variables)]
128            #[inline(always)]
129            unsafe fn filter_fetch<'__w>(
130                _fetch: &mut <Self as #path::query::WorldQuery>::Fetch<'__w>,
131                _entity: #path::entity::Entity,
132                _table_row: #path::storage::TableRow,
133            ) -> bool {
134                true #(&& <#field_types>::filter_fetch(&mut _fetch.#named_field_idents, _entity, _table_row))*
135            }
136        }
137    };
138
139    let filter_asserts = quote! {
140        #( assert_filter::<#field_types>(); )*
141    };
142
143    TokenStream::from(quote! {
144        #item_struct
145
146        const _: () = {
147            #[doc(hidden)]
148            #[doc = "Automatically generated internal [`WorldQuery`] state type for [`"]
149            #[doc = stringify!(#struct_name)]
150            #[doc = "`], used for caching."]
151            #[automatically_derived]
152            #visibility struct #state_struct_name #user_impl_generics #user_where_clauses {
153                #(#named_field_idents: <#field_types as #path::query::WorldQuery>::State,)*
154            }
155
156            #world_query_impl
157
158            #filter_impl
159        };
160
161        #[allow(dead_code)]
162        const _: () = {
163
164            fn assert_filter<T>()
165            where
166                T: #path::query::QueryFilter,
167            {
168            }
169
170            // We generate a filter assertion for every struct member.
171            fn assert_all #user_impl_generics_with_world () #user_where_clauses_with_world {
172                #filter_asserts
173            }
174        };
175
176        // The original struct will most likely be left unused. As we don't want our users having
177        // to specify `#[allow(dead_code)]` for their custom queries, we are using this cursed
178        // workaround.
179        #[allow(dead_code)]
180        const _: () = {
181            fn dead_code_workaround #user_impl_generics (
182                q: #struct_name #user_ty_generics,
183                q2: #struct_name #user_ty_generics
184            ) #user_where_clauses {
185                #(q.#field_idents;)*
186                #(q2.#field_idents;)*
187            }
188        };
189    })
190}