1use std::{
22 collections::{hash_map::Entry, HashMap, HashSet},
23 fmt,
24 fmt::Debug,
25};
26
27use libp2p_identity::PeerId;
28
29use crate::{
30 topic::TopicHash,
31 types::{MessageId, RawMessage},
32};
33
34#[derive(Debug, Clone, PartialEq, Eq, Hash)]
36pub(crate) struct CacheEntry {
37 mid: MessageId,
38 topic: TopicHash,
39}
40
41#[derive(Clone)]
43pub(crate) struct MessageCache {
44 msgs: HashMap<MessageId, (RawMessage, HashSet<PeerId>)>,
45 iwant_counts: HashMap<MessageId, HashMap<PeerId, u32>>,
47 history: Vec<Vec<CacheEntry>>,
48 gossip: usize,
52}
53
54impl fmt::Debug for MessageCache {
55 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
56 f.debug_struct("MessageCache")
57 .field("msgs", &self.msgs)
58 .field("history", &self.history)
59 .field("gossip", &self.gossip)
60 .finish()
61 }
62}
63
64impl MessageCache {
66 pub(crate) fn new(gossip: usize, history_capacity: usize) -> Self {
67 MessageCache {
68 gossip,
69 msgs: HashMap::default(),
70 iwant_counts: HashMap::default(),
71 history: vec![Vec::new(); history_capacity],
72 }
73 }
74
75 pub(crate) fn put(&mut self, message_id: &MessageId, msg: RawMessage) -> bool {
79 match self.msgs.entry(message_id.clone()) {
80 Entry::Occupied(_) => {
81 false
83 }
84 Entry::Vacant(entry) => {
85 let cache_entry = CacheEntry {
86 mid: message_id.clone(),
87 topic: msg.topic.clone(),
88 };
89 entry.insert((msg, HashSet::default()));
90 self.history[0].push(cache_entry);
91
92 tracing::trace!(message=?message_id, "Put message in mcache");
93 true
94 }
95 }
96 }
97
98 pub(crate) fn observe_duplicate(&mut self, message_id: &MessageId, source: &PeerId) {
100 if let Some((message, originating_peers)) = self.msgs.get_mut(message_id) {
101 if message.validated {
104 return;
105 }
106
107 originating_peers.insert(*source);
108 }
109 }
110
111 #[cfg(test)]
113 pub(crate) fn get(&self, message_id: &MessageId) -> Option<&RawMessage> {
114 self.msgs.get(message_id).map(|(message, _)| message)
115 }
116
117 pub(crate) fn get_with_iwant_counts(
120 &mut self,
121 message_id: &MessageId,
122 peer: &PeerId,
123 ) -> Option<(&RawMessage, u32)> {
124 let iwant_counts = &mut self.iwant_counts;
125 self.msgs.get(message_id).and_then(|(message, _)| {
126 if !message.validated {
127 None
128 } else {
129 Some((message, {
130 let count = iwant_counts
131 .entry(message_id.clone())
132 .or_default()
133 .entry(*peer)
134 .or_default();
135 *count += 1;
136 *count
137 }))
138 }
139 })
140 }
141
142 pub(crate) fn validate(
146 &mut self,
147 message_id: &MessageId,
148 ) -> Option<(&RawMessage, HashSet<PeerId>)> {
149 self.msgs.get_mut(message_id).map(|(message, known_peers)| {
150 message.validated = true;
151 let originating_peers = std::mem::take(known_peers);
154 (&*message, originating_peers)
155 })
156 }
157
158 pub(crate) fn get_gossip_message_ids(&self, topic: &TopicHash) -> Vec<MessageId> {
160 self.history[..self.gossip]
161 .iter()
162 .fold(vec![], |mut current_entries, entries| {
163 let mut found_entries: Vec<MessageId> = entries
165 .iter()
166 .filter_map(|entry| {
167 if &entry.topic == topic {
168 let mid = &entry.mid;
169 if let Some(true) = self.msgs.get(mid).map(|(msg, _)| msg.validated) {
171 Some(mid.clone())
172 } else {
173 None
174 }
175 } else {
176 None
177 }
178 })
179 .collect();
180
181 current_entries.append(&mut found_entries);
183 current_entries
184 })
185 }
186
187 pub(crate) fn shift(&mut self) {
190 for entry in self.history.pop().expect("history is always > 1") {
191 if let Some((msg, _)) = self.msgs.remove(&entry.mid) {
192 if !msg.validated {
193 tracing::debug!(
197 message=%&entry.mid,
198 "The message got removed from the cache without being validated."
199 );
200 }
201 }
202 tracing::trace!(message=%&entry.mid, "Remove message from the cache");
203
204 self.iwant_counts.remove(&entry.mid);
205 }
206
207 self.history.insert(0, Vec::new());
209 }
210
211 pub(crate) fn remove(
213 &mut self,
214 message_id: &MessageId,
215 ) -> Option<(RawMessage, HashSet<PeerId>)> {
216 self.iwant_counts.remove(message_id);
220 self.msgs.remove(message_id)
221 }
222}
223
#[cfg(test)]
mod tests {
    use super::*;
    use crate::IdentTopic as Topic;

    /// Builds a message with sequence number `x` on `topic`, sent from a fresh random peer.
    ///
    /// Returns the message together with an id derived from its source and sequence number.
    fn gen_testm(x: u64, topic: TopicHash) -> (MessageId, RawMessage) {
        let default_id = |message: &RawMessage| {
            // The id is the base58 source peer id followed by the sequence number.
            let mut source_string = message.source.as_ref().unwrap().to_base58();
            source_string.push_str(&message.sequence_number.unwrap().to_string());
            MessageId::from(source_string)
        };
        // Only the low byte of `x` is used as payload; the truncation is harmless here.
        let u8x: u8 = x as u8;
        let source = Some(PeerId::random());
        let data: Vec<u8> = vec![u8x];
        let sequence_number = Some(x);

        let m = RawMessage {
            source,
            data,
            sequence_number,
            topic,
            signature: None,
            key: None,
            validated: false,
        };

        let id = default_id(&m);
        (id, m)
    }

    /// Creates a cache with `gossip_size` gossip windows and `history` windows in total.
    fn new_cache(gossip_size: usize, history: usize) -> MessageCache {
        MessageCache::new(gossip_size, history)
    }

    /// A new cache remembers its gossip window count.
    #[test]
    fn test_new_cache() {
        let x: usize = 3;
        let mc = new_cache(x, 5);

        assert_eq!(mc.gossip, x);
    }

    /// A stored message lands in the newest window and can be read back unchanged.
    #[test]
    fn test_put_get_one() {
        let mut mc = new_cache(10, 15);

        let topic1_hash = Topic::new("topic1").hash();
        let (id, m) = gen_testm(10, topic1_hash);

        mc.put(&id, m.clone());

        assert_eq!(mc.history[0].len(), 1);

        let fetched = mc.get(&id);

        assert_eq!(fetched.unwrap(), &m);
    }

    /// Looking up an id that was never stored returns `None`.
    #[test]
    fn test_get_wrong() {
        let mut mc = new_cache(10, 15);

        let topic1_hash = Topic::new("topic1").hash();
        let (id, m) = gen_testm(10, topic1_hash);

        mc.put(&id, m);

        let wrong_id = MessageId::new(b"wrongid");
        let fetched = mc.get(&wrong_id);
        assert!(fetched.is_none());
    }

    /// Looking up any id in an empty cache returns `None`.
    #[test]
    fn test_get_empty() {
        let mc = new_cache(10, 15);

        let wrong_string = MessageId::new(b"imempty");
        let fetched = mc.get(&wrong_string);
        assert!(fetched.is_none());
    }

    /// One shift moves the newest window back without evicting anything yet.
    #[test]
    fn test_shift() {
        let mut mc = new_cache(1, 5);

        let topic1_hash = Topic::new("topic1").hash();

        for i in 0..10 {
            let (id, m) = gen_testm(i, topic1_hash.clone());
            mc.put(&id, m);
        }

        mc.shift();

        // A fresh, empty window is now at the front; the messages moved one window back.
        assert!(mc.history[0].is_empty());
        assert_eq!(mc.history[1].len(), 10);

        // Nothing has aged out of the cache yet.
        assert_eq!(mc.msgs.len(), 10);
    }

    /// Shifting with no new messages keeps moving the filled window back.
    #[test]
    fn test_empty_shift() {
        let mut mc = new_cache(1, 5);

        let topic1_hash = Topic::new("topic1").hash();

        for i in 0..10 {
            let (id, m) = gen_testm(i, topic1_hash.clone());
            mc.put(&id, m);
        }

        mc.shift();

        assert!(mc.history[0].is_empty());
        assert_eq!(mc.history[1].len(), 10);

        mc.shift();

        assert_eq!(mc.history[2].len(), 10);
        assert!(mc.history[1].is_empty());
        assert!(mc.history[0].is_empty());
    }

    /// Messages are evicted once their window is shifted out of the history.
    #[test]
    fn test_remove_last_from_shift() {
        let mut mc = new_cache(4, 5);

        let topic1_hash = Topic::new("topic1").hash();

        for i in 0..10 {
            let (id, m) = gen_testm(i, topic1_hash.clone());
            mc.put(&id, m);
        }

        // Move the filled window to the oldest position.
        for _ in 0..4 {
            mc.shift();
        }

        assert_eq!(mc.history[mc.history.len() - 1].len(), 10);

        // One more shift drops that window and evicts its messages.
        mc.shift();
        assert_eq!(mc.history[mc.history.len() - 1].len(), 0);
        assert_eq!(mc.history[0].len(), 0);
        assert_eq!(mc.msgs.len(), 0);
    }
}