1use hashbag::{HashBag, SetIter};
4use serde::{de::SeqAccess, ser::SerializeSeq, Deserialize, Deserializer, Serialize, Serializer};
5use std::{
6 fmt::{self, Debug, Display},
7 hash::{Hash, Hasher},
8 marker::PhantomData,
9};
10
11#[derive(Clone, PartialEq, Eq, Default)]
12pub struct Bag<T>
13where
14 T: Hash + Eq + PartialEq,
15{
16 bag: HashBag<T>,
17}
18
19impl<T: Hash + Eq> Bag<T> {
20 #[inline]
21 pub fn new() -> Bag<T> {
22 Bag {
23 bag: HashBag::new(),
24 }
25 }
26
27 pub fn insert(&mut self, value: T) -> usize {
28 self.bag.insert(value)
29 }
30
31 pub fn insert_many(&mut self, value: T, n: usize) -> usize {
32 self.bag.insert_many(value, n)
33 }
34
35 pub fn contains(&self, value: &T) -> usize {
36 self.bag.contains(value)
37 }
38
39 pub fn len(&self) -> usize {
40 self.bag.len()
41 }
42
43 pub fn is_empty(&self) -> bool {
44 self.bag.is_empty()
45 }
46
47 pub fn iter(&self) -> SetIter<'_, T> {
48 self.bag.set_iter()
49 }
50}
51
52impl<T: Hash + Eq + PartialEq> Display for Bag<T>
53where
54 T: Display,
55{
56 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
57 let v: Vec<String> = self
58 .bag
59 .set_iter()
60 .map(|(t, n)| format!("{}/{}", t, n))
61 .collect();
62 write!(f, "Bag [{}]", v.join(", "))
63 }
64}
65
66impl<T> Debug for Bag<T>
73where
74 T: Hash + Eq + Debug,
75{
76 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
77 let v: Vec<String> = self
78 .bag
79 .set_iter()
80 .map(|(t, n)| format!("{:?}/{}", t, n))
81 .collect();
82 write!(f, "Bag [{}]", v.join(", "))
83 }
84}
85
86impl<T> FromIterator<T> for Bag<T>
87where
88 T: Eq + Hash,
89{
90 fn from_iter<I: IntoIterator<Item = T>>(iter: I) -> Self {
91 let mut bag = Bag::new();
92 for t in iter {
93 bag.insert(t);
94 }
95 bag
96 }
97}
98
99impl<T, const N: usize> From<[T; N]> for Bag<T>
100where
101 T: Eq + Hash,
102{
103 fn from(arr: [T; N]) -> Self {
104 let mut bag = Bag::new();
105 for x in arr {
106 bag.insert(x);
107 }
108 bag
109 }
110}
111
112impl<T> Serialize for Bag<T>
113where
114 T: Hash + Eq + Serialize,
115{
116 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
117 where
118 S: Serializer,
119 {
120 let mut bag = serializer.serialize_seq(Some(self.len()))?;
121
122 for (entry, count) in self.iter() {
123 bag.serialize_element(&(entry, count))?;
124 }
125
126 bag.end()
127 }
128}
129
130impl<'de, T> Deserialize<'de> for Bag<T>
131where
132 T: Deserialize<'de> + Eq + Hash,
133{
134 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
135 where
136 D: Deserializer<'de>,
137 {
138 deserializer.deserialize_seq(BagVisitor::new())
139 }
140}
141
142use serde::de::Visitor;
143struct BagVisitor<T>
144where
145 T: Hash + Eq,
146{
147 marker: PhantomData<fn() -> Bag<T>>,
148}
149
150impl<T> BagVisitor<T>
151where
152 T: Hash + Eq,
153{
154 fn new() -> Self {
155 BagVisitor {
156 marker: PhantomData,
157 }
158 }
159}
160
161impl<'de, T> Visitor<'de> for BagVisitor<T>
162where
163 T: Hash + Eq + Deserialize<'de>,
164{
165 type Value = Bag<T>;
166
167 fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
168 formatter.write_str("a Bag")
169 }
170
171 fn visit_seq<M>(self, mut access: M) -> Result<Self::Value, M::Error>
172 where
173 M: SeqAccess<'de>,
174 {
175 let mut bag: Bag<T> = Bag::new();
176 while let Some(entry) = access.next_element::<(T, usize)>()? {
177 let (t, n) = entry;
178 bag.insert_many(t, n);
179 }
180 Ok(bag)
181 }
182}
183
184impl<T: Hash + Eq> Hash for Bag<T> {
186 fn hash<H: Hasher>(&self, hasher: &mut H) {
187 let vec = Vec::from_iter(self.bag.iter());
188 vec.hash(hasher)
189 }
190}
191
#[cfg(test)]
mod tests {
    use super::*;

    /// Repeated inserts of the same value accumulate its count.
    #[test]
    fn bag_test() {
        let mut bag = Bag::new();
        for item in ["a", "b", "b"] {
            bag.insert(item);
        }
        assert_eq!(bag.contains(&"b"), 2);
    }

    /// JSON `[item, count]` pairs are merged per item when deserializing.
    #[test]
    fn deser_test() {
        let json = r#"[ ["a",2],["b",2],["a",1]]"#;
        let parsed: Bag<char> = serde_json::from_str(json).unwrap();
        let expected = Bag::from(['a', 'a', 'a', 'b', 'b']);
        assert_eq!(parsed, expected);
    }

    /// Collecting from an iterator counts each occurrence; absent items report 0.
    #[test]
    fn bag_from_iter() {
        let bag = Bag::from_iter(vec!['a', 'b', 'a']);
        for (ch, expected) in [('a', 2), ('b', 1), ('c', 0)] {
            assert_eq!(bag.contains(&ch), expected);
        }
    }
}