1#![cfg_attr(docsrs, feature(doc_cfg, doc_auto_cfg))]
68
69#[cfg(feature = "cbor")]
70pub mod cbor;
71mod codec;
72mod handler;
73#[cfg(feature = "json")]
74pub mod json;
75
76use std::{
77 collections::{HashMap, HashSet, VecDeque},
78 fmt, io,
79 sync::{atomic::AtomicU64, Arc},
80 task::{Context, Poll},
81 time::Duration,
82};
83
84pub use codec::Codec;
85use futures::channel::oneshot;
86use handler::Handler;
87pub use handler::ProtocolSupport;
88use libp2p_core::{transport::PortUse, ConnectedPoint, Endpoint, Multiaddr};
89use libp2p_identity::PeerId;
90use libp2p_swarm::{
91 behaviour::{AddressChange, ConnectionClosed, DialFailure, FromSwarm},
92 dial_opts::DialOpts,
93 ConnectionDenied, ConnectionHandler, ConnectionId, NetworkBehaviour, NotifyHandler,
94 PeerAddresses, THandler, THandlerInEvent, THandlerOutEvent, ToSwarm,
95};
96use smallvec::SmallVec;
97
98use crate::handler::OutboundMessage;
99
/// An inbound request or response.
#[derive(Debug)]
pub enum Message<TRequest, TResponse, TChannelResponse = TResponse> {
    /// A request received from a remote peer.
    Request {
        /// The ID of this inbound request.
        request_id: InboundRequestId,
        /// The decoded request.
        request: TRequest,
        /// The channel through which the response is to be sent, via
        /// [`Behaviour::send_response`].
        ///
        /// If the channel is dropped without a response being sent, the failure is
        /// reported as [`InboundFailure::ResponseOmission`].
        channel: ResponseChannel<TChannelResponse>,
    },
    /// A response to a request previously sent with [`Behaviour::send_request`].
    Response {
        /// The ID returned by [`Behaviour::send_request`] for the request this
        /// response answers.
        request_id: OutboundRequestId,
        /// The decoded response.
        response: TResponse,
    },
}
126
/// The events emitted by a request-response [`Behaviour`].
#[derive(Debug)]
pub enum Event<TRequest, TResponse, TChannelResponse = TResponse> {
    /// An incoming message (request or response).
    Message {
        /// The peer who sent the message.
        peer: PeerId,
        /// The ID of the connection on which the message was received.
        connection_id: ConnectionId,
        /// The incoming message.
        message: Message<TRequest, TResponse, TChannelResponse>,
    },
    /// An outbound request failed.
    OutboundFailure {
        /// The peer to whom the request was addressed.
        peer: PeerId,
        /// The connection the failure relates to. For [`OutboundFailure::DialFailure`]
        /// this is the ID of the failed dialing attempt.
        connection_id: ConnectionId,
        /// The ID returned by [`Behaviour::send_request`] for the failed request.
        request_id: OutboundRequestId,
        /// The reason the request failed.
        error: OutboundFailure,
    },
    /// An inbound request failed.
    InboundFailure {
        /// The peer from whom the request was received.
        peer: PeerId,
        /// The ID of the connection on which the request was received.
        connection_id: ConnectionId,
        /// The ID of the failed inbound request.
        request_id: InboundRequestId,
        /// The reason the request failed.
        error: InboundFailure,
    },
    /// The connection handler reported that the response to an inbound request
    /// has been sent.
    ResponseSent {
        /// The peer to whom the response was sent.
        peer: PeerId,
        /// The ID of the connection on which the response was sent.
        connection_id: ConnectionId,
        /// The ID of the inbound request whose response was sent.
        request_id: InboundRequestId,
    },
}
174
/// Reasons why an outbound request (sending a request and awaiting its response)
/// can fail.
#[derive(Debug)]
pub enum OutboundFailure {
    /// The request could not be sent because dialing the peer failed.
    DialFailure,
    /// No response arrived before the request timed out.
    Timeout,
    /// The connection was closed before a response was received.
    ConnectionClosed,
    /// The remote supports none of the requested protocols.
    UnsupportedProtocols,
    /// An I/O error occurred on the outbound stream.
    Io(io::Error),
}

impl fmt::Display for OutboundFailure {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let description = match self {
            Self::DialFailure => "Failed to dial the requested peer",
            Self::Timeout => "Timeout while waiting for a response",
            Self::ConnectionClosed => "Connection was closed before a response was received",
            Self::UnsupportedProtocols => "The remote supports none of the requested protocols",
            // The only variant carrying a payload: embed the underlying I/O error.
            Self::Io(e) => return write!(f, "IO error on outbound stream: {e}"),
        };
        f.write_str(description)
    }
}

impl std::error::Error for OutboundFailure {}
214
/// Reasons why an inbound request (receiving a request and sending its response)
/// can fail.
#[derive(Debug)]
pub enum InboundFailure {
    /// Receiving the request or sending the response timed out.
    Timeout,
    /// The connection was closed before a response could be sent.
    ConnectionClosed,
    /// The local peer supports none of the protocols requested by the remote.
    UnsupportedProtocols,
    /// The response channel was dropped without a response being sent.
    ResponseOmission,
    /// An I/O error occurred on the inbound stream.
    Io(io::Error),
}

impl fmt::Display for InboundFailure {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let description = match self {
            Self::Timeout => "Timeout while receiving request or sending response",
            Self::ConnectionClosed => "Connection was closed before a response could be sent",
            Self::UnsupportedProtocols => {
                "The local peer supports none of the protocols requested by the remote"
            }
            Self::ResponseOmission => {
                "The response channel was dropped without sending a response to the remote"
            }
            // The only variant carrying a payload: embed the underlying I/O error.
            Self::Io(e) => return write!(f, "IO error on inbound stream: {e}"),
        };
        f.write_str(description)
    }
}

impl std::error::Error for InboundFailure {}
260
261#[derive(Debug)]
265pub struct ResponseChannel<TResponse> {
266 sender: oneshot::Sender<TResponse>,
267}
268
269impl<TResponse> ResponseChannel<TResponse> {
270 pub fn is_open(&self) -> bool {
278 !self.sender.is_canceled()
279 }
280}
281
/// The ID of an inbound request.
///
/// The behaviour shares one counter (starting at 1) with all its connection handlers.
/// Presumably the handlers draw inbound IDs from it; confirm in `handler`.
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub struct InboundRequestId(u64);

impl fmt::Display for InboundRequestId {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Forward to `u64`'s `Display` so width/fill/alignment flags are honoured.
        // `write!(f, "{}", ..)` would start a fresh format and silently drop them.
        fmt::Display::fmt(&self.0, f)
    }
}
294
/// The ID of an outbound request.
///
/// Returned by [`Behaviour::send_request`]. IDs are allocated sequentially per
/// behaviour, starting at 1.
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub struct OutboundRequestId(u64);

impl fmt::Display for OutboundRequestId {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Forward to `u64`'s `Display` so width/fill/alignment flags are honoured.
        // `write!(f, "{}", ..)` would start a fresh format and silently drop them.
        fmt::Display::fmt(&self.0, f)
    }
}
307
/// The configuration of a request-response [`Behaviour`].
///
/// Both settings are handed to every connection handler the behaviour creates.
#[derive(Debug, Clone)]
pub struct Config {
    request_timeout: Duration,
    max_concurrent_streams: usize,
}

impl Default for Config {
    /// A 10 second request timeout and a limit of 100 concurrent streams.
    fn default() -> Self {
        Self {
            request_timeout: Duration::from_secs(10),
            max_concurrent_streams: 100,
        }
    }
}

impl Config {
    /// Sets the request timeout in place.
    #[deprecated(note = "Use `Config::with_request_timeout` for one-liner constructions.")]
    pub fn set_request_timeout(&mut self, v: Duration) -> &mut Self {
        self.request_timeout = v;
        self
    }

    /// Returns this configuration with the request timeout set to `v`.
    ///
    /// Consumes `self`. Discarding the returned value also discards the new setting,
    /// hence `#[must_use]`.
    #[must_use]
    pub fn with_request_timeout(mut self, v: Duration) -> Self {
        self.request_timeout = v;
        self
    }

    /// Returns this configuration with the concurrent-stream limit set to `num_streams`.
    ///
    /// The limit is passed to each connection handler. Presumably it applies per
    /// connection; confirm in `handler`. Consumes `self`, hence `#[must_use]`.
    #[must_use]
    pub fn with_max_concurrent_streams(mut self, num_streams: usize) -> Self {
        self.max_concurrent_streams = num_streams;
        self
    }
}
344
/// A request/response network behaviour, generic over the message [`Codec`].
pub struct Behaviour<TCodec>
where
    TCodec: Codec + Clone + Send + 'static,
{
    /// Protocols accepted for inbound requests; handed to every connection handler.
    inbound_protocols: SmallVec<[TCodec::Protocol; 2]>,
    /// Protocols offered for outbound requests; copied into every outbound message.
    outbound_protocols: SmallVec<[TCodec::Protocol; 2]>,
    /// The ID the next call to `send_request` will hand out.
    next_outbound_request_id: OutboundRequestId,
    /// Inbound request ID counter, shared (via `Arc`) with every connection handler.
    next_inbound_request_id: Arc<AtomicU64>,
    /// The behaviour's configuration.
    config: Config,
    /// The codec; cloned into every connection handler.
    codec: TCodec,
    /// Events queued for the swarm, handed out one at a time by `poll`.
    pending_events:
        VecDeque<ToSwarm<Event<TCodec::Request, TCodec::Response>, OutboundMessage<TCodec>>>,
    /// Established connections per peer, each with its in-flight inbound and
    /// outbound request IDs and, for dialed connections, the remote address.
    connected: HashMap<PeerId, SmallVec<[Connection; 2]>>,
    /// Peer addresses, fed by the deprecated `add_address`/`remove_address` and by swarm
    /// events; consulted when the swarm asks for addresses to dial.
    addresses: PeerAddresses,
    /// Requests that have not been sent yet because no connection to the peer exists;
    /// flushed into the next connection handler created for that peer.
    pending_outbound_requests: HashMap<PeerId, SmallVec<[OutboundMessage<TCodec>; 10]>>,
}
374
375impl<TCodec> Behaviour<TCodec>
376where
377 TCodec: Codec + Default + Clone + Send + 'static,
378{
379 pub fn new<I>(protocols: I, cfg: Config) -> Self
382 where
383 I: IntoIterator<Item = (TCodec::Protocol, ProtocolSupport)>,
384 {
385 Self::with_codec(TCodec::default(), protocols, cfg)
386 }
387}
388
389impl<TCodec> Behaviour<TCodec>
390where
391 TCodec: Codec + Clone + Send + 'static,
392{
393 pub fn with_codec<I>(codec: TCodec, protocols: I, cfg: Config) -> Self
396 where
397 I: IntoIterator<Item = (TCodec::Protocol, ProtocolSupport)>,
398 {
399 let mut inbound_protocols = SmallVec::new();
400 let mut outbound_protocols = SmallVec::new();
401 for (p, s) in protocols {
402 if s.inbound() {
403 inbound_protocols.push(p.clone());
404 }
405 if s.outbound() {
406 outbound_protocols.push(p.clone());
407 }
408 }
409 Behaviour {
410 inbound_protocols,
411 outbound_protocols,
412 next_outbound_request_id: OutboundRequestId(1),
413 next_inbound_request_id: Arc::new(AtomicU64::new(1)),
414 config: cfg,
415 codec,
416 pending_events: VecDeque::new(),
417 connected: HashMap::new(),
418 pending_outbound_requests: HashMap::new(),
419 addresses: PeerAddresses::default(),
420 }
421 }
422
423 pub fn send_request(&mut self, peer: &PeerId, request: TCodec::Request) -> OutboundRequestId {
437 let request_id = self.next_outbound_request_id();
438 let request = OutboundMessage {
439 request_id,
440 request,
441 protocols: self.outbound_protocols.clone(),
442 };
443
444 if let Some(request) = self.try_send_request(peer, request) {
445 self.pending_events.push_back(ToSwarm::Dial {
446 opts: DialOpts::peer_id(*peer).build(),
447 });
448 self.pending_outbound_requests
449 .entry(*peer)
450 .or_default()
451 .push(request);
452 }
453
454 request_id
455 }
456
457 pub fn send_response(
469 &mut self,
470 ch: ResponseChannel<TCodec::Response>,
471 rs: TCodec::Response,
472 ) -> Result<(), TCodec::Response> {
473 ch.sender.send(rs)
474 }
475
476 #[deprecated(note = "Use `Swarm::add_peer_address` instead.")]
485 pub fn add_address(&mut self, peer: &PeerId, address: Multiaddr) -> bool {
486 self.addresses.add(*peer, address)
487 }
488
489 #[deprecated(note = "Will be removed with the next breaking release and won't be replaced.")]
491 pub fn remove_address(&mut self, peer: &PeerId, address: &Multiaddr) {
492 self.addresses.remove(peer, address);
493 }
494
495 pub fn is_connected(&self, peer: &PeerId) -> bool {
497 if let Some(connections) = self.connected.get(peer) {
498 !connections.is_empty()
499 } else {
500 false
501 }
502 }
503
504 pub fn is_pending_outbound(&self, peer: &PeerId, request_id: &OutboundRequestId) -> bool {
508 let est_conn = self
510 .connected
511 .get(peer)
512 .map(|cs| {
513 cs.iter()
514 .any(|c| c.pending_outbound_responses.contains(request_id))
515 })
516 .unwrap_or(false);
517 let pen_conn = self
519 .pending_outbound_requests
520 .get(peer)
521 .map(|rps| rps.iter().any(|rp| rp.request_id == *request_id))
522 .unwrap_or(false);
523
524 est_conn || pen_conn
525 }
526
527 pub fn is_pending_inbound(&self, peer: &PeerId, request_id: &InboundRequestId) -> bool {
531 self.connected
532 .get(peer)
533 .map(|cs| {
534 cs.iter()
535 .any(|c| c.pending_inbound_responses.contains(request_id))
536 })
537 .unwrap_or(false)
538 }
539
540 fn next_outbound_request_id(&mut self) -> OutboundRequestId {
542 let request_id = self.next_outbound_request_id;
543 self.next_outbound_request_id.0 += 1;
544 request_id
545 }
546
547 fn try_send_request(
551 &mut self,
552 peer: &PeerId,
553 request: OutboundMessage<TCodec>,
554 ) -> Option<OutboundMessage<TCodec>> {
555 if let Some(connections) = self.connected.get_mut(peer) {
556 if connections.is_empty() {
557 return Some(request);
558 }
559 let ix = (request.request_id.0 as usize) % connections.len();
560 let conn = &mut connections[ix];
561 conn.pending_outbound_responses.insert(request.request_id);
562 self.pending_events.push_back(ToSwarm::NotifyHandler {
563 peer_id: *peer,
564 handler: NotifyHandler::One(conn.id),
565 event: request,
566 });
567 None
568 } else {
569 Some(request)
570 }
571 }
572
573 fn remove_pending_outbound_response(
579 &mut self,
580 peer: &PeerId,
581 connection_id: ConnectionId,
582 request: OutboundRequestId,
583 ) -> bool {
584 self.get_connection_mut(peer, connection_id)
585 .map(|c| c.pending_outbound_responses.remove(&request))
586 .unwrap_or(false)
587 }
588
589 fn remove_pending_inbound_response(
595 &mut self,
596 peer: &PeerId,
597 connection_id: ConnectionId,
598 request: InboundRequestId,
599 ) -> bool {
600 self.get_connection_mut(peer, connection_id)
601 .map(|c| c.pending_inbound_responses.remove(&request))
602 .unwrap_or(false)
603 }
604
605 fn get_connection_mut(
608 &mut self,
609 peer: &PeerId,
610 connection_id: ConnectionId,
611 ) -> Option<&mut Connection> {
612 self.connected
613 .get_mut(peer)
614 .and_then(|connections| connections.iter_mut().find(|c| c.id == connection_id))
615 }
616
617 fn on_address_change(
618 &mut self,
619 AddressChange {
620 peer_id,
621 connection_id,
622 new,
623 ..
624 }: AddressChange,
625 ) {
626 let new_address = match new {
627 ConnectedPoint::Dialer { address, .. } => Some(address.clone()),
628 ConnectedPoint::Listener { .. } => None,
629 };
630 let connections = self
631 .connected
632 .get_mut(&peer_id)
633 .expect("Address change can only happen on an established connection.");
634
635 let connection = connections
636 .iter_mut()
637 .find(|c| c.id == connection_id)
638 .expect("Address change can only happen on an established connection.");
639 connection.remote_address = new_address;
640 }
641
642 fn on_connection_closed(
643 &mut self,
644 ConnectionClosed {
645 peer_id,
646 connection_id,
647 remaining_established,
648 ..
649 }: ConnectionClosed,
650 ) {
651 let connections = self
652 .connected
653 .get_mut(&peer_id)
654 .expect("Expected some established connection to peer before closing.");
655
656 let connection = connections
657 .iter()
658 .position(|c| c.id == connection_id)
659 .map(|p: usize| connections.remove(p))
660 .expect("Expected connection to be established before closing.");
661
662 debug_assert_eq!(connections.is_empty(), remaining_established == 0);
663 if connections.is_empty() {
664 self.connected.remove(&peer_id);
665 }
666
667 for request_id in connection.pending_inbound_responses {
668 self.pending_events
669 .push_back(ToSwarm::GenerateEvent(Event::InboundFailure {
670 peer: peer_id,
671 connection_id,
672 request_id,
673 error: InboundFailure::ConnectionClosed,
674 }));
675 }
676
677 for request_id in connection.pending_outbound_responses {
678 self.pending_events
679 .push_back(ToSwarm::GenerateEvent(Event::OutboundFailure {
680 peer: peer_id,
681 connection_id,
682 request_id,
683 error: OutboundFailure::ConnectionClosed,
684 }));
685 }
686 }
687
688 fn on_dial_failure(
689 &mut self,
690 DialFailure {
691 peer_id,
692 connection_id,
693 ..
694 }: DialFailure,
695 ) {
696 if let Some(peer) = peer_id {
697 if let Some(pending) = self.pending_outbound_requests.remove(&peer) {
704 for request in pending {
705 self.pending_events
706 .push_back(ToSwarm::GenerateEvent(Event::OutboundFailure {
707 peer,
708 connection_id,
709 request_id: request.request_id,
710 error: OutboundFailure::DialFailure,
711 }));
712 }
713 }
714 }
715 }
716
717 fn preload_new_handler(
720 &mut self,
721 handler: &mut Handler<TCodec>,
722 peer: PeerId,
723 connection_id: ConnectionId,
724 remote_address: Option<Multiaddr>,
725 ) {
726 let mut connection = Connection::new(connection_id, remote_address);
727
728 if let Some(pending_requests) = self.pending_outbound_requests.remove(&peer) {
729 for request in pending_requests {
730 connection
731 .pending_outbound_responses
732 .insert(request.request_id);
733 handler.on_behaviour_event(request);
734 }
735 }
736
737 self.connected.entry(peer).or_default().push(connection);
738 }
739}
740
741impl<TCodec> NetworkBehaviour for Behaviour<TCodec>
742where
743 TCodec: Codec + Send + Clone + 'static,
744{
745 type ConnectionHandler = Handler<TCodec>;
746 type ToSwarm = Event<TCodec::Request, TCodec::Response>;
747
748 fn handle_established_inbound_connection(
749 &mut self,
750 connection_id: ConnectionId,
751 peer: PeerId,
752 _: &Multiaddr,
753 _: &Multiaddr,
754 ) -> Result<THandler<Self>, ConnectionDenied> {
755 let mut handler = Handler::new(
756 self.inbound_protocols.clone(),
757 self.codec.clone(),
758 self.config.request_timeout,
759 self.next_inbound_request_id.clone(),
760 self.config.max_concurrent_streams,
761 );
762
763 self.preload_new_handler(&mut handler, peer, connection_id, None);
764
765 Ok(handler)
766 }
767
768 fn handle_pending_outbound_connection(
769 &mut self,
770 _connection_id: ConnectionId,
771 maybe_peer: Option<PeerId>,
772 _addresses: &[Multiaddr],
773 _effective_role: Endpoint,
774 ) -> Result<Vec<Multiaddr>, ConnectionDenied> {
775 let peer = match maybe_peer {
776 None => return Ok(vec![]),
777 Some(peer) => peer,
778 };
779
780 let mut addresses = Vec::new();
781 if let Some(connections) = self.connected.get(&peer) {
782 addresses.extend(connections.iter().filter_map(|c| c.remote_address.clone()))
783 }
784
785 let cached_addrs = self.addresses.get(&peer);
786 addresses.extend(cached_addrs);
787
788 Ok(addresses)
789 }
790
791 fn handle_established_outbound_connection(
792 &mut self,
793 connection_id: ConnectionId,
794 peer: PeerId,
795 remote_address: &Multiaddr,
796 _: Endpoint,
797 _: PortUse,
798 ) -> Result<THandler<Self>, ConnectionDenied> {
799 let mut handler = Handler::new(
800 self.inbound_protocols.clone(),
801 self.codec.clone(),
802 self.config.request_timeout,
803 self.next_inbound_request_id.clone(),
804 self.config.max_concurrent_streams,
805 );
806
807 self.preload_new_handler(
808 &mut handler,
809 peer,
810 connection_id,
811 Some(remote_address.clone()),
812 );
813
814 Ok(handler)
815 }
816
817 fn on_swarm_event(&mut self, event: FromSwarm) {
818 self.addresses.on_swarm_event(&event);
819 match event {
820 FromSwarm::ConnectionEstablished(_) => {}
821 FromSwarm::ConnectionClosed(connection_closed) => {
822 self.on_connection_closed(connection_closed)
823 }
824 FromSwarm::AddressChange(address_change) => self.on_address_change(address_change),
825 FromSwarm::DialFailure(dial_failure) => self.on_dial_failure(dial_failure),
826 _ => {}
827 }
828 }
829
830 fn on_connection_handler_event(
831 &mut self,
832 peer: PeerId,
833 connection_id: ConnectionId,
834 event: THandlerOutEvent<Self>,
835 ) {
836 match event {
837 handler::Event::Response {
838 request_id,
839 response,
840 } => {
841 let removed =
842 self.remove_pending_outbound_response(&peer, connection_id, request_id);
843 debug_assert!(
844 removed,
845 "Expect request_id to be pending before receiving response.",
846 );
847
848 let message = Message::Response {
849 request_id,
850 response,
851 };
852 self.pending_events
853 .push_back(ToSwarm::GenerateEvent(Event::Message {
854 peer,
855 connection_id,
856 message,
857 }));
858 }
859 handler::Event::Request {
860 request_id,
861 request,
862 sender,
863 } => match self.get_connection_mut(&peer, connection_id) {
864 Some(connection) => {
865 let inserted = connection.pending_inbound_responses.insert(request_id);
866 debug_assert!(inserted, "Expect id of new request to be unknown.");
867
868 let channel = ResponseChannel { sender };
869 let message = Message::Request {
870 request_id,
871 request,
872 channel,
873 };
874 self.pending_events
875 .push_back(ToSwarm::GenerateEvent(Event::Message {
876 peer,
877 connection_id,
878 message,
879 }));
880 }
881 None => {
882 tracing::debug!("Connection ({connection_id}) closed after `Event::Request` ({request_id}) has been emitted.");
883 }
884 },
885 handler::Event::ResponseSent(request_id) => {
886 let removed =
887 self.remove_pending_inbound_response(&peer, connection_id, request_id);
888 debug_assert!(
889 removed,
890 "Expect request_id to be pending before response is sent."
891 );
892
893 self.pending_events
894 .push_back(ToSwarm::GenerateEvent(Event::ResponseSent {
895 peer,
896 connection_id,
897 request_id,
898 }));
899 }
900 handler::Event::ResponseOmission(request_id) => {
901 let removed =
902 self.remove_pending_inbound_response(&peer, connection_id, request_id);
903 debug_assert!(
904 removed,
905 "Expect request_id to be pending before response is omitted.",
906 );
907
908 self.pending_events
909 .push_back(ToSwarm::GenerateEvent(Event::InboundFailure {
910 peer,
911 connection_id,
912 request_id,
913 error: InboundFailure::ResponseOmission,
914 }));
915 }
916 handler::Event::OutboundTimeout(request_id) => {
917 let removed =
918 self.remove_pending_outbound_response(&peer, connection_id, request_id);
919 debug_assert!(
920 removed,
921 "Expect request_id to be pending before request times out."
922 );
923
924 self.pending_events
925 .push_back(ToSwarm::GenerateEvent(Event::OutboundFailure {
926 peer,
927 connection_id,
928 request_id,
929 error: OutboundFailure::Timeout,
930 }));
931 }
932 handler::Event::OutboundUnsupportedProtocols(request_id) => {
933 let removed =
934 self.remove_pending_outbound_response(&peer, connection_id, request_id);
935 debug_assert!(
936 removed,
937 "Expect request_id to be pending before failing to connect.",
938 );
939
940 self.pending_events
941 .push_back(ToSwarm::GenerateEvent(Event::OutboundFailure {
942 peer,
943 connection_id,
944 request_id,
945 error: OutboundFailure::UnsupportedProtocols,
946 }));
947 }
948 handler::Event::OutboundStreamFailed { request_id, error } => {
949 let removed =
950 self.remove_pending_outbound_response(&peer, connection_id, request_id);
951 debug_assert!(removed, "Expect request_id to be pending upon failure");
952
953 self.pending_events
954 .push_back(ToSwarm::GenerateEvent(Event::OutboundFailure {
955 peer,
956 connection_id,
957 request_id,
958 error: OutboundFailure::Io(error),
959 }))
960 }
961 handler::Event::InboundTimeout(request_id) => {
962 let removed =
963 self.remove_pending_inbound_response(&peer, connection_id, request_id);
964
965 if removed {
966 self.pending_events
967 .push_back(ToSwarm::GenerateEvent(Event::InboundFailure {
968 peer,
969 connection_id,
970 request_id,
971 error: InboundFailure::Timeout,
972 }));
973 } else {
974 tracing::debug!(
976 "Inbound request timeout for an unknown request_id ({request_id})"
977 );
978 }
979 }
980 handler::Event::InboundStreamFailed { request_id, error } => {
981 let removed =
982 self.remove_pending_inbound_response(&peer, connection_id, request_id);
983
984 if removed {
985 self.pending_events
986 .push_back(ToSwarm::GenerateEvent(Event::InboundFailure {
987 peer,
988 connection_id,
989 request_id,
990 error: InboundFailure::Io(error),
991 }));
992 } else {
993 tracing::debug!("Inbound failure is reported for an unknown request_id ({request_id}): {error}");
995 }
996 }
997 }
998 }
999
1000 #[tracing::instrument(level = "trace", name = "NetworkBehaviour::poll", skip(self))]
1001 fn poll(&mut self, _: &mut Context<'_>) -> Poll<ToSwarm<Self::ToSwarm, THandlerInEvent<Self>>> {
1002 if let Some(ev) = self.pending_events.pop_front() {
1003 return Poll::Ready(ev);
1004 } else if self.pending_events.capacity() > EMPTY_QUEUE_SHRINK_THRESHOLD {
1005 self.pending_events.shrink_to_fit();
1006 }
1007
1008 Poll::Pending
1009 }
1010}
1011
/// Capacity of the drained `pending_events` queue above which `poll` releases the
/// excess storage via `shrink_to_fit`.
const EMPTY_QUEUE_SHRINK_THRESHOLD: usize = 100;
1017
1018struct Connection {
1020 id: ConnectionId,
1021 remote_address: Option<Multiaddr>,
1022 pending_outbound_responses: HashSet<OutboundRequestId>,
1026 pending_inbound_responses: HashSet<InboundRequestId>,
1029}
1030
1031impl Connection {
1032 fn new(id: ConnectionId, remote_address: Option<Multiaddr>) -> Self {
1033 Self {
1034 id,
1035 remote_address,
1036 pending_outbound_responses: Default::default(),
1037 pending_inbound_responses: Default::default(),
1038 }
1039 }
1040}