// bevy_ecs_macros/states.rs

1use proc_macro::TokenStream;
2use quote::{format_ident, quote};
3use syn::{parse_macro_input, spanned::Spanned, DeriveInput, Pat, Path, Result};
4
5use crate::bevy_ecs_path;
6
7pub fn derive_states(input: TokenStream) -> TokenStream {
8    let ast = parse_macro_input!(input as DeriveInput);
9
10    let generics = ast.generics;
11    let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
12
13    let mut base_trait_path = bevy_ecs_path();
14    base_trait_path
15        .segments
16        .push(format_ident!("schedule").into());
17
18    let mut trait_path = base_trait_path.clone();
19    trait_path.segments.push(format_ident!("States").into());
20
21    let mut state_mutation_trait_path = base_trait_path.clone();
22    state_mutation_trait_path
23        .segments
24        .push(format_ident!("FreelyMutableState").into());
25
26    let struct_name = &ast.ident;
27
28    quote! {
29        impl #impl_generics #trait_path for #struct_name #ty_generics #where_clause {}
30
31        impl #impl_generics #state_mutation_trait_path for #struct_name #ty_generics #where_clause {
32        }
33    }
34    .into()
35}
36
37struct Source {
38    source_type: Path,
39    source_value: Pat,
40}
41
42fn parse_sources_attr(ast: &DeriveInput) -> Result<Source> {
43    let mut result = ast
44        .attrs
45        .iter()
46        .filter(|a| a.path().is_ident("source"))
47        .map(|meta| {
48            let mut source = None;
49            let value = meta.parse_nested_meta(|nested| {
50                let source_type = nested.path.clone();
51                let source_value = Pat::parse_multi(nested.value()?)?;
52                source = Some(Source {
53                    source_type,
54                    source_value,
55                });
56                Ok(())
57            });
58            match source {
59                Some(value) => Ok(value),
60                None => match value {
61                    Ok(_) => Err(syn::Error::new(
62                        ast.span(),
63                        "Couldn't parse SubStates source",
64                    )),
65                    Err(e) => Err(e),
66                },
67            }
68        })
69        .collect::<Result<Vec<_>>>()?;
70
71    if result.len() > 1 {
72        return Err(syn::Error::new(
73            ast.span(),
74            "Only one source is allowed for SubStates",
75        ));
76    }
77
78    let Some(result) = result.pop() else {
79        return Err(syn::Error::new(ast.span(), "SubStates require a source"));
80    };
81
82    Ok(result)
83}
84
85pub fn derive_substates(input: TokenStream) -> TokenStream {
86    let ast = parse_macro_input!(input as DeriveInput);
87    let sources = parse_sources_attr(&ast).expect("Failed to parse substate sources");
88
89    let generics = ast.generics;
90    let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
91
92    let mut base_trait_path = bevy_ecs_path();
93    base_trait_path
94        .segments
95        .push(format_ident!("schedule").into());
96
97    let mut trait_path = base_trait_path.clone();
98    trait_path.segments.push(format_ident!("SubStates").into());
99
100    let mut state_set_trait_path = base_trait_path.clone();
101    state_set_trait_path
102        .segments
103        .push(format_ident!("StateSet").into());
104
105    let mut state_trait_path = base_trait_path.clone();
106    state_trait_path
107        .segments
108        .push(format_ident!("States").into());
109
110    let mut state_mutation_trait_path = base_trait_path.clone();
111    state_mutation_trait_path
112        .segments
113        .push(format_ident!("FreelyMutableState").into());
114
115    let struct_name = &ast.ident;
116
117    let source_state_type = sources.source_type;
118    let source_state_value = sources.source_value;
119
120    let result = quote! {
121        impl #impl_generics #trait_path for #struct_name #ty_generics #where_clause {
122            type SourceStates = #source_state_type;
123
124            fn should_exist(sources: #source_state_type) -> Option<Self> {
125                if matches!(sources, #source_state_value) {
126                    Some(Self::default())
127                } else {
128                    None
129                }
130            }
131        }
132
133        impl #impl_generics #state_trait_path for #struct_name #ty_generics #where_clause {
134            const DEPENDENCY_DEPTH : usize = <Self as #trait_path>::SourceStates::SET_DEPENDENCY_DEPTH + 1;
135        }
136
137        impl #impl_generics #state_mutation_trait_path for #struct_name #ty_generics #where_clause {
138        }
139    };
140
141    // panic!("Got Result\n{}", result.to_string());
142
143    result.into()
144}