1use std::borrow::Cow;
2use std::collections::BTreeMap;
3use std::fmt;
4use std::fs;
5use std::io;
6use std::path::Path;
7use std::path::PathBuf;
8use std::sync::Arc;
9
10use memo_map::MemoMap;
11use self_cell::self_cell;
12
13use crate::compiler::instructions::Instructions;
14use crate::error::{Error, ErrorKind};
15use crate::template::CompiledTemplate;
16use crate::template::TemplateConfig;
17
18type LoadFunc = dyn for<'a> Fn(&'a str) -> Result<Option<String>, Error> + Send + Sync;
19
/// Storage for compiled templates, keyed by template name.
///
/// Templates whose name and source both borrow from `'source` live in
/// `borrowed_templates`; all others (including templates produced by the
/// loader) live in `owned_templates`. The insert methods keep a name in at most
/// one of the two maps.
#[derive(Clone)]
pub(crate) struct LoaderStore<'source> {
    /// Configuration handed to `CompiledTemplate::new` for every template compiled by this store.
    pub template_config: TemplateConfig,
    /// Optional callback consulted by `get` when a name is not registered yet.
    loader: Option<Arc<LoadFunc>>,
    /// Templates that own their name and source; can be filled lazily through `&self` by `get`.
    owned_templates: MemoMap<Arc<str>, Arc<LoadedTemplate>>,
    /// Templates compiled directly from `'source`-borrowed name/source pairs.
    borrowed_templates: BTreeMap<&'source str, Arc<CompiledTemplate<'source>>>,
}
33
34impl fmt::Debug for LoaderStore<'_> {
35 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
36 let mut l = f.debug_list();
37 for key in self.owned_templates.keys() {
38 l.entry(key);
39 }
40 for key in self.borrowed_templates.keys() {
41 if !self.owned_templates.contains_key(*key) {
42 l.entry(key);
43 }
44 }
45 l.finish()
46 }
47}
48
// A compiled template bundled with the owned buffers it borrows from.
//
// The owner tuple holds the template name and its source text; the dependent
// `CompiledTemplate` is built from references into that tuple (see
// `LoaderStore::make_owned_template`), which is why both must live in the
// same self-referential cell.
self_cell! {
    struct LoadedTemplate {
        // (template name, template source)
        owner: (Arc<str>, Box<str>),
        #[covariant]
        dependent: CompiledTemplate,
    }
}
56
// Compiled `Instructions` bundled with the owned string they borrow from.
//
// NOTE(review): the owner is presumably the source text the instructions were
// compiled from; this type is not constructed in this part of the file —
// confirm at the call site.
self_cell! {
    pub(crate) struct OwnedInstructions {
        owner: Box<str>,
        #[covariant]
        dependent: Instructions,
    }
}
64
65impl fmt::Debug for LoadedTemplate {
66 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
67 fmt::Debug::fmt(&self.borrow_dependent(), f)
68 }
69}
70
71impl<'source> LoaderStore<'source> {
72 pub fn new(template_config: TemplateConfig) -> LoaderStore<'source> {
73 LoaderStore {
74 template_config,
75 loader: None,
76 owned_templates: MemoMap::default(),
77 borrowed_templates: BTreeMap::default(),
78 }
79 }
80
81 pub fn insert(&mut self, name: &'source str, source: &'source str) -> Result<(), Error> {
82 self.insert_cow(Cow::Borrowed(name), Cow::Borrowed(source))
83 }
84
85 pub fn insert_cow(
86 &mut self,
87 name: Cow<'source, str>,
88 source: Cow<'source, str>,
89 ) -> Result<(), Error> {
90 match (source, name) {
91 (Cow::Borrowed(source), Cow::Borrowed(name)) => {
92 self.owned_templates.remove(name);
93 self.borrowed_templates.insert(
94 name,
95 Arc::new(ok!(CompiledTemplate::new(
96 name,
97 source,
98 &self.template_config
99 ))),
100 );
101 }
102 (source, name) => {
103 self.borrowed_templates.remove(&name as &str);
104 let name: Arc<str> = name.into();
105 self.owned_templates.replace(
106 name.clone(),
107 ok!(self.make_owned_template(name, source.to_string())),
108 );
109 }
110 }
111
112 Ok(())
113 }
114
115 pub fn remove(&mut self, name: &str) {
116 self.borrowed_templates.remove(name);
117 self.owned_templates.remove(name);
118 }
119
120 pub fn clear(&mut self) {
121 self.borrowed_templates.clear();
122 self.owned_templates.clear();
123 }
124
125 pub fn get(&self, name: &str) -> Result<&CompiledTemplate<'_>, Error> {
126 if let Some(rv) = self.borrowed_templates.get(name) {
127 Ok(&**rv)
128 } else {
129 let name: Arc<str> = name.into();
130 self.owned_templates
131 .get_or_try_insert(&name.clone(), || -> Result<_, Error> {
132 let loader_result = match self.loader {
133 Some(ref loader) => ok!(loader(&name)),
134 None => None,
135 }
136 .ok_or_else(|| Error::new_not_found(&name));
137 self.make_owned_template(name, ok!(loader_result))
138 })
139 .map(|x| x.borrow_dependent())
140 }
141 }
142
143 pub fn set_loader<F>(&mut self, f: F)
144 where
145 F: Fn(&str) -> Result<Option<String>, Error> + Send + Sync + 'static,
146 {
147 self.loader = Some(Arc::new(f));
148 }
149
150 fn make_owned_template(
151 &self,
152 name: Arc<str>,
153 source: String,
154 ) -> Result<Arc<LoadedTemplate>, Error> {
155 LoadedTemplate::try_new(
156 (name, source.into_boxed_str()),
157 |(name, source)| -> Result<_, Error> {
158 CompiledTemplate::new(name, source, &self.template_config)
159 },
160 )
161 .map(Arc::new)
162 }
163
164 pub fn iter(&self) -> impl Iterator<Item = (&str, &CompiledTemplate<'_>)> {
165 let borrowed = self
166 .borrowed_templates
167 .iter()
168 .map(|(name, template)| (*name, &**template));
169
170 let owned = self
171 .owned_templates
172 .iter()
173 .map(|(name, template)| (&**name, template.borrow_dependent()));
174
175 borrowed.chain(owned)
176 }
177}
178
/// Joins a `/`-separated template name onto `base` without escaping it.
///
/// Returns `None` if any segment:
/// - starts with `.` (this rejects `.`, `..` and hidden files), or
/// - contains a backslash, or
/// - is not made up solely of plain path components.
///
/// The last check matters on Windows. There, a segment such as `C:` is parsed
/// as a path prefix, and `PathBuf::push` would replace `base` entirely instead
/// of appending to it. On Unix, every segment without `/` is already a plain
/// component, so the check changes nothing there.
pub fn safe_join(base: &Path, template: &str) -> Option<PathBuf> {
    let mut rv = base.to_path_buf();
    for segment in template.split('/') {
        let is_plain = Path::new(segment)
            .components()
            .all(|c| matches!(c, std::path::Component::Normal(_)));
        if segment.starts_with('.') || segment.contains('\\') || !is_plain {
            return None;
        }
        rv.push(segment);
    }
    Some(rv)
}
190
191#[cfg_attr(docsrs, doc(cfg(feature = "loader")))]
208pub fn path_loader<'x, P: AsRef<Path> + 'x>(
209 dir: P,
210) -> impl for<'a> Fn(&'a str) -> Result<Option<String>, Error> + Send + Sync + 'static {
211 let dir = dir.as_ref().to_path_buf();
212 move |name| {
213 let Some(path) = safe_join(&dir, name) else {
214 return Ok(None);
215 };
216 match fs::read_to_string(path) {
217 Ok(result) => Ok(Some(result)),
218 Err(err) if err.kind() == io::ErrorKind::NotFound => Ok(None),
219 Err(err) => Err(
220 Error::new(ErrorKind::InvalidOperation, "could not read template").with_source(err),
221 ),
222 }
223 }
224}
225
#[cfg(test)]
mod tests {
    use super::*;

    use similar_asserts::assert_eq;

    /// Plain names are joined; names with dot-prefixed segments are rejected.
    #[test]
    fn test_safe_join() {
        assert_eq!(
            safe_join(Path::new("foo"), "bar/baz"),
            Some(PathBuf::from("foo").join("bar").join("baz"))
        );
        assert_eq!(
            safe_join(Path::new("foo"), "bar"),
            Some(PathBuf::from("foo").join("bar"))
        );
        assert_eq!(safe_join(Path::new("foo"), ".bar/baz"), None);
        assert_eq!(safe_join(Path::new("foo"), "bar/.baz"), None);
        assert_eq!(safe_join(Path::new("foo"), "bar/../baz"), None);
    }

    /// Bare `.` / `..` segments anywhere in the name are rejected.
    #[test]
    fn test_safe_join_rejects_dot_segments() {
        assert_eq!(safe_join(Path::new("foo"), ".."), None);
        assert_eq!(safe_join(Path::new("foo"), "./bar"), None);
        assert_eq!(safe_join(Path::new("foo"), "bar/."), None);
    }

    /// Backslashes are rejected so Windows-style separators cannot smuggle in
    /// extra path components.
    #[test]
    fn test_safe_join_rejects_backslashes() {
        assert_eq!(safe_join(Path::new("foo"), "bar\\baz"), None);
        assert_eq!(safe_join(Path::new("foo"), "bar\\..\\baz"), None);
    }
}