1use std::{
22 collections::{HashMap, HashSet},
23 iter,
24 task::{ready, Context, Poll},
25 time::Duration,
26};
27
28use bimap::BiMap;
29use futures::{future::BoxFuture, stream::FuturesUnordered, FutureExt, StreamExt};
30use libp2p_core::{transport::PortUse, Endpoint, Multiaddr};
31use libp2p_identity::PeerId;
32use libp2p_request_response::ProtocolSupport;
33use libp2p_swarm::{
34 behaviour::FromSwarm, ConnectionDenied, ConnectionId, NetworkBehaviour, THandler,
35 THandlerInEvent, THandlerOutEvent, ToSwarm,
36};
37
38use crate::{
39 codec::{Cookie, ErrorCode, Message, Namespace, NewRegistration, Registration, Ttl},
40 MAX_TTL, MIN_TTL,
41};
42
43pub struct Behaviour {
44 inner: libp2p_request_response::Behaviour<crate::codec::Codec>,
45
46 registrations: Registrations,
47}
48
49pub struct Config {
50 min_ttl: Ttl,
51 max_ttl: Ttl,
52}
53
54impl Config {
55 pub fn with_min_ttl(mut self, min_ttl: Ttl) -> Self {
56 self.min_ttl = min_ttl;
57 self
58 }
59
60 pub fn with_max_ttl(mut self, max_ttl: Ttl) -> Self {
61 self.max_ttl = max_ttl;
62 self
63 }
64}
65
66impl Default for Config {
67 fn default() -> Self {
68 Self {
69 min_ttl: MIN_TTL,
70 max_ttl: MAX_TTL,
71 }
72 }
73}
74
75impl Behaviour {
76 pub fn new(config: Config) -> Self {
78 Self {
79 inner: libp2p_request_response::Behaviour::with_codec(
80 crate::codec::Codec::default(),
81 iter::once((crate::PROTOCOL_IDENT, ProtocolSupport::Inbound)),
82 libp2p_request_response::Config::default(),
83 ),
84
85 registrations: Registrations::with_config(config),
86 }
87 }
88}
89
90#[derive(Debug)]
91#[allow(clippy::large_enum_variant)]
92pub enum Event {
93 DiscoverServed {
95 enquirer: PeerId,
96 registrations: Vec<Registration>,
97 },
98 DiscoverNotServed { enquirer: PeerId, error: ErrorCode },
100 PeerRegistered {
102 peer: PeerId,
103 registration: Registration,
104 },
105 PeerNotRegistered {
107 peer: PeerId,
108 namespace: Namespace,
109 error: ErrorCode,
110 },
111 PeerUnregistered { peer: PeerId, namespace: Namespace },
113 RegistrationExpired(Registration),
115}
116
117impl NetworkBehaviour for Behaviour {
118 type ConnectionHandler = <libp2p_request_response::Behaviour<
119 crate::codec::Codec,
120 > as NetworkBehaviour>::ConnectionHandler;
121
122 type ToSwarm = Event;
123
124 fn handle_established_inbound_connection(
125 &mut self,
126 connection_id: ConnectionId,
127 peer: PeerId,
128 local_addr: &Multiaddr,
129 remote_addr: &Multiaddr,
130 ) -> Result<THandler<Self>, ConnectionDenied> {
131 self.inner.handle_established_inbound_connection(
132 connection_id,
133 peer,
134 local_addr,
135 remote_addr,
136 )
137 }
138
139 fn handle_established_outbound_connection(
140 &mut self,
141 connection_id: ConnectionId,
142 peer: PeerId,
143 addr: &Multiaddr,
144 role_override: Endpoint,
145 port_use: PortUse,
146 ) -> Result<THandler<Self>, ConnectionDenied> {
147 self.inner.handle_established_outbound_connection(
148 connection_id,
149 peer,
150 addr,
151 role_override,
152 port_use,
153 )
154 }
155
156 fn on_connection_handler_event(
157 &mut self,
158 peer_id: PeerId,
159 connection: ConnectionId,
160 event: THandlerOutEvent<Self>,
161 ) {
162 self.inner
163 .on_connection_handler_event(peer_id, connection, event);
164 }
165
166 #[tracing::instrument(level = "trace", name = "NetworkBehaviour::poll", skip(self, cx))]
167 fn poll(
168 &mut self,
169 cx: &mut Context<'_>,
170 ) -> Poll<ToSwarm<Self::ToSwarm, THandlerInEvent<Self>>> {
171 if let Poll::Ready(ExpiredRegistration(registration)) = self.registrations.poll(cx) {
172 return Poll::Ready(ToSwarm::GenerateEvent(Event::RegistrationExpired(
173 registration,
174 )));
175 }
176
177 loop {
178 if let Poll::Ready(to_swarm) = self.inner.poll(cx) {
179 match to_swarm {
180 ToSwarm::GenerateEvent(libp2p_request_response::Event::Message {
181 peer: peer_id,
182 message:
183 libp2p_request_response::Message::Request {
184 request, channel, ..
185 },
186 ..
187 }) => {
188 if let Some((event, response)) =
189 handle_request(peer_id, request, &mut self.registrations)
190 {
191 if let Some(resp) = response {
192 if let Err(resp) = self.inner.send_response(channel, resp) {
193 tracing::debug!(
194 %peer_id,
195 "Failed to send response, peer disconnected {resp:?}"
196 );
197 }
198 }
199
200 return Poll::Ready(ToSwarm::GenerateEvent(event));
201 }
202
203 continue;
204 }
205 ToSwarm::GenerateEvent(libp2p_request_response::Event::InboundFailure {
206 peer,
207 request_id,
208 error,
209 ..
210 }) => {
211 tracing::warn!(
212 %peer,
213 request=%request_id,
214 "Inbound request with peer failed: {error}"
215 );
216
217 continue;
218 }
219 ToSwarm::GenerateEvent(libp2p_request_response::Event::ResponseSent {
220 ..
221 })
222 | ToSwarm::GenerateEvent(libp2p_request_response::Event::Message {
223 peer: _,
224 message: libp2p_request_response::Message::Response { .. },
225 ..
226 })
227 | ToSwarm::GenerateEvent(libp2p_request_response::Event::OutboundFailure {
228 ..
229 }) => {
230 continue;
231 }
232 other => {
233 let new_to_swarm = other
234 .map_out(|_| unreachable!("we manually map `GenerateEvent` variants"));
235
236 return Poll::Ready(new_to_swarm);
237 }
238 };
239 }
240
241 return Poll::Pending;
242 }
243 }
244
245 fn on_swarm_event(&mut self, event: FromSwarm) {
246 self.inner.on_swarm_event(event);
247 }
248}
249
250fn handle_request(
251 peer_id: PeerId,
252 message: Message,
253 registrations: &mut Registrations,
254) -> Option<(Event, Option<Message>)> {
255 match message {
256 Message::Register(registration) => {
257 if registration.record.peer_id() != peer_id {
258 let error = ErrorCode::NotAuthorized;
259
260 let event = Event::PeerNotRegistered {
261 peer: peer_id,
262 namespace: registration.namespace,
263 error,
264 };
265
266 return Some((event, Some(Message::RegisterResponse(Err(error)))));
267 }
268
269 let namespace = registration.namespace.clone();
270
271 match registrations.add(registration) {
272 Ok(registration) => {
273 let response = Message::RegisterResponse(Ok(registration.ttl));
274
275 let event = Event::PeerRegistered {
276 peer: peer_id,
277 registration,
278 };
279
280 Some((event, Some(response)))
281 }
282 Err(TtlOutOfRange::TooLong { .. }) | Err(TtlOutOfRange::TooShort { .. }) => {
283 let error = ErrorCode::InvalidTtl;
284
285 let response = Message::RegisterResponse(Err(error));
286
287 let event = Event::PeerNotRegistered {
288 peer: peer_id,
289 namespace,
290 error,
291 };
292
293 Some((event, Some(response)))
294 }
295 }
296 }
297 Message::Unregister(namespace) => {
298 registrations.remove(namespace.clone(), peer_id);
299
300 let event = Event::PeerUnregistered {
301 peer: peer_id,
302 namespace,
303 };
304
305 Some((event, None))
306 }
307 Message::Discover {
308 namespace,
309 cookie,
310 limit,
311 } => match registrations.get(namespace, cookie, limit) {
312 Ok((registrations, cookie)) => {
313 let discovered = registrations.cloned().collect::<Vec<_>>();
314
315 let response = Message::DiscoverResponse(Ok((discovered.clone(), cookie)));
316
317 let event = Event::DiscoverServed {
318 enquirer: peer_id,
319 registrations: discovered,
320 };
321
322 Some((event, Some(response)))
323 }
324 Err(_) => {
325 let error = ErrorCode::InvalidCookie;
326
327 let response = Message::DiscoverResponse(Err(error));
328
329 let event = Event::DiscoverNotServed {
330 enquirer: peer_id,
331 error,
332 };
333
334 Some((event, Some(response)))
335 }
336 },
337 Message::RegisterResponse(_) => None,
338 Message::DiscoverResponse(_) => None,
339 }
340}
341
/// Internal identifier of a stored registration.
///
/// It keys the registration map, is the right-hand side of the
/// `(peer, namespace)` bimap, and is what the expiry timers resolve to.
#[derive(Debug, Eq, PartialEq, Hash, Copy, Clone)]
struct RegistrationId(u64);
344
345impl RegistrationId {
346 fn new() -> Self {
347 Self(rand::random())
348 }
349}
350
351#[derive(Debug, PartialEq)]
352struct ExpiredRegistration(Registration);
353
354pub struct Registrations {
355 registrations_for_peer: BiMap<(PeerId, Namespace), RegistrationId>,
356 registrations: HashMap<RegistrationId, Registration>,
357 cookies: HashMap<Cookie, HashSet<RegistrationId>>,
358 min_ttl: Ttl,
359 max_ttl: Ttl,
360 next_expiry: FuturesUnordered<BoxFuture<'static, RegistrationId>>,
361}
362
363#[derive(Debug, thiserror::Error)]
364pub enum TtlOutOfRange {
365 #[error("Requested TTL ({requested}s) is too long; max {bound}s")]
366 TooLong { bound: Ttl, requested: Ttl },
367 #[error("Requested TTL ({requested}s) is too short; min {bound}s")]
368 TooShort { bound: Ttl, requested: Ttl },
369}
370
371impl Default for Registrations {
372 fn default() -> Self {
373 Registrations::with_config(Config::default())
374 }
375}
376
377impl Registrations {
378 pub fn with_config(config: Config) -> Self {
379 Self {
380 registrations_for_peer: Default::default(),
381 registrations: Default::default(),
382 min_ttl: config.min_ttl,
383 max_ttl: config.max_ttl,
384 cookies: Default::default(),
385 next_expiry: FuturesUnordered::from_iter(vec![futures::future::pending().boxed()]),
386 }
387 }
388
389 pub fn add(
390 &mut self,
391 new_registration: NewRegistration,
392 ) -> Result<Registration, TtlOutOfRange> {
393 let ttl = new_registration.effective_ttl();
394 if ttl > self.max_ttl {
395 return Err(TtlOutOfRange::TooLong {
396 bound: self.max_ttl,
397 requested: ttl,
398 });
399 }
400 if ttl < self.min_ttl {
401 return Err(TtlOutOfRange::TooShort {
402 bound: self.min_ttl,
403 requested: ttl,
404 });
405 }
406
407 let namespace = new_registration.namespace;
408 let registration_id = RegistrationId::new();
409
410 if let Some(old_registration) = self
411 .registrations_for_peer
412 .get_by_left(&(new_registration.record.peer_id(), namespace.clone()))
413 {
414 self.registrations.remove(old_registration);
415 }
416
417 self.registrations_for_peer.insert(
418 (new_registration.record.peer_id(), namespace.clone()),
419 registration_id,
420 );
421
422 let registration = Registration {
423 namespace,
424 record: new_registration.record,
425 ttl,
426 };
427 self.registrations
428 .insert(registration_id, registration.clone());
429
430 let next_expiry = futures_timer::Delay::new(Duration::from_secs(ttl))
431 .map(move |_| registration_id)
432 .boxed();
433
434 self.next_expiry.push(next_expiry);
435
436 Ok(registration)
437 }
438
439 pub fn remove(&mut self, namespace: Namespace, peer_id: PeerId) {
440 let reggo_to_remove = self
441 .registrations_for_peer
442 .remove_by_left(&(peer_id, namespace));
443
444 if let Some((_, reggo_to_remove)) = reggo_to_remove {
445 self.registrations.remove(®go_to_remove);
446 }
447 }
448
449 pub fn get(
450 &mut self,
451 discover_namespace: Option<Namespace>,
452 cookie: Option<Cookie>,
453 limit: Option<u64>,
454 ) -> Result<(impl Iterator<Item = &Registration> + '_, Cookie), CookieNamespaceMismatch> {
455 let cookie_namespace = cookie.as_ref().and_then(|cookie| cookie.namespace());
456
457 match (discover_namespace.as_ref(), cookie_namespace) {
458 (None, Some(_)) => return Err(CookieNamespaceMismatch),
460 (Some(namespace), Some(cookie_namespace)) if namespace != cookie_namespace => {
462 return Err(CookieNamespaceMismatch)
463 }
464 _ => {}
466 }
467
468 let mut reggos_of_last_discover = cookie
469 .and_then(|cookie| self.cookies.get(&cookie))
470 .cloned()
471 .unwrap_or_default();
472
473 let ids = self
474 .registrations_for_peer
475 .iter()
476 .filter_map({
477 |((_, namespace), registration_id)| {
478 if reggos_of_last_discover.contains(registration_id) {
479 return None;
480 }
481
482 match discover_namespace.as_ref() {
483 Some(discover_namespace) if discover_namespace == namespace => {
484 Some(registration_id)
485 }
486 Some(_) => None,
487 None => Some(registration_id),
488 }
489 }
490 })
491 .take(limit.unwrap_or(u64::MAX) as usize)
492 .cloned()
493 .collect::<Vec<_>>();
494
495 reggos_of_last_discover.extend(&ids);
496
497 let new_cookie = discover_namespace
498 .map(Cookie::for_namespace)
499 .unwrap_or_else(Cookie::for_all_namespaces);
500 self.cookies
501 .insert(new_cookie.clone(), reggos_of_last_discover);
502
503 let regs = &self.registrations;
504 let registrations = ids
505 .into_iter()
506 .map(move |id| regs.get(&id).expect("bad internal data structure"));
507
508 Ok((registrations, new_cookie))
509 }
510
511 fn poll(&mut self, cx: &mut Context<'_>) -> Poll<ExpiredRegistration> {
512 loop {
513 let expired_registration = ready!(self.next_expiry.poll_next_unpin(cx)).expect(
514 "This stream should never finish because it is initialised with a pending future",
515 );
516
517 self.cookies.retain(|_, registrations| {
519 registrations.remove(&expired_registration);
520
521 !registrations.is_empty()
523 });
524
525 self.registrations_for_peer
526 .remove_by_right(&expired_registration);
527 match self.registrations.remove(&expired_registration) {
528 None => {
529 continue;
530 }
531 Some(registration) => {
532 return Poll::Ready(ExpiredRegistration(registration));
533 }
534 }
535 }
536 }
537}
538
539#[derive(Debug, thiserror::Error, Eq, PartialEq)]
540#[error("The provided cookie is not valid for a DISCOVER request for the given namespace")]
541pub struct CookieNamespaceMismatch;
542
#[cfg(test)]
mod tests {
    use libp2p_core::PeerRecord;
    use libp2p_identity as identity;
    use web_time::SystemTime;

    use super::*;

    /// Re-using the cookie of a discover yields only what was not returned before.
    #[test]
    fn given_cookie_from_discover_when_discover_again_then_only_get_diff() {
        let mut store = Registrations::default();
        store.add(new_dummy_registration("foo")).unwrap();
        store.add(new_dummy_registration("foo")).unwrap();

        let (first, cookie) = store.get(None, None, None).unwrap();
        assert_eq!(first.count(), 2);

        let (second, _) = store.get(None, Some(cookie), None).unwrap();
        assert_eq!(second.count(), 0);
    }

    /// A discover without namespace, cookie or limit returns every registration.
    #[test]
    fn given_registrations_when_discover_all_then_all_are_returned() {
        let mut store = Registrations::default();
        store.add(new_dummy_registration("foo")).unwrap();
        store.add(new_dummy_registration("foo")).unwrap();

        let (found, _) = store.get(None, None, None).unwrap();

        assert_eq!(found.count(), 2);
    }

    /// A discover for one namespace ignores registrations in other namespaces.
    #[test]
    fn given_registrations_when_discover_only_for_specific_namespace_then_only_those_are_returned()
    {
        let mut store = Registrations::default();
        store.add(new_dummy_registration("foo")).unwrap();
        store.add(new_dummy_registration("bar")).unwrap();

        let (found, _) = store
            .get(Some(Namespace::from_static("foo")), None, None)
            .unwrap();
        let namespaces = found.map(|r| &r.namespace).collect::<Vec<_>>();

        assert_eq!(namespaces, vec!["foo"]);
    }

    /// Registering twice as the same peer in the same namespace keeps one entry.
    #[test]
    fn given_reregistration_old_registration_is_discarded() {
        let alice = identity::Keypair::generate_ed25519();
        let mut store = Registrations::default();
        store
            .add(new_registration("foo", alice.clone(), None))
            .unwrap();
        store.add(new_registration("foo", alice, None)).unwrap();

        let (found, _) = store
            .get(Some(Namespace::from_static("foo")), None, None)
            .unwrap();
        let namespaces = found.map(|r| &r.namespace).collect::<Vec<_>>();

        assert_eq!(namespaces, vec!["foo"]);
    }

    /// A cookie obtained from a follow-up discover still hides the earlier results.
    #[test]
    fn given_cookie_from_2nd_discover_does_not_return_nodes_from_first_discover() {
        let mut store = Registrations::default();
        store.add(new_dummy_registration("foo")).unwrap();
        store.add(new_dummy_registration("foo")).unwrap();

        let (first, cookie1) = store.get(None, None, None).unwrap();
        assert_eq!(first.count(), 2);

        let (second, cookie2) = store.get(None, Some(cookie1), None).unwrap();
        assert_eq!(second.count(), 0);

        let (third, _) = store.get(None, Some(cookie2), None).unwrap();
        assert_eq!(third.count(), 0);
    }

    /// A cookie issued for one namespace is rejected for another namespace.
    #[test]
    fn cookie_from_different_discover_request_is_not_valid() {
        let mut store = Registrations::default();
        store.add(new_dummy_registration("foo")).unwrap();
        store.add(new_dummy_registration("bar")).unwrap();

        let (_, foo_discover_cookie) = store
            .get(Some(Namespace::from_static("foo")), None, None)
            .unwrap();
        let outcome = store.get(
            Some(Namespace::from_static("bar")),
            Some(foo_discover_cookie),
            None,
        );

        assert!(matches!(outcome, Err(CookieNamespaceMismatch)))
    }

    /// Of two registrations with TTLs of 1s and 4s, only the first expires after ~1s.
    #[tokio::test]
    async fn given_two_registration_ttls_one_expires_one_lives() {
        let mut store = Registrations::with_config(Config {
            min_ttl: 0,
            max_ttl: 4,
        });

        let started_at = SystemTime::now();

        store
            .add(new_dummy_registration_with_ttl("foo", 1))
            .unwrap();
        store
            .add(new_dummy_registration_with_ttl("bar", 4))
            .unwrap();

        let expired = store.next_event().await;

        let waited = started_at.elapsed().unwrap();
        assert!(waited.as_secs() >= 1);
        assert!(waited.as_secs() < 2);

        assert_eq!(expired.0.namespace, Namespace::from_static("foo"));

        // Scoped so the iterator's borrow of `store` ends before the next `get`.
        {
            let (mut foo_left, _) = store
                .get(Some(Namespace::from_static("foo")), None, None)
                .unwrap();
            assert!(foo_left.next().is_none());
        }
        let (mut bar_left, _) = store
            .get(Some(Namespace::from_static("bar")), None, None)
            .unwrap();
        assert!(bar_left.next().is_some());
    }

    /// A registration that was removed before its TTL elapsed produces no expiry event.
    #[tokio::test]
    async fn given_peer_unregisters_before_expiry_do_not_emit_registration_expired() {
        let mut store = Registrations::with_config(Config {
            min_ttl: 1,
            max_ttl: 10,
        });
        let registration = new_dummy_registration_with_ttl("foo", 2);
        let namespace = registration.namespace.clone();
        let peer_id = registration.record.peer_id();

        store.add(registration).unwrap();
        store.no_event_for(1).await;
        store.remove(namespace, peer_id);

        store.no_event_for(3).await;
    }

    /// After everything has expired, a new registration is still accepted and expires.
    #[tokio::test]
    async fn given_all_registrations_expired_then_successfully_handle_new_registration_and_expiry()
    {
        let mut store = Registrations::with_config(Config {
            min_ttl: 0,
            max_ttl: 10,
        });
        let registration = new_dummy_registration_with_ttl("foo", 1);

        store.add(registration.clone()).unwrap();
        let _ = store.next_event_in_at_most(2).await;

        store.no_event_for(1).await;

        store.add(registration).unwrap();
        let _ = store.next_event_in_at_most(2).await;
    }

    /// A cookie whose registrations have all expired is dropped from the store.
    #[tokio::test]
    async fn cookies_are_cleaned_up_if_registrations_expire() {
        let mut store = Registrations::with_config(Config {
            min_ttl: 1,
            max_ttl: 10,
        });

        store
            .add(new_dummy_registration_with_ttl("foo", 2))
            .unwrap();
        let (_, _) = store.get(None, None, None).unwrap();

        assert_eq!(store.cookies.len(), 1);

        let _ = store.next_event_in_at_most(3).await;

        assert_eq!(store.cookies.len(), 0);
    }

    /// A limit caps the number of registrations returned.
    #[test]
    fn given_limit_discover_only_returns_n_results() {
        let mut store = Registrations::default();
        store.add(new_dummy_registration("foo")).unwrap();
        store.add(new_dummy_registration("foo")).unwrap();

        let (found, _) = store.get(None, None, Some(1)).unwrap();

        assert_eq!(found.count(), 1);
    }

    /// The cookie of a limited discover lets the next discover return the remainder.
    #[test]
    fn given_limit_cookie_can_be_used_for_pagination() {
        let mut store = Registrations::default();
        store.add(new_dummy_registration("foo")).unwrap();
        store.add(new_dummy_registration("foo")).unwrap();

        let (page1, cookie) = store.get(None, None, Some(1)).unwrap();
        assert_eq!(page1.count(), 1);

        let (page2, _) = store.get(None, Some(cookie), None).unwrap();
        assert_eq!(page2.count(), 1);
    }

    /// Registration in `namespace` for a freshly generated identity, without a TTL.
    fn new_dummy_registration(namespace: &'static str) -> NewRegistration {
        new_registration(namespace, identity::Keypair::generate_ed25519(), None)
    }

    /// Registration in `namespace` for a freshly generated identity with the given TTL.
    fn new_dummy_registration_with_ttl(namespace: &'static str, ttl: Ttl) -> NewRegistration {
        new_registration(namespace, identity::Keypair::generate_ed25519(), Some(ttl))
    }

    /// Builds a registration for `identity` with a single fixed loopback address.
    fn new_registration(
        namespace: &'static str,
        identity: identity::Keypair,
        ttl: Option<Ttl>,
    ) -> NewRegistration {
        let addresses = vec!["/ip4/127.0.0.1/tcp/1234".parse().unwrap()];
        let record = PeerRecord::new(&identity, addresses).unwrap();

        NewRegistration::new(Namespace::from_static(namespace), record, ttl)
    }

    /// Test-only helpers for awaiting expiry events.
    impl Registrations {
        /// Waits for the next expired registration.
        async fn next_event(&mut self) -> ExpiredRegistration {
            futures::future::poll_fn(|cx| self.poll(cx)).await
        }

        /// Panics if a registration expires within `seconds`.
        async fn no_event_for(&mut self, seconds: u64) {
            let window = Duration::from_secs(seconds);

            tokio::time::timeout(window, self.next_event())
                .await
                .unwrap_err();
        }

        /// Returns the next expired registration; panics if none expires within `seconds`.
        async fn next_event_in_at_most(&mut self, seconds: u64) -> ExpiredRegistration {
            let window = Duration::from_secs(seconds);

            tokio::time::timeout(window, self.next_event())
                .await
                .unwrap()
        }
    }
}