1use crate::algebra::{Expression, GraphPattern};
2use oxrdf::Variable;
3use spargebra::algebra::Function;
4use spargebra::term::{GroundTerm, GroundTermPattern, NamedNodePattern};
5use std::collections::HashMap;
6use std::ops::{BitAnd, BitOr};
7
8pub fn infer_graph_pattern_types(
9 pattern: &GraphPattern,
10 mut types: VariableTypes,
11) -> VariableTypes {
12 match pattern {
13 GraphPattern::QuadPattern {
14 subject,
15 predicate,
16 object,
17 graph_name,
18 } => {
19 add_ground_term_pattern_types(subject, &mut types, false);
20 if let NamedNodePattern::Variable(v) = predicate {
21 types.intersect_variable_with(v.clone(), VariableType::NAMED_NODE)
22 }
23 add_ground_term_pattern_types(object, &mut types, true);
24 if let Some(NamedNodePattern::Variable(v)) = graph_name {
25 types.intersect_variable_with(v.clone(), VariableType::NAMED_NODE)
26 }
27 types
28 }
29 GraphPattern::Path {
30 subject,
31 object,
32 graph_name,
33 ..
34 } => {
35 add_ground_term_pattern_types(subject, &mut types, false);
36 add_ground_term_pattern_types(object, &mut types, true);
37 if let Some(NamedNodePattern::Variable(v)) = graph_name {
38 types.intersect_variable_with(v.clone(), VariableType::NAMED_NODE)
39 }
40 types
41 }
42 GraphPattern::Graph { graph_name } => {
43 if let NamedNodePattern::Variable(v) = graph_name {
44 types.intersect_variable_with(v.clone(), VariableType::NAMED_NODE)
45 }
46 types
47 }
48 GraphPattern::Join { left, right, .. } => {
49 let mut output_types = infer_graph_pattern_types(left, types.clone());
50 output_types.intersect_with(infer_graph_pattern_types(right, types));
51 output_types
52 }
53 #[cfg(feature = "sep-0006")]
54 GraphPattern::Lateral { left, right } => {
55 infer_graph_pattern_types(right, infer_graph_pattern_types(left, types))
56 }
57 GraphPattern::LeftJoin { left, right, .. } => {
58 let mut right_types = infer_graph_pattern_types(right, types.clone()); for t in right_types.inner.values_mut() {
60 t.undef = true; }
62 let mut output_types = infer_graph_pattern_types(left, types);
63 output_types.intersect_with(right_types);
64 output_types
65 }
66 GraphPattern::Minus { left, .. } => infer_graph_pattern_types(left, types),
67 GraphPattern::Union { inner } => inner
68 .iter()
69 .map(|inner| infer_graph_pattern_types(inner, types.clone()))
70 .reduce(|mut a, b| {
71 a.union_with(b);
72 a
73 })
74 .unwrap_or_default(),
75 GraphPattern::Extend {
76 inner,
77 variable,
78 expression,
79 } => {
80 let mut types = infer_graph_pattern_types(inner, types);
81 types.intersect_variable_with(
82 variable.clone(),
83 infer_expression_type(expression, &types),
84 );
85 types
86 }
87 GraphPattern::Filter { inner, .. } => infer_graph_pattern_types(inner, types),
88 GraphPattern::Project { inner, variables } => VariableTypes {
89 inner: infer_graph_pattern_types(inner, types)
90 .inner
91 .into_iter()
92 .filter(|(v, _)| variables.contains(v))
93 .collect(),
94 },
95 GraphPattern::Distinct { inner }
96 | GraphPattern::Reduced { inner }
97 | GraphPattern::OrderBy { inner, .. }
98 | GraphPattern::Slice { inner, .. } => infer_graph_pattern_types(inner, types),
99 GraphPattern::Group {
100 inner,
101 variables,
102 aggregates,
103 } => {
104 let types = infer_graph_pattern_types(inner, types);
105 VariableTypes {
106 inner: infer_graph_pattern_types(inner, types)
107 .inner
108 .into_iter()
109 .filter(|(v, _)| variables.contains(v))
110 .chain(aggregates.iter().map(|(v, _)| (v.clone(), VariableType::ANY))) .collect(),
112 }
113 }
114 GraphPattern::Values {
115 variables,
116 bindings,
117 } => {
118 for (i, v) in variables.iter().enumerate() {
119 let mut t = VariableType::default();
120 for binding in bindings {
121 match binding[i] {
122 Some(GroundTerm::NamedNode(_)) => t.named_node = true,
123 Some(GroundTerm::Literal(_)) => t.literal = true,
124 #[cfg(feature = "rdf-star")]
125 Some(GroundTerm::Triple(_)) => t.triple = true,
126 None => t.undef = true,
127 }
128 }
129 types.intersect_variable_with(v.clone(), t)
130 }
131 types
132 }
133 GraphPattern::Service {
134 name,
135 inner,
136 silent,
137 } => {
138 let parent_types = types.clone();
139 let mut types = infer_graph_pattern_types(inner, types);
140 if *silent {
141 types.union_with(parent_types);
143 } else if let NamedNodePattern::Variable(v) = name {
144 types.intersect_variable_with(v.clone(), VariableType::NAMED_NODE)
145 }
146 types
147 }
148 }
149}
150
151fn add_ground_term_pattern_types(
152 pattern: &GroundTermPattern,
153 types: &mut VariableTypes,
154 is_object: bool,
155) {
156 if let GroundTermPattern::Variable(v) = pattern {
157 types.intersect_variable_with(
158 v.clone(),
159 if is_object {
160 VariableType::TERM
161 } else {
162 VariableType::SUBJECT
163 },
164 )
165 }
166 #[cfg(feature = "rdf-star")]
167 if let GroundTermPattern::Triple(t) = pattern {
168 add_ground_term_pattern_types(&t.subject, types, false);
169 if let NamedNodePattern::Variable(v) = &t.predicate {
170 types.intersect_variable_with(v.clone(), VariableType::NAMED_NODE)
171 }
172 add_ground_term_pattern_types(&t.object, types, true);
173 }
174}
175
176pub fn infer_expression_type(expression: &Expression, types: &VariableTypes) -> VariableType {
177 match expression {
178 Expression::NamedNode(_) => VariableType::NAMED_NODE,
179 Expression::Literal(_) | Expression::Exists(_) | Expression::Bound(_) => {
180 VariableType::LITERAL
181 }
182 Expression::Variable(v) => types.get(v),
183 Expression::FunctionCall(Function::Datatype | Function::Iri, _) => {
184 VariableType::NAMED_NODE | VariableType::UNDEF
185 }
186 #[cfg(feature = "rdf-star")]
187 Expression::FunctionCall(Function::Predicate, _) => {
188 VariableType::NAMED_NODE | VariableType::UNDEF
189 }
190 Expression::FunctionCall(Function::BNode, args) => {
191 if args.is_empty() {
192 VariableType::BLANK_NODE
193 } else {
194 VariableType::BLANK_NODE | VariableType::UNDEF
195 }
196 }
197 Expression::FunctionCall(
198 Function::Rand | Function::Now | Function::Uuid | Function::StrUuid,
199 _,
200 ) => VariableType::LITERAL,
201 Expression::Or(_)
202 | Expression::And(_)
203 | Expression::Equal(_, _)
204 | Expression::Greater(_, _)
205 | Expression::GreaterOrEqual(_, _)
206 | Expression::Less(_, _)
207 | Expression::LessOrEqual(_, _)
208 | Expression::Add(_, _)
209 | Expression::Subtract(_, _)
210 | Expression::Multiply(_, _)
211 | Expression::Divide(_, _)
212 | Expression::UnaryPlus(_)
213 | Expression::UnaryMinus(_)
214 | Expression::Not(_)
215 | Expression::FunctionCall(
216 Function::Str
217 | Function::Lang
218 | Function::LangMatches
219 | Function::Abs
220 | Function::Ceil
221 | Function::Floor
222 | Function::Round
223 | Function::Concat
224 | Function::SubStr
225 | Function::StrLen
226 | Function::Replace
227 | Function::UCase
228 | Function::LCase
229 | Function::EncodeForUri
230 | Function::Contains
231 | Function::StrStarts
232 | Function::StrEnds
233 | Function::StrBefore
234 | Function::StrAfter
235 | Function::Year
236 | Function::Month
237 | Function::Day
238 | Function::Hours
239 | Function::Minutes
240 | Function::Seconds
241 | Function::Timezone
242 | Function::Tz
243 | Function::Md5
244 | Function::Sha1
245 | Function::Sha256
246 | Function::Sha384
247 | Function::Sha512
248 | Function::StrLang
249 | Function::StrDt
250 | Function::IsIri
251 | Function::IsBlank
252 | Function::IsLiteral
253 | Function::IsNumeric
254 | Function::Regex,
255 _,
256 ) => VariableType::LITERAL | VariableType::UNDEF,
257 #[cfg(feature = "sep-0002")]
258 Expression::FunctionCall(Function::Adjust, _) => {
259 VariableType::LITERAL | VariableType::UNDEF
260 }
261 #[cfg(feature = "rdf-star")]
262 Expression::FunctionCall(Function::IsTriple, _) => {
263 VariableType::LITERAL | VariableType::UNDEF
264 }
265 Expression::SameTerm(left, right) => {
266 if infer_expression_type(left, types).undef || infer_expression_type(right, types).undef
267 {
268 VariableType::LITERAL | VariableType::UNDEF
269 } else {
270 VariableType::LITERAL
271 }
272 }
273 Expression::If(_, then, els) => {
274 infer_expression_type(then, types) | infer_expression_type(els, types)
275 }
276 Expression::Coalesce(inner) => {
277 let mut t = VariableType::UNDEF;
278 for e in inner {
279 let new = infer_expression_type(e, types);
280 t = t | new;
281 if !new.undef {
282 t.undef = false;
283 return t;
284 }
285 }
286 t
287 }
288 #[cfg(feature = "rdf-star")]
289 Expression::FunctionCall(Function::Triple, _) => VariableType::TRIPLE | VariableType::UNDEF,
290 #[cfg(feature = "rdf-star")]
291 Expression::FunctionCall(Function::Subject, _) => {
292 VariableType::SUBJECT | VariableType::UNDEF
293 }
294 #[cfg(feature = "rdf-star")]
295 Expression::FunctionCall(Function::Object, _) => VariableType::TERM | VariableType::UNDEF,
296 Expression::FunctionCall(Function::Custom(_), _) => VariableType::ANY,
297 }
298}
299
300#[derive(Default, Clone, Debug)]
301pub struct VariableTypes {
302 inner: HashMap<Variable, VariableType>,
303}
304
305impl VariableTypes {
306 pub fn get(&self, variable: &Variable) -> VariableType {
307 self.inner
308 .get(variable)
309 .copied()
310 .unwrap_or(VariableType::UNDEF)
311 }
312
313 pub fn iter(&self) -> impl Iterator<Item = (&Variable, &VariableType)> {
314 self.inner.iter()
315 }
316
317 pub fn intersect_with(&mut self, other: Self) {
318 for (v, t) in other.inner {
319 self.intersect_variable_with(v, t);
320 }
321 }
322
323 pub fn union_with(&mut self, other: Self) {
324 for (v, t) in &mut self.inner {
325 if other.get(v).undef {
326 t.undef = true; }
328 }
329 for (v, mut t) in other.inner {
330 self.inner
331 .entry(v)
332 .and_modify(|ex| *ex = *ex | t)
333 .or_insert({
334 t.undef = true;
335 t
336 });
337 }
338 }
339
340 fn intersect_variable_with(&mut self, variable: Variable, t: VariableType) {
341 let t = self.get(&variable) & t;
342 if t != VariableType::UNDEF {
343 self.inner.insert(variable, t);
344 }
345 }
346}
347
/// The kinds of RDF terms a variable may be bound to, plus whether it may be unbound.
///
/// Each flag is independent. [`VariableType::UNDEF`] has only `undef` set. The all-`false`
/// [`Default`] value allows neither a term nor an unbound value.
// One boolean per term kind is the simplest encoding of this small set, hence the allow.
#[allow(clippy::struct_excessive_bools)]
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub struct VariableType {
    /// The variable may be left unbound.
    pub undef: bool,
    /// The variable may be bound to a named node (IRI).
    pub named_node: bool,
    /// The variable may be bound to a blank node.
    pub blank_node: bool,
    /// The variable may be bound to a literal.
    pub literal: bool,
    /// The variable may be bound to a quoted triple.
    #[cfg(feature = "rdf-star")]
    pub triple: bool,
}

impl VariableType {
    /// Any term, or unbound.
    const ANY: Self = Self {
        undef: true,
        named_node: true,
        blank_node: true,
        literal: true,
        #[cfg(feature = "rdf-star")]
        triple: true,
    };
    /// Any term, always bound.
    const TERM: Self = Self {
        undef: false,
        ..Self::ANY
    };
    /// Any term allowed in subject position, always bound.
    const SUBJECT: Self = Self {
        literal: false,
        ..Self::TERM
    };
    /// Always unbound.
    pub const UNDEF: Self = Self {
        undef: true,
        named_node: false,
        blank_node: false,
        literal: false,
        #[cfg(feature = "rdf-star")]
        triple: false,
    };
    /// Always a blank node.
    const BLANK_NODE: Self = Self {
        undef: false,
        blank_node: true,
        ..Self::UNDEF
    };
    /// Always a literal.
    const LITERAL: Self = Self {
        undef: false,
        literal: true,
        ..Self::UNDEF
    };
    /// Always a named node.
    const NAMED_NODE: Self = Self {
        undef: false,
        named_node: true,
        ..Self::UNDEF
    };
    /// Always a quoted triple.
    #[cfg(feature = "rdf-star")]
    const TRIPLE: Self = Self {
        undef: false,
        triple: true,
        ..Self::UNDEF
    };
}

/// Union: the result allows everything that either operand allows.
impl BitOr for VariableType {
    type Output = Self;

    fn bitor(self, rhs: Self) -> Self {
        Self {
            undef: self.undef | rhs.undef,
            named_node: self.named_node | rhs.named_node,
            blank_node: self.blank_node | rhs.blank_node,
            literal: self.literal | rhs.literal,
            #[cfg(feature = "rdf-star")]
            triple: self.triple | rhs.triple,
        }
    }
}

/// Join compatibility: the values a variable may take after joining a side typed `self` with
/// a side typed `rhs`.
///
/// A term kind is kept when both sides allow it, or when one side allows it and the other
/// side may be unbound. The result may be unbound only if both sides may be.
impl BitAnd for VariableType {
    type Output = Self;

    // The per-kind formula is intentionally written in its unfactored, readable form.
    #[allow(clippy::nonminimal_bool)]
    fn bitand(self, rhs: Self) -> Self {
        let (left_undef, right_undef) = (self.undef, rhs.undef);
        let keeps = |left: bool, right: bool| {
            (left && right) || (left_undef && right) || (left && right_undef)
        };
        Self {
            undef: left_undef && right_undef,
            named_node: keeps(self.named_node, rhs.named_node),
            blank_node: keeps(self.blank_node, rhs.blank_node),
            literal: keeps(self.literal, rhs.literal),
            #[cfg(feature = "rdf-star")]
            triple: keeps(self.triple, rhs.triple),
        }
    }
}