1use std::{
22 collections::{hash_map::Entry, HashMap, HashSet},
23 fmt,
24 fmt::Debug,
25};
26
27use libp2p_identity::PeerId;
28
29use crate::{
30 topic::TopicHash,
31 types::{MessageId, RawMessage},
32};
33
34#[derive(Debug, Clone, PartialEq, Eq, Hash)]
36pub(crate) struct CacheEntry {
37 mid: MessageId,
38 topic: TopicHash,
39}
40
41#[derive(Clone)]
43pub(crate) struct MessageCache {
44 msgs: HashMap<MessageId, (RawMessage, HashSet<PeerId>)>,
45 iwant_counts: HashMap<MessageId, HashMap<PeerId, u32>>,
47 history: Vec<Vec<CacheEntry>>,
48 gossip: usize,
52}
53
54impl fmt::Debug for MessageCache {
55 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
56 f.debug_struct("MessageCache")
57 .field("msgs", &self.msgs)
58 .field("history", &self.history)
59 .field("gossip", &self.gossip)
60 .finish()
61 }
62}
63
64impl MessageCache {
66 pub(crate) fn new(gossip: usize, history_capacity: usize) -> Self {
67 MessageCache {
68 gossip,
69 msgs: HashMap::default(),
70 iwant_counts: HashMap::default(),
71 history: vec![Vec::new(); history_capacity],
72 }
73 }
74
75 pub(crate) fn put(&mut self, message_id: &MessageId, msg: RawMessage) -> bool {
79 if self.history.is_empty() {
80 return true;
81 }
82 match self.msgs.entry(message_id.clone()) {
83 Entry::Occupied(_) => {
84 false
86 }
87 Entry::Vacant(entry) => {
88 let cache_entry = CacheEntry {
89 mid: message_id.clone(),
90 topic: msg.topic.clone(),
91 };
92 entry.insert((msg, HashSet::default()));
93 self.history[0].push(cache_entry);
94
95 tracing::trace!(message=?message_id, "Put message in mcache");
96 true
97 }
98 }
99 }
100
101 pub(crate) fn observe_duplicate(&mut self, message_id: &MessageId, source: &PeerId) {
103 if let Some((message, originating_peers)) = self.msgs.get_mut(message_id) {
104 if message.validated {
107 return;
108 }
109
110 originating_peers.insert(*source);
111 }
112 }
113
114 #[cfg(test)]
116 pub(crate) fn get(&self, message_id: &MessageId) -> Option<&RawMessage> {
117 self.msgs.get(message_id).map(|(message, _)| message)
118 }
119
120 pub(crate) fn get_with_iwant_counts(
123 &mut self,
124 message_id: &MessageId,
125 peer: &PeerId,
126 ) -> Option<(&RawMessage, u32)> {
127 let iwant_counts = &mut self.iwant_counts;
128 self.msgs.get(message_id).and_then(|(message, _)| {
129 if !message.validated {
130 None
131 } else {
132 Some((message, {
133 let count = iwant_counts
134 .entry(message_id.clone())
135 .or_default()
136 .entry(*peer)
137 .or_default();
138 *count += 1;
139 *count
140 }))
141 }
142 })
143 }
144
145 pub(crate) fn validate(
149 &mut self,
150 message_id: &MessageId,
151 ) -> Option<(&RawMessage, HashSet<PeerId>)> {
152 self.msgs.get_mut(message_id).map(|(message, known_peers)| {
153 message.validated = true;
154 let originating_peers = std::mem::take(known_peers);
157 (&*message, originating_peers)
158 })
159 }
160
161 pub(crate) fn get_gossip_message_ids(&self, topic: &TopicHash) -> Vec<MessageId> {
163 self.history[..self.gossip]
164 .iter()
165 .fold(vec![], |mut current_entries, entries| {
166 let mut found_entries: Vec<MessageId> = entries
168 .iter()
169 .filter_map(|entry| {
170 if &entry.topic == topic {
171 let mid = &entry.mid;
172 if let Some(true) = self.msgs.get(mid).map(|(msg, _)| msg.validated) {
174 Some(mid.clone())
175 } else {
176 None
177 }
178 } else {
179 None
180 }
181 })
182 .collect();
183
184 current_entries.append(&mut found_entries);
186 current_entries
187 })
188 }
189
190 pub(crate) fn shift(&mut self) {
193 if self.history.is_empty() {
194 return;
195 }
196
197 for entry in self.history.pop().expect("history is always > 1") {
198 if let Some((msg, _)) = self.msgs.remove(&entry.mid) {
199 if !msg.validated {
200 tracing::debug!(
204 message=%&entry.mid,
205 "The message got removed from the cache without being validated."
206 );
207 }
208 }
209 tracing::trace!(message=%&entry.mid, "Remove message from the cache");
210
211 self.iwant_counts.remove(&entry.mid);
212 }
213
214 self.history.insert(0, Vec::new());
216 }
217
218 pub(crate) fn remove(
220 &mut self,
221 message_id: &MessageId,
222 ) -> Option<(RawMessage, HashSet<PeerId>)> {
223 self.iwant_counts.remove(message_id);
227 self.msgs.remove(message_id)
228 }
229}
230
#[cfg(test)]
mod tests {
    use super::*;
    use crate::IdentTopic as Topic;

    /// Builds an unvalidated message with a random source and sequence number `x` on `topic`.
    ///
    /// Returns it with its id, which is the base58 source followed by the sequence number.
    fn gen_testm(x: u64, topic: TopicHash) -> (MessageId, RawMessage) {
        let default_id = |message: &RawMessage| {
            let mut source_string = message.source.as_ref().unwrap().to_base58();
            source_string.push_str(&message.sequence_number.unwrap().to_string());
            MessageId::from(source_string)
        };
        // Truncation is harmless: tests only use small `x`, and the payload is not inspected.
        let u8x: u8 = x as u8;
        let source = Some(PeerId::random());
        let data: Vec<u8> = vec![u8x];
        let sequence_number = Some(x);

        let m = RawMessage {
            source,
            data,
            sequence_number,
            topic,
            signature: None,
            key: None,
            validated: false,
        };

        let id = default_id(&m);
        (id, m)
    }

    fn new_cache(gossip_size: usize, history: usize) -> MessageCache {
        MessageCache::new(gossip_size, history)
    }

    #[test]
    fn test_new_cache() {
        let x: usize = 3;
        let mc = new_cache(x, 5);

        assert_eq!(mc.gossip, x);
    }

    #[test]
    fn test_put_get_one() {
        let mut mc = new_cache(10, 15);

        let topic1_hash = Topic::new("topic1").hash();
        let (id, m) = gen_testm(10, topic1_hash);

        mc.put(&id, m.clone());

        assert_eq!(mc.history[0].len(), 1);

        let fetched = mc.get(&id);

        assert_eq!(fetched.unwrap(), &m);
    }

    #[test]
    fn test_put_duplicate_is_rejected() {
        let mut mc = new_cache(10, 15);

        let topic1_hash = Topic::new("topic1").hash();
        let (id, m) = gen_testm(10, topic1_hash);

        assert!(mc.put(&id, m.clone()));
        assert!(!mc.put(&id, m));

        // A rejected duplicate must not add a second history entry.
        assert_eq!(mc.history[0].len(), 1);
        assert_eq!(mc.msgs.len(), 1);
    }

    #[test]
    fn test_get_wrong() {
        let mut mc = new_cache(10, 15);

        let topic1_hash = Topic::new("topic1").hash();
        let (id, m) = gen_testm(10, topic1_hash);

        mc.put(&id, m);

        // Look up an id that was never inserted.
        let wrong_id = MessageId::new(b"wrongid");
        let fetched = mc.get(&wrong_id);
        assert!(fetched.is_none());
    }

    #[test]
    fn test_get_empty() {
        let mc = new_cache(10, 15);

        let wrong_string = MessageId::new(b"imempty");
        let fetched = mc.get(&wrong_string);
        assert!(fetched.is_none());
    }

    #[test]
    fn test_shift() {
        let mut mc = new_cache(1, 5);

        let topic1_hash = Topic::new("topic1").hash();

        for i in 0..10 {
            let (id, m) = gen_testm(i, topic1_hash.clone());
            mc.put(&id, m);
        }

        mc.shift();

        // All entries moved one window back; nothing is evicted yet.
        assert!(mc.history[0].is_empty());
        assert_eq!(mc.history[1].len(), 10);
        assert_eq!(mc.msgs.len(), 10);
    }

    #[test]
    fn test_empty_shift() {
        let mut mc = new_cache(1, 5);

        let topic1_hash = Topic::new("topic1").hash();

        for i in 0..10 {
            let (id, m) = gen_testm(i, topic1_hash.clone());
            mc.put(&id, m);
        }

        mc.shift();

        assert!(mc.history[0].is_empty());
        assert_eq!(mc.history[1].len(), 10);

        mc.shift();

        assert_eq!(mc.history[2].len(), 10);
        assert!(mc.history[1].is_empty());
        assert!(mc.history[0].is_empty());
    }

    #[test]
    fn test_remove_last_from_shift() {
        let mut mc = new_cache(4, 5);

        let topic1_hash = Topic::new("topic1").hash();

        for i in 0..10 {
            let (id, m) = gen_testm(i, topic1_hash.clone());
            mc.put(&id, m);
        }

        // Four shifts move the entries into the last of the five windows.
        mc.shift();
        mc.shift();
        mc.shift();
        mc.shift();

        assert_eq!(mc.history[mc.history.len() - 1].len(), 10);

        // The fifth shift expires that window and evicts every message.
        mc.shift();
        assert_eq!(mc.history[mc.history.len() - 1].len(), 0);
        assert_eq!(mc.history[0].len(), 0);
        assert_eq!(mc.msgs.len(), 0);
    }

    #[test]
    fn test_gossip_ids_only_include_validated_messages() {
        let mut mc = new_cache(1, 5);

        let topic1_hash = Topic::new("topic1").hash();
        let ids: Vec<MessageId> = (0..3)
            .map(|i| {
                let (id, m) = gen_testm(i, topic1_hash.clone());
                mc.put(&id, m);
                id
            })
            .collect();

        // Nothing is gossiped before validation.
        assert!(mc.get_gossip_message_ids(&topic1_hash).is_empty());

        mc.validate(&ids[0]);
        mc.validate(&ids[2]);
        assert_eq!(
            mc.get_gossip_message_ids(&topic1_hash),
            vec![ids[0].clone(), ids[2].clone()]
        );

        // Other topics see none of these ids.
        let topic2_hash = Topic::new("topic2").hash();
        assert!(mc.get_gossip_message_ids(&topic2_hash).is_empty());

        // After one shift the messages leave the single gossip window but stay cached.
        mc.shift();
        assert!(mc.get_gossip_message_ids(&topic1_hash).is_empty());
        assert_eq!(mc.msgs.len(), 3);
    }

    #[test]
    fn test_iwant_counts_only_for_validated_messages() {
        let mut mc = new_cache(1, 5);

        let topic1_hash = Topic::new("topic1").hash();
        let (id, m) = gen_testm(1, topic1_hash);
        mc.put(&id, m);
        let peer = PeerId::random();

        // Unvalidated messages are neither returned nor counted.
        assert!(mc.get_with_iwant_counts(&id, &peer).is_none());

        mc.validate(&id);
        assert_eq!(
            mc.get_with_iwant_counts(&id, &peer).map(|(_, count)| count),
            Some(1)
        );
        assert_eq!(
            mc.get_with_iwant_counts(&id, &peer).map(|(_, count)| count),
            Some(2)
        );

        // Counts are dropped together with the message.
        assert!(mc.remove(&id).is_some());
        assert!(mc.iwant_counts.get(&id).is_none());
    }
}