logos_codegen/parser/
ignore_flags.rs

1use std::ops::{BitAnd, BitOr};
2
3use proc_macro2::{Ident, TokenStream, TokenTree};
4
5use crate::parser::Parser;
6use crate::util::is_punct;
7
/// A small set of "ignore" flags, stored as a bit mask in a single byte.
///
/// Each flag occupies one bit of `bits`; a value of `0` means that no flag
/// is enabled. Values are cheap to copy and compare bit-for-bit.
#[derive(Copy, Clone, Eq, PartialEq)]
pub struct IgnoreFlags {
    /// Raw bit mask; one bit per enabled flag.
    bits: u8,
}
12
13#[allow(non_upper_case_globals)]
14impl IgnoreFlags {
15    pub const Empty: Self = Self::new(0x00);
16    pub const IgnoreCase: Self = Self::new(0x01);
17    pub const IgnoreAsciiCase: Self = Self::new(0x02);
18
19    #[inline]
20    pub const fn new(bits: u8) -> Self {
21        Self { bits }
22    }
23
24    /// Enables a variant.
25    #[inline]
26    pub fn enable(&mut self, variant: Self) {
27        self.bits |= variant.bits;
28    }
29
30    /// Checks if this `IgnoreFlags` contains *any* of the given variants.
31    #[inline]
32    pub fn contains(&self, variants: Self) -> bool {
33        self.bits & variants.bits != 0
34    }
35
36    #[inline]
37    pub fn is_empty(&self) -> bool {
38        self.bits == 0
39    }
40
41    /// Parses an identifier an enables it for `self`.
42    ///
43    /// Valid inputs are (that produces `true`):
44    /// * `"case"` (incompatible with `"ascii_case"`)
45    /// * `"ascii_case"` (incompatible with `"case"`)
46    ///
47    /// An error causes this function to return `false` and emits an error to
48    /// the given `Parser`.
49    fn parse_ident(&mut self, ident: Ident, parser: &mut Parser) -> bool {
50        match ident.to_string().as_str() {
51            "case" => {
52                if self.contains(Self::IgnoreAsciiCase) {
53                    parser.err(
54                        "\
55                        The flag \"case\" cannot be used along with \"ascii_case\"\
56                        ",
57                        ident.span(),
58                    );
59                    false
60                } else {
61                    self.enable(Self::IgnoreCase);
62                    true
63                }
64            }
65            "ascii_case" => {
66                if self.contains(Self::IgnoreCase) {
67                    parser.err(
68                        "\
69                        The flag \"ascii_case\" cannot be used along with \"case\"\
70                        ",
71                        ident.span(),
72                    );
73                    false
74                } else {
75                    self.enable(Self::IgnoreAsciiCase);
76                    true
77                }
78            }
79            unknown => {
80                parser.err(
81                    format!(
82                        "\
83                        Unknown flag: {}\n\
84                        \n\
85                        Expected one of: case, ascii_case\
86                        ",
87                        unknown
88                    ),
89                    ident.span(),
90                );
91                false
92            }
93        }
94    }
95
96    pub fn parse_group(&mut self, name: Ident, tokens: TokenStream, parser: &mut Parser) {
97        // Little finite state machine to parse "<flag>(,<flag>)*,?"
98
99        // FSM description for future maintenance
100        // 0: Initial state
101        //   <flag> -> 1
102        //        _ -> error
103        // 1: A flag was found
104        //        , -> 2
105        //     None -> done
106        //        _ -> error
107        // 2: A comma was found (after a <flag>)
108        //   <flag> -> 1
109        //     None -> done
110        //        _ -> error
111        let mut state = 0u8;
112
113        let mut tokens = tokens.into_iter();
114
115        loop {
116            state = match state {
117                0 => match tokens.next() {
118                    Some(TokenTree::Ident(ident)) => {
119                        if self.parse_ident(ident, parser) {
120                            1
121                        } else {
122                            return;
123                        }
124                    }
125                    _ => {
126                        parser.err(
127                            "\
128                            Invalid ignore flag\n\
129                            \n\
130                            Expected one of: case, ascii_case\
131                            ",
132                            name.span(),
133                        );
134                        return;
135                    }
136                },
137                1 => match tokens.next() {
138                    Some(tt) if is_punct(&tt, ',') => 2,
139                    None => return,
140                    Some(unexpected_tt) => {
141                        parser.err(
142                            format!(
143                                "\
144                                Unexpected token: {:?}\
145                                ",
146                                unexpected_tt.to_string(),
147                            ),
148                            unexpected_tt.span(),
149                        );
150                        return;
151                    }
152                },
153                2 => match tokens.next() {
154                    Some(TokenTree::Ident(ident)) => {
155                        if self.parse_ident(ident, parser) {
156                            1
157                        } else {
158                            return;
159                        }
160                    }
161                    None => return,
162                    Some(unexpected_tt) => {
163                        parser.err(
164                            format!(
165                                "\
166                                Unexpected token: {:?}\
167                                ",
168                                unexpected_tt.to_string(),
169                            ),
170                            unexpected_tt.span(),
171                        );
172                        return;
173                    }
174                },
175                _ => unreachable!("Internal Error: invalid state ({})", state),
176            }
177        }
178    }
179}
180
181impl BitOr for IgnoreFlags {
182    type Output = Self;
183
184    fn bitor(self, other: Self) -> Self {
185        Self::new(self.bits | other.bits)
186    }
187}
188
189impl BitAnd for IgnoreFlags {
190    type Output = Self;
191
192    fn bitand(self, other: Self) -> Self {
193        Self::new(self.bits & other.bits)
194    }
195}
196
197pub mod ascii_case {
198    use regex_syntax::hir;
199
200    use crate::mir::Mir;
201    use crate::parser::Literal;
202
203    macro_rules! literal {
204        ($byte:expr) => {
205            hir::Literal(Box::new([$byte]))
206        };
207        (@char $c:expr) => {
208            hir::Literal(
209                $c.encode_utf8(&mut [0; 4])
210                    .as_bytes()
211                    .to_vec()
212                    .into_boxed_slice(),
213            )
214        };
215    }
216
217    pub trait MakeAsciiCaseInsensitive {
218        /// Creates a equivalent regular expression which ignore the letter casing
219        /// of ascii characters.
220        fn make_ascii_case_insensitive(self) -> Mir;
221    }
222
223    impl MakeAsciiCaseInsensitive for u8 {
224        fn make_ascii_case_insensitive(self) -> Mir {
225            if self.is_ascii_lowercase() {
226                Mir::Alternation(vec![
227                    Mir::Literal(literal!(self - 32)),
228                    Mir::Literal(literal!(self)),
229                ])
230            } else if self.is_ascii_uppercase() {
231                Mir::Alternation(vec![
232                    Mir::Literal(literal!(self)),
233                    Mir::Literal(literal!(self + 32)),
234                ])
235            } else {
236                Mir::Literal(literal!(self))
237            }
238        }
239    }
240
241    impl MakeAsciiCaseInsensitive for char {
242        fn make_ascii_case_insensitive(self) -> Mir {
243            if self.is_ascii() {
244                (self as u8).make_ascii_case_insensitive()
245            } else {
246                Mir::Literal(literal!(@char self))
247            }
248        }
249    }
250
251    impl MakeAsciiCaseInsensitive for hir::Literal {
252        fn make_ascii_case_insensitive(self) -> Mir {
253            Mir::Concat(
254                self.0
255                    .iter()
256                    .map(|x| x.make_ascii_case_insensitive())
257                    .collect(),
258            )
259        }
260    }
261
262    impl MakeAsciiCaseInsensitive for hir::ClassBytes {
263        fn make_ascii_case_insensitive(mut self) -> Mir {
264            self.case_fold_simple();
265            Mir::Class(hir::Class::Bytes(self))
266        }
267    }
268
269    impl MakeAsciiCaseInsensitive for hir::ClassUnicode {
270        fn make_ascii_case_insensitive(mut self) -> Mir {
271            use std::cmp;
272
273            // Manuall implementation to only perform the case folding on ascii characters.
274
275            let mut ranges = Vec::new();
276
277            for range in self.ranges() {
278                #[inline]
279                fn overlaps(st1: u8, end1: u8, st2: u8, end2: u8) -> bool {
280                    (st2 <= st1 && st1 <= end2) || (st1 <= st2 && st2 <= end1)
281                }
282
283                #[inline]
284                fn make_ascii(c: char) -> Option<u8> {
285                    if c.is_ascii() {
286                        Some(c as u8)
287                    } else {
288                        None
289                    }
290                }
291
292                match (make_ascii(range.start()), make_ascii(range.end())) {
293                    (Some(start), Some(end)) => {
294                        if overlaps(b'a', b'z', start, end) {
295                            let lower = cmp::max(start, b'a');
296                            let upper = cmp::min(end, b'z');
297                            ranges.push(hir::ClassUnicodeRange::new(
298                                (lower - 32) as char,
299                                (upper - 32) as char,
300                            ))
301                        }
302
303                        if overlaps(b'A', b'Z', start, end) {
304                            let lower = cmp::max(start, b'A');
305                            let upper = cmp::min(end, b'Z');
306                            ranges.push(hir::ClassUnicodeRange::new(
307                                (lower + 32) as char,
308                                (upper + 32) as char,
309                            ))
310                        }
311                    }
312                    (Some(start), None) => {
313                        if overlaps(b'a', b'z', start, b'z') {
314                            let lower = cmp::max(start, b'a');
315                            ranges.push(hir::ClassUnicodeRange::new((lower - 32) as char, 'Z'))
316                        }
317
318                        if overlaps(b'A', b'Z', start, b'Z') {
319                            let lower = cmp::max(start, b'A');
320                            ranges.push(hir::ClassUnicodeRange::new((lower + 32) as char, 'Z'))
321                        }
322                    }
323                    _ => (),
324                }
325            }
326
327            self.union(&hir::ClassUnicode::new(ranges));
328
329            Mir::Class(hir::Class::Unicode(self))
330        }
331    }
332
333    impl MakeAsciiCaseInsensitive for hir::Class {
334        fn make_ascii_case_insensitive(self) -> Mir {
335            match self {
336                hir::Class::Bytes(b) => b.make_ascii_case_insensitive(),
337                hir::Class::Unicode(u) => u.make_ascii_case_insensitive(),
338            }
339        }
340    }
341
342    impl MakeAsciiCaseInsensitive for &Literal {
343        fn make_ascii_case_insensitive(self) -> Mir {
344            match self {
345                Literal::Bytes(bytes) => Mir::Concat(
346                    bytes
347                        .value()
348                        .into_iter()
349                        .map(|b| b.make_ascii_case_insensitive())
350                        .collect(),
351                ),
352                Literal::Utf8(s) => Mir::Concat(
353                    s.value()
354                        .chars()
355                        .map(|b| b.make_ascii_case_insensitive())
356                        .collect(),
357                ),
358            }
359        }
360    }
361
362    impl MakeAsciiCaseInsensitive for Mir {
363        fn make_ascii_case_insensitive(self) -> Mir {
364            match self {
365                Mir::Empty => Mir::Empty,
366                Mir::Loop(l) => Mir::Loop(Box::new(l.make_ascii_case_insensitive())),
367                Mir::Maybe(m) => Mir::Maybe(Box::new(m.make_ascii_case_insensitive())),
368                Mir::Concat(c) => Mir::Concat(
369                    c.into_iter()
370                        .map(|m| m.make_ascii_case_insensitive())
371                        .collect(),
372                ),
373                Mir::Alternation(a) => Mir::Alternation(
374                    a.into_iter()
375                        .map(|m| m.make_ascii_case_insensitive())
376                        .collect(),
377                ),
378                Mir::Class(c) => c.make_ascii_case_insensitive(),
379                Mir::Literal(l) => l.make_ascii_case_insensitive(),
380            }
381        }
382    }
383
384    #[cfg(test)]
385    mod tests {
386        use super::MakeAsciiCaseInsensitive;
387        use crate::mir::{Class, Mir};
388        use regex_syntax::hir::{ClassUnicode, ClassUnicodeRange};
389
390        fn assert_range(in_s: char, in_e: char, expected: &[(char, char)]) {
391            let range = ClassUnicodeRange::new(in_s, in_e);
392            let class = ClassUnicode::new(vec![range]);
393
394            let expected =
395                ClassUnicode::new(expected.iter().map(|&(a, b)| ClassUnicodeRange::new(a, b)));
396
397            if let Mir::Class(Class::Unicode(result)) = class.make_ascii_case_insensitive() {
398                assert_eq!(result, expected);
399            } else {
400                panic!("Not a unicode class");
401            };
402        }
403
404        #[test]
405        fn no_letters_left() {
406            assert_range(' ', '+', &[(' ', '+')]);
407        }
408
409        #[test]
410        fn no_letters_right() {
411            assert_range('{', '~', &[('{', '~')]);
412        }
413
414        #[test]
415        fn no_letters_middle() {
416            assert_range('[', '`', &[('[', '`')]);
417        }
418
419        #[test]
420        fn lowercase_left_edge() {
421            assert_range('a', 'd', &[('a', 'd'), ('A', 'D')]);
422        }
423
424        #[test]
425        fn lowercase_right_edge() {
426            assert_range('r', 'z', &[('r', 'z'), ('R', 'Z')]);
427        }
428
429        #[test]
430        fn lowercase_total() {
431            assert_range('a', 'z', &[('a', 'z'), ('A', 'Z')]);
432        }
433
434        #[test]
435        fn uppercase_left_edge() {
436            assert_range('A', 'D', &[('a', 'd'), ('A', 'D')]);
437        }
438
439        #[test]
440        fn uppercase_right_edge() {
441            assert_range('R', 'Z', &[('r', 'z'), ('R', 'Z')]);
442        }
443
444        #[test]
445        fn uppercase_total() {
446            assert_range('A', 'Z', &[('a', 'z'), ('A', 'Z')]);
447        }
448
449        #[test]
450        fn lowercase_cross_left() {
451            assert_range('[', 'h', &[('[', 'h'), ('A', 'H')]);
452        }
453
454        #[test]
455        fn lowercase_cross_right() {
456            assert_range('d', '}', &[('d', '}'), ('D', 'Z')]);
457        }
458
459        #[test]
460        fn uppercase_cross_left() {
461            assert_range(';', 'H', &[(';', 'H'), ('a', 'h')]);
462        }
463
464        #[test]
465        fn uppercase_cross_right() {
466            assert_range('T', ']', &[('t', 'z'), ('T', ']')]);
467        }
468
469        #[test]
470        fn cross_both() {
471            assert_range('X', 'c', &[('X', 'c'), ('x', 'z'), ('A', 'C')]);
472        }
473
474        #[test]
475        fn all_letters() {
476            assert_range('+', '|', &[('+', '|')]);
477        }
478
479        #[test]
480        fn oob_all_letters() {
481            assert_range('#', 'é', &[('#', 'é')]);
482        }
483
484        #[test]
485        fn oob_from_uppercase() {
486            assert_range('Q', 'é', &[('A', 'é')]);
487        }
488
489        #[test]
490        fn oob_from_lowercase() {
491            assert_range('q', 'é', &[('q', 'é'), ('Q', 'Z')]);
492        }
493
494        #[test]
495        fn oob_no_letters() {
496            assert_range('|', 'é', &[('|', 'é')]);
497        }
498    }
499}