1use std::ops::{BitAnd, BitOr};
2
3use proc_macro2::{Ident, TokenStream, TokenTree};
4
5use crate::parser::Parser;
6use crate::util::is_punct;
7
/// Bit set of "ignore" flags (`case`, `ascii_case`).
///
/// The individual flags are exposed as associated constants and can be
/// combined with `|` / `&`.
#[derive(Clone, Copy, PartialEq, Eq)]
pub struct IgnoreFlags {
    /// Raw bit mask; each flag constant occupies one bit.
    bits: u8,
}
12
13#[allow(non_upper_case_globals)]
14impl IgnoreFlags {
15 pub const Empty: Self = Self::new(0x00);
16 pub const IgnoreCase: Self = Self::new(0x01);
17 pub const IgnoreAsciiCase: Self = Self::new(0x02);
18
19 #[inline]
20 pub const fn new(bits: u8) -> Self {
21 Self { bits }
22 }
23
24 #[inline]
26 pub fn enable(&mut self, variant: Self) {
27 self.bits |= variant.bits;
28 }
29
30 #[inline]
32 pub fn contains(&self, variants: Self) -> bool {
33 self.bits & variants.bits != 0
34 }
35
36 #[inline]
37 pub fn is_empty(&self) -> bool {
38 self.bits == 0
39 }
40
41 fn parse_ident(&mut self, ident: Ident, parser: &mut Parser) -> bool {
50 match ident.to_string().as_str() {
51 "case" => {
52 if self.contains(Self::IgnoreAsciiCase) {
53 parser.err(
54 "\
55 The flag \"case\" cannot be used along with \"ascii_case\"\
56 ",
57 ident.span(),
58 );
59 false
60 } else {
61 self.enable(Self::IgnoreCase);
62 true
63 }
64 }
65 "ascii_case" => {
66 if self.contains(Self::IgnoreCase) {
67 parser.err(
68 "\
69 The flag \"ascii_case\" cannot be used along with \"case\"\
70 ",
71 ident.span(),
72 );
73 false
74 } else {
75 self.enable(Self::IgnoreAsciiCase);
76 true
77 }
78 }
79 unknown => {
80 parser.err(
81 format!(
82 "\
83 Unknown flag: {}\n\
84 \n\
85 Expected one of: case, ascii_case\
86 ",
87 unknown
88 ),
89 ident.span(),
90 );
91 false
92 }
93 }
94 }
95
96 pub fn parse_group(&mut self, name: Ident, tokens: TokenStream, parser: &mut Parser) {
97 let mut state = 0u8;
112
113 let mut tokens = tokens.into_iter();
114
115 loop {
116 state = match state {
117 0 => match tokens.next() {
118 Some(TokenTree::Ident(ident)) => {
119 if self.parse_ident(ident, parser) {
120 1
121 } else {
122 return;
123 }
124 }
125 _ => {
126 parser.err(
127 "\
128 Invalid ignore flag\n\
129 \n\
130 Expected one of: case, ascii_case\
131 ",
132 name.span(),
133 );
134 return;
135 }
136 },
137 1 => match tokens.next() {
138 Some(tt) if is_punct(&tt, ',') => 2,
139 None => return,
140 Some(unexpected_tt) => {
141 parser.err(
142 format!(
143 "\
144 Unexpected token: {:?}\
145 ",
146 unexpected_tt.to_string(),
147 ),
148 unexpected_tt.span(),
149 );
150 return;
151 }
152 },
153 2 => match tokens.next() {
154 Some(TokenTree::Ident(ident)) => {
155 if self.parse_ident(ident, parser) {
156 1
157 } else {
158 return;
159 }
160 }
161 None => return,
162 Some(unexpected_tt) => {
163 parser.err(
164 format!(
165 "\
166 Unexpected token: {:?}\
167 ",
168 unexpected_tt.to_string(),
169 ),
170 unexpected_tt.span(),
171 );
172 return;
173 }
174 },
175 _ => unreachable!("Internal Error: invalid state ({})", state),
176 }
177 }
178 }
179}
180
181impl BitOr for IgnoreFlags {
182 type Output = Self;
183
184 fn bitor(self, other: Self) -> Self {
185 Self::new(self.bits | other.bits)
186 }
187}
188
189impl BitAnd for IgnoreFlags {
190 type Output = Self;
191
192 fn bitand(self, other: Self) -> Self {
193 Self::new(self.bits & other.bits)
194 }
195}
196
197pub mod ascii_case {
198 use regex_syntax::hir;
199
200 use crate::mir::Mir;
201 use crate::parser::Literal;
202
/// Builds an `hir::Literal`.
///
/// * `literal!(byte)` wraps a single byte.
/// * `literal!(@char c)` wraps the UTF-8 encoding of the character `c`.
macro_rules! literal {
    (@char $c:expr) => {{
        let mut utf8 = [0u8; 4];
        hir::Literal(Box::from($c.encode_utf8(&mut utf8).as_bytes()))
    }};
    ($byte:expr) => {
        hir::Literal(vec![$byte].into_boxed_slice())
    };
}
216
217 pub trait MakeAsciiCaseInsensitive {
218 fn make_ascii_case_insensitive(self) -> Mir;
221 }
222
223 impl MakeAsciiCaseInsensitive for u8 {
224 fn make_ascii_case_insensitive(self) -> Mir {
225 if self.is_ascii_lowercase() {
226 Mir::Alternation(vec![
227 Mir::Literal(literal!(self - 32)),
228 Mir::Literal(literal!(self)),
229 ])
230 } else if self.is_ascii_uppercase() {
231 Mir::Alternation(vec![
232 Mir::Literal(literal!(self)),
233 Mir::Literal(literal!(self + 32)),
234 ])
235 } else {
236 Mir::Literal(literal!(self))
237 }
238 }
239 }
240
241 impl MakeAsciiCaseInsensitive for char {
242 fn make_ascii_case_insensitive(self) -> Mir {
243 if self.is_ascii() {
244 (self as u8).make_ascii_case_insensitive()
245 } else {
246 Mir::Literal(literal!(@char self))
247 }
248 }
249 }
250
251 impl MakeAsciiCaseInsensitive for hir::Literal {
252 fn make_ascii_case_insensitive(self) -> Mir {
253 Mir::Concat(
254 self.0
255 .iter()
256 .map(|x| x.make_ascii_case_insensitive())
257 .collect(),
258 )
259 }
260 }
261
262 impl MakeAsciiCaseInsensitive for hir::ClassBytes {
263 fn make_ascii_case_insensitive(mut self) -> Mir {
264 self.case_fold_simple();
265 Mir::Class(hir::Class::Bytes(self))
266 }
267 }
268
269 impl MakeAsciiCaseInsensitive for hir::ClassUnicode {
270 fn make_ascii_case_insensitive(mut self) -> Mir {
271 use std::cmp;
272
273 let mut ranges = Vec::new();
276
277 for range in self.ranges() {
278 #[inline]
279 fn overlaps(st1: u8, end1: u8, st2: u8, end2: u8) -> bool {
280 (st2 <= st1 && st1 <= end2) || (st1 <= st2 && st2 <= end1)
281 }
282
283 #[inline]
284 fn make_ascii(c: char) -> Option<u8> {
285 if c.is_ascii() {
286 Some(c as u8)
287 } else {
288 None
289 }
290 }
291
292 match (make_ascii(range.start()), make_ascii(range.end())) {
293 (Some(start), Some(end)) => {
294 if overlaps(b'a', b'z', start, end) {
295 let lower = cmp::max(start, b'a');
296 let upper = cmp::min(end, b'z');
297 ranges.push(hir::ClassUnicodeRange::new(
298 (lower - 32) as char,
299 (upper - 32) as char,
300 ))
301 }
302
303 if overlaps(b'A', b'Z', start, end) {
304 let lower = cmp::max(start, b'A');
305 let upper = cmp::min(end, b'Z');
306 ranges.push(hir::ClassUnicodeRange::new(
307 (lower + 32) as char,
308 (upper + 32) as char,
309 ))
310 }
311 }
312 (Some(start), None) => {
313 if overlaps(b'a', b'z', start, b'z') {
314 let lower = cmp::max(start, b'a');
315 ranges.push(hir::ClassUnicodeRange::new((lower - 32) as char, 'Z'))
316 }
317
318 if overlaps(b'A', b'Z', start, b'Z') {
319 let lower = cmp::max(start, b'A');
320 ranges.push(hir::ClassUnicodeRange::new((lower + 32) as char, 'Z'))
321 }
322 }
323 _ => (),
324 }
325 }
326
327 self.union(&hir::ClassUnicode::new(ranges));
328
329 Mir::Class(hir::Class::Unicode(self))
330 }
331 }
332
333 impl MakeAsciiCaseInsensitive for hir::Class {
334 fn make_ascii_case_insensitive(self) -> Mir {
335 match self {
336 hir::Class::Bytes(b) => b.make_ascii_case_insensitive(),
337 hir::Class::Unicode(u) => u.make_ascii_case_insensitive(),
338 }
339 }
340 }
341
342 impl MakeAsciiCaseInsensitive for &Literal {
343 fn make_ascii_case_insensitive(self) -> Mir {
344 match self {
345 Literal::Bytes(bytes) => Mir::Concat(
346 bytes
347 .value()
348 .into_iter()
349 .map(|b| b.make_ascii_case_insensitive())
350 .collect(),
351 ),
352 Literal::Utf8(s) => Mir::Concat(
353 s.value()
354 .chars()
355 .map(|b| b.make_ascii_case_insensitive())
356 .collect(),
357 ),
358 }
359 }
360 }
361
362 impl MakeAsciiCaseInsensitive for Mir {
363 fn make_ascii_case_insensitive(self) -> Mir {
364 match self {
365 Mir::Empty => Mir::Empty,
366 Mir::Loop(l) => Mir::Loop(Box::new(l.make_ascii_case_insensitive())),
367 Mir::Maybe(m) => Mir::Maybe(Box::new(m.make_ascii_case_insensitive())),
368 Mir::Concat(c) => Mir::Concat(
369 c.into_iter()
370 .map(|m| m.make_ascii_case_insensitive())
371 .collect(),
372 ),
373 Mir::Alternation(a) => Mir::Alternation(
374 a.into_iter()
375 .map(|m| m.make_ascii_case_insensitive())
376 .collect(),
377 ),
378 Mir::Class(c) => c.make_ascii_case_insensitive(),
379 Mir::Literal(l) => l.make_ascii_case_insensitive(),
380 }
381 }
382 }
383
#[cfg(test)]
mod tests {
    use super::MakeAsciiCaseInsensitive;
    use crate::mir::{Class, Mir};
    use regex_syntax::hir::{ClassUnicode, ClassUnicodeRange};

    /// Converts the single-range class `in_s..=in_e` and checks that the
    /// result equals the class built from the `expected` ranges.
    fn assert_range(in_s: char, in_e: char, expected: &[(char, char)]) {
        let input = ClassUnicode::new(vec![ClassUnicodeRange::new(in_s, in_e)]);

        let expected = ClassUnicode::new(
            expected
                .iter()
                .map(|&(start, end)| ClassUnicodeRange::new(start, end)),
        );

        match input.make_ascii_case_insensitive() {
            Mir::Class(Class::Unicode(actual)) => assert_eq!(actual, expected),
            _ => panic!("Not a unicode class"),
        }
    }

    // Ranges that contain no ASCII letters.

    #[test]
    fn no_letters_left() {
        assert_range(' ', '+', &[(' ', '+')]);
    }

    #[test]
    fn no_letters_right() {
        assert_range('{', '~', &[('{', '~')]);
    }

    #[test]
    fn no_letters_middle() {
        assert_range('[', '`', &[('[', '`')]);
    }

    // Ranges entirely inside the lowercase letters.

    #[test]
    fn lowercase_left_edge() {
        assert_range('a', 'd', &[('a', 'd'), ('A', 'D')]);
    }

    #[test]
    fn lowercase_right_edge() {
        assert_range('r', 'z', &[('r', 'z'), ('R', 'Z')]);
    }

    #[test]
    fn lowercase_total() {
        assert_range('a', 'z', &[('a', 'z'), ('A', 'Z')]);
    }

    // Ranges entirely inside the uppercase letters.

    #[test]
    fn uppercase_left_edge() {
        assert_range('A', 'D', &[('a', 'd'), ('A', 'D')]);
    }

    #[test]
    fn uppercase_right_edge() {
        assert_range('R', 'Z', &[('r', 'z'), ('R', 'Z')]);
    }

    #[test]
    fn uppercase_total() {
        assert_range('A', 'Z', &[('a', 'z'), ('A', 'Z')]);
    }

    // Ranges that cross a boundary of one or both letter blocks.

    #[test]
    fn lowercase_cross_left() {
        assert_range('[', 'h', &[('[', 'h'), ('A', 'H')]);
    }

    #[test]
    fn lowercase_cross_right() {
        assert_range('d', '}', &[('d', '}'), ('D', 'Z')]);
    }

    #[test]
    fn uppercase_cross_left() {
        assert_range(';', 'H', &[(';', 'H'), ('a', 'h')]);
    }

    #[test]
    fn uppercase_cross_right() {
        assert_range('T', ']', &[('t', 'z'), ('T', ']')]);
    }

    #[test]
    fn cross_both() {
        assert_range('X', 'c', &[('X', 'c'), ('x', 'z'), ('A', 'C')]);
    }

    #[test]
    fn all_letters() {
        assert_range('+', '|', &[('+', '|')]);
    }

    // Ranges whose end lies outside ASCII.

    #[test]
    fn oob_all_letters() {
        assert_range('#', 'é', &[('#', 'é')]);
    }

    #[test]
    fn oob_from_uppercase() {
        assert_range('Q', 'é', &[('A', 'é')]);
    }

    #[test]
    fn oob_from_lowercase() {
        assert_range('q', 'é', &[('q', 'é'), ('Q', 'Z')]);
    }

    #[test]
    fn oob_no_letters() {
        assert_range('|', 'é', &[('|', 'é')]);
    }
}
499}