bevy_macro_utils/
label.rs

1use proc_macro::{TokenStream, TokenTree};
2use quote::{quote, quote_spanned};
3use std::collections::HashSet;
4use syn::{spanned::Spanned, Ident};
5
6/// Finds an identifier that will not conflict with the specified set of tokens.
7/// If the identifier is present in `haystack`, extra characters will be added
8/// to it until it no longer conflicts with anything.
9///
10/// Note that the returned identifier can still conflict in niche cases,
11/// such as if an identifier in `haystack` is hidden behind an un-expanded macro.
12pub fn ensure_no_collision(value: Ident, haystack: TokenStream) -> Ident {
13    // Collect all the identifiers in `haystack` into a set.
14    let idents = {
15        // List of token streams that will be visited in future loop iterations.
16        let mut unvisited = vec![haystack];
17        // Identifiers we have found while searching tokens.
18        let mut found = HashSet::new();
19        while let Some(tokens) = unvisited.pop() {
20            for t in tokens {
21                match t {
22                    // Collect any identifiers we encounter.
23                    TokenTree::Ident(ident) => {
24                        found.insert(ident.to_string());
25                    }
26                    // Queue up nested token streams to be visited in a future loop iteration.
27                    TokenTree::Group(g) => unvisited.push(g.stream()),
28                    TokenTree::Punct(_) | TokenTree::Literal(_) => {}
29                }
30            }
31        }
32
33        found
34    };
35
36    let span = value.span();
37
38    // If there's a collision, add more characters to the identifier
39    // until it doesn't collide with anything anymore.
40    let mut value = value.to_string();
41    while idents.contains(&value) {
42        value.push('X');
43    }
44
45    Ident::new(&value, span)
46}
47
48/// Derive a label trait
49///
50/// # Args
51///
52/// - `input`: The [`syn::DeriveInput`] for struct that is deriving the label trait
53/// - `trait_name`: Name of the label trait
54/// - `trait_path`: The [path](`syn::Path`) to the label trait
55/// - `dyn_eq_path`: The [path](`syn::Path`) to the `DynEq` trait
56pub fn derive_label(
57    input: syn::DeriveInput,
58    trait_name: &str,
59    trait_path: &syn::Path,
60    dyn_eq_path: &syn::Path,
61) -> TokenStream {
62    if let syn::Data::Union(_) = &input.data {
63        let message = format!("Cannot derive {trait_name} for unions.");
64        return quote_spanned! {
65            input.span() => compile_error!(#message);
66        }
67        .into();
68    }
69
70    let ident = input.ident.clone();
71    let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
72    let mut where_clause = where_clause.cloned().unwrap_or_else(|| syn::WhereClause {
73        where_token: Default::default(),
74        predicates: Default::default(),
75    });
76    where_clause.predicates.push(
77        syn::parse2(quote! {
78            Self: 'static + Send + Sync + Clone + Eq + ::std::fmt::Debug + ::std::hash::Hash
79        })
80        .unwrap(),
81    );
82    quote! {
83        impl #impl_generics #trait_path for #ident #ty_generics #where_clause {
84            fn dyn_clone(&self) -> ::std::boxed::Box<dyn #trait_path> {
85                ::std::boxed::Box::new(::std::clone::Clone::clone(self))
86            }
87
88            fn as_dyn_eq(&self) -> &dyn #dyn_eq_path {
89                self
90            }
91
92            fn dyn_hash(&self, mut state: &mut dyn ::std::hash::Hasher) {
93                let ty_id = ::std::any::TypeId::of::<Self>();
94                ::std::hash::Hash::hash(&ty_id, &mut state);
95                ::std::hash::Hash::hash(self, &mut state);
96            }
97        }
98    }
99    .into()
100}