bevy_ecs_macros/
lib.rs

1// FIXME(3492): remove once docs are ready
2#![allow(missing_docs)]
3#![cfg_attr(docsrs, feature(doc_auto_cfg))]
4
5extern crate proc_macro;
6
7mod component;
8mod query_data;
9mod query_filter;
10mod states;
11mod world_query;
12
13use crate::{query_data::derive_query_data_impl, query_filter::derive_query_filter_impl};
14use bevy_macro_utils::{derive_label, ensure_no_collision, get_struct_fields, BevyManifest};
15use proc_macro::TokenStream;
16use proc_macro2::Span;
17use quote::{format_ident, quote};
18use syn::{
19    parse_macro_input, parse_quote, punctuated::Punctuated, spanned::Spanned, token::Comma,
20    ConstParam, DeriveInput, GenericParam, Ident, Index, TypeParam,
21};
22
/// How `#[derive(Bundle)]` treats one field of the annotated struct.
enum BundleFieldKind {
    /// The field is itself a `Bundle`; its components are part of the derived bundle.
    Component,
    /// The field is marked `#[bundle(ignore)]`: it contributes no components and is filled
    /// with `Default::default()` when the bundle is rebuilt by `from_components`.
    Ignore,
}
27
/// Name of the helper attribute recognised by `#[derive(Bundle)]`, i.e. `#[bundle(...)]`.
const BUNDLE_ATTRIBUTE_NAME: &'static str = "bundle";
/// The only argument accepted inside `#[bundle(...)]`; it marks a field as ignored.
const BUNDLE_ATTRIBUTE_IGNORE_NAME: &'static str = "ignore";
30
31#[proc_macro_derive(Bundle, attributes(bundle))]
32pub fn derive_bundle(input: TokenStream) -> TokenStream {
33    let ast = parse_macro_input!(input as DeriveInput);
34    let ecs_path = bevy_ecs_path();
35
36    let named_fields = match get_struct_fields(&ast.data) {
37        Ok(fields) => fields,
38        Err(e) => return e.into_compile_error().into(),
39    };
40
41    let mut field_kind = Vec::with_capacity(named_fields.len());
42
43    for field in named_fields {
44        for attr in field
45            .attrs
46            .iter()
47            .filter(|a| a.path().is_ident(BUNDLE_ATTRIBUTE_NAME))
48        {
49            if let Err(error) = attr.parse_nested_meta(|meta| {
50                if meta.path.is_ident(BUNDLE_ATTRIBUTE_IGNORE_NAME) {
51                    field_kind.push(BundleFieldKind::Ignore);
52                    Ok(())
53                } else {
54                    Err(meta.error(format!(
55                        "Invalid bundle attribute. Use `{BUNDLE_ATTRIBUTE_IGNORE_NAME}`"
56                    )))
57                }
58            }) {
59                return error.into_compile_error().into();
60            }
61        }
62
63        field_kind.push(BundleFieldKind::Component);
64    }
65
66    let field = named_fields
67        .iter()
68        .map(|field| field.ident.as_ref())
69        .collect::<Vec<_>>();
70
71    let field_type = named_fields
72        .iter()
73        .map(|field| &field.ty)
74        .collect::<Vec<_>>();
75
76    let mut field_component_ids = Vec::new();
77    let mut field_get_component_ids = Vec::new();
78    let mut field_get_components = Vec::new();
79    let mut field_from_components = Vec::new();
80    for (((i, field_type), field_kind), field) in field_type
81        .iter()
82        .enumerate()
83        .zip(field_kind.iter())
84        .zip(field.iter())
85    {
86        match field_kind {
87            BundleFieldKind::Component => {
88                field_component_ids.push(quote! {
89                <#field_type as #ecs_path::bundle::Bundle>::component_ids(components, storages, &mut *ids);
90                });
91                field_get_component_ids.push(quote! {
92                    <#field_type as #ecs_path::bundle::Bundle>::get_component_ids(components, &mut *ids);
93                });
94                match field {
95                    Some(field) => {
96                        field_get_components.push(quote! {
97                            self.#field.get_components(&mut *func);
98                        });
99                        field_from_components.push(quote! {
100                            #field: <#field_type as #ecs_path::bundle::Bundle>::from_components(ctx, &mut *func),
101                        });
102                    }
103                    None => {
104                        let index = syn::Index::from(i);
105                        field_get_components.push(quote! {
106                            self.#index.get_components(&mut *func);
107                        });
108                        field_from_components.push(quote! {
109                            #index: <#field_type as #ecs_path::bundle::Bundle>::from_components(ctx, &mut *func),
110                        });
111                    }
112                }
113            }
114
115            BundleFieldKind::Ignore => {
116                field_from_components.push(quote! {
117                    #field: ::std::default::Default::default(),
118                });
119            }
120        }
121    }
122    let generics = ast.generics;
123    let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
124    let struct_name = &ast.ident;
125
126    TokenStream::from(quote! {
127        // SAFETY:
128        // - ComponentId is returned in field-definition-order. [from_components] and [get_components] use field-definition-order
129        // - `Bundle::get_components` is exactly once for each member. Rely's on the Component -> Bundle implementation to properly pass
130        //   the correct `StorageType` into the callback.
131        unsafe impl #impl_generics #ecs_path::bundle::Bundle for #struct_name #ty_generics #where_clause {
132            fn component_ids(
133                components: &mut #ecs_path::component::Components,
134                storages: &mut #ecs_path::storage::Storages,
135                ids: &mut impl FnMut(#ecs_path::component::ComponentId)
136            ){
137                #(#field_component_ids)*
138            }
139
140            fn get_component_ids(
141                components: &#ecs_path::component::Components,
142                ids: &mut impl FnMut(Option<#ecs_path::component::ComponentId>)
143            ){
144                #(#field_get_component_ids)*
145            }
146
147            #[allow(unused_variables, non_snake_case)]
148            unsafe fn from_components<__T, __F>(ctx: &mut __T, func: &mut __F) -> Self
149            where
150                __F: FnMut(&mut __T) -> #ecs_path::ptr::OwningPtr<'_>
151            {
152                Self{
153                    #(#field_from_components)*
154                }
155            }
156        }
157
158        impl #impl_generics #ecs_path::bundle::DynamicBundle for #struct_name #ty_generics #where_clause {
159            #[allow(unused_variables)]
160            #[inline]
161            fn get_components(
162                self,
163                func: &mut impl FnMut(#ecs_path::component::StorageType, #ecs_path::ptr::OwningPtr<'_>)
164            ) {
165                #(#field_get_components)*
166            }
167        }
168    })
169}
170
171fn get_idents(fmt_string: fn(usize) -> String, count: usize) -> Vec<Ident> {
172    (0..count)
173        .map(|i| Ident::new(&fmt_string(i), Span::call_site()))
174        .collect::<Vec<Ident>>()
175}
176
177#[proc_macro]
178pub fn impl_param_set(_input: TokenStream) -> TokenStream {
179    let mut tokens = TokenStream::new();
180    let max_params = 8;
181    let params = get_idents(|i| format!("P{i}"), max_params);
182    let metas = get_idents(|i| format!("m{i}"), max_params);
183    let mut param_fn_muts = Vec::new();
184    for (i, param) in params.iter().enumerate() {
185        let fn_name = Ident::new(&format!("p{i}"), Span::call_site());
186        let index = Index::from(i);
187        let ordinal = match i {
188            1 => "1st".to_owned(),
189            2 => "2nd".to_owned(),
190            3 => "3rd".to_owned(),
191            x => format!("{x}th"),
192        };
193        let comment =
194            format!("Gets exclusive access to the {ordinal} parameter in this [`ParamSet`].");
195        param_fn_muts.push(quote! {
196            #[doc = #comment]
197            /// No other parameters may be accessed while this one is active.
198            pub fn #fn_name<'a>(&'a mut self) -> SystemParamItem<'a, 'a, #param> {
199                // SAFETY: systems run without conflicts with other systems.
200                // Conflicting params in ParamSet are not accessible at the same time
201                // ParamSets are guaranteed to not conflict with other SystemParams
202                unsafe {
203                    #param::get_param(&mut self.param_states.#index, &self.system_meta, self.world, self.change_tick)
204                }
205            }
206        });
207    }
208
209    for param_count in 1..=max_params {
210        let param = &params[0..param_count];
211        let meta = &metas[0..param_count];
212        let param_fn_mut = &param_fn_muts[0..param_count];
213        tokens.extend(TokenStream::from(quote! {
214            // SAFETY: All parameters are constrained to ReadOnlySystemParam, so World is only read
215            unsafe impl<'w, 's, #(#param,)*> ReadOnlySystemParam for ParamSet<'w, 's, (#(#param,)*)>
216            where #(#param: ReadOnlySystemParam,)*
217            { }
218
219            // SAFETY: Relevant parameter ComponentId and ArchetypeComponentId access is applied to SystemMeta. If any ParamState conflicts
220            // with any prior access, a panic will occur.
221            unsafe impl<'_w, '_s, #(#param: SystemParam,)*> SystemParam for ParamSet<'_w, '_s, (#(#param,)*)>
222            {
223                type State = (#(#param::State,)*);
224                type Item<'w, 's> = ParamSet<'w, 's, (#(#param,)*)>;
225
226                // Note: We allow non snake case so the compiler don't complain about the creation of non_snake_case variables
227                #[allow(non_snake_case)]
228                fn init_state(world: &mut World, system_meta: &mut SystemMeta) -> Self::State {
229                    #(
230                        // Pretend to add each param to the system alone, see if it conflicts
231                        let mut #meta = system_meta.clone();
232                        #meta.component_access_set.clear();
233                        #meta.archetype_component_access.clear();
234                        #param::init_state(world, &mut #meta);
235                        // The variable is being defined with non_snake_case here
236                        let #param = #param::init_state(world, &mut system_meta.clone());
237                    )*
238                    // Make the ParamSet non-send if any of its parameters are non-send.
239                    if false #(|| !#meta.is_send())* {
240                        system_meta.set_non_send();
241                    }
242                    #(
243                        system_meta
244                            .component_access_set
245                            .extend(#meta.component_access_set);
246                        system_meta
247                            .archetype_component_access
248                            .extend(&#meta.archetype_component_access);
249                    )*
250                    (#(#param,)*)
251                }
252
253                unsafe fn new_archetype(state: &mut Self::State, archetype: &Archetype, system_meta: &mut SystemMeta) {
254                    // SAFETY: The caller ensures that `archetype` is from the World the state was initialized from in `init_state`.
255                    unsafe { <(#(#param,)*) as SystemParam>::new_archetype(state, archetype, system_meta); }
256                }
257
258                fn apply(state: &mut Self::State, system_meta: &SystemMeta, world: &mut World) {
259                    <(#(#param,)*) as SystemParam>::apply(state, system_meta, world);
260                }
261
262                #[inline]
263                unsafe fn get_param<'w, 's>(
264                    state: &'s mut Self::State,
265                    system_meta: &SystemMeta,
266                    world: UnsafeWorldCell<'w>,
267                    change_tick: Tick,
268                ) -> Self::Item<'w, 's> {
269                    ParamSet {
270                        param_states: state,
271                        system_meta: system_meta.clone(),
272                        world,
273                        change_tick,
274                    }
275                }
276            }
277
278            impl<'w, 's, #(#param: SystemParam,)*> ParamSet<'w, 's, (#(#param,)*)>
279            {
280                #(#param_fn_mut)*
281            }
282        }));
283    }
284
285    tokens
286}
287
288/// Implement `SystemParam` to use a struct as a parameter in a system
289#[proc_macro_derive(SystemParam, attributes(system_param))]
290pub fn derive_system_param(input: TokenStream) -> TokenStream {
291    let token_stream = input.clone();
292    let ast = parse_macro_input!(input as DeriveInput);
293    let syn::Data::Struct(syn::DataStruct {
294        fields: field_definitions,
295        ..
296    }) = ast.data
297    else {
298        return syn::Error::new(
299            ast.span(),
300            "Invalid `SystemParam` type: expected a `struct`",
301        )
302        .into_compile_error()
303        .into();
304    };
305    let path = bevy_ecs_path();
306
307    let mut field_locals = Vec::new();
308    let mut fields = Vec::new();
309    let mut field_types = Vec::new();
310    for (i, field) in field_definitions.iter().enumerate() {
311        field_locals.push(format_ident!("f{i}"));
312        let i = Index::from(i);
313        fields.push(
314            field
315                .ident
316                .as_ref()
317                .map(|f| quote! { #f })
318                .unwrap_or_else(|| quote! { #i }),
319        );
320        field_types.push(&field.ty);
321    }
322
323    let generics = ast.generics;
324
325    // Emit an error if there's any unrecognized lifetime names.
326    for lt in generics.lifetimes() {
327        let ident = &lt.lifetime.ident;
328        let w = format_ident!("w");
329        let s = format_ident!("s");
330        if ident != &w && ident != &s {
331            return syn::Error::new_spanned(
332                lt,
333                r#"invalid lifetime name: expected `'w` or `'s`
334 'w -- refers to data stored in the World.
335 's -- refers to data stored in the SystemParam's state.'"#,
336            )
337            .into_compile_error()
338            .into();
339        }
340    }
341
342    let (_impl_generics, ty_generics, where_clause) = generics.split_for_impl();
343
344    let lifetimeless_generics: Vec<_> = generics
345        .params
346        .iter()
347        .filter(|g| !matches!(g, GenericParam::Lifetime(_)))
348        .collect();
349
350    let shadowed_lifetimes: Vec<_> = generics.lifetimes().map(|_| quote!('_)).collect();
351
352    let mut punctuated_generics = Punctuated::<_, Comma>::new();
353    punctuated_generics.extend(lifetimeless_generics.iter().map(|g| match g {
354        GenericParam::Type(g) => GenericParam::Type(TypeParam {
355            default: None,
356            ..g.clone()
357        }),
358        GenericParam::Const(g) => GenericParam::Const(ConstParam {
359            default: None,
360            ..g.clone()
361        }),
362        _ => unreachable!(),
363    }));
364
365    let mut punctuated_generic_idents = Punctuated::<_, Comma>::new();
366    punctuated_generic_idents.extend(lifetimeless_generics.iter().map(|g| match g {
367        GenericParam::Type(g) => &g.ident,
368        GenericParam::Const(g) => &g.ident,
369        _ => unreachable!(),
370    }));
371
372    let punctuated_generics_no_bounds: Punctuated<_, Comma> = lifetimeless_generics
373        .iter()
374        .map(|&g| match g.clone() {
375            GenericParam::Type(mut g) => {
376                g.bounds.clear();
377                GenericParam::Type(g)
378            }
379            g => g,
380        })
381        .collect();
382
383    let mut tuple_types: Vec<_> = field_types.iter().map(|x| quote! { #x }).collect();
384    let mut tuple_patterns: Vec<_> = field_locals.iter().map(|x| quote! { #x }).collect();
385
386    // If the number of fields exceeds the 16-parameter limit,
387    // fold the fields into tuples of tuples until we are below the limit.
388    const LIMIT: usize = 16;
389    while tuple_types.len() > LIMIT {
390        let end = Vec::from_iter(tuple_types.drain(..LIMIT));
391        tuple_types.push(parse_quote!( (#(#end,)*) ));
392
393        let end = Vec::from_iter(tuple_patterns.drain(..LIMIT));
394        tuple_patterns.push(parse_quote!( (#(#end,)*) ));
395    }
396
397    // Create a where clause for the `ReadOnlySystemParam` impl.
398    // Ensure that each field implements `ReadOnlySystemParam`.
399    let mut read_only_generics = generics.clone();
400    let read_only_where_clause = read_only_generics.make_where_clause();
401    for field_type in &field_types {
402        read_only_where_clause
403            .predicates
404            .push(syn::parse_quote!(#field_type: #path::system::ReadOnlySystemParam));
405    }
406
407    let fields_alias =
408        ensure_no_collision(format_ident!("__StructFieldsAlias"), token_stream.clone());
409
410    let struct_name = &ast.ident;
411    let state_struct_visibility = &ast.vis;
412    let state_struct_name = ensure_no_collision(format_ident!("FetchState"), token_stream);
413
414    TokenStream::from(quote! {
415        // We define the FetchState struct in an anonymous scope to avoid polluting the user namespace.
416        // The struct can still be accessed via SystemParam::State, e.g. EventReaderState can be accessed via
417        // <EventReader<'static, 'static, T> as SystemParam>::State
418        const _: () = {
419            // Allows rebinding the lifetimes of each field type.
420            type #fields_alias <'w, 's, #punctuated_generics_no_bounds> = (#(#tuple_types,)*);
421
422            #[doc(hidden)]
423            #state_struct_visibility struct #state_struct_name <#(#lifetimeless_generics,)*>
424            #where_clause {
425                state: <#fields_alias::<'static, 'static, #punctuated_generic_idents> as #path::system::SystemParam>::State,
426            }
427
428            unsafe impl<#punctuated_generics> #path::system::SystemParam for
429                #struct_name <#(#shadowed_lifetimes,)* #punctuated_generic_idents> #where_clause
430            {
431                type State = #state_struct_name<#punctuated_generic_idents>;
432                type Item<'w, 's> = #struct_name #ty_generics;
433
434                fn init_state(world: &mut #path::world::World, system_meta: &mut #path::system::SystemMeta) -> Self::State {
435                    #state_struct_name {
436                        state: <#fields_alias::<'_, '_, #punctuated_generic_idents> as #path::system::SystemParam>::init_state(world, system_meta),
437                    }
438                }
439
440                unsafe fn new_archetype(state: &mut Self::State, archetype: &#path::archetype::Archetype, system_meta: &mut #path::system::SystemMeta) {
441                    // SAFETY: The caller ensures that `archetype` is from the World the state was initialized from in `init_state`.
442                    unsafe { <#fields_alias::<'_, '_, #punctuated_generic_idents> as #path::system::SystemParam>::new_archetype(&mut state.state, archetype, system_meta) }
443                }
444
445                fn apply(state: &mut Self::State, system_meta: &#path::system::SystemMeta, world: &mut #path::world::World) {
446                    <#fields_alias::<'_, '_, #punctuated_generic_idents> as #path::system::SystemParam>::apply(&mut state.state, system_meta, world);
447                }
448
449                fn queue(state: &mut Self::State, system_meta: &#path::system::SystemMeta, world: #path::world::DeferredWorld) {
450                    <#fields_alias::<'_, '_, #punctuated_generic_idents> as #path::system::SystemParam>::queue(&mut state.state, system_meta, world);
451                }
452
453                unsafe fn get_param<'w, 's>(
454                    state: &'s mut Self::State,
455                    system_meta: &#path::system::SystemMeta,
456                    world: #path::world::unsafe_world_cell::UnsafeWorldCell<'w>,
457                    change_tick: #path::component::Tick,
458                ) -> Self::Item<'w, 's> {
459                    let (#(#tuple_patterns,)*) = <
460                        (#(#tuple_types,)*) as #path::system::SystemParam
461                    >::get_param(&mut state.state, system_meta, world, change_tick);
462                    #struct_name {
463                        #(#fields: #field_locals,)*
464                    }
465                }
466            }
467
468            // Safety: Each field is `ReadOnlySystemParam`, so this can only read from the `World`
469            unsafe impl<'w, 's, #punctuated_generics> #path::system::ReadOnlySystemParam for #struct_name #ty_generics #read_only_where_clause {}
470        };
471    })
472}
473
474/// Implement `QueryData` to use a struct as a data parameter in a query
475#[proc_macro_derive(QueryData, attributes(query_data))]
476pub fn derive_query_data(input: TokenStream) -> TokenStream {
477    derive_query_data_impl(input)
478}
479
480/// Implement `QueryFilter` to use a struct as a filter parameter in a query
481#[proc_macro_derive(QueryFilter, attributes(query_filter))]
482pub fn derive_query_filter(input: TokenStream) -> TokenStream {
483    derive_query_filter_impl(input)
484}
485
486/// Derive macro generating an impl of the trait `ScheduleLabel`.
487///
488/// This does not work for unions.
489#[proc_macro_derive(ScheduleLabel)]
490pub fn derive_schedule_label(input: TokenStream) -> TokenStream {
491    let input = parse_macro_input!(input as DeriveInput);
492    let mut trait_path = bevy_ecs_path();
493    trait_path.segments.push(format_ident!("schedule").into());
494    let mut dyn_eq_path = trait_path.clone();
495    trait_path
496        .segments
497        .push(format_ident!("ScheduleLabel").into());
498    dyn_eq_path.segments.push(format_ident!("DynEq").into());
499    derive_label(input, "ScheduleLabel", &trait_path, &dyn_eq_path)
500}
501
502/// Derive macro generating an impl of the trait `SystemSet`.
503///
504/// This does not work for unions.
505#[proc_macro_derive(SystemSet)]
506pub fn derive_system_set(input: TokenStream) -> TokenStream {
507    let input = parse_macro_input!(input as DeriveInput);
508    let mut trait_path = bevy_ecs_path();
509    trait_path.segments.push(format_ident!("schedule").into());
510    let mut dyn_eq_path = trait_path.clone();
511    trait_path.segments.push(format_ident!("SystemSet").into());
512    dyn_eq_path.segments.push(format_ident!("DynEq").into());
513    derive_label(input, "SystemSet", &trait_path, &dyn_eq_path)
514}
515
516pub(crate) fn bevy_ecs_path() -> syn::Path {
517    BevyManifest::default().get_path("bevy_ecs")
518}
519
520#[proc_macro_derive(Event)]
521pub fn derive_event(input: TokenStream) -> TokenStream {
522    component::derive_event(input)
523}
524
525#[proc_macro_derive(Resource)]
526pub fn derive_resource(input: TokenStream) -> TokenStream {
527    component::derive_resource(input)
528}
529
530#[proc_macro_derive(Component, attributes(component))]
531pub fn derive_component(input: TokenStream) -> TokenStream {
532    component::derive_component(input)
533}
534
535#[proc_macro_derive(States)]
536pub fn derive_states(input: TokenStream) -> TokenStream {
537    states::derive_states(input)
538}
539
540#[proc_macro_derive(SubStates, attributes(source))]
541pub fn derive_substates(input: TokenStream) -> TokenStream {
542    states::derive_substates(input)
543}