logos_codegen/generator/
fork.rs

1use std::cmp::max;
2
3use fnv::FnvHashMap as Map;
4use proc_macro2::TokenStream;
5use quote::quote;
6
7use crate::generator::{Context, Generator};
8use crate::graph::{Fork, NodeId, Range};
9use crate::util::ToIdent;
10
/// Branches of a fork grouped by destination: each target node maps to the
/// list of byte ranges that lead to it.
type Targets = Map<NodeId, Vec<Range>>;
12
13impl Generator<'_> {
14    pub fn generate_fork(&mut self, this: NodeId, fork: &Fork, mut ctx: Context) -> TokenStream {
15        let mut targets: Targets = Map::default();
16
17        for (range, then) in fork.branches() {
18            targets.entry(then).or_default().push(range);
19        }
20        let loops_to_self = self.meta[this].loop_entry_from.contains(&this);
21
22        match targets.len() {
23            1 if loops_to_self => return self.generate_fast_loop(fork, ctx),
24            0..=2 => (),
25            _ => return self.generate_fork_jump_table(this, fork, targets, ctx),
26        }
27        let miss = ctx.miss(fork.miss, self);
28        let end = self.fork_end(this, &miss);
29        let (byte, read) = self.fork_read(this, end, &mut ctx);
30        let branches = targets.into_iter().map(|(id, ranges)| {
31            let next = self.goto(id, ctx.advance(1));
32
33            match *ranges {
34                [range] => {
35                    quote!(#range => #next,)
36                }
37                [a, b] if a.is_byte() && b.is_byte() => {
38                    quote!(#a | #b => #next,)
39                }
40                _ => {
41                    let test = self.generate_test(ranges).clone();
42                    let next = self.goto(id, ctx.advance(1));
43
44                    quote!(byte if #test(byte) => #next,)
45                }
46            }
47        });
48
49        quote! {
50            #read
51
52            match #byte {
53                #(#branches)*
54                _ => #miss,
55            }
56        }
57    }
58
59    fn generate_fork_jump_table(
60        &mut self,
61        this: NodeId,
62        fork: &Fork,
63        targets: Targets,
64        mut ctx: Context,
65    ) -> TokenStream {
66        let miss = ctx.miss(fork.miss, self);
67        let end = self.fork_end(this, &miss);
68        let (byte, read) = self.fork_read(this, end, &mut ctx);
69
70        let mut table: [u8; 256] = [0; 256];
71        let mut jumps = vec!["__".to_ident()];
72
73        let branches = targets
74            .into_iter()
75            .enumerate()
76            .map(|(idx, (id, ranges))| {
77                let idx = (idx as u8) + 1;
78                let next = self.goto(id, ctx.advance(1));
79                jumps.push(format!("J{}", id).to_ident());
80
81                for byte in ranges.into_iter().flatten() {
82                    table[byte as usize] = idx;
83                }
84                let jump = jumps.last().unwrap();
85
86                quote!(Jump::#jump => #next,)
87            })
88            .collect::<TokenStream>();
89
90        let may_error = table.iter().any(|&idx| idx == 0);
91
92        let jumps = jumps.as_slice();
93        let table = table.iter().copied().map(|idx| &jumps[idx as usize]);
94
95        let jumps = if may_error { jumps } else { &jumps[1..] };
96        let error_branch = if may_error {
97            Some(quote!(Jump::__ => #miss))
98        } else {
99            None
100        };
101
102        quote! {
103            enum Jump {
104                #(#jumps,)*
105            }
106
107            const LUT: [Jump; 256] = {
108                use Jump::*;
109
110                [#(#table),*]
111            };
112
113            #read
114
115            match LUT[#byte as usize] {
116                #branches
117                #error_branch
118            }
119        }
120    }
121
122    fn fork_end(&self, this: NodeId, miss: &TokenStream) -> TokenStream {
123        if this == self.root {
124            quote!(_end(lex))
125        } else {
126            miss.clone()
127        }
128    }
129
130    fn fork_read(
131        &self,
132        this: NodeId,
133        end: TokenStream,
134        ctx: &mut Context,
135    ) -> (TokenStream, TokenStream) {
136        let min_read = self.meta[this].min_read;
137
138        if ctx.remainder() >= max(min_read, 1) {
139            let read = ctx.read_byte();
140
141            return (quote!(byte), quote!(let byte = #read;));
142        }
143
144        match min_read {
145            0 | 1 => {
146                let read = ctx.read(0);
147
148                (
149                    quote!(byte),
150                    quote! {
151                        let byte = match #read {
152                            Some(byte) => byte,
153                            None => return #end,
154                        };
155                    },
156                )
157            }
158            len => {
159                let read = ctx.read(len);
160
161                (
162                    quote!(arr[0]),
163                    quote! {
164                        let arr = match #read {
165                            Some(arr) => arr,
166                            None => return #end,
167                        };
168                    },
169                )
170            }
171        }
172    }
173
174    fn generate_fast_loop(&mut self, fork: &Fork, ctx: Context) -> TokenStream {
175        let miss = ctx.miss(fork.miss, self);
176        let ranges = fork.branches().map(|(range, _)| range).collect::<Vec<_>>();
177        let test = self.generate_test(ranges);
178
179        quote! {
180            _fast_loop!(lex, #test, #miss);
181        }
182    }
183
    /// Returns the definition of the `_fast_loop!` macro that
    /// `generate_fast_loop` invokes.
    ///
    /// The macro takes a lexer identifier, a test identifier and a miss
    /// expression. While a 16-byte chunk can be read it tests the bytes in
    /// order; if all 16 pass it bumps by 16 and continues, otherwise it bumps
    /// by the number of leading bytes that passed and returns the miss
    /// expression. Once a full chunk can no longer be read it bumps one byte
    /// at a time while `$lex.test($test)` holds, then evaluates to the miss
    /// expression.
    pub fn fast_loop_macro() -> TokenStream {
        quote! {
            macro_rules! _fast_loop {
                ($lex:ident, $test:ident, $miss:expr) => {
                    // Do one bounds check for multiple bytes till EOF
                    while let Some(arr) = $lex.read::<&[u8; 16]>() {
                        // Manually unrolled: 16 nested tests, one per byte of the chunk.
                        if $test(arr[0])  { if $test(arr[1])  { if $test(arr[2])  { if $test(arr[3]) {
                        if $test(arr[4])  { if $test(arr[5])  { if $test(arr[6])  { if $test(arr[7]) {
                        if $test(arr[8])  { if $test(arr[9])  { if $test(arr[10]) { if $test(arr[11]) {
                        if $test(arr[12]) { if $test(arr[13]) { if $test(arr[14]) { if $test(arr[15]) {

                        // Each closing brace is followed by the exit for "byte k failed":
                        // bump past the k bytes that passed, then return the miss expression.
                        $lex.bump_unchecked(16); continue;     } $lex.bump_unchecked(15); return $miss; }
                        $lex.bump_unchecked(14); return $miss; } $lex.bump_unchecked(13); return $miss; }
                        $lex.bump_unchecked(12); return $miss; } $lex.bump_unchecked(11); return $miss; }
                        $lex.bump_unchecked(10); return $miss; } $lex.bump_unchecked(9); return $miss;  }
                        $lex.bump_unchecked(8); return $miss;  } $lex.bump_unchecked(7); return $miss;  }
                        $lex.bump_unchecked(6); return $miss;  } $lex.bump_unchecked(5); return $miss;  }
                        $lex.bump_unchecked(4); return $miss;  } $lex.bump_unchecked(3); return $miss;  }
                        $lex.bump_unchecked(2); return $miss;  } $lex.bump_unchecked(1); return $miss;  }

                        // The very first byte failed: nothing to bump.
                        return $miss;
                    }

                    // Fewer than 16 bytes could be read: fall back to one byte at a time.
                    while $lex.test($test) {
                        $lex.bump_unchecked(1);
                    }

                    $miss
                };
            }
        }
    }
216}