1use std::borrow::Borrow;
41use std::collections::hash_map::{self, Entry, RandomState};
42use std::collections::HashMap;
43use std::convert::Infallible;
44use std::hash::{BuildHasher, Hash};
45use std::mem::{transmute, ManuallyDrop};
46use std::sync::{Mutex, MutexGuard};
47
// Locks `$mutex` and evaluates to its guard.
//
// Poisoning is deliberately ignored: when a previous holder panicked, the guard
// is recovered from the `PoisonError` instead of propagating the error.
macro_rules! lock {
    ($mutex:expr) => {
        $mutex
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner())
    };
}
56
// Binds `$target` to a `&mut` reference to the data inside `$mutex` without
// locking it; `Mutex::get_mut` needs exclusive access to the mutex instead.
//
// Poisoning is deliberately ignored. `PoisonError::into_inner` hands the
// reference back by value, so `$target` carries the lifetime of the borrow of
// `$mutex` itself rather than that of a temporary created inside the macro.
// This lets callers return the reference (or anything derived from it) without
// resorting to lifetime-extending `unsafe` code.
macro_rules! get_mut {
    (let $target:ident, $mutex:expr) => {
        let $target = $mutex
            .get_mut()
            .unwrap_or_else(|poisoned| poisoned.into_inner());
    };
}
66
/// A hash map guarded by a [`Mutex`] in which every value lives in its own [`Box`].
///
/// `K` is the key type, `V` the value type and `S` the hasher builder, which
/// defaults to [`RandomState`].
#[derive(Debug)]
pub struct MemoMap<K, V, S = RandomState> {
    /// The underlying table. Values are boxed individually, so a value's heap
    /// address is independent of where the table keeps its buckets.
    inner: Mutex<HashMap<K, Box<V>, S>>,
}
72
73impl<K: Clone, V: Clone, S: Clone> Clone for MemoMap<K, V, S> {
74 fn clone(&self) -> Self {
75 Self {
76 inner: Mutex::new(lock!(self.inner).clone()),
77 }
78 }
79}
80
81impl<K, V, S: Default> Default for MemoMap<K, V, S> {
82 fn default() -> Self {
83 MemoMap {
84 inner: Mutex::new(HashMap::default()),
85 }
86 }
87}
88
89impl<K, V> MemoMap<K, V, RandomState> {
90 pub fn new() -> MemoMap<K, V, RandomState> {
92 MemoMap {
93 inner: Mutex::default(),
94 }
95 }
96}
97
98impl<K, V, S> MemoMap<K, V, S> {
99 pub fn with_hasher(hash_builder: S) -> MemoMap<K, V, S> {
102 MemoMap {
103 inner: Mutex::new(HashMap::with_hasher(hash_builder)),
104 }
105 }
106}
107
108impl<K, V, S> MemoMap<K, V, S>
109where
110 K: Eq + Hash,
111 S: BuildHasher,
112{
113 pub fn insert(&self, key: K, value: V) -> bool {
121 let mut inner = lock!(self.inner);
122 match inner.entry(key) {
123 Entry::Occupied(_) => false,
124 Entry::Vacant(vacant) => {
125 vacant.insert(Box::new(value));
126 true
127 }
128 }
129 }
130
131 pub fn replace(&mut self, key: K, value: V) {
137 lock!(self.inner).insert(key, Box::new(value));
138 }
139
140 pub fn contains_key<Q>(&self, key: &Q) -> bool
145 where
146 Q: Hash + Eq + ?Sized,
147 K: Borrow<Q>,
148 {
149 lock!(self.inner).contains_key(key)
150 }
151
152 pub fn get<Q>(&self, key: &Q) -> Option<&V>
157 where
158 Q: Hash + Eq + ?Sized,
159 K: Borrow<Q>,
160 {
161 let inner = lock!(self.inner);
162 let value = inner.get(key)?;
163 Some(unsafe { transmute::<&V, &V>(&**value) })
164 }
165
166 pub fn get_mut<Q>(&mut self, key: &Q) -> Option<&mut V>
171 where
172 Q: Hash + Eq + ?Sized,
173 K: Borrow<Q>,
174 {
175 get_mut!(let map, self.inner);
176 Some(unsafe { transmute::<&mut V, &mut V>(&mut **map.get_mut(key)?) })
177 }
178
179 pub fn get_or_try_insert<Q, F, E>(&self, key: &Q, creator: F) -> Result<&V, E>
188 where
189 Q: Hash + Eq + ToOwned<Owned = K> + ?Sized,
190 K: Borrow<Q>,
191 F: FnOnce() -> Result<V, E>,
192 {
193 let mut inner = lock!(self.inner);
194 let value = if let Some(value) = inner.get(key) {
195 value
196 } else {
197 inner.insert(key.to_owned(), Box::new(creator()?));
198 inner.get(key).unwrap()
199 };
200 Ok(unsafe { transmute::<&V, &V>(&**value) })
201 }
202
203 pub fn get_or_insert_owned<F>(&self, key: K, creator: F) -> &V
205 where
206 F: FnOnce() -> V,
207 {
208 self.get_or_try_insert_owned(key, || Ok::<_, Infallible>(creator()))
209 .unwrap()
210 }
211
212 pub fn get_or_try_insert_owned<F, E>(&self, key: K, creator: F) -> Result<&V, E>
216 where
217 F: FnOnce() -> Result<V, E>,
218 {
219 let mut inner = lock!(self.inner);
220 let entry = inner.entry(key);
221 let value = match entry {
222 Entry::Occupied(ref val) => val.get(),
223 Entry::Vacant(entry) => entry.insert(Box::new(creator()?)),
224 };
225 Ok(unsafe { transmute::<&V, &V>(&**value) })
226 }
227
228 pub fn get_or_insert<Q, F>(&self, key: &Q, creator: F) -> &V
251 where
252 Q: Hash + Eq + ToOwned<Owned = K> + ?Sized,
253 K: Borrow<Q>,
254 F: FnOnce() -> V,
255 {
256 self.get_or_try_insert(key, || Ok::<_, Infallible>(creator()))
257 .unwrap()
258 }
259
260 pub fn remove<Q>(&mut self, key: &Q) -> Option<V>
266 where
267 Q: Hash + Eq + ?Sized,
268 K: Borrow<Q>,
269 {
270 lock!(self.inner).remove(key).map(|x| *x)
271 }
272
273 pub fn clear(&mut self) {
275 lock!(self.inner).clear();
276 }
277
278 pub fn len(&self) -> usize {
293 lock!(self.inner).len()
294 }
295
296 pub fn is_empty(&self) -> bool {
298 lock!(self.inner).is_empty()
299 }
300
301 pub fn iter(&self) -> Iter<'_, K, V, S> {
308 let guard = lock!(self.inner);
309 let iter = guard.iter();
310 Iter {
311 iter: unsafe {
312 transmute::<hash_map::Iter<'_, K, Box<V>>, hash_map::Iter<'_, K, Box<V>>>(iter)
313 },
314 guard: ManuallyDrop::new(guard),
315 }
316 }
317
318 pub fn iter_mut(&mut self) -> IterMut<'_, K, V> {
323 get_mut!(let map, self.inner);
324 IterMut {
325 iter: unsafe {
326 transmute::<hash_map::IterMut<'_, K, Box<V>>, hash_map::IterMut<'_, K, Box<V>>>(
327 map.iter_mut(),
328 )
329 },
330 }
331 }
332
333 pub fn values_mut(&mut self) -> ValuesMut<'_, K, V> {
338 get_mut!(let map, self.inner);
339 ValuesMut {
340 iter: unsafe {
341 transmute::<hash_map::ValuesMut<'_, K, Box<V>>, hash_map::ValuesMut<'_, K, Box<V>>>(
342 map.values_mut(),
343 )
344 },
345 }
346 }
347
348 pub fn keys(&self) -> Keys<'_, K, V, S> {
351 Keys { iter: self.iter() }
352 }
353}
354
/// Iterator over the `(key, value)` pairs of a locked table.
///
/// The mutex guard is stored alongside the table iterator and is released when
/// this value is dropped.
pub struct Iter<'a, K, V, S> {
    /// Guard that keeps the table locked while `iter` is in use.
    guard: ManuallyDrop<MutexGuard<'a, HashMap<K, Box<V>, S>>>,
    /// Iterator over the table's entries.
    iter: hash_map::Iter<'a, K, Box<V>>,
}

impl<'a, K, V, S> Drop for Iter<'a, K, V, S> {
    /// Releases the lock held by this iterator.
    fn drop(&mut self) {
        let guard = &mut self.guard;
        // SAFETY: this is the only place where the guard is dropped, and `drop`
        // runs at most once, so the guard is never used or dropped again.
        unsafe { ManuallyDrop::drop(guard) }
    }
}

impl<'a, K, V, S> Iterator for Iter<'a, K, V, S> {
    type Item = (&'a K, &'a V);

    /// Yields the next entry, exposing the value without its `Box`.
    fn next(&mut self) -> Option<Self::Item> {
        let (key, boxed) = self.iter.next()?;
        Some((key, &**boxed))
    }
}
379
380pub struct Keys<'a, K, V, S> {
385 iter: Iter<'a, K, V, S>,
386}
387
388impl<'a, K, V, S> Iterator for Keys<'a, K, V, S> {
389 type Item = &'a K;
390
391 fn next(&mut self) -> Option<Self::Item> {
392 self.iter.next().map(|(k, _)| k)
393 }
394}
395
/// Iterator over the entries of a table that yields mutable value references.
pub struct IterMut<'a, K, V> {
    /// Mutable iterator over the table's entries.
    iter: hash_map::IterMut<'a, K, Box<V>>,
}

impl<'a, K, V> Iterator for IterMut<'a, K, V> {
    type Item = (&'a K, &'a mut V);

    /// Yields the next entry, exposing the value without its `Box`.
    fn next(&mut self) -> Option<Self::Item> {
        let (key, boxed) = self.iter.next()?;
        Some((key, &mut **boxed))
    }
}
408
/// Iterator over the values of a table that yields mutable references.
pub struct ValuesMut<'a, K, V> {
    /// Mutable iterator over the table's boxed values.
    iter: hash_map::ValuesMut<'a, K, Box<V>>,
}

impl<'a, K, V> Iterator for ValuesMut<'a, K, V> {
    type Item = &'a mut V;

    /// Yields the next value without its `Box`.
    fn next(&mut self) -> Option<Self::Item> {
        let boxed = self.iter.next()?;
        Some(&mut **boxed)
    }
}
421
#[cfg(test)]
mod tests {
    use super::*;

    // `insert` stores the first value for a key and refuses to overwrite it.
    #[test]
    fn test_insert() {
        let memo = MemoMap::new();
        assert!(memo.insert(23u32, Box::new(1u32)));
        assert!(!memo.insert(23u32, Box::new(2u32)));
        assert_eq!(memo.get(&23u32).cloned(), Some(Box::new(1)));
    }

    // `iter` yields every inserted pair (order is unspecified, hence the sort).
    #[test]
    fn test_iter() {
        let memo = MemoMap::new();
        memo.insert(1, "one");
        memo.insert(2, "two");
        memo.insert(3, "three");
        let mut values = memo.iter().map(|(k, v)| (*k, *v)).collect::<Vec<_>>();
        values.sort();
        assert_eq!(values, vec![(1, "one"), (2, "two"), (3, "three")]);
    }

    // `keys` yields every inserted key.
    #[test]
    fn test_keys() {
        let memo = MemoMap::new();
        memo.insert(1, "one");
        memo.insert(2, "two");
        memo.insert(3, "three");
        let mut values = memo.keys().copied().collect::<Vec<_>>();
        values.sort();
        assert_eq!(values, vec![1, 2, 3]);
    }

    #[test]
    fn test_contains() {
        let memo = MemoMap::new();
        memo.insert(1, "one");
        assert!(memo.contains_key(&1));
        assert!(!memo.contains_key(&2));
    }

    // `remove` hands back the stored value and the key is gone afterwards.
    #[test]
    fn test_remove() {
        let mut memo = MemoMap::new();
        memo.insert(1, "one");
        let value = memo.get(&1);
        assert!(value.is_some());
        let old_value = memo.remove(&1);
        assert_eq!(old_value, Some("one"));
        let value = memo.get(&1);
        assert!(value.is_none());
    }

    #[test]
    fn test_clear() {
        let mut memo = MemoMap::new();
        memo.insert(1, "one");
        memo.insert(2, "two");
        assert_eq!(memo.len(), 2);
        assert!(!memo.is_empty());
        memo.clear();
        assert_eq!(memo.len(), 0);
        assert!(memo.is_empty());
    }

    // References handed out earlier must stay valid while later insertions
    // force the table to grow.
    #[test]
    fn test_ref_after_resize() {
        let memo = MemoMap::new();
        let mut refs = Vec::new();

        // Keep the run short under Miri, which executes far more slowly.
        let iterations = if cfg!(miri) { 100 } else { 10000 };

        for key in 0..iterations {
            refs.push((key, memo.get_or_insert(&key, || Box::new(key))));
        }
        for (key, val) in refs {
            assert_eq!(memo.get(&key), Some(val));
        }
    }

    // Same as `test_ref_after_resize`, but through the owned-key entry point.
    #[test]
    fn test_ref_after_resize_owned() {
        let memo = MemoMap::new();
        let mut refs = Vec::new();

        // Keep the run short under Miri, which executes far more slowly.
        let iterations = if cfg!(miri) { 100 } else { 10000 };

        for key in 0..iterations {
            refs.push((
                key,
                memo.get_or_insert_owned(key.to_string(), || Box::new(key)),
            ));
        }
        for (key, val) in refs {
            assert_eq!(memo.get(&key.to_string()), Some(val));
        }
    }

    // `replace` overwrites an existing value.
    #[test]
    fn test_replace() {
        let mut memo = MemoMap::new();
        memo.insert("foo", "bar");
        memo.replace("foo", "bar2");
        assert_eq!(memo.get("foo"), Some(&"bar2"));
    }

    #[test]
    fn test_get_mut() {
        let mut memo = MemoMap::new();
        memo.insert("foo", "bar");
        *memo.get_mut("foo").unwrap() = "bar2";
        assert_eq!(memo.get("foo"), Some(&"bar2"));
    }

    #[test]
    fn test_iter_mut() {
        let mut memo = MemoMap::new();
        memo.insert("foo", "bar");
        for item in memo.iter_mut() {
            *item.1 = "bar2";
        }
        assert_eq!(memo.get("foo"), Some(&"bar2"));
    }

    #[test]
    fn test_values_mut() {
        let mut memo = MemoMap::new();
        memo.insert("foo", "bar");
        for item in memo.values_mut() {
            *item = "bar2";
        }
        assert_eq!(memo.get("foo"), Some(&"bar2"));
    }
}