sparopt/
type_inference.rs

1use crate::algebra::{Expression, GraphPattern};
2use oxrdf::Variable;
3use spargebra::algebra::Function;
4use spargebra::term::{GroundTerm, GroundTermPattern, NamedNodePattern};
5use std::collections::HashMap;
6use std::ops::{BitAnd, BitOr};
7
8pub fn infer_graph_pattern_types(
9    pattern: &GraphPattern,
10    mut types: VariableTypes,
11) -> VariableTypes {
12    match pattern {
13        GraphPattern::QuadPattern {
14            subject,
15            predicate,
16            object,
17            graph_name,
18        } => {
19            add_ground_term_pattern_types(subject, &mut types, false);
20            if let NamedNodePattern::Variable(v) = predicate {
21                types.intersect_variable_with(v.clone(), VariableType::NAMED_NODE)
22            }
23            add_ground_term_pattern_types(object, &mut types, true);
24            if let Some(NamedNodePattern::Variable(v)) = graph_name {
25                types.intersect_variable_with(v.clone(), VariableType::NAMED_NODE)
26            }
27            types
28        }
29        GraphPattern::Path {
30            subject,
31            object,
32            graph_name,
33            ..
34        } => {
35            add_ground_term_pattern_types(subject, &mut types, false);
36            add_ground_term_pattern_types(object, &mut types, true);
37            if let Some(NamedNodePattern::Variable(v)) = graph_name {
38                types.intersect_variable_with(v.clone(), VariableType::NAMED_NODE)
39            }
40            types
41        }
42        GraphPattern::Graph { graph_name } => {
43            if let NamedNodePattern::Variable(v) = graph_name {
44                types.intersect_variable_with(v.clone(), VariableType::NAMED_NODE)
45            }
46            types
47        }
48        GraphPattern::Join { left, right, .. } => {
49            let mut output_types = infer_graph_pattern_types(left, types.clone());
50            output_types.intersect_with(infer_graph_pattern_types(right, types));
51            output_types
52        }
53        #[cfg(feature = "sep-0006")]
54        GraphPattern::Lateral { left, right } => {
55            infer_graph_pattern_types(right, infer_graph_pattern_types(left, types))
56        }
57        GraphPattern::LeftJoin { left, right, .. } => {
58            let mut right_types = infer_graph_pattern_types(right, types.clone()); // TODO: expression
59            for t in right_types.inner.values_mut() {
60                t.undef = true; // Right might be unset
61            }
62            let mut output_types = infer_graph_pattern_types(left, types);
63            output_types.intersect_with(right_types);
64            output_types
65        }
66        GraphPattern::Minus { left, .. } => infer_graph_pattern_types(left, types),
67        GraphPattern::Union { inner } => inner
68            .iter()
69            .map(|inner| infer_graph_pattern_types(inner, types.clone()))
70            .reduce(|mut a, b| {
71                a.union_with(b);
72                a
73            })
74            .unwrap_or_default(),
75        GraphPattern::Extend {
76            inner,
77            variable,
78            expression,
79        } => {
80            let mut types = infer_graph_pattern_types(inner, types);
81            types.intersect_variable_with(
82                variable.clone(),
83                infer_expression_type(expression, &types),
84            );
85            types
86        }
87        GraphPattern::Filter { inner, .. } => infer_graph_pattern_types(inner, types),
88        GraphPattern::Project { inner, variables } => VariableTypes {
89            inner: infer_graph_pattern_types(inner, types)
90                .inner
91                .into_iter()
92                .filter(|(v, _)| variables.contains(v))
93                .collect(),
94        },
95        GraphPattern::Distinct { inner }
96        | GraphPattern::Reduced { inner }
97        | GraphPattern::OrderBy { inner, .. }
98        | GraphPattern::Slice { inner, .. } => infer_graph_pattern_types(inner, types),
99        GraphPattern::Group {
100            inner,
101            variables,
102            aggregates,
103        } => {
104            let types = infer_graph_pattern_types(inner, types);
105            VariableTypes {
106                inner: infer_graph_pattern_types(inner, types)
107                    .inner
108                    .into_iter()
109                    .filter(|(v, _)| variables.contains(v))
110                    .chain(aggregates.iter().map(|(v, _)| (v.clone(), VariableType::ANY))) //TODO: guess from aggregate
111                    .collect(),
112            }
113        }
114        GraphPattern::Values {
115            variables,
116            bindings,
117        } => {
118            for (i, v) in variables.iter().enumerate() {
119                let mut t = VariableType::default();
120                for binding in bindings {
121                    match binding[i] {
122                        Some(GroundTerm::NamedNode(_)) => t.named_node = true,
123                        Some(GroundTerm::Literal(_)) => t.literal = true,
124                        #[cfg(feature = "rdf-star")]
125                        Some(GroundTerm::Triple(_)) => t.triple = true,
126                        None => t.undef = true,
127                    }
128                }
129                types.intersect_variable_with(v.clone(), t)
130            }
131            types
132        }
133        GraphPattern::Service {
134            name,
135            inner,
136            silent,
137        } => {
138            let parent_types = types.clone();
139            let mut types = infer_graph_pattern_types(inner, types);
140            if *silent {
141                // On failure, single empty solution
142                types.union_with(parent_types);
143            } else if let NamedNodePattern::Variable(v) = name {
144                types.intersect_variable_with(v.clone(), VariableType::NAMED_NODE)
145            }
146            types
147        }
148    }
149}
150
151fn add_ground_term_pattern_types(
152    pattern: &GroundTermPattern,
153    types: &mut VariableTypes,
154    is_object: bool,
155) {
156    if let GroundTermPattern::Variable(v) = pattern {
157        types.intersect_variable_with(
158            v.clone(),
159            if is_object {
160                VariableType::TERM
161            } else {
162                VariableType::SUBJECT
163            },
164        )
165    }
166    #[cfg(feature = "rdf-star")]
167    if let GroundTermPattern::Triple(t) = pattern {
168        add_ground_term_pattern_types(&t.subject, types, false);
169        if let NamedNodePattern::Variable(v) = &t.predicate {
170            types.intersect_variable_with(v.clone(), VariableType::NAMED_NODE)
171        }
172        add_ground_term_pattern_types(&t.object, types, true);
173    }
174}
175
176pub fn infer_expression_type(expression: &Expression, types: &VariableTypes) -> VariableType {
177    match expression {
178        Expression::NamedNode(_) => VariableType::NAMED_NODE,
179        Expression::Literal(_) | Expression::Exists(_) | Expression::Bound(_) => {
180            VariableType::LITERAL
181        }
182        Expression::Variable(v) => types.get(v),
183        Expression::FunctionCall(Function::Datatype | Function::Iri, _) => {
184            VariableType::NAMED_NODE | VariableType::UNDEF
185        }
186        #[cfg(feature = "rdf-star")]
187        Expression::FunctionCall(Function::Predicate, _) => {
188            VariableType::NAMED_NODE | VariableType::UNDEF
189        }
190        Expression::FunctionCall(Function::BNode, args) => {
191            if args.is_empty() {
192                VariableType::BLANK_NODE
193            } else {
194                VariableType::BLANK_NODE | VariableType::UNDEF
195            }
196        }
197        Expression::FunctionCall(
198            Function::Rand | Function::Now | Function::Uuid | Function::StrUuid,
199            _,
200        ) => VariableType::LITERAL,
201        Expression::Or(_)
202        | Expression::And(_)
203        | Expression::Equal(_, _)
204        | Expression::Greater(_, _)
205        | Expression::GreaterOrEqual(_, _)
206        | Expression::Less(_, _)
207        | Expression::LessOrEqual(_, _)
208        | Expression::Add(_, _)
209        | Expression::Subtract(_, _)
210        | Expression::Multiply(_, _)
211        | Expression::Divide(_, _)
212        | Expression::UnaryPlus(_)
213        | Expression::UnaryMinus(_)
214        | Expression::Not(_)
215        | Expression::FunctionCall(
216            Function::Str
217            | Function::Lang
218            | Function::LangMatches
219            | Function::Abs
220            | Function::Ceil
221            | Function::Floor
222            | Function::Round
223            | Function::Concat
224            | Function::SubStr
225            | Function::StrLen
226            | Function::Replace
227            | Function::UCase
228            | Function::LCase
229            | Function::EncodeForUri
230            | Function::Contains
231            | Function::StrStarts
232            | Function::StrEnds
233            | Function::StrBefore
234            | Function::StrAfter
235            | Function::Year
236            | Function::Month
237            | Function::Day
238            | Function::Hours
239            | Function::Minutes
240            | Function::Seconds
241            | Function::Timezone
242            | Function::Tz
243            | Function::Md5
244            | Function::Sha1
245            | Function::Sha256
246            | Function::Sha384
247            | Function::Sha512
248            | Function::StrLang
249            | Function::StrDt
250            | Function::IsIri
251            | Function::IsBlank
252            | Function::IsLiteral
253            | Function::IsNumeric
254            | Function::Regex,
255            _,
256        ) => VariableType::LITERAL | VariableType::UNDEF,
257        #[cfg(feature = "sep-0002")]
258        Expression::FunctionCall(Function::Adjust, _) => {
259            VariableType::LITERAL | VariableType::UNDEF
260        }
261        #[cfg(feature = "rdf-star")]
262        Expression::FunctionCall(Function::IsTriple, _) => {
263            VariableType::LITERAL | VariableType::UNDEF
264        }
265        Expression::SameTerm(left, right) => {
266            if infer_expression_type(left, types).undef || infer_expression_type(right, types).undef
267            {
268                VariableType::LITERAL | VariableType::UNDEF
269            } else {
270                VariableType::LITERAL
271            }
272        }
273        Expression::If(_, then, els) => {
274            infer_expression_type(then, types) | infer_expression_type(els, types)
275        }
276        Expression::Coalesce(inner) => {
277            let mut t = VariableType::UNDEF;
278            for e in inner {
279                let new = infer_expression_type(e, types);
280                t = t | new;
281                if !new.undef {
282                    t.undef = false;
283                    return t;
284                }
285            }
286            t
287        }
288        #[cfg(feature = "rdf-star")]
289        Expression::FunctionCall(Function::Triple, _) => VariableType::TRIPLE | VariableType::UNDEF,
290        #[cfg(feature = "rdf-star")]
291        Expression::FunctionCall(Function::Subject, _) => {
292            VariableType::SUBJECT | VariableType::UNDEF
293        }
294        #[cfg(feature = "rdf-star")]
295        Expression::FunctionCall(Function::Object, _) => VariableType::TERM | VariableType::UNDEF,
296        Expression::FunctionCall(Function::Custom(_), _) => VariableType::ANY,
297    }
298}
299
300#[derive(Default, Clone, Debug)]
301pub struct VariableTypes {
302    inner: HashMap<Variable, VariableType>,
303}
304
305impl VariableTypes {
306    pub fn get(&self, variable: &Variable) -> VariableType {
307        self.inner
308            .get(variable)
309            .copied()
310            .unwrap_or(VariableType::UNDEF)
311    }
312
313    pub fn iter(&self) -> impl Iterator<Item = (&Variable, &VariableType)> {
314        self.inner.iter()
315    }
316
317    pub fn intersect_with(&mut self, other: Self) {
318        for (v, t) in other.inner {
319            self.intersect_variable_with(v, t);
320        }
321    }
322
323    pub fn union_with(&mut self, other: Self) {
324        for (v, t) in &mut self.inner {
325            if other.get(v).undef {
326                t.undef = true; // Might be undefined
327            }
328        }
329        for (v, mut t) in other.inner {
330            self.inner
331                .entry(v)
332                .and_modify(|ex| *ex = *ex | t)
333                .or_insert({
334                    t.undef = true;
335                    t
336                });
337        }
338    }
339
340    fn intersect_variable_with(&mut self, variable: Variable, t: VariableType) {
341        let t = self.get(&variable) & t;
342        if t != VariableType::UNDEF {
343            self.inner.insert(variable, t);
344        }
345    }
346}
347
/// What a variable may hold in a solution: one flag per RDF term kind, plus `undef` for
/// "may be unbound". `Default` has every flag cleared.
#[allow(clippy::struct_excessive_bools)] // one independent flag per term kind is the point
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub struct VariableType {
    /// The variable may be unbound.
    pub undef: bool,
    /// The variable may be bound to an IRI.
    pub named_node: bool,
    /// The variable may be bound to a blank node.
    pub blank_node: bool,
    /// The variable may be bound to a literal.
    pub literal: bool,
    /// The variable may be bound to a quoted triple.
    #[cfg(feature = "rdf-star")]
    pub triple: bool,
}

impl VariableType {
    /// Every flag cleared; base for the other constants.
    const EMPTY: Self = Self {
        undef: false,
        named_node: false,
        blank_node: false,
        literal: false,
        #[cfg(feature = "rdf-star")]
        triple: false,
    };
    /// Always an IRI.
    const NAMED_NODE: Self = Self {
        named_node: true,
        ..Self::EMPTY
    };
    /// Always a blank node.
    const BLANK_NODE: Self = Self {
        blank_node: true,
        ..Self::EMPTY
    };
    /// Always a literal.
    const LITERAL: Self = Self {
        literal: true,
        ..Self::EMPTY
    };
    /// Always a quoted triple.
    #[cfg(feature = "rdf-star")]
    const TRIPLE: Self = Self {
        triple: true,
        ..Self::EMPTY
    };
    /// Any bound term except a literal.
    const SUBJECT: Self = Self {
        named_node: true,
        blank_node: true,
        #[cfg(feature = "rdf-star")]
        triple: true,
        ..Self::EMPTY
    };
    /// Any bound term.
    const TERM: Self = Self {
        literal: true,
        ..Self::SUBJECT
    };
    /// Any term, or unbound.
    const ANY: Self = Self {
        undef: true,
        ..Self::TERM
    };
    /// Unbound only.
    pub const UNDEF: Self = Self {
        undef: true,
        ..Self::EMPTY
    };
}

impl BitOr for VariableType {
    type Output = Self;

    /// Union: a flag is set when it is set on either side.
    fn bitor(self, rhs: Self) -> Self {
        Self {
            undef: self.undef || rhs.undef,
            named_node: self.named_node || rhs.named_node,
            blank_node: self.blank_node || rhs.blank_node,
            literal: self.literal || rhs.literal,
            #[cfg(feature = "rdf-star")]
            triple: self.triple || rhs.triple,
        }
    }
}

impl BitAnd for VariableType {
    type Output = Self;

    /// Intersection of two constraints on the same variable.
    ///
    /// A term kind survives when both sides allow it, or when one side allows it and the
    /// other side may leave the variable unbound. The result may be unbound only when both
    /// sides may be.
    #[allow(clippy::nonminimal_bool)]
    fn bitand(self, rhs: Self) -> Self {
        let keep = |mine: bool, theirs: bool| {
            (mine && theirs) || (self.undef && theirs) || (mine && rhs.undef)
        };
        Self {
            undef: self.undef && rhs.undef,
            named_node: keep(self.named_node, rhs.named_node),
            blank_node: keep(self.blank_node, rhs.blank_node),
            literal: keep(self.literal, rhs.literal),
            #[cfg(feature = "rdf-star")]
            triple: keep(self.triple, rhs.triple),
        }
    }
}