logos_codegen/parser/
type_params.rs

1use proc_macro2::{Ident, Span, TokenStream};
2use quote::quote;
3use syn::spanned::Spanned;
4use syn::{Lifetime, LifetimeParam, Path, Type};
5
6use crate::error::Errors;
7
8#[derive(Default)]
9pub struct TypeParams {
10    lifetime: bool,
11    type_params: Vec<(Ident, Option<Type>)>,
12}
13
14impl TypeParams {
15    pub fn explicit_lifetime(&mut self, lt: LifetimeParam, errors: &mut Errors) {
16        if self.lifetime {
17            let span = lt.span();
18
19            errors.err("Logos types can only have one lifetime can be set", span);
20        }
21
22        self.lifetime = true;
23    }
24
25    pub fn add(&mut self, param: Ident) {
26        self.type_params.push((param, None));
27    }
28
29    pub fn set(&mut self, param: Ident, ty: TokenStream, errors: &mut Errors) {
30        let ty = match syn::parse2::<Type>(ty) {
31            Ok(mut ty) => {
32                replace_lifetimes(&mut ty);
33                ty
34            }
35            Err(err) => {
36                errors.err(err.to_string(), err.span());
37                return;
38            }
39        };
40
41        match self.type_params.iter_mut().find(|(name, _)| *name == param) {
42            Some((_, slot)) => {
43                if let Some(previous) = slot.replace(ty) {
44                    errors
45                        .err(
46                            format!("{} can only have one type assigned to it", param),
47                            param.span(),
48                        )
49                        .err("Previously assigned here", previous.span());
50                }
51            }
52            None => {
53                errors.err(
54                    format!("{} is not a declared type parameter", param),
55                    param.span(),
56                );
57            }
58        }
59    }
60
61    pub fn find(&self, path: &Path) -> Option<Type> {
62        for (ident, ty) in &self.type_params {
63            if path.is_ident(ident) {
64                return ty.clone();
65            }
66        }
67
68        None
69    }
70
71    pub fn generics(&self, errors: &mut Errors) -> Option<TokenStream> {
72        if !self.lifetime && self.type_params.is_empty() {
73            return None;
74        }
75
76        let mut generics = Vec::new();
77
78        if self.lifetime {
79            generics.push(quote!('s));
80        }
81
82        for (ty, replace) in self.type_params.iter() {
83            match replace {
84                Some(ty) => generics.push(quote!(#ty)),
85                None => {
86                    errors.err(
87                        format!(
88                            "Generic type parameter without a concrete type\n\
89                            \n\
90                            Define a concrete type Logos can use: #[logos(type {} = Type)]",
91                            ty,
92                        ),
93                        ty.span(),
94                    );
95                }
96            }
97        }
98
99        if generics.is_empty() {
100            None
101        } else {
102            Some(quote!(<#(#generics),*>))
103        }
104    }
105}
106
107pub fn replace_lifetimes(ty: &mut Type) {
108    traverse_type(ty, &mut replace_lifetime)
109}
110
111pub fn replace_lifetime(ty: &mut Type) {
112    use syn::{GenericArgument, PathArguments};
113
114    match ty {
115        Type::Path(p) => {
116            p.path
117                .segments
118                .iter_mut()
119                .filter_map(|segment| match &mut segment.arguments {
120                    PathArguments::AngleBracketed(ab) => Some(ab),
121                    _ => None,
122                })
123                .flat_map(|ab| ab.args.iter_mut())
124                .for_each(|arg| {
125                    if let GenericArgument::Lifetime(lt) = arg {
126                        *lt = Lifetime::new("'s", lt.span());
127                    }
128                });
129        }
130        Type::Reference(r) => {
131            let span = match r.lifetime.take() {
132                Some(lt) => lt.span(),
133                None => Span::call_site(),
134            };
135
136            r.lifetime = Some(Lifetime::new("'s", span));
137        }
138        _ => (),
139    }
140}
141
142pub fn traverse_type(ty: &mut Type, f: &mut impl FnMut(&mut Type)) {
143    f(ty);
144    match ty {
145        Type::Array(array) => traverse_type(&mut array.elem, f),
146        Type::BareFn(bare_fn) => {
147            for input in &mut bare_fn.inputs {
148                traverse_type(&mut input.ty, f);
149            }
150            if let syn::ReturnType::Type(_, ty) = &mut bare_fn.output {
151                traverse_type(ty, f);
152            }
153        }
154        Type::Group(group) => traverse_type(&mut group.elem, f),
155        Type::Paren(paren) => traverse_type(&mut paren.elem, f),
156        Type::Path(path) => traverse_path(&mut path.path, f),
157        Type::Ptr(p) => traverse_type(&mut p.elem, f),
158        Type::Reference(r) => traverse_type(&mut r.elem, f),
159        Type::Slice(slice) => traverse_type(&mut slice.elem, f),
160        Type::TraitObject(object) => object.bounds.iter_mut().for_each(|bound| {
161            if let syn::TypeParamBound::Trait(trait_bound) = bound {
162                traverse_path(&mut trait_bound.path, f);
163            }
164        }),
165        Type::Tuple(tuple) => tuple
166            .elems
167            .iter_mut()
168            .for_each(|elem| traverse_type(elem, f)),
169        _ => (),
170    }
171}
172
173fn traverse_path(path: &mut Path, f: &mut impl FnMut(&mut Type)) {
174    for segment in &mut path.segments {
175        match &mut segment.arguments {
176            syn::PathArguments::None => (),
177            syn::PathArguments::AngleBracketed(args) => {
178                for arg in &mut args.args {
179                    match arg {
180                        syn::GenericArgument::Type(ty) => {
181                            traverse_type(ty, f);
182                        }
183                        syn::GenericArgument::AssocType(assoc) => {
184                            traverse_type(&mut assoc.ty, f);
185                        }
186                        _ => (),
187                    }
188                }
189            }
190            syn::PathArguments::Parenthesized(args) => {
191                for arg in &mut args.inputs {
192                    traverse_type(arg, f);
193                }
194                if let syn::ReturnType::Type(_, ty) = &mut args.output {
195                    traverse_type(ty, f);
196                }
197            }
198        }
199    }
200}