oxigraph/storage/
memory.rs

1use crate::model::{GraphNameRef, NamedOrBlankNodeRef, QuadRef, TermRef};
2pub use crate::storage::error::StorageError;
3use crate::storage::numeric_encoder::{
4    insert_term, Decoder, EncodedQuad, EncodedTerm, StrHash, StrHashHasher, StrLookup,
5};
6use crate::storage::CorruptionError;
7use dashmap::iter::Iter;
8use dashmap::mapref::entry::Entry;
9use dashmap::{DashMap, DashSet};
10use oxrdf::Quad;
11use rustc_hash::FxHasher;
12use std::borrow::Borrow;
13use std::error::Error;
14use std::hash::{BuildHasherDefault, Hash, Hasher};
15use std::mem::transmute;
16use std::sync::atomic::{AtomicUsize, Ordering};
17use std::sync::{Arc, Mutex, RwLock, Weak};
18
/// In-memory storage working with MVCC.
///
/// Each quad and graph name is annotated by a version range, allowing old versions to be read
/// while updates are applied.
/// To simplify the implementation, only a single write transaction is currently allowed at a time.
/// This restriction should be lifted in the future.
///
/// Cloning is cheap: every field is an `Arc`, so all clones share the same data.
#[derive(Clone)]
pub struct MemoryStorage {
    /// Quad nodes, per-term chain indexes and named graphs.
    content: Arc<Content>,
    /// Strings referenced by encoded terms, keyed by their hash.
    id2str: Arc<DashMap<StrHash, String, BuildHasherDefault<StrHashHasher>>>,
    /// Last committed version; new snapshots observe this value.
    version_counter: Arc<AtomicUsize>,
    /// Last transaction id handed out. Its lock is held for the whole duration of a write
    /// transaction or of a bulk load, which serializes writers.
    transaction_counter: Arc<Mutex<usize>>,
}
30
/// Data shared by all clones of a [`MemoryStorage`].
struct Content {
    /// Every quad node, looked up by its quad. Deleting a quad does not remove its node from
    /// this set: deletion only narrows the node's version range.
    quad_set: DashSet<Arc<QuadListNode>, BuildHasherDefault<FxHasher>>,
    /// Most recently inserted node: head of the global chain linked by `previous`.
    last_quad: RwLock<Option<Weak<QuadListNode>>>,
    /// Per subject: head of the chain linked by `previous_subject`, and number of nodes ever
    /// linked into that chain (visible or not).
    last_quad_by_subject:
        DashMap<EncodedTerm, (Weak<QuadListNode>, u64), BuildHasherDefault<FxHasher>>,
    /// Per predicate: head of the chain linked by `previous_predicate`, and its length.
    last_quad_by_predicate:
        DashMap<EncodedTerm, (Weak<QuadListNode>, u64), BuildHasherDefault<FxHasher>>,
    /// Per object: head of the chain linked by `previous_object`, and its length.
    last_quad_by_object:
        DashMap<EncodedTerm, (Weak<QuadListNode>, u64), BuildHasherDefault<FxHasher>>,
    /// Per graph name (including the default graph): head of the chain linked by
    /// `previous_graph_name`, and its length.
    last_quad_by_graph_name:
        DashMap<EncodedTerm, (Weak<QuadListNode>, u64), BuildHasherDefault<FxHasher>>,
    /// Named graphs and the versions in which they exist.
    graphs: DashMap<EncodedTerm, VersionRange>,
}
44
45impl MemoryStorage {
46    pub fn new() -> Self {
47        Self {
48            content: Arc::new(Content {
49                quad_set: DashSet::default(),
50                last_quad: RwLock::new(None),
51                last_quad_by_subject: DashMap::default(),
52                last_quad_by_predicate: DashMap::default(),
53                last_quad_by_object: DashMap::default(),
54                last_quad_by_graph_name: DashMap::default(),
55                graphs: DashMap::default(),
56            }),
57            id2str: Arc::new(DashMap::default()),
58            version_counter: Arc::new(AtomicUsize::new(0)),
59            #[allow(clippy::mutex_atomic)]
60            transaction_counter: Arc::new(Mutex::new(usize::MAX >> 1)),
61        }
62    }
63
64    pub fn snapshot(&self) -> MemoryStorageReader {
65        MemoryStorageReader {
66            storage: self.clone(),
67            snapshot_id: self.version_counter.load(Ordering::Acquire),
68        }
69    }
70
71    #[allow(clippy::unwrap_in_result)]
72    pub fn transaction<T, E: Error + 'static + From<StorageError>>(
73        &self,
74        f: impl for<'a> Fn(MemoryStorageWriter<'a>) -> Result<T, E>,
75    ) -> Result<T, E> {
76        let mut transaction_mutex = self.transaction_counter.lock().unwrap();
77        *transaction_mutex += 1;
78        let transaction_id = *transaction_mutex;
79        let snapshot_id = self.version_counter.load(Ordering::Acquire);
80        let mut operations = Vec::new();
81        let result = f(MemoryStorageWriter {
82            storage: self,
83            log: &mut operations,
84            transaction_id,
85        });
86        if result.is_ok() {
87            let new_version_id = snapshot_id + 1;
88            for operation in operations {
89                match operation {
90                    LogEntry::QuadNode(node) => {
91                        node.range
92                            .lock()
93                            .unwrap()
94                            .upgrade_transaction(transaction_id, new_version_id);
95                    }
96                    LogEntry::Graph(graph_name) => {
97                        if let Some(mut entry) = self.content.graphs.get_mut(&graph_name) {
98                            entry
99                                .value_mut()
100                                .upgrade_transaction(transaction_id, new_version_id)
101                        }
102                    }
103                }
104            }
105            self.version_counter
106                .store(new_version_id, Ordering::Release);
107        } else {
108            for operation in operations {
109                match operation {
110                    LogEntry::QuadNode(node) => {
111                        node.range
112                            .lock()
113                            .unwrap()
114                            .rollback_transaction(transaction_id);
115                    }
116                    LogEntry::Graph(graph_name) => {
117                        if let Some(mut entry) = self.content.graphs.get_mut(&graph_name) {
118                            entry.value_mut().rollback_transaction(transaction_id)
119                        }
120                    }
121                }
122            }
123        }
124        // TODO: garbage collection
125        result
126    }
127
128    pub fn bulk_loader(&self) -> MemoryStorageBulkLoader {
129        MemoryStorageBulkLoader {
130            storage: self.clone(),
131            hooks: Vec::new(),
132        }
133    }
134}
135
/// Read-only view of a [`MemoryStorage`] at a fixed version.
///
/// A quad or named graph is visible when its version range contains `snapshot_id`.
#[derive(Clone)]
pub struct MemoryStorageReader {
    /// Shared handle on the storage; cloning it only clones `Arc`s.
    storage: MemoryStorage,
    /// Version observed by this reader. It is the committed version counter for snapshots,
    /// and the transaction id for readers created by a writer.
    snapshot_id: usize,
}
141
142impl MemoryStorageReader {
143    pub fn len(&self) -> usize {
144        self.storage
145            .content
146            .quad_set
147            .iter()
148            .filter(|e| self.is_node_in_range(e))
149            .count()
150    }
151
152    pub fn is_empty(&self) -> bool {
153        !self
154            .storage
155            .content
156            .quad_set
157            .iter()
158            .any(|e| self.is_node_in_range(&e))
159    }
160
161    pub fn contains(&self, quad: &EncodedQuad) -> bool {
162        self.storage
163            .content
164            .quad_set
165            .get(quad)
166            .is_some_and(|node| self.is_node_in_range(&node))
167    }
168
169    pub fn quads_for_pattern(
170        &self,
171        subject: Option<&EncodedTerm>,
172        predicate: Option<&EncodedTerm>,
173        object: Option<&EncodedTerm>,
174        graph_name: Option<&EncodedTerm>,
175    ) -> QuadIterator {
176        fn get_start_and_count(
177            map: &DashMap<EncodedTerm, (Weak<QuadListNode>, u64), BuildHasherDefault<FxHasher>>,
178            term: Option<&EncodedTerm>,
179        ) -> (Option<Weak<QuadListNode>>, u64) {
180            let Some(term) = term else {
181                return (None, u64::MAX);
182            };
183            map.view(term, |_, (node, count)| (Some(Weak::clone(node)), *count))
184                .unwrap_or_default()
185        }
186
187        let (subject_start, subject_count) =
188            get_start_and_count(&self.storage.content.last_quad_by_subject, subject);
189        let (predicate_start, predicate_count) =
190            get_start_and_count(&self.storage.content.last_quad_by_predicate, predicate);
191        let (object_start, object_count) =
192            get_start_and_count(&self.storage.content.last_quad_by_object, object);
193        let (graph_name_start, graph_name_count) =
194            get_start_and_count(&self.storage.content.last_quad_by_graph_name, graph_name);
195
196        let (start, kind) = if subject.is_some()
197            && subject_count <= predicate_count
198            && subject_count <= object_count
199            && subject_count <= graph_name_count
200        {
201            (subject_start, QuadIteratorKind::Subject)
202        } else if predicate.is_some()
203            && predicate_count <= object_count
204            && predicate_count <= graph_name_count
205        {
206            (predicate_start, QuadIteratorKind::Predicate)
207        } else if object.is_some() && object_count <= graph_name_count {
208            (object_start, QuadIteratorKind::Object)
209        } else if graph_name.is_some() {
210            (graph_name_start, QuadIteratorKind::GraphName)
211        } else {
212            (
213                self.storage.content.last_quad.read().unwrap().clone(),
214                QuadIteratorKind::All,
215            )
216        };
217        QuadIterator {
218            reader: self.clone(),
219            current: start,
220            kind,
221            expect_subject: if kind == QuadIteratorKind::Subject {
222                None
223            } else {
224                subject.cloned()
225            },
226            expect_predicate: if kind == QuadIteratorKind::Predicate {
227                None
228            } else {
229                predicate.cloned()
230            },
231            expect_object: if kind == QuadIteratorKind::Object {
232                None
233            } else {
234                object.cloned()
235            },
236            expect_graph_name: if kind == QuadIteratorKind::GraphName {
237                None
238            } else {
239                graph_name.cloned()
240            },
241        }
242    }
243
244    #[allow(unsafe_code)]
245    pub fn named_graphs(&self) -> MemoryDecodingGraphIterator {
246        MemoryDecodingGraphIterator {
247            reader: self.clone(),
248            // SAFETY: this is fine, the owning struct also owns the iterated data structure
249            iter: unsafe {
250                transmute::<Iter<'_, _, _>, Iter<'static, _, _>>(self.storage.content.graphs.iter())
251            },
252        }
253    }
254
255    pub fn contains_named_graph(&self, graph_name: &EncodedTerm) -> bool {
256        self.storage
257            .content
258            .graphs
259            .get(graph_name)
260            .is_some_and(|range| self.is_in_range(&range))
261    }
262
263    pub fn contains_str(&self, key: &StrHash) -> bool {
264        self.storage.id2str.contains_key(key)
265    }
266
267    /// Validates that all the storage invariants held in the data
268    #[allow(clippy::unwrap_in_result)]
269    pub fn validate(&self) -> Result<(), StorageError> {
270        // All used named graphs are in graph set
271        let expected_quad_len = self.storage.content.quad_set.len() as u64;
272
273        // last quad chain
274        let mut next = self.storage.content.last_quad.read().unwrap().clone();
275        let mut count_last_quad = 0;
276        while let Some(current) = next.take().and_then(|c| c.upgrade()) {
277            count_last_quad += 1;
278            if !self
279                .storage
280                .content
281                .quad_set
282                .get(&current.quad)
283                .is_some_and(|e| Arc::ptr_eq(&e, &current))
284            {
285                return Err(
286                    CorruptionError::new("Quad in previous chain but not in quad set").into(),
287                );
288            }
289            self.decode_quad(&current.quad)?;
290            if !current.quad.graph_name.is_default_graph()
291                && !self
292                    .storage
293                    .content
294                    .graphs
295                    .contains_key(&current.quad.graph_name)
296            {
297                return Err(
298                    CorruptionError::new("Quad in named graph that does not exists").into(),
299                );
300            };
301            next.clone_from(&current.previous);
302        }
303        if count_last_quad != expected_quad_len {
304            return Err(CorruptionError::new("Too many quads in quad_set").into());
305        }
306
307        // By subject chain
308        let mut count_last_by_subject = 0;
309        for entry in &self.storage.content.last_quad_by_subject {
310            let mut next = Some(Weak::clone(&entry.value().0));
311            let mut element_count = 0;
312            while let Some(current) = next.take().and_then(|n| n.upgrade()) {
313                element_count += 1;
314                if current.quad.subject != *entry.key() {
315                    return Err(CorruptionError::new("Quad in wrong list").into());
316                }
317                if !self
318                    .storage
319                    .content
320                    .quad_set
321                    .get(&current.quad)
322                    .is_some_and(|e| Arc::ptr_eq(&e, &current))
323                {
324                    return Err(
325                        CorruptionError::new("Quad in previous chain but not in quad set").into(),
326                    );
327                }
328                next.clone_from(&current.previous_subject);
329            }
330            if element_count != entry.value().1 {
331                return Err(CorruptionError::new("Too many quads in a chain").into());
332            }
333            count_last_by_subject += element_count;
334        }
335        if count_last_by_subject != expected_quad_len {
336            return Err(CorruptionError::new("Too many quads in quad_set").into());
337        }
338
339        // By predicate chains
340        let mut count_last_by_predicate = 0;
341        for entry in &self.storage.content.last_quad_by_predicate {
342            let mut next = Some(Weak::clone(&entry.value().0));
343            let mut element_count = 0;
344            while let Some(current) = next.take().and_then(|n| n.upgrade()) {
345                element_count += 1;
346                if current.quad.predicate != *entry.key() {
347                    return Err(CorruptionError::new("Quad in wrong list").into());
348                }
349                if !self
350                    .storage
351                    .content
352                    .quad_set
353                    .get(&current.quad)
354                    .is_some_and(|e| Arc::ptr_eq(&e, &current))
355                {
356                    return Err(
357                        CorruptionError::new("Quad in previous chain but not in quad set").into(),
358                    );
359                }
360                next.clone_from(&current.previous_predicate);
361            }
362            if element_count != entry.value().1 {
363                return Err(CorruptionError::new("Too many quads in a chain").into());
364            }
365            count_last_by_predicate += element_count;
366        }
367        if count_last_by_predicate != expected_quad_len {
368            return Err(CorruptionError::new("Too many quads in quad_set").into());
369        }
370
371        // By object chains
372        let mut count_last_by_object = 0;
373        for entry in &self.storage.content.last_quad_by_object {
374            let mut next = Some(Weak::clone(&entry.value().0));
375            let mut element_count = 0;
376            while let Some(current) = next.take().and_then(|n| n.upgrade()) {
377                element_count += 1;
378                if current.quad.object != *entry.key() {
379                    return Err(CorruptionError::new("Quad in wrong list").into());
380                }
381                if !self
382                    .storage
383                    .content
384                    .quad_set
385                    .get(&current.quad)
386                    .is_some_and(|e| Arc::ptr_eq(&e, &current))
387                {
388                    return Err(
389                        CorruptionError::new("Quad in previous chain but not in quad set").into(),
390                    );
391                }
392                next.clone_from(&current.previous_object);
393            }
394            if element_count != entry.value().1 {
395                return Err(CorruptionError::new("Too many quads in a chain").into());
396            }
397            count_last_by_object += element_count;
398        }
399        if count_last_by_object != expected_quad_len {
400            return Err(CorruptionError::new("Too many quads in quad_set").into());
401        }
402
403        // By graph_name chains
404        let mut count_last_by_graph_name = 0;
405        for entry in &self.storage.content.last_quad_by_graph_name {
406            let mut next = Some(Weak::clone(&entry.value().0));
407            let mut element_count = 0;
408            while let Some(current) = next.take().and_then(|n| n.upgrade()) {
409                element_count += 1;
410                if current.quad.graph_name != *entry.key() {
411                    return Err(CorruptionError::new("Quad in wrong list").into());
412                }
413                if !self
414                    .storage
415                    .content
416                    .quad_set
417                    .get(&current.quad)
418                    .is_some_and(|e| Arc::ptr_eq(&e, &current))
419                {
420                    return Err(
421                        CorruptionError::new("Quad in previous chain but not in quad set").into(),
422                    );
423                }
424                next.clone_from(&current.previous_graph_name);
425            }
426            if element_count != entry.value().1 {
427                return Err(CorruptionError::new("Too many quads in a chain").into());
428            }
429            count_last_by_graph_name += element_count;
430        }
431        if count_last_by_graph_name != expected_quad_len {
432            return Err(CorruptionError::new("Too many quads in quad_set").into());
433        }
434
435        Ok(())
436    }
437
438    fn is_in_range(&self, range: &VersionRange) -> bool {
439        range.contains(self.snapshot_id)
440    }
441
442    fn is_node_in_range(&self, node: &QuadListNode) -> bool {
443        let range = node.range.lock().unwrap();
444        self.is_in_range(&range)
445    }
446}
447
448impl StrLookup for MemoryStorageReader {
449    fn get_str(&self, key: &StrHash) -> Result<Option<String>, StorageError> {
450        Ok(self.storage.id2str.view(key, |_, v| v.clone()))
451    }
452}
453
/// Handle used to modify a [`MemoryStorage`] inside a write transaction or a bulk load.
///
/// Every version-range change is appended to `log`. `MemoryStorage::transaction` uses the log
/// to either promote the changes to the committed version or roll them back.
pub struct MemoryStorageWriter<'a> {
    storage: &'a MemoryStorage,
    /// Quad nodes and named graphs whose version range was changed.
    log: &'a mut Vec<LogEntry>,
    /// Id tagging the changes made through this writer; also the snapshot id of `reader()`.
    transaction_id: usize,
}
459
460impl MemoryStorageWriter<'_> {
461    pub fn reader(&self) -> MemoryStorageReader {
462        MemoryStorageReader {
463            storage: self.storage.clone(),
464            snapshot_id: self.transaction_id,
465        }
466    }
467
468    pub fn insert(&mut self, quad: QuadRef<'_>) -> bool {
469        let encoded: EncodedQuad = quad.into();
470        if let Some(node) = self
471            .storage
472            .content
473            .quad_set
474            .get(&encoded)
475            .map(|node| Arc::clone(&node))
476        {
477            let added = node.range.lock().unwrap().add(self.transaction_id);
478            if added {
479                self.log.push(LogEntry::QuadNode(node));
480                if !quad.graph_name.is_default_graph()
481                    && self
482                        .storage
483                        .content
484                        .graphs
485                        .get_mut(&encoded.graph_name)
486                        .unwrap()
487                        .add(self.transaction_id)
488                {
489                    self.log.push(LogEntry::Graph(encoded.graph_name.clone()));
490                }
491            }
492            added
493        } else {
494            let node = Arc::new(QuadListNode {
495                quad: encoded.clone(),
496                range: Mutex::new(VersionRange::Start(self.transaction_id)),
497                previous: self.storage.content.last_quad.read().unwrap().clone(),
498                previous_subject: self
499                    .storage
500                    .content
501                    .last_quad_by_subject
502                    .view(&encoded.subject, |_, (node, _)| Weak::clone(node)),
503                previous_predicate: self
504                    .storage
505                    .content
506                    .last_quad_by_predicate
507                    .view(&encoded.predicate, |_, (node, _)| Weak::clone(node)),
508                previous_object: self
509                    .storage
510                    .content
511                    .last_quad_by_object
512                    .view(&encoded.object, |_, (node, _)| Weak::clone(node)),
513                previous_graph_name: self
514                    .storage
515                    .content
516                    .last_quad_by_graph_name
517                    .view(&encoded.graph_name, |_, (node, _)| Weak::clone(node)),
518            });
519            self.storage.content.quad_set.insert(Arc::clone(&node));
520            *self.storage.content.last_quad.write().unwrap() = Some(Arc::downgrade(&node));
521            self.storage
522                .content
523                .last_quad_by_subject
524                .entry(encoded.subject.clone())
525                .and_modify(|(e, count)| {
526                    *e = Arc::downgrade(&node);
527                    *count += 1;
528                })
529                .or_insert_with(|| (Arc::downgrade(&node), 1));
530            self.storage
531                .content
532                .last_quad_by_predicate
533                .entry(encoded.predicate.clone())
534                .and_modify(|(e, count)| {
535                    *e = Arc::downgrade(&node);
536                    *count += 1;
537                })
538                .or_insert_with(|| (Arc::downgrade(&node), 1));
539            self.storage
540                .content
541                .last_quad_by_object
542                .entry(encoded.object.clone())
543                .and_modify(|(e, count)| {
544                    *e = Arc::downgrade(&node);
545                    *count += 1;
546                })
547                .or_insert_with(|| (Arc::downgrade(&node), 1));
548            self.storage
549                .content
550                .last_quad_by_graph_name
551                .entry(encoded.graph_name.clone())
552                .and_modify(|(e, count)| {
553                    *e = Arc::downgrade(&node);
554                    *count += 1;
555                })
556                .or_insert_with(|| (Arc::downgrade(&node), 1));
557
558            self.insert_term(quad.subject.into(), &encoded.subject);
559            self.insert_term(quad.predicate.into(), &encoded.predicate);
560            self.insert_term(quad.object, &encoded.object);
561
562            match quad.graph_name {
563                GraphNameRef::NamedNode(graph_name) => {
564                    self.insert_encoded_named_graph(graph_name.into(), encoded.graph_name.clone());
565                }
566                GraphNameRef::BlankNode(graph_name) => {
567                    self.insert_encoded_named_graph(graph_name.into(), encoded.graph_name.clone());
568                }
569                GraphNameRef::DefaultGraph => (),
570            }
571            self.log.push(LogEntry::QuadNode(node));
572            true
573        }
574    }
575
576    pub fn insert_named_graph(&mut self, graph_name: NamedOrBlankNodeRef<'_>) -> bool {
577        self.insert_encoded_named_graph(graph_name, graph_name.into())
578    }
579
580    fn insert_encoded_named_graph(
581        &mut self,
582        graph_name: NamedOrBlankNodeRef<'_>,
583        encoded_graph_name: EncodedTerm,
584    ) -> bool {
585        let added = match self
586            .storage
587            .content
588            .graphs
589            .entry(encoded_graph_name.clone())
590        {
591            Entry::Occupied(mut entry) => entry.get_mut().add(self.transaction_id),
592            Entry::Vacant(entry) => {
593                entry.insert(VersionRange::Start(self.transaction_id));
594                self.insert_term(graph_name.into(), &encoded_graph_name);
595                true
596            }
597        };
598        if added {
599            self.log.push(LogEntry::Graph(encoded_graph_name));
600        }
601        added
602    }
603
604    fn insert_term(&self, term: TermRef<'_>, encoded: &EncodedTerm) {
605        insert_term(term, encoded, &mut |key, value| {
606            self.insert_str(key, value);
607            Ok(())
608        })
609        .unwrap()
610    }
611
612    fn insert_str(&self, key: &StrHash, value: &str) {
613        let inserted = self
614            .storage
615            .id2str
616            .entry(*key)
617            .or_insert_with(|| value.into());
618        debug_assert_eq!(*inserted, value, "Hash conflict for two strings");
619    }
620
621    pub fn remove(&mut self, quad: QuadRef<'_>) -> bool {
622        self.remove_encoded(&quad.into())
623    }
624
625    fn remove_encoded(&mut self, quad: &EncodedQuad) -> bool {
626        let Some(node) = self
627            .storage
628            .content
629            .quad_set
630            .get(quad)
631            .map(|node| Arc::clone(&node))
632        else {
633            return false;
634        };
635        let removed = node.range.lock().unwrap().remove(self.transaction_id);
636        if removed {
637            self.log.push(LogEntry::QuadNode(node));
638        }
639        removed
640    }
641
642    pub fn clear_graph(&mut self, graph_name: GraphNameRef<'_>) {
643        self.clear_encoded_graph(&graph_name.into())
644    }
645
646    fn clear_encoded_graph(&mut self, graph_name: &EncodedTerm) {
647        let mut next = self
648            .storage
649            .content
650            .last_quad_by_graph_name
651            .view(graph_name, |_, (node, _)| Weak::clone(node));
652        while let Some(current) = next.take().and_then(|c| c.upgrade()) {
653            if current.range.lock().unwrap().remove(self.transaction_id) {
654                self.log.push(LogEntry::QuadNode(Arc::clone(&current)));
655            }
656            next.clone_from(&current.previous_graph_name);
657        }
658    }
659
660    pub fn clear_all_named_graphs(&mut self) {
661        for graph_name in self.reader().named_graphs() {
662            self.clear_encoded_graph(&graph_name)
663        }
664    }
665
666    pub fn clear_all_graphs(&mut self) {
667        self.storage.content.quad_set.iter().for_each(|node| {
668            if node.range.lock().unwrap().remove(self.transaction_id) {
669                self.log.push(LogEntry::QuadNode(Arc::clone(&node)));
670            }
671        });
672    }
673
674    pub fn remove_named_graph(&mut self, graph_name: NamedOrBlankNodeRef<'_>) -> bool {
675        self.remove_encoded_named_graph(&graph_name.into())
676    }
677
678    fn remove_encoded_named_graph(&mut self, graph_name: &EncodedTerm) -> bool {
679        self.clear_encoded_graph(graph_name);
680        let removed = self
681            .storage
682            .content
683            .graphs
684            .get_mut(graph_name)
685            .is_some_and(|mut entry| entry.value_mut().remove(self.transaction_id));
686        if removed {
687            self.log.push(LogEntry::Graph(graph_name.clone()));
688        }
689        removed
690    }
691
692    pub fn remove_all_named_graphs(&mut self) {
693        self.clear_all_named_graphs();
694        self.do_remove_graphs();
695    }
696
697    fn do_remove_graphs(&mut self) {
698        self.storage
699            .content
700            .graphs
701            .iter_mut()
702            .for_each(|mut entry| {
703                if entry.value_mut().remove(self.transaction_id) {
704                    self.log.push(LogEntry::Graph(entry.key().clone()));
705                }
706            });
707    }
708
709    pub fn clear(&mut self) {
710        self.clear_all_graphs();
711        self.do_remove_graphs();
712    }
713}
714
/// Iterator over the quads matching a pattern, built by
/// [`MemoryStorageReader::quads_for_pattern`].
///
/// It follows one chain of [`QuadListNode`]s (selected by `kind`) from its head towards older
/// nodes. The `expect_*` fields filter on the pattern terms that chain does not already
/// guarantee.
pub struct QuadIterator {
    /// Reader whose snapshot decides which nodes are visible.
    reader: MemoryStorageReader,
    /// Next node to visit; `None` once the chain is exhausted.
    current: Option<Weak<QuadListNode>>,
    /// Which `previous*` link of each node is followed.
    kind: QuadIteratorKind,
    /// Required subject, or `None` for "any" (or already guaranteed by the chain).
    expect_subject: Option<EncodedTerm>,
    /// Required predicate, or `None`.
    expect_predicate: Option<EncodedTerm>,
    /// Required object, or `None`.
    expect_object: Option<EncodedTerm>,
    /// Required graph name, or `None`.
    expect_graph_name: Option<EncodedTerm>,
}
724
/// Selects which chain of nodes a [`QuadIterator`] follows.
#[derive(PartialEq, Eq, Clone, Copy)]
enum QuadIteratorKind {
    /// Global chain of all nodes (`previous` links).
    All,
    /// Nodes sharing a subject (`previous_subject` links).
    Subject,
    /// Nodes sharing a predicate (`previous_predicate` links).
    Predicate,
    /// Nodes sharing an object (`previous_object` links).
    Object,
    /// Nodes sharing a graph name (`previous_graph_name` links).
    GraphName,
}
733
734impl Iterator for QuadIterator {
735    type Item = EncodedQuad;
736
737    fn next(&mut self) -> Option<EncodedQuad> {
738        loop {
739            let current = self.current.take()?.upgrade()?;
740            self.current = match self.kind {
741                QuadIteratorKind::All => current.previous.clone(),
742                QuadIteratorKind::Subject => current.previous_subject.clone(),
743                QuadIteratorKind::Predicate => current.previous_predicate.clone(),
744                QuadIteratorKind::Object => current.previous_object.clone(),
745                QuadIteratorKind::GraphName => current.previous_graph_name.clone(),
746            };
747            if !self.reader.is_node_in_range(&current) {
748                continue;
749            }
750            if let Some(expect_subject) = &self.expect_subject {
751                if current.quad.subject != *expect_subject {
752                    continue;
753                }
754            }
755            if let Some(expect_predicate) = &self.expect_predicate {
756                if current.quad.predicate != *expect_predicate {
757                    continue;
758                }
759            }
760            if let Some(expect_object) = &self.expect_object {
761                if current.quad.object != *expect_object {
762                    continue;
763                }
764            }
765            if let Some(expect_graph_name) = &self.expect_graph_name {
766                if current.quad.graph_name != *expect_graph_name {
767                    continue;
768                }
769            }
770            return Some(current.quad.clone());
771        }
772    }
773}
774
775pub struct MemoryDecodingGraphIterator {
776    reader: MemoryStorageReader, // Needed to make sure the underlying map is not GCed
777    iter: Iter<'static, EncodedTerm, VersionRange>,
778}
779
780impl Iterator for MemoryDecodingGraphIterator {
781    type Item = EncodedTerm;
782
783    fn next(&mut self) -> Option<EncodedTerm> {
784        loop {
785            let entry = self.iter.next()?;
786            if self.reader.is_in_range(entry.value()) {
787                return Some(entry.key().clone());
788            }
789        }
790    }
791}
792
/// Loader inserting many quads under a single new version, without per-quad transactions.
///
/// Created by [`MemoryStorage::bulk_loader`].
#[must_use]
pub struct MemoryStorageBulkLoader {
    storage: MemoryStorage,
    /// Progress callbacks registered with `on_progress`.
    hooks: Vec<Box<dyn Fn(u64)>>,
}
798
799impl MemoryStorageBulkLoader {
800    pub fn on_progress(mut self, callback: impl Fn(u64) + 'static) -> Self {
801        self.hooks.push(Box::new(callback));
802        self
803    }
804
805    #[allow(clippy::unwrap_in_result)]
806    pub fn load<EI, EO: From<StorageError> + From<EI>>(
807        &self,
808        quads: impl IntoIterator<Item = Result<Quad, EI>>,
809    ) -> Result<(), EO> {
810        // We lock content here to make sure there is not a transaction committing at the same time
811        let _transaction_lock = self.storage.transaction_counter.lock().unwrap();
812        let mut done_counter = 0;
813        let version_id = self.storage.version_counter.load(Ordering::Acquire) + 1;
814        let mut log = Vec::new();
815        for quad in quads {
816            MemoryStorageWriter {
817                storage: &self.storage,
818                log: &mut log,
819                transaction_id: version_id,
820            }
821            .insert(quad?.as_ref());
822            log.clear();
823            done_counter += 1;
824            if done_counter % 1_000_000 == 0 {
825                for hook in &self.hooks {
826                    hook(done_counter);
827                }
828            }
829        }
830        self.storage
831            .version_counter
832            .store(version_id, Ordering::Release);
833        Ok(())
834    }
835}
836
837enum LogEntry {
838    QuadNode(Arc<QuadListNode>),
839    Graph(EncodedTerm),
840}
841
842struct QuadListNode {
843    quad: EncodedQuad,
844    range: Mutex<VersionRange>,
845    previous: Option<Weak<Self>>,
846    previous_subject: Option<Weak<Self>>,
847    previous_predicate: Option<Weak<Self>>,
848    previous_object: Option<Weak<Self>>,
849    previous_graph_name: Option<Weak<Self>>,
850}
851
852impl PartialEq for QuadListNode {
853    #[inline]
854    fn eq(&self, other: &Self) -> bool {
855        self.quad == other.quad
856    }
857}
858
859impl Eq for QuadListNode {}
860
861impl Hash for QuadListNode {
862    #[inline]
863    fn hash<H: Hasher>(&self, state: &mut H) {
864        self.quad.hash(state)
865    }
866}
867
868impl Borrow<EncodedQuad> for Arc<QuadListNode> {
869    fn borrow(&self) -> &EncodedQuad {
870        &self.quad
871    }
872}
873
// TODO: reduce the size to 128bits
/// Set of versions, stored as half-open `start..end` intervals, the last one possibly
/// open-ended.
#[derive(Clone, Default, Eq, PartialEq)]
enum VersionRange {
    /// No version at all.
    #[default]
    Empty,
    /// Every version greater than or equal to the start.
    Start(usize),
    /// Every version in `start..end`.
    StartEnd(usize, usize),
    /// Flattened `start, end` bounds; a trailing unpaired bound is an open-ended start.
    Bigger(Box<[usize]>),
}
883
884impl VersionRange {
885    fn contains(&self, version: usize) -> bool {
886        match self {
887            VersionRange::Empty => false,
888            VersionRange::Start(start) => *start <= version,
889            VersionRange::StartEnd(start, end) => *start <= version && version < *end,
890            VersionRange::Bigger(range) => {
891                for start_end in range.chunks(2) {
892                    match start_end {
893                        [start, end] => {
894                            if *start <= version && version < *end {
895                                return true;
896                            }
897                        }
898                        [start] => {
899                            if *start <= version {
900                                return true;
901                            }
902                        }
903                        _ => (),
904                    }
905                }
906                false
907            }
908        }
909    }
910
911    fn add(&mut self, version: usize) -> bool {
912        match self {
913            VersionRange::Empty => {
914                *self = VersionRange::Start(version);
915                true
916            }
917            VersionRange::Start(_) => false,
918            VersionRange::StartEnd(start, end) => {
919                *self = if version == *end {
920                    VersionRange::Start(*start)
921                } else {
922                    VersionRange::Bigger(Box::new([*start, *end, version]))
923                };
924                true
925            }
926            VersionRange::Bigger(vec) => {
927                if vec.len() % 2 == 0 {
928                    *self = VersionRange::Bigger(if vec.ends_with(&[version]) {
929                        pop_boxed_slice(vec)
930                    } else {
931                        push_boxed_slice(vec, version)
932                    });
933                    true
934                } else {
935                    false
936                }
937            }
938        }
939    }
940
941    fn remove(&mut self, version: usize) -> bool {
942        match self {
943            VersionRange::Empty | VersionRange::StartEnd(_, _) => false,
944            VersionRange::Start(start) => {
945                *self = if *start == version {
946                    VersionRange::Empty
947                } else {
948                    VersionRange::StartEnd(*start, version)
949                };
950                true
951            }
952            VersionRange::Bigger(vec) => {
953                if vec.len() % 2 == 0 {
954                    false
955                } else {
956                    *self = if vec.ends_with(&[version]) {
957                        match vec.as_ref() {
958                            [start, end, _] => Self::StartEnd(*start, *end),
959                            _ => Self::Bigger(pop_boxed_slice(vec)),
960                        }
961                    } else {
962                        Self::Bigger(push_boxed_slice(vec, version))
963                    };
964                    true
965                }
966            }
967        }
968    }
969
970    fn upgrade_transaction(&mut self, transaction_id: usize, version_id: usize) {
971        match self {
972            VersionRange::Empty => (),
973            VersionRange::Start(start) => {
974                if *start == transaction_id {
975                    *start = version_id;
976                }
977            }
978            VersionRange::StartEnd(_, end) => {
979                if *end == transaction_id {
980                    *end = version_id
981                }
982            }
983            VersionRange::Bigger(vec) => {
984                if vec.ends_with(&[transaction_id]) {
985                    vec[vec.len() - 1] = version_id
986                }
987            }
988        }
989    }
990
991    fn rollback_transaction(&mut self, transaction_id: usize) {
992        match self {
993            VersionRange::Empty => (),
994            VersionRange::Start(start) => {
995                if *start == transaction_id {
996                    *self = VersionRange::Empty;
997                }
998            }
999            VersionRange::StartEnd(start, end) => {
1000                if *end == transaction_id {
1001                    *self = VersionRange::Start(*start)
1002                }
1003            }
1004            VersionRange::Bigger(vec) => {
1005                if vec.ends_with(&[transaction_id]) {
1006                    *self = match vec.as_ref() {
1007                        [start, end, _] => Self::StartEnd(*start, *end),
1008                        _ => Self::Bigger(pop_boxed_slice(vec)),
1009                    }
1010                }
1011            }
1012        }
1013    }
1014}
1015
/// Returns a new boxed slice holding the elements of `slice` followed by `element`.
fn push_boxed_slice<T: Copy>(slice: &[T], element: T) -> Box<[T]> {
    slice
        .iter()
        .copied()
        .chain(std::iter::once(element))
        .collect()
}
1022
/// Returns a new boxed slice holding every element of `slice` except the last one.
///
/// # Panics
///
/// Panics if `slice` is empty.
fn pop_boxed_slice<T: Copy>(slice: &[T]) -> Box<[T]> {
    let kept = slice.len() - 1;
    Box::from(&slice[..kept])
}
1026
#[cfg(test)]
#[allow(clippy::panic_in_result_fn)]
mod tests {
    use super::*;
    use oxrdf::NamedNodeRef;

    /// Walks a single `VersionRange` through adds and removes, checking the returned flags and
    /// the resulting membership.
    #[test]
    fn test_range() {
        let mut range = VersionRange::default();

        // Empty -> Start(1): present from version 1 onwards
        assert!(range.add(1));
        assert!(!range.add(1));
        assert!(range.contains(1));
        assert!(!range.contains(0));
        assert!(range.contains(2));

        // Removing at the start version empties the range
        assert!(range.remove(1));
        assert!(!range.remove(1));
        assert!(!range.contains(1));

        // Empty -> Start(1) -> StartEnd(1, 2)
        assert!(range.add(1));
        assert!(range.remove(2));
        assert!(!range.remove(2));
        assert!(range.contains(1));
        assert!(!range.contains(2));

        // Re-adding at the end version merges back into Start(1)
        assert!(range.add(2));
        assert!(range.contains(3));

        // StartEnd(1, 2) -> Bigger([1, 2, 4]) -> Bigger([1, 2, 4, 6])
        assert!(range.remove(2));
        assert!(range.add(4));
        assert!(range.remove(6));
        assert!(!range.contains(3));
        assert!(range.contains(4));
        assert!(!range.contains(6));
    }

    /// Exercises the `Bigger` representation: refusing adds/removes in the wrong state,
    /// dropping the last bound, and collapsing back to `StartEnd`.
    #[test]
    fn test_bigger_range() {
        let mut range = VersionRange::default();

        // Start(1) -> StartEnd(1, 2) -> Bigger([1, 2, 4]) -> Bigger([1, 2, 4, 6])
        assert!(range.add(1));
        assert!(range.remove(2));
        assert!(range.add(4));
        assert!(range.remove(6));
        assert!(!range.remove(7));

        // Bigger([1, 2, 4, 6, 8]) is open-ended again, so it can't be added to
        assert!(range.add(8));
        assert!(!range.add(9));
        assert!(range.contains(9));
        assert!(!range.contains(7));

        // Removing at the last start drops it: back to Bigger([1, 2, 4, 6])
        assert!(range.remove(8));
        assert!(!range.contains(8));

        // Re-adding at the last end merges the intervals: Bigger([1, 2, 4])
        assert!(range.add(6));
        assert!(range.contains(6));

        // Removing at the last start of a three-bound range collapses to StartEnd(1, 2)
        assert!(range.remove(4));
        assert!(range == VersionRange::StartEnd(1, 2));
    }

    /// Checks that a transaction id used as a bound is replaced by the committed version id,
    /// and that other ids are left untouched.
    #[test]
    fn test_upgrade() {
        let mut range = VersionRange::default();

        // Start(1000) -> Start(1)
        assert!(range.add(1000));
        range.upgrade_transaction(999, 1);
        assert!(!range.contains(1));
        range.upgrade_transaction(1000, 1);
        assert!(range.contains(1));

        // StartEnd(1, 1000) -> StartEnd(1, 2)
        assert!(range.remove(1000));
        range.upgrade_transaction(999, 2);
        assert!(range.contains(2));
        range.upgrade_transaction(1000, 2);
        assert!(!range.contains(2));

        // Bigger([1, 2, 1000]) -> Bigger([1, 2, 3])
        assert!(range.add(1000));
        range.upgrade_transaction(999, 3);
        assert!(!range.contains(3));
        range.upgrade_transaction(1000, 3);
        assert!(range.contains(3));
    }

    /// Checks that rolling back only undoes changes made with the given transaction id.
    #[test]
    fn test_rollback() {
        let mut range = VersionRange::default();

        // Start(1000) is only reverted by its own transaction id, and then becomes Empty
        assert!(range.add(1000));
        range.rollback_transaction(999);
        assert!(range.contains(1000));
        range.rollback_transaction(1000);
        assert!(!range.contains(1));
    }

    /// Checks rollbacks of removals and of additions on already closed ranges.
    #[test]
    fn test_rollback_closed_ranges() {
        // Rolling back a removal reopens the range
        let mut range = VersionRange::Start(1);
        assert!(range.remove(1000));
        range.rollback_transaction(1000);
        assert!(range == VersionRange::Start(1));

        // Rolling back an addition after a closed range restores it
        let mut range = VersionRange::StartEnd(1, 2);
        assert!(range.add(1000));
        range.rollback_transaction(1000);
        assert!(range == VersionRange::StartEnd(1, 2));

        // Rolling back the removal of a longer range only drops its last bound
        let mut range = VersionRange::StartEnd(1, 2);
        assert!(range.add(3));
        assert!(range.remove(1000));
        range.rollback_transaction(1000);
        assert!(range.contains(1000));
        assert!(!range.contains(2));
    }

    /// Checks that committed writes are visible to new snapshots but not to older ones, and
    /// that a failing transaction leaves no trace.
    #[test]
    fn test_transaction() -> Result<(), StorageError> {
        let example = NamedNodeRef::new_unchecked("http://example.com/1");
        let example2 = NamedNodeRef::new_unchecked("http://example.com/2");
        let encoded_example = EncodedTerm::from(example);
        let encoded_example2 = EncodedTerm::from(example2);
        let default_quad = QuadRef::new(example, example, example, GraphNameRef::DefaultGraph);
        let encoded_default_quad = EncodedQuad::from(default_quad);
        let named_graph_quad = QuadRef::new(example, example, example, example);
        let encoded_named_graph_quad = EncodedQuad::from(named_graph_quad);

        let storage = MemoryStorage::new();

        // We start with a graph
        let snapshot = storage.snapshot();
        storage.transaction(|mut writer| {
            writer.insert_named_graph(example.into());
            Ok::<_, StorageError>(())
        })?;
        assert!(!snapshot.contains_named_graph(&encoded_example));
        assert!(storage.snapshot().contains_named_graph(&encoded_example));
        storage.snapshot().validate()?;

        // We add two quads
        let snapshot = storage.snapshot();
        storage.transaction(|mut writer| {
            writer.insert(default_quad);
            writer.insert(named_graph_quad);
            Ok::<_, StorageError>(())
        })?;
        assert!(!snapshot.contains(&encoded_default_quad));
        assert!(!snapshot.contains(&encoded_named_graph_quad));
        assert!(storage.snapshot().contains(&encoded_default_quad));
        assert!(storage.snapshot().contains(&encoded_named_graph_quad));
        storage.snapshot().validate()?;

        // We remove the quads
        // (removing the named graph is expected to also remove `named_graph_quad`)
        let snapshot = storage.snapshot();
        storage.transaction(|mut writer| {
            writer.remove(default_quad);
            writer.remove_named_graph(example.into());
            Ok::<_, StorageError>(())
        })?;
        assert!(snapshot.contains(&encoded_default_quad));
        assert!(snapshot.contains(&encoded_named_graph_quad));
        assert!(snapshot.contains_named_graph(&encoded_example));
        assert!(!storage.snapshot().contains(&encoded_default_quad));
        assert!(!storage.snapshot().contains(&encoded_named_graph_quad));
        assert!(!storage.snapshot().contains_named_graph(&encoded_example));
        storage.snapshot().validate()?;

        // We add the quads again but rollback
        let snapshot = storage.snapshot();
        assert!(storage
            .transaction(|mut writer| {
                writer.insert(default_quad);
                writer.insert(named_graph_quad);
                writer.insert_named_graph(example2.into());
                Err::<(), _>(StorageError::Other("foo".into()))
            })
            .is_err());
        assert!(!snapshot.contains(&encoded_default_quad));
        assert!(!snapshot.contains(&encoded_named_graph_quad));
        assert!(!snapshot.contains_named_graph(&encoded_example));
        assert!(!snapshot.contains_named_graph(&encoded_example2));
        assert!(!storage.snapshot().contains(&encoded_default_quad));
        assert!(!storage.snapshot().contains(&encoded_named_graph_quad));
        assert!(!storage.snapshot().contains_named_graph(&encoded_example));
        assert!(!storage.snapshot().contains_named_graph(&encoded_example2));
        storage.snapshot().validate()?;

        // We add quads and graph (through the bulk loader and a transaction), then clear
        storage.bulk_loader().load::<StorageError, StorageError>([
            Ok(default_quad.into_owned()),
            Ok(named_graph_quad.into_owned()),
        ])?;
        storage.transaction(|mut writer| {
            writer.insert_named_graph(example2.into());
            Ok::<_, StorageError>(())
        })?;
        storage.transaction(|mut writer| {
            writer.clear();
            Ok::<_, StorageError>(())
        })?;
        assert!(!storage.snapshot().contains(&encoded_default_quad));
        assert!(!storage.snapshot().contains(&encoded_named_graph_quad));
        assert!(!storage.snapshot().contains_named_graph(&encoded_example));
        assert!(!storage.snapshot().contains_named_graph(&encoded_example2));
        assert!(storage.snapshot().is_empty());
        storage.snapshot().validate()?;

        Ok(())
    }
}