1use std::{
22 collections::{HashMap, HashSet},
23 iter,
24 task::{ready, Context, Poll},
25 time::Duration,
26};
27
28use bimap::BiMap;
29use futures::{future::BoxFuture, stream::FuturesUnordered, FutureExt, StreamExt};
30use libp2p_core::{transport::PortUse, Endpoint, Multiaddr};
31use libp2p_identity::PeerId;
32use libp2p_request_response::ProtocolSupport;
33use libp2p_swarm::{
34 behaviour::FromSwarm, ConnectionDenied, ConnectionId, NetworkBehaviour, THandler,
35 THandlerInEvent, THandlerOutEvent, ToSwarm,
36};
37
38use crate::{
39 codec::{Cookie, ErrorCode, Message, Namespace, NewRegistration, Registration, Ttl},
40 MAX_TTL, MIN_TTL,
41};
42
43pub struct Behaviour {
44 inner: libp2p_request_response::Behaviour<crate::codec::Codec>,
45
46 registrations: Registrations,
47}
48
49pub struct Config {
50 min_ttl: Ttl,
51 max_ttl: Ttl,
52}
53
54impl Config {
55 pub fn with_min_ttl(mut self, min_ttl: Ttl) -> Self {
56 self.min_ttl = min_ttl;
57 self
58 }
59
60 pub fn with_max_ttl(mut self, max_ttl: Ttl) -> Self {
61 self.max_ttl = max_ttl;
62 self
63 }
64}
65
66impl Default for Config {
67 fn default() -> Self {
68 Self {
69 min_ttl: MIN_TTL,
70 max_ttl: MAX_TTL,
71 }
72 }
73}
74
75impl Behaviour {
76 pub fn new(config: Config) -> Self {
78 Self {
79 inner: libp2p_request_response::Behaviour::with_codec(
80 crate::codec::Codec::default(),
81 iter::once((crate::PROTOCOL_IDENT, ProtocolSupport::Inbound)),
82 libp2p_request_response::Config::default(),
83 ),
84
85 registrations: Registrations::with_config(config),
86 }
87 }
88}
89
90#[derive(Debug)]
91#[allow(clippy::large_enum_variant)]
92pub enum Event {
93 DiscoverServed {
95 enquirer: PeerId,
96 registrations: Vec<Registration>,
97 },
98 DiscoverNotServed { enquirer: PeerId, error: ErrorCode },
100 PeerRegistered {
102 peer: PeerId,
103 registration: Registration,
104 },
105 PeerNotRegistered {
107 peer: PeerId,
108 namespace: Namespace,
109 error: ErrorCode,
110 },
111 PeerUnregistered { peer: PeerId, namespace: Namespace },
113 RegistrationExpired(Registration),
115}
116
117impl NetworkBehaviour for Behaviour {
118 type ConnectionHandler = <libp2p_request_response::Behaviour<
119 crate::codec::Codec,
120 > as NetworkBehaviour>::ConnectionHandler;
121
122 type ToSwarm = Event;
123
124 fn handle_established_inbound_connection(
125 &mut self,
126 connection_id: ConnectionId,
127 peer: PeerId,
128 local_addr: &Multiaddr,
129 remote_addr: &Multiaddr,
130 ) -> Result<THandler<Self>, ConnectionDenied> {
131 self.inner.handle_established_inbound_connection(
132 connection_id,
133 peer,
134 local_addr,
135 remote_addr,
136 )
137 }
138
139 fn handle_established_outbound_connection(
140 &mut self,
141 connection_id: ConnectionId,
142 peer: PeerId,
143 addr: &Multiaddr,
144 role_override: Endpoint,
145 port_use: PortUse,
146 ) -> Result<THandler<Self>, ConnectionDenied> {
147 self.inner.handle_established_outbound_connection(
148 connection_id,
149 peer,
150 addr,
151 role_override,
152 port_use,
153 )
154 }
155
156 fn on_connection_handler_event(
157 &mut self,
158 peer_id: PeerId,
159 connection: ConnectionId,
160 event: THandlerOutEvent<Self>,
161 ) {
162 self.inner
163 .on_connection_handler_event(peer_id, connection, event);
164 }
165
166 #[tracing::instrument(level = "trace", name = "NetworkBehaviour::poll", skip(self, cx))]
167 fn poll(
168 &mut self,
169 cx: &mut Context<'_>,
170 ) -> Poll<ToSwarm<Self::ToSwarm, THandlerInEvent<Self>>> {
171 if let Poll::Ready(ExpiredRegistration(registration)) = self.registrations.poll(cx) {
172 return Poll::Ready(ToSwarm::GenerateEvent(Event::RegistrationExpired(
173 registration,
174 )));
175 }
176
177 loop {
178 if let Poll::Ready(to_swarm) = self.inner.poll(cx) {
179 match to_swarm {
180 ToSwarm::GenerateEvent(libp2p_request_response::Event::Message {
181 peer: peer_id,
182 message:
183 libp2p_request_response::Message::Request {
184 request, channel, ..
185 },
186 ..
187 }) => {
188 if let Some((event, response)) =
189 handle_request(peer_id, request, &mut self.registrations)
190 {
191 if let Some(resp) = response {
192 self.inner
193 .send_response(channel, resp)
194 .expect("Send response");
195 }
196
197 return Poll::Ready(ToSwarm::GenerateEvent(event));
198 }
199
200 continue;
201 }
202 ToSwarm::GenerateEvent(libp2p_request_response::Event::InboundFailure {
203 peer,
204 request_id,
205 error,
206 ..
207 }) => {
208 tracing::warn!(
209 %peer,
210 request=%request_id,
211 "Inbound request with peer failed: {error}"
212 );
213
214 continue;
215 }
216 ToSwarm::GenerateEvent(libp2p_request_response::Event::ResponseSent {
217 ..
218 })
219 | ToSwarm::GenerateEvent(libp2p_request_response::Event::Message {
220 peer: _,
221 message: libp2p_request_response::Message::Response { .. },
222 ..
223 })
224 | ToSwarm::GenerateEvent(libp2p_request_response::Event::OutboundFailure {
225 ..
226 }) => {
227 continue;
228 }
229 other => {
230 let new_to_swarm = other
231 .map_out(|_| unreachable!("we manually map `GenerateEvent` variants"));
232
233 return Poll::Ready(new_to_swarm);
234 }
235 };
236 }
237
238 return Poll::Pending;
239 }
240 }
241
242 fn on_swarm_event(&mut self, event: FromSwarm) {
243 self.inner.on_swarm_event(event);
244 }
245}
246
247fn handle_request(
248 peer_id: PeerId,
249 message: Message,
250 registrations: &mut Registrations,
251) -> Option<(Event, Option<Message>)> {
252 match message {
253 Message::Register(registration) => {
254 if registration.record.peer_id() != peer_id {
255 let error = ErrorCode::NotAuthorized;
256
257 let event = Event::PeerNotRegistered {
258 peer: peer_id,
259 namespace: registration.namespace,
260 error,
261 };
262
263 return Some((event, Some(Message::RegisterResponse(Err(error)))));
264 }
265
266 let namespace = registration.namespace.clone();
267
268 match registrations.add(registration) {
269 Ok(registration) => {
270 let response = Message::RegisterResponse(Ok(registration.ttl));
271
272 let event = Event::PeerRegistered {
273 peer: peer_id,
274 registration,
275 };
276
277 Some((event, Some(response)))
278 }
279 Err(TtlOutOfRange::TooLong { .. }) | Err(TtlOutOfRange::TooShort { .. }) => {
280 let error = ErrorCode::InvalidTtl;
281
282 let response = Message::RegisterResponse(Err(error));
283
284 let event = Event::PeerNotRegistered {
285 peer: peer_id,
286 namespace,
287 error,
288 };
289
290 Some((event, Some(response)))
291 }
292 }
293 }
294 Message::Unregister(namespace) => {
295 registrations.remove(namespace.clone(), peer_id);
296
297 let event = Event::PeerUnregistered {
298 peer: peer_id,
299 namespace,
300 };
301
302 Some((event, None))
303 }
304 Message::Discover {
305 namespace,
306 cookie,
307 limit,
308 } => match registrations.get(namespace, cookie, limit) {
309 Ok((registrations, cookie)) => {
310 let discovered = registrations.cloned().collect::<Vec<_>>();
311
312 let response = Message::DiscoverResponse(Ok((discovered.clone(), cookie)));
313
314 let event = Event::DiscoverServed {
315 enquirer: peer_id,
316 registrations: discovered,
317 };
318
319 Some((event, Some(response)))
320 }
321 Err(_) => {
322 let error = ErrorCode::InvalidCookie;
323
324 let response = Message::DiscoverResponse(Err(error));
325
326 let event = Event::DiscoverNotServed {
327 enquirer: peer_id,
328 error,
329 };
330
331 Some((event, Some(response)))
332 }
333 },
334 Message::RegisterResponse(_) => None,
335 Message::DiscoverResponse(_) => None,
336 }
337}
338
339#[derive(Debug, Eq, PartialEq, Hash, Copy, Clone)]
340struct RegistrationId(u64);
341
342impl RegistrationId {
343 fn new() -> Self {
344 Self(rand::random())
345 }
346}
347
348#[derive(Debug, PartialEq)]
349struct ExpiredRegistration(Registration);
350
351pub struct Registrations {
352 registrations_for_peer: BiMap<(PeerId, Namespace), RegistrationId>,
353 registrations: HashMap<RegistrationId, Registration>,
354 cookies: HashMap<Cookie, HashSet<RegistrationId>>,
355 min_ttl: Ttl,
356 max_ttl: Ttl,
357 next_expiry: FuturesUnordered<BoxFuture<'static, RegistrationId>>,
358}
359
360#[derive(Debug, thiserror::Error)]
361pub enum TtlOutOfRange {
362 #[error("Requested TTL ({requested}s) is too long; max {bound}s")]
363 TooLong { bound: Ttl, requested: Ttl },
364 #[error("Requested TTL ({requested}s) is too short; min {bound}s")]
365 TooShort { bound: Ttl, requested: Ttl },
366}
367
368impl Default for Registrations {
369 fn default() -> Self {
370 Registrations::with_config(Config::default())
371 }
372}
373
374impl Registrations {
375 pub fn with_config(config: Config) -> Self {
376 Self {
377 registrations_for_peer: Default::default(),
378 registrations: Default::default(),
379 min_ttl: config.min_ttl,
380 max_ttl: config.max_ttl,
381 cookies: Default::default(),
382 next_expiry: FuturesUnordered::from_iter(vec![futures::future::pending().boxed()]),
383 }
384 }
385
386 pub fn add(
387 &mut self,
388 new_registration: NewRegistration,
389 ) -> Result<Registration, TtlOutOfRange> {
390 let ttl = new_registration.effective_ttl();
391 if ttl > self.max_ttl {
392 return Err(TtlOutOfRange::TooLong {
393 bound: self.max_ttl,
394 requested: ttl,
395 });
396 }
397 if ttl < self.min_ttl {
398 return Err(TtlOutOfRange::TooShort {
399 bound: self.min_ttl,
400 requested: ttl,
401 });
402 }
403
404 let namespace = new_registration.namespace;
405 let registration_id = RegistrationId::new();
406
407 if let Some(old_registration) = self
408 .registrations_for_peer
409 .get_by_left(&(new_registration.record.peer_id(), namespace.clone()))
410 {
411 self.registrations.remove(old_registration);
412 }
413
414 self.registrations_for_peer.insert(
415 (new_registration.record.peer_id(), namespace.clone()),
416 registration_id,
417 );
418
419 let registration = Registration {
420 namespace,
421 record: new_registration.record,
422 ttl,
423 };
424 self.registrations
425 .insert(registration_id, registration.clone());
426
427 let next_expiry = futures_timer::Delay::new(Duration::from_secs(ttl))
428 .map(move |_| registration_id)
429 .boxed();
430
431 self.next_expiry.push(next_expiry);
432
433 Ok(registration)
434 }
435
436 pub fn remove(&mut self, namespace: Namespace, peer_id: PeerId) {
437 let reggo_to_remove = self
438 .registrations_for_peer
439 .remove_by_left(&(peer_id, namespace));
440
441 if let Some((_, reggo_to_remove)) = reggo_to_remove {
442 self.registrations.remove(®go_to_remove);
443 }
444 }
445
446 pub fn get(
447 &mut self,
448 discover_namespace: Option<Namespace>,
449 cookie: Option<Cookie>,
450 limit: Option<u64>,
451 ) -> Result<(impl Iterator<Item = &Registration> + '_, Cookie), CookieNamespaceMismatch> {
452 let cookie_namespace = cookie.as_ref().and_then(|cookie| cookie.namespace());
453
454 match (discover_namespace.as_ref(), cookie_namespace) {
455 (None, Some(_)) => return Err(CookieNamespaceMismatch),
457 (Some(namespace), Some(cookie_namespace)) if namespace != cookie_namespace => {
459 return Err(CookieNamespaceMismatch)
460 }
461 _ => {}
463 }
464
465 let mut reggos_of_last_discover = cookie
466 .and_then(|cookie| self.cookies.get(&cookie))
467 .cloned()
468 .unwrap_or_default();
469
470 let ids = self
471 .registrations_for_peer
472 .iter()
473 .filter_map({
474 |((_, namespace), registration_id)| {
475 if reggos_of_last_discover.contains(registration_id) {
476 return None;
477 }
478
479 match discover_namespace.as_ref() {
480 Some(discover_namespace) if discover_namespace == namespace => {
481 Some(registration_id)
482 }
483 Some(_) => None,
484 None => Some(registration_id),
485 }
486 }
487 })
488 .take(limit.unwrap_or(u64::MAX) as usize)
489 .cloned()
490 .collect::<Vec<_>>();
491
492 reggos_of_last_discover.extend(&ids);
493
494 let new_cookie = discover_namespace
495 .map(Cookie::for_namespace)
496 .unwrap_or_else(Cookie::for_all_namespaces);
497 self.cookies
498 .insert(new_cookie.clone(), reggos_of_last_discover);
499
500 let regs = &self.registrations;
501 let registrations = ids
502 .into_iter()
503 .map(move |id| regs.get(&id).expect("bad internal data structure"));
504
505 Ok((registrations, new_cookie))
506 }
507
508 fn poll(&mut self, cx: &mut Context<'_>) -> Poll<ExpiredRegistration> {
509 loop {
510 let expired_registration = ready!(self.next_expiry.poll_next_unpin(cx)).expect(
511 "This stream should never finish because it is initialised with a pending future",
512 );
513
514 self.cookies.retain(|_, registrations| {
516 registrations.remove(&expired_registration);
517
518 !registrations.is_empty()
520 });
521
522 self.registrations_for_peer
523 .remove_by_right(&expired_registration);
524 match self.registrations.remove(&expired_registration) {
525 None => {
526 continue;
527 }
528 Some(registration) => {
529 return Poll::Ready(ExpiredRegistration(registration));
530 }
531 }
532 }
533 }
534}
535
536#[derive(Debug, thiserror::Error, Eq, PartialEq)]
537#[error("The provided cookie is not valid for a DISCOVER request for the given namespace")]
538pub struct CookieNamespaceMismatch;
539
#[cfg(test)]
mod tests {
    use libp2p_core::PeerRecord;
    use libp2p_identity as identity;
    use web_time::Instant;

    use super::*;

    #[test]
    fn given_cookie_from_discover_when_discover_again_then_only_get_diff() {
        let mut registrations = Registrations::default();
        registrations.add(new_dummy_registration("foo")).unwrap();
        registrations.add(new_dummy_registration("foo")).unwrap();

        let (initial_discover, cookie) = registrations.get(None, None, None).unwrap();
        assert_eq!(initial_discover.count(), 2);

        let (subsequent_discover, _) = registrations.get(None, Some(cookie), None).unwrap();
        assert_eq!(subsequent_discover.count(), 0);
    }

    #[test]
    fn given_registrations_when_discover_all_then_all_are_returned() {
        let mut registrations = Registrations::default();
        registrations.add(new_dummy_registration("foo")).unwrap();
        registrations.add(new_dummy_registration("foo")).unwrap();

        let (discover, _) = registrations.get(None, None, None).unwrap();

        assert_eq!(discover.count(), 2);
    }

    #[test]
    fn given_registrations_when_discover_only_for_specific_namespace_then_only_those_are_returned()
    {
        let mut registrations = Registrations::default();
        registrations.add(new_dummy_registration("foo")).unwrap();
        registrations.add(new_dummy_registration("bar")).unwrap();

        let (discover, _) = registrations
            .get(Some(Namespace::from_static("foo")), None, None)
            .unwrap();

        assert_eq!(
            discover.map(|r| &r.namespace).collect::<Vec<_>>(),
            vec!["foo"]
        );
    }

    #[test]
    fn given_reregistration_old_registration_is_discarded() {
        let alice = identity::Keypair::generate_ed25519();
        let mut registrations = Registrations::default();
        registrations
            .add(new_registration("foo", alice.clone(), None))
            .unwrap();
        registrations
            .add(new_registration("foo", alice, None))
            .unwrap();

        let (discover, _) = registrations
            .get(Some(Namespace::from_static("foo")), None, None)
            .unwrap();

        assert_eq!(
            discover.map(|r| &r.namespace).collect::<Vec<_>>(),
            vec!["foo"]
        );
    }

    #[test]
    fn given_cookie_from_2nd_discover_does_not_return_nodes_from_first_discover() {
        let mut registrations = Registrations::default();
        registrations.add(new_dummy_registration("foo")).unwrap();
        registrations.add(new_dummy_registration("foo")).unwrap();

        let (initial_discover, cookie1) = registrations.get(None, None, None).unwrap();
        assert_eq!(initial_discover.count(), 2);

        let (subsequent_discover, cookie2) = registrations.get(None, Some(cookie1), None).unwrap();
        assert_eq!(subsequent_discover.count(), 0);

        let (subsequent_discover, _) = registrations.get(None, Some(cookie2), None).unwrap();
        assert_eq!(subsequent_discover.count(), 0);
    }

    #[test]
    fn cookie_from_different_discover_request_is_not_valid() {
        let mut registrations = Registrations::default();
        registrations.add(new_dummy_registration("foo")).unwrap();
        registrations.add(new_dummy_registration("bar")).unwrap();

        let (_, foo_discover_cookie) = registrations
            .get(Some(Namespace::from_static("foo")), None, None)
            .unwrap();
        let result = registrations.get(
            Some(Namespace::from_static("bar")),
            Some(foo_discover_cookie),
            None,
        );

        assert!(matches!(result, Err(CookieNamespaceMismatch)))
    }

    #[tokio::test]
    async fn given_two_registration_ttls_one_expires_one_lives() {
        let mut registrations = Registrations::with_config(Config {
            min_ttl: 0,
            max_ttl: 4,
        });

        // Monotonic clock: unlike `SystemTime`, it cannot jump if the wall clock is adjusted
        // while the test is running.
        let start_time = Instant::now();

        registrations
            .add(new_dummy_registration_with_ttl("foo", 1))
            .unwrap();
        registrations
            .add(new_dummy_registration_with_ttl("bar", 4))
            .unwrap();

        let event = registrations.next_event().await;

        let elapsed = start_time.elapsed();
        assert!(elapsed.as_secs() >= 1);
        assert!(elapsed.as_secs() < 2);

        assert_eq!(event.0.namespace, Namespace::from_static("foo"));

        {
            let (mut discovered_foo, _) = registrations
                .get(Some(Namespace::from_static("foo")), None, None)
                .unwrap();
            assert!(discovered_foo.next().is_none());
        }
        let (mut discovered_bar, _) = registrations
            .get(Some(Namespace::from_static("bar")), None, None)
            .unwrap();
        assert!(discovered_bar.next().is_some());
    }

    #[tokio::test]
    async fn given_peer_unregisters_before_expiry_do_not_emit_registration_expired() {
        let mut registrations = Registrations::with_config(Config {
            min_ttl: 1,
            max_ttl: 10,
        });
        let dummy_registration = new_dummy_registration_with_ttl("foo", 2);
        let namespace = dummy_registration.namespace.clone();
        let peer_id = dummy_registration.record.peer_id();

        registrations.add(dummy_registration).unwrap();
        registrations.no_event_for(1).await;
        registrations.remove(namespace, peer_id);

        registrations.no_event_for(3).await
    }

    /// After every timer has fired, a newly added registration must still expire.
    ///
    /// This relies on `next_expiry` being seeded with a never-resolving future, so the expiry
    /// stream is never exhausted.
    #[tokio::test]
    async fn given_all_registrations_expired_then_successfully_handle_new_registration_and_expiry()
    {
        let mut registrations = Registrations::with_config(Config {
            min_ttl: 0,
            max_ttl: 10,
        });
        let dummy_registration = new_dummy_registration_with_ttl("foo", 1);

        registrations.add(dummy_registration.clone()).unwrap();
        let _ = registrations.next_event_in_at_most(2).await;

        registrations.no_event_for(1).await;

        registrations.add(dummy_registration).unwrap();
        let _ = registrations.next_event_in_at_most(2).await;
    }

    #[tokio::test]
    async fn cookies_are_cleaned_up_if_registrations_expire() {
        let mut registrations = Registrations::with_config(Config {
            min_ttl: 1,
            max_ttl: 10,
        });

        registrations
            .add(new_dummy_registration_with_ttl("foo", 2))
            .unwrap();
        let (_, _) = registrations.get(None, None, None).unwrap();

        assert_eq!(registrations.cookies.len(), 1);

        let _ = registrations.next_event_in_at_most(3).await;

        assert_eq!(registrations.cookies.len(), 0);
    }

    #[test]
    fn given_limit_discover_only_returns_n_results() {
        let mut registrations = Registrations::default();
        registrations.add(new_dummy_registration("foo")).unwrap();
        registrations.add(new_dummy_registration("foo")).unwrap();

        let (registrations, _) = registrations.get(None, None, Some(1)).unwrap();

        assert_eq!(registrations.count(), 1);
    }

    #[test]
    fn given_limit_cookie_can_be_used_for_pagination() {
        let mut registrations = Registrations::default();
        registrations.add(new_dummy_registration("foo")).unwrap();
        registrations.add(new_dummy_registration("foo")).unwrap();

        let (discover1, cookie) = registrations.get(None, None, Some(1)).unwrap();
        assert_eq!(discover1.count(), 1);

        let (discover2, _) = registrations.get(None, Some(cookie), None).unwrap();
        assert_eq!(discover2.count(), 1);
    }

    /// Registration for a freshly generated identity, using the default TTL.
    fn new_dummy_registration(namespace: &'static str) -> NewRegistration {
        let identity = identity::Keypair::generate_ed25519();

        new_registration(namespace, identity, None)
    }

    /// Registration for a freshly generated identity with an explicit TTL (seconds).
    fn new_dummy_registration_with_ttl(namespace: &'static str, ttl: Ttl) -> NewRegistration {
        let identity = identity::Keypair::generate_ed25519();

        new_registration(namespace, identity, Some(ttl))
    }

    fn new_registration(
        namespace: &'static str,
        identity: identity::Keypair,
        ttl: Option<Ttl>,
    ) -> NewRegistration {
        NewRegistration::new(
            Namespace::from_static(namespace),
            PeerRecord::new(&identity, vec!["/ip4/127.0.0.1/tcp/1234".parse().unwrap()]).unwrap(),
            ttl,
        )
    }

    /// Test-only helpers for driving `Registrations::poll` from async tests.
    impl Registrations {
        /// Waits for the next expired registration.
        async fn next_event(&mut self) -> ExpiredRegistration {
            futures::future::poll_fn(|cx| self.poll(cx)).await
        }

        /// Asserts that no registration expires within `seconds`.
        async fn no_event_for(&mut self, seconds: u64) {
            tokio::time::timeout(Duration::from_secs(seconds), self.next_event())
                .await
                .unwrap_err();
        }

        /// Asserts that a registration expires within `seconds` and returns it.
        async fn next_event_in_at_most(&mut self, seconds: u64) -> ExpiredRegistration {
            tokio::time::timeout(Duration::from_secs(seconds), self.next_event())
                .await
                .unwrap()
        }
    }
}