1pub(crate) mod r#mut;
4pub(crate) mod r#ref;
5
6use std::{borrow::Cow, iter};
7
8use proc_macro2::TokenStream;
9use quote::{quote, ToTokens};
10use syn::{parse_quote, spanned::Spanned, Token};
11
12use crate::utils::{
13 attr::{self, ParseMultiple as _},
14 Either, GenericsSearch, Spanning,
15};
16
17pub fn expand(
19 input: &syn::DeriveInput,
20 trait_info: ExpansionCtx<'_>,
21) -> syn::Result<TokenStream> {
22 let (trait_ident, attr_name, _) = trait_info;
23
24 let data = match &input.data {
25 syn::Data::Struct(data) => Ok(data),
26 syn::Data::Enum(e) => Err(syn::Error::new(
27 e.enum_token.span(),
28 format!("`{trait_ident}` cannot be derived for enums"),
29 )),
30 syn::Data::Union(u) => Err(syn::Error::new(
31 u.union_token.span(),
32 format!("`{trait_ident}` cannot be derived for unions"),
33 )),
34 }?;
35
36 let expansions = if let Some(attr) =
37 StructAttribute::parse_attrs(&input.attrs, attr_name)?
38 {
39 if data.fields.len() != 1 {
40 return Err(syn::Error::new(
41 if data.fields.is_empty() {
42 data.struct_token.span
43 } else {
44 data.fields.span()
45 },
46 format!(
47 "`#[{attr_name}(...)]` attribute can only be placed on structs with exactly \
48 one field",
49 ),
50 ));
51 }
52
53 let field = data.fields.iter().next().unwrap();
54 if FieldAttribute::parse_attrs(&field.attrs, attr_name)?.is_some() {
55 return Err(syn::Error::new(
56 field.span(),
57 format!("`#[{attr_name}(...)]` cannot be placed on both struct and its field"),
58 ));
59 }
60
61 vec![Expansion {
62 trait_info,
63 ident: &input.ident,
64 generics: &input.generics,
65 field,
66 field_index: 0,
67 conversions: Some(attr.into_inner()),
68 }]
69 } else {
70 let attrs = data
71 .fields
72 .iter()
73 .map(|field| FieldAttribute::parse_attrs(&field.attrs, attr_name))
74 .collect::<syn::Result<Vec<_>>>()?;
75
76 let present_attrs = attrs.iter().filter_map(Option::as_ref).collect::<Vec<_>>();
77
78 let all = present_attrs
79 .iter()
80 .all(|attr| matches!(attr.item, FieldAttribute::Skip(_)));
81
82 if !all {
83 if let Some(skip_attr) = present_attrs.iter().find_map(|attr| {
84 if let FieldAttribute::Skip(skip) = &attr.item {
85 Some(attr.as_ref().map(|_| skip))
86 } else {
87 None
88 }
89 }) {
90 return Err(syn::Error::new(
91 skip_attr.span(),
92 format!(
93 "`#[{attr_name}({})]` cannot be used in the same struct with other \
94 `#[{attr_name}(...)]` attributes",
95 skip_attr.name(),
96 ),
97 ));
98 }
99 }
100
101 if all {
102 data.fields
103 .iter()
104 .enumerate()
105 .zip(attrs)
106 .filter_map(|((i, field), attr)| {
107 attr.is_none().then_some(Expansion {
108 trait_info,
109 ident: &input.ident,
110 generics: &input.generics,
111 field,
112 field_index: i,
113 conversions: None,
114 })
115 })
116 .collect()
117 } else {
118 data.fields
119 .iter()
120 .enumerate()
121 .zip(attrs)
122 .filter_map(|((i, field), attr)| match attr.map(Spanning::into_inner) {
123 Some(
124 attr @ (FieldAttribute::Empty(_)
125 | FieldAttribute::Forward(_)
126 | FieldAttribute::Types(_)),
127 ) => Some(Expansion {
128 trait_info,
129 ident: &input.ident,
130 generics: &input.generics,
131 field,
132 field_index: i,
133 conversions: attr.into(),
134 }),
135 Some(FieldAttribute::Skip(_)) => unreachable!(),
136 None => None,
137 })
138 .collect()
139 }
140 };
141 Ok(expansions
142 .into_iter()
143 .map(ToTokens::into_token_stream)
144 .collect())
145}
146
147type ExpansionCtx<'a> = (&'a syn::Ident, &'a syn::Ident, Option<&'a Token![mut]>);
154
155struct Expansion<'a> {
158 trait_info: ExpansionCtx<'a>,
160
161 ident: &'a syn::Ident,
165
166 generics: &'a syn::Generics,
168
169 field: &'a syn::Field,
171
172 field_index: usize,
174
175 conversions: Option<attr::Conversion>,
177}
178
179impl<'a> ToTokens for Expansion<'a> {
180 fn to_tokens(&self, tokens: &mut TokenStream) {
181 let field_ty = &self.field.ty;
182 let field_ident = self.field.ident.as_ref().map_or_else(
183 || Either::Right(syn::Index::from(self.field_index)),
184 Either::Left,
185 );
186
187 let (trait_ident, method_ident, mut_) = &self.trait_info;
188 let ty_ident = &self.ident;
189
190 let field_ref = quote! { & #mut_ self.#field_ident };
191
192 let generics_search = GenericsSearch {
193 types: self.generics.type_params().map(|p| &p.ident).collect(),
194 lifetimes: self
195 .generics
196 .lifetimes()
197 .map(|p| &p.lifetime.ident)
198 .collect(),
199 consts: self.generics.const_params().map(|p| &p.ident).collect(),
200 };
201 let field_contains_generics = generics_search.any_in(field_ty);
202
203 let is_blanket =
204 matches!(&self.conversions, Some(attr::Conversion::Forward(_)));
205
206 let return_tys = match &self.conversions {
207 Some(attr::Conversion::Forward(_)) => {
208 Either::Left(iter::once(Cow::Owned(parse_quote! { __AsT })))
209 }
210 Some(attr::Conversion::Types(tys)) => {
211 Either::Right(tys.0.iter().map(Cow::Borrowed))
212 }
213 None => Either::Left(iter::once(Cow::Borrowed(field_ty))),
214 };
215
216 for return_ty in return_tys {
217 enum ImplKind {
219 Direct,
221
222 Forwarded,
224
225 Specialized,
230 }
231
232 let impl_kind = if is_blanket {
233 ImplKind::Forwarded
234 } else if field_ty == return_ty.as_ref() {
235 ImplKind::Direct
236 } else if field_contains_generics || generics_search.any_in(&return_ty) {
237 ImplKind::Forwarded
238 } else {
239 ImplKind::Specialized
240 };
241
242 let trait_ty = quote! {
243 derive_more::#trait_ident <#return_ty>
244 };
245
246 let generics = match &impl_kind {
247 ImplKind::Forwarded => {
248 let mut generics = self.generics.clone();
249 generics
250 .make_where_clause()
251 .predicates
252 .push(parse_quote! { #field_ty: #trait_ty });
253 if is_blanket {
254 generics
255 .params
256 .push(parse_quote! { #return_ty: ?derive_more::core::marker::Sized });
257 }
258 Cow::Owned(generics)
259 }
260 ImplKind::Direct | ImplKind::Specialized => {
261 Cow::Borrowed(self.generics)
262 }
263 };
264 let (impl_gens, _, where_clause) = generics.split_for_impl();
265 let (_, ty_gens, _) = self.generics.split_for_impl();
266
267 let body = match &impl_kind {
268 ImplKind::Direct => Cow::Borrowed(&field_ref),
269 ImplKind::Forwarded => Cow::Owned(quote! {
270 <#field_ty as #trait_ty>::#method_ident(#field_ref)
271 }),
272 ImplKind::Specialized => Cow::Owned(quote! {
273 use derive_more::__private::ExtractRef as _;
274
275 let conv =
276 <derive_more::__private::Conv<& #mut_ #field_ty, #return_ty>
277 as derive_more::core::default::Default>::default();
278 (&&conv).__extract_ref(#field_ref)
279 }),
280 };
281
282 quote! {
283 #[automatically_derived]
284 impl #impl_gens #trait_ty for #ty_ident #ty_gens #where_clause {
285 #[inline]
286 fn #method_ident(& #mut_ self) -> & #mut_ #return_ty {
287 #body
288 }
289 }
290 }
291 .to_tokens(tokens);
292 }
293 }
294}
295
296type StructAttribute = attr::Conversion;
303
304type FieldAttribute = attr::FieldConversion;