1mod error;
22
23pub(crate) mod pool;
24mod supported_protocols;
25
26use std::{
27 collections::{HashMap, HashSet},
28 fmt,
29 fmt::{Display, Formatter},
30 future::Future,
31 io, mem,
32 pin::Pin,
33 sync::atomic::{AtomicUsize, Ordering},
34 task::{Context, Poll, Waker},
35 time::Duration,
36};
37
38pub use error::ConnectionError;
39pub(crate) use error::{PendingInboundConnectionError, PendingOutboundConnectionError};
40use futures::{future::BoxFuture, stream, stream::FuturesUnordered, FutureExt, StreamExt};
41use futures_timer::Delay;
42use libp2p_core::{
43 connection::ConnectedPoint,
44 multiaddr::Multiaddr,
45 muxing::{StreamMuxerBox, StreamMuxerEvent, StreamMuxerExt, SubstreamBox},
46 transport::PortUse,
47 upgrade,
48 upgrade::{NegotiationError, ProtocolError},
49 Endpoint,
50};
51use libp2p_identity::PeerId;
52pub use supported_protocols::SupportedProtocols;
53use web_time::Instant;
54
55use crate::{
56 handler::{
57 AddressChange, ConnectionEvent, ConnectionHandler, DialUpgradeError,
58 FullyNegotiatedInbound, FullyNegotiatedOutbound, ListenUpgradeError, ProtocolSupport,
59 ProtocolsChange, UpgradeInfoSend,
60 },
61 stream::ActiveStreamCounter,
62 upgrade::{InboundUpgradeSend, OutboundUpgradeSend},
63 ConnectionHandlerEvent, Stream, StreamProtocol, StreamUpgradeError, SubstreamProtocol,
64};
65
/// Process-wide counter backing [`ConnectionId::next`]; the first ID handed out is `1`.
static NEXT_CONNECTION_ID: AtomicUsize = AtomicUsize::new(1);

/// Connection identifier.
///
/// A thin newtype around a `usize`. Ordering, equality and hashing are all
/// derived from the wrapped number.
#[derive(Debug, Copy, Clone, Hash, PartialEq, Eq, PartialOrd, Ord)]
pub struct ConnectionId(usize);

impl ConnectionId {
    /// Creates an _unchecked_ [`ConnectionId`] from a raw number.
    ///
    /// Unlike [`ConnectionId::next`], this constructor does not consult the
    /// global counter, so nothing prevents the returned ID from colliding
    /// with one that was, or later will be, handed out by `next`.
    pub fn new_unchecked(id: usize) -> Self {
        Self(id)
    }

    /// Returns the next available [`ConnectionId`].
    ///
    /// Atomically increments the global counter and wraps its previous value.
    pub(crate) fn next() -> Self {
        Self(NEXT_CONNECTION_ID.fetch_add(1, Ordering::SeqCst))
    }
}

impl Display for ConnectionId {
    /// Formats the ID as its bare inner number.
    fn fmt(&self, f: &mut Formatter) -> fmt::Result {
        write!(f, "{}", self.0)
    }
}
95
96#[derive(Debug, Clone, PartialEq, Eq)]
98pub(crate) struct Connected {
99 pub(crate) endpoint: ConnectedPoint,
101 pub(crate) peer_id: PeerId,
103}
104
105#[derive(Debug, Clone)]
107pub(crate) enum Event<T> {
108 Handler(T),
110 AddressChange(Multiaddr),
112}
113
114pub(crate) struct Connection<THandler>
116where
117 THandler: ConnectionHandler,
118{
119 muxing: StreamMuxerBox,
121 handler: THandler,
123 negotiating_in: FuturesUnordered<
125 StreamUpgrade<
126 THandler::InboundOpenInfo,
127 <THandler::InboundProtocol as InboundUpgradeSend>::Output,
128 <THandler::InboundProtocol as InboundUpgradeSend>::Error,
129 >,
130 >,
131 negotiating_out: FuturesUnordered<
133 StreamUpgrade<
134 THandler::OutboundOpenInfo,
135 <THandler::OutboundProtocol as OutboundUpgradeSend>::Output,
136 <THandler::OutboundProtocol as OutboundUpgradeSend>::Error,
137 >,
138 >,
139 shutdown: Shutdown,
141 substream_upgrade_protocol_override: Option<upgrade::Version>,
143 max_negotiating_inbound_streams: usize,
152 requested_substreams: FuturesUnordered<
157 SubstreamRequested<THandler::OutboundOpenInfo, THandler::OutboundProtocol>,
158 >,
159
160 local_supported_protocols:
161 HashMap<AsStrHashEq<<THandler::InboundProtocol as UpgradeInfoSend>::Info>, bool>,
162 remote_supported_protocols: HashSet<StreamProtocol>,
163 protocol_buffer: Vec<StreamProtocol>,
164
165 idle_timeout: Duration,
166 stream_counter: ActiveStreamCounter,
167}
168
169impl<THandler> fmt::Debug for Connection<THandler>
170where
171 THandler: ConnectionHandler + fmt::Debug,
172 THandler::OutboundOpenInfo: fmt::Debug,
173{
174 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
175 f.debug_struct("Connection")
176 .field("handler", &self.handler)
177 .finish()
178 }
179}
180
181impl<THandler> Unpin for Connection<THandler> where THandler: ConnectionHandler {}
182
183impl<THandler> Connection<THandler>
184where
185 THandler: ConnectionHandler,
186{
187 pub(crate) fn new(
190 muxer: StreamMuxerBox,
191 mut handler: THandler,
192 substream_upgrade_protocol_override: Option<upgrade::Version>,
193 max_negotiating_inbound_streams: usize,
194 idle_timeout: Duration,
195 ) -> Self {
196 let initial_protocols = gather_supported_protocols(&handler);
197 let mut buffer = Vec::new();
198
199 if !initial_protocols.is_empty() {
200 handler.on_connection_event(ConnectionEvent::LocalProtocolsChange(
201 ProtocolsChange::from_initial_protocols(
202 initial_protocols.keys().map(|e| &e.0),
203 &mut buffer,
204 ),
205 ));
206 }
207
208 Connection {
209 muxing: muxer,
210 handler,
211 negotiating_in: Default::default(),
212 negotiating_out: Default::default(),
213 shutdown: Shutdown::None,
214 substream_upgrade_protocol_override,
215 max_negotiating_inbound_streams,
216 requested_substreams: Default::default(),
217 local_supported_protocols: initial_protocols,
218 remote_supported_protocols: Default::default(),
219 protocol_buffer: buffer,
220 idle_timeout,
221 stream_counter: ActiveStreamCounter::default(),
222 }
223 }
224
225 pub(crate) fn on_behaviour_event(&mut self, event: THandler::FromBehaviour) {
227 self.handler.on_behaviour_event(event);
228 }
229
230 pub(crate) fn close(
233 self,
234 ) -> (
235 impl futures::Stream<Item = THandler::ToBehaviour>,
236 impl Future<Output = io::Result<()>>,
237 ) {
238 let Connection {
239 mut handler,
240 muxing,
241 ..
242 } = self;
243
244 (
245 stream::poll_fn(move |cx| handler.poll_close(cx)),
246 muxing.close(),
247 )
248 }
249
250 #[tracing::instrument(level = "debug", name = "Connection::poll", skip(self, cx))]
253 pub(crate) fn poll(
254 self: Pin<&mut Self>,
255 cx: &mut Context<'_>,
256 ) -> Poll<Result<Event<THandler::ToBehaviour>, ConnectionError>> {
257 let Self {
258 requested_substreams,
259 muxing,
260 handler,
261 negotiating_out,
262 negotiating_in,
263 shutdown,
264 max_negotiating_inbound_streams,
265 substream_upgrade_protocol_override,
266 local_supported_protocols: supported_protocols,
267 remote_supported_protocols,
268 protocol_buffer,
269 idle_timeout,
270 stream_counter,
271 ..
272 } = self.get_mut();
273
274 loop {
275 match requested_substreams.poll_next_unpin(cx) {
276 Poll::Ready(Some(Ok(()))) => continue,
277 Poll::Ready(Some(Err(info))) => {
278 handler.on_connection_event(ConnectionEvent::DialUpgradeError(
279 DialUpgradeError {
280 info,
281 error: StreamUpgradeError::Timeout,
282 },
283 ));
284 continue;
285 }
286 Poll::Ready(None) | Poll::Pending => {}
287 }
288
289 match handler.poll(cx) {
291 Poll::Pending => {}
292 Poll::Ready(ConnectionHandlerEvent::OutboundSubstreamRequest { protocol }) => {
293 let timeout = *protocol.timeout();
294 let (upgrade, user_data) = protocol.into_upgrade();
295
296 requested_substreams.push(SubstreamRequested::new(user_data, timeout, upgrade));
297 continue; }
299 Poll::Ready(ConnectionHandlerEvent::NotifyBehaviour(event)) => {
300 return Poll::Ready(Ok(Event::Handler(event)));
301 }
302 Poll::Ready(ConnectionHandlerEvent::ReportRemoteProtocols(
303 ProtocolSupport::Added(protocols),
304 )) => {
305 if let Some(added) =
306 ProtocolsChange::add(remote_supported_protocols, protocols, protocol_buffer)
307 {
308 handler.on_connection_event(ConnectionEvent::RemoteProtocolsChange(added));
309 remote_supported_protocols.extend(protocol_buffer.drain(..));
310 }
311 continue;
312 }
313 Poll::Ready(ConnectionHandlerEvent::ReportRemoteProtocols(
314 ProtocolSupport::Removed(protocols),
315 )) => {
316 if let Some(removed) = ProtocolsChange::remove(
317 remote_supported_protocols,
318 protocols,
319 protocol_buffer,
320 ) {
321 handler
322 .on_connection_event(ConnectionEvent::RemoteProtocolsChange(removed));
323 }
324 continue;
325 }
326 }
327
328 match negotiating_out.poll_next_unpin(cx) {
331 Poll::Pending | Poll::Ready(None) => {}
332 Poll::Ready(Some((info, Ok(protocol)))) => {
333 handler.on_connection_event(ConnectionEvent::FullyNegotiatedOutbound(
334 FullyNegotiatedOutbound { protocol, info },
335 ));
336 continue;
337 }
338 Poll::Ready(Some((info, Err(error)))) => {
339 handler.on_connection_event(ConnectionEvent::DialUpgradeError(
340 DialUpgradeError { info, error },
341 ));
342 continue;
343 }
344 }
345
346 match negotiating_in.poll_next_unpin(cx) {
349 Poll::Pending | Poll::Ready(None) => {}
350 Poll::Ready(Some((info, Ok(protocol)))) => {
351 handler.on_connection_event(ConnectionEvent::FullyNegotiatedInbound(
352 FullyNegotiatedInbound { protocol, info },
353 ));
354 continue;
355 }
356 Poll::Ready(Some((info, Err(StreamUpgradeError::Apply(error))))) => {
357 handler.on_connection_event(ConnectionEvent::ListenUpgradeError(
358 ListenUpgradeError { info, error },
359 ));
360 continue;
361 }
362 Poll::Ready(Some((_, Err(StreamUpgradeError::Io(e))))) => {
363 tracing::debug!("failed to upgrade inbound stream: {e}");
364 continue;
365 }
366 Poll::Ready(Some((_, Err(StreamUpgradeError::NegotiationFailed)))) => {
367 tracing::debug!("no protocol could be agreed upon for inbound stream");
368 continue;
369 }
370 Poll::Ready(Some((_, Err(StreamUpgradeError::Timeout)))) => {
371 tracing::debug!("inbound stream upgrade timed out");
372 continue;
373 }
374 }
375
376 if negotiating_in.is_empty()
380 && negotiating_out.is_empty()
381 && requested_substreams.is_empty()
382 && stream_counter.has_no_active_streams()
383 {
384 if let Some(new_timeout) =
385 compute_new_shutdown(handler.connection_keep_alive(), shutdown, *idle_timeout)
386 {
387 *shutdown = new_timeout;
388 }
389
390 match shutdown {
391 Shutdown::None => {}
392 Shutdown::Asap => return Poll::Ready(Err(ConnectionError::KeepAliveTimeout)),
393 Shutdown::Later(delay) => match Future::poll(Pin::new(delay), cx) {
394 Poll::Ready(_) => {
395 return Poll::Ready(Err(ConnectionError::KeepAliveTimeout))
396 }
397 Poll::Pending => {}
398 },
399 }
400 } else {
401 *shutdown = Shutdown::None;
402 }
403
404 match muxing.poll_unpin(cx)? {
405 Poll::Pending => {}
406 Poll::Ready(StreamMuxerEvent::AddressChange(address)) => {
407 handler.on_connection_event(ConnectionEvent::AddressChange(AddressChange {
408 new_address: &address,
409 }));
410 return Poll::Ready(Ok(Event::AddressChange(address)));
411 }
412 }
413
414 if let Some(requested_substream) = requested_substreams.iter_mut().next() {
415 match muxing.poll_outbound_unpin(cx)? {
416 Poll::Pending => {}
417 Poll::Ready(substream) => {
418 let (user_data, timeout, upgrade) = requested_substream.extract();
419
420 negotiating_out.push(StreamUpgrade::new_outbound(
421 substream,
422 user_data,
423 timeout,
424 upgrade,
425 *substream_upgrade_protocol_override,
426 stream_counter.clone(),
427 ));
428
429 continue;
432 }
433 }
434 }
435
436 if negotiating_in.len() < *max_negotiating_inbound_streams {
437 match muxing.poll_inbound_unpin(cx)? {
438 Poll::Pending => {}
439 Poll::Ready(substream) => {
440 let protocol = handler.listen_protocol();
441
442 negotiating_in.push(StreamUpgrade::new_inbound(
443 substream,
444 protocol,
445 stream_counter.clone(),
446 ));
447
448 continue;
451 }
452 }
453 }
454
455 let changes = ProtocolsChange::from_full_sets(
456 supported_protocols,
457 handler.listen_protocol().upgrade().protocol_info(),
458 protocol_buffer,
459 );
460
461 if !changes.is_empty() {
462 for change in changes {
463 handler.on_connection_event(ConnectionEvent::LocalProtocolsChange(change));
464 }
465 continue;
467 }
468
469 return Poll::Pending;
471 }
472 }
473
474 #[cfg(test)]
475 fn poll_noop_waker(&mut self) -> Poll<Result<Event<THandler::ToBehaviour>, ConnectionError>> {
476 Pin::new(self).poll(&mut Context::from_waker(futures::task::noop_waker_ref()))
477 }
478}
479
480fn gather_supported_protocols<C: ConnectionHandler>(
481 handler: &C,
482) -> HashMap<AsStrHashEq<<C::InboundProtocol as UpgradeInfoSend>::Info>, bool> {
483 handler
484 .listen_protocol()
485 .upgrade()
486 .protocol_info()
487 .map(|info| (AsStrHashEq(info), true))
488 .collect()
489}
490
491fn compute_new_shutdown(
492 handler_keep_alive: bool,
493 current_shutdown: &Shutdown,
494 idle_timeout: Duration,
495) -> Option<Shutdown> {
496 match (current_shutdown, handler_keep_alive) {
497 (_, false) if idle_timeout == Duration::ZERO => Some(Shutdown::Asap),
498 (Shutdown::Later(_), false) => None,
500 (_, false) => {
501 let now = Instant::now();
502 let safe_keep_alive = checked_add_fraction(now, idle_timeout);
503
504 Some(Shutdown::Later(Delay::new(safe_keep_alive)))
505 }
506 (_, true) => Some(Shutdown::None),
507 }
508}
509
510fn checked_add_fraction(start: Instant, mut duration: Duration) -> Duration {
517 while start.checked_add(duration).is_none() {
518 tracing::debug!(start=?start, duration=?duration, "start + duration cannot be presented, halving duration");
519
520 duration /= 2;
521 }
522
523 duration
524}
525
526#[derive(Debug, Copy, Clone)]
528pub(crate) struct IncomingInfo<'a> {
529 pub(crate) local_addr: &'a Multiaddr,
531 pub(crate) send_back_addr: &'a Multiaddr,
533}
534
535impl IncomingInfo<'_> {
536 pub(crate) fn create_connected_point(&self) -> ConnectedPoint {
538 ConnectedPoint::Listener {
539 local_addr: self.local_addr.clone(),
540 send_back_addr: self.send_back_addr.clone(),
541 }
542 }
543}
544
545struct StreamUpgrade<UserData, TOk, TErr> {
546 user_data: Option<UserData>,
547 timeout: Delay,
548 upgrade: BoxFuture<'static, Result<TOk, StreamUpgradeError<TErr>>>,
549}
550
551impl<UserData, TOk, TErr> StreamUpgrade<UserData, TOk, TErr> {
552 fn new_outbound<Upgrade>(
553 substream: SubstreamBox,
554 user_data: UserData,
555 timeout: Delay,
556 upgrade: Upgrade,
557 version_override: Option<upgrade::Version>,
558 counter: ActiveStreamCounter,
559 ) -> Self
560 where
561 Upgrade: OutboundUpgradeSend<Output = TOk, Error = TErr>,
562 {
563 let effective_version = match version_override {
564 Some(version_override) if version_override != upgrade::Version::default() => {
565 tracing::debug!(
566 "Substream upgrade protocol override: {:?} -> {:?}",
567 upgrade::Version::default(),
568 version_override
569 );
570
571 version_override
572 }
573 _ => upgrade::Version::default(),
574 };
575 let protocols = upgrade.protocol_info();
576
577 Self {
578 user_data: Some(user_data),
579 timeout,
580 upgrade: Box::pin(async move {
581 let (info, stream) = multistream_select::dialer_select_proto(
582 substream,
583 protocols,
584 effective_version,
585 )
586 .await
587 .map_err(to_stream_upgrade_error)?;
588
589 let output = upgrade
590 .upgrade_outbound(Stream::new(stream, counter), info)
591 .await
592 .map_err(StreamUpgradeError::Apply)?;
593
594 Ok(output)
595 }),
596 }
597 }
598}
599
600impl<UserData, TOk, TErr> StreamUpgrade<UserData, TOk, TErr> {
601 fn new_inbound<Upgrade>(
602 substream: SubstreamBox,
603 protocol: SubstreamProtocol<Upgrade, UserData>,
604 counter: ActiveStreamCounter,
605 ) -> Self
606 where
607 Upgrade: InboundUpgradeSend<Output = TOk, Error = TErr>,
608 {
609 let timeout = *protocol.timeout();
610 let (upgrade, open_info) = protocol.into_upgrade();
611 let protocols = upgrade.protocol_info();
612
613 Self {
614 user_data: Some(open_info),
615 timeout: Delay::new(timeout),
616 upgrade: Box::pin(async move {
617 let (info, stream) =
618 multistream_select::listener_select_proto(substream, protocols)
619 .await
620 .map_err(to_stream_upgrade_error)?;
621
622 let output = upgrade
623 .upgrade_inbound(Stream::new(stream, counter), info)
624 .await
625 .map_err(StreamUpgradeError::Apply)?;
626
627 Ok(output)
628 }),
629 }
630 }
631}
632
633fn to_stream_upgrade_error<T>(e: NegotiationError) -> StreamUpgradeError<T> {
634 match e {
635 NegotiationError::Failed => StreamUpgradeError::NegotiationFailed,
636 NegotiationError::ProtocolError(ProtocolError::IoError(e)) => StreamUpgradeError::Io(e),
637 NegotiationError::ProtocolError(other) => {
638 StreamUpgradeError::Io(io::Error::new(io::ErrorKind::Other, other))
639 }
640 }
641}
642
643impl<UserData, TOk, TErr> Unpin for StreamUpgrade<UserData, TOk, TErr> {}
644
645impl<UserData, TOk, TErr> Future for StreamUpgrade<UserData, TOk, TErr> {
646 type Output = (UserData, Result<TOk, StreamUpgradeError<TErr>>);
647
648 fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
649 match self.timeout.poll_unpin(cx) {
650 Poll::Ready(()) => {
651 return Poll::Ready((
652 self.user_data
653 .take()
654 .expect("Future not to be polled again once ready."),
655 Err(StreamUpgradeError::Timeout),
656 ))
657 }
658
659 Poll::Pending => {}
660 }
661
662 let result = futures::ready!(self.upgrade.poll_unpin(cx));
663 let user_data = self
664 .user_data
665 .take()
666 .expect("Future not to be polled again once ready.");
667
668 Poll::Ready((user_data, result))
669 }
670}
671
672enum SubstreamRequested<UserData, Upgrade> {
673 Waiting {
674 user_data: UserData,
675 timeout: Delay,
676 upgrade: Upgrade,
677 extracted_waker: Option<Waker>,
682 },
683 Done,
684}
685
686impl<UserData, Upgrade> SubstreamRequested<UserData, Upgrade> {
687 fn new(user_data: UserData, timeout: Duration, upgrade: Upgrade) -> Self {
688 Self::Waiting {
689 user_data,
690 timeout: Delay::new(timeout),
691 upgrade,
692 extracted_waker: None,
693 }
694 }
695
696 fn extract(&mut self) -> (UserData, Delay, Upgrade) {
697 match mem::replace(self, Self::Done) {
698 SubstreamRequested::Waiting {
699 user_data,
700 timeout,
701 upgrade,
702 extracted_waker: waker,
703 } => {
704 if let Some(waker) = waker {
705 waker.wake();
706 }
707
708 (user_data, timeout, upgrade)
709 }
710 SubstreamRequested::Done => panic!("cannot extract twice"),
711 }
712 }
713}
714
715impl<UserData, Upgrade> Unpin for SubstreamRequested<UserData, Upgrade> {}
716
717impl<UserData, Upgrade> Future for SubstreamRequested<UserData, Upgrade> {
718 type Output = Result<(), UserData>;
719
720 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
721 let this = self.get_mut();
722
723 match mem::replace(this, Self::Done) {
724 SubstreamRequested::Waiting {
725 user_data,
726 upgrade,
727 mut timeout,
728 ..
729 } => match timeout.poll_unpin(cx) {
730 Poll::Ready(()) => Poll::Ready(Err(user_data)),
731 Poll::Pending => {
732 *this = Self::Waiting {
733 user_data,
734 upgrade,
735 timeout,
736 extracted_waker: Some(cx.waker().clone()),
737 };
738 Poll::Pending
739 }
740 },
741 SubstreamRequested::Done => Poll::Ready(Ok(())),
742 }
743 }
744}
745
746#[derive(Debug)]
756enum Shutdown {
757 None,
759 Asap,
761 Later(Delay),
763}
764
/// A wrapper type that implements `Eq` and `Hash` in terms of the wrapped
/// value's `AsRef<str>` representation.
///
/// Two wrappers are equal, and hash identically, exactly when their inner
/// values yield the same string slice.
pub(crate) struct AsStrHashEq<T>(pub(crate) T);

impl<T: AsRef<str>> Eq for AsStrHashEq<T> {}

impl<T: AsRef<str>> PartialEq for AsStrHashEq<T> {
    /// Compares the string views of both values.
    fn eq(&self, other: &Self) -> bool {
        self.0.as_ref() == other.0.as_ref()
    }
}

impl<T: AsRef<str>> std::hash::Hash for AsStrHashEq<T> {
    /// Hashes the string view, keeping `Hash` consistent with `PartialEq`.
    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
        self.0.as_ref().hash(state)
    }
}
783
#[cfg(test)]
mod tests {
    use std::{
        convert::Infallible,
        sync::{Arc, Weak},
        time::Instant,
    };

    use futures::{future, AsyncRead, AsyncWrite};
    use libp2p_core::{
        upgrade::{DeniedUpgrade, InboundUpgrade, OutboundUpgrade, UpgradeInfo},
        StreamMuxer,
    };
    use quickcheck::*;
    use tracing_subscriber::EnvFilter;

    use super::*;
    use crate::dummy;

    /// The connection never holds more negotiating inbound streams than the
    /// configured maximum, even if the muxer always has one ready.
    #[test]
    fn max_negotiating_inbound_streams() {
        let _ = tracing_subscriber::fmt()
            .with_env_filter(EnvFilter::from_default_env())
            .try_init();

        fn prop(max_negotiating_inbound_streams: u8) {
            let max_negotiating_inbound_streams: usize = max_negotiating_inbound_streams.into();

            // Every substream handed out by the muxer holds a `Weak` to this
            // `Arc`, so the weak count equals the number of live substreams.
            let alive_substream_counter = Arc::new(());
            let mut connection = Connection::new(
                StreamMuxerBox::new(DummyStreamMuxer {
                    counter: alive_substream_counter.clone(),
                }),
                MockConnectionHandler::new(Duration::from_secs(10)),
                None,
                max_negotiating_inbound_streams,
                Duration::ZERO,
            );

            let result = connection.poll_noop_waker();

            assert!(result.is_pending());
            assert_eq!(
                Arc::weak_count(&alive_substream_counter),
                max_negotiating_inbound_streams,
                "Expect no more than the maximum number of allowed streams"
            );
        }

        QuickCheck::new().quickcheck(prop as fn(_));
    }

    /// The upgrade timeout of an outbound stream starts when the handler
    /// requests it, not when the muxer eventually provides a substream.
    #[test]
    fn outbound_stream_timeout_starts_on_request() {
        let upgrade_timeout = Duration::from_secs(1);
        let mut connection = Connection::new(
            StreamMuxerBox::new(PendingStreamMuxer),
            MockConnectionHandler::new(upgrade_timeout),
            None,
            2,
            Duration::ZERO,
        );

        connection.handler.open_new_outbound();
        let _ = connection.poll_noop_waker();

        // The muxer never yields a substream; wait until the timeout elapsed.
        std::thread::sleep(upgrade_timeout + Duration::from_secs(1));

        let _ = connection.poll_noop_waker();

        assert!(matches!(
            connection.handler.error.unwrap(),
            StreamUpgradeError::Timeout
        ))
    }

    /// Changes to the handler's inbound protocols are reported back to it as
    /// incremental `LocalProtocolsChange` events.
    #[test]
    fn propagates_changes_to_supported_inbound_protocols() {
        let mut connection = Connection::new(
            StreamMuxerBox::new(PendingStreamMuxer),
            ConfigurableProtocolConnectionHandler::default(),
            None,
            0,
            Duration::ZERO,
        );

        // First we start listening on a single protocol.
        connection.handler.listen_on(&["/foo"]);
        let _ = connection.poll_noop_waker();

        assert_eq!(connection.handler.local_added, vec![vec!["/foo"]]);
        assert!(connection.handler.local_removed.is_empty());

        // Second we add another protocol but keep the old one.
        connection.handler.listen_on(&["/foo", "/bar"]);
        let _ = connection.poll_noop_waker();

        assert_eq!(
            connection.handler.local_added,
            vec![vec!["/foo"], vec!["/bar"]],
            "expect to only receive an event for the newly added protocols"
        );
        assert!(connection.handler.local_removed.is_empty());

        // Third we stop listening on the first protocol.
        connection.handler.listen_on(&["/bar"]);
        let _ = connection.poll_noop_waker();

        assert_eq!(
            connection.handler.local_added,
            vec![vec!["/foo"], vec!["/bar"]]
        );
        assert_eq!(connection.handler.local_removed, vec![vec!["/foo"]]);
    }

    /// Only effective changes to the remote's protocol set reach the handler.
    #[test]
    fn only_propagates_actual_changes_to_remote_protocols_to_handler() {
        let mut connection = Connection::new(
            StreamMuxerBox::new(PendingStreamMuxer),
            ConfigurableProtocolConnectionHandler::default(),
            None,
            0,
            Duration::ZERO,
        );

        // First the remote supports a single protocol.
        connection.handler.remote_adds_support_for(&["/foo"]);
        let _ = connection.poll_noop_waker();

        assert_eq!(connection.handler.remote_added, vec![vec!["/foo"]]);
        assert!(connection.handler.remote_removed.is_empty());

        // Second it adds a protocol, but the report also contains the
        // already known one.
        connection
            .handler
            .remote_adds_support_for(&["/foo", "/bar"]);
        let _ = connection.poll_noop_waker();

        assert_eq!(
            connection.handler.remote_added,
            vec![vec!["/foo"], vec!["/bar"]],
            "expect to only receive an event for the newly added protocol"
        );
        assert!(connection.handler.remote_removed.is_empty());

        // Third a protocol is reported as removed that was never reported as
        // supported; this must not produce an event.
        connection.handler.remote_removes_support_for(&["/baz"]);
        let _ = connection.poll_noop_waker();

        assert_eq!(
            connection.handler.remote_added,
            vec![vec!["/foo"], vec!["/bar"]]
        );
        assert!(&connection.handler.remote_removed.is_empty());

        // Fourth a protocol that was actually supported is removed.
        connection.handler.remote_removes_support_for(&["/bar"]);
        let _ = connection.poll_noop_waker();

        assert_eq!(
            connection.handler.remote_added,
            vec![vec!["/foo"], vec!["/bar"]]
        );
        assert_eq!(connection.handler.remote_removed, vec![vec!["/bar"]]);
    }

    /// A connection using `dummy::ConnectionHandler` is closed with
    /// `KeepAliveTimeout` once the idle timeout has passed.
    #[tokio::test]
    async fn idle_timeout_with_keep_alive_no() {
        let idle_timeout = Duration::from_millis(100);

        let mut connection = Connection::new(
            StreamMuxerBox::new(PendingStreamMuxer),
            dummy::ConnectionHandler,
            None,
            0,
            idle_timeout,
        );

        assert!(connection.poll_noop_waker().is_pending());

        tokio::time::sleep(idle_timeout).await;

        assert!(matches!(
            connection.poll_noop_waker(),
            Poll::Ready(Err(ConnectionError::KeepAliveTimeout))
        ));
    }

    /// `checked_add_fraction` shrinks an unrepresentably large duration to
    /// one that can be added to the start instant.
    #[test]
    fn checked_add_fraction_can_add_u64_max() {
        let _ = tracing_subscriber::fmt()
            .with_env_filter(tracing_subscriber::EnvFilter::from_default_env())
            .try_init();
        let start = Instant::now();

        let duration = checked_add_fraction(start, Duration::from_secs(u64::MAX));

        assert!(start.checked_add(duration).is_some())
    }

    /// `compute_new_shutdown` must not panic for any combination of inputs.
    #[test]
    fn compute_new_shutdown_does_not_panic() {
        let _ = tracing_subscriber::fmt()
            .with_env_filter(EnvFilter::from_default_env())
            .try_init();

        #[derive(Debug)]
        struct ArbitraryShutdown(Shutdown);

        impl Clone for ArbitraryShutdown {
            fn clone(&self) -> Self {
                let shutdown = match self.0 {
                    Shutdown::None => Shutdown::None,
                    Shutdown::Asap => Shutdown::Asap,
                    Shutdown::Later(_) => Shutdown::Later(
                        // `compute_new_shutdown` never inspects the delay and
                        // `Delay` is not `Clone`, so a dummy delay suffices.
                        Delay::new(Duration::from_secs(1)),
                    ),
                };

                ArbitraryShutdown(shutdown)
            }
        }

        impl Arbitrary for ArbitraryShutdown {
            fn arbitrary(g: &mut Gen) -> Self {
                let shutdown = match g.gen_range(1u8..4) {
                    1 => Shutdown::None,
                    2 => Shutdown::Asap,
                    3 => Shutdown::Later(Delay::new(Duration::from_secs(u32::arbitrary(g) as u64))),
                    _ => unreachable!(),
                };

                Self(shutdown)
            }
        }

        fn prop(
            handler_keep_alive: bool,
            current_shutdown: ArbitraryShutdown,
            idle_timeout: Duration,
        ) {
            compute_new_shutdown(handler_keep_alive, &current_shutdown.0, idle_timeout);
        }

        QuickCheck::new().quickcheck(prop as fn(_, _, _));
    }

    /// A [`StreamMuxer`] that always has an inbound substream ready and
    /// tracks the number of live substreams through `counter`.
    struct DummyStreamMuxer {
        counter: Arc<()>,
    }

    impl StreamMuxer for DummyStreamMuxer {
        type Substream = PendingSubstream;
        type Error = Infallible;

        fn poll_inbound(
            self: Pin<&mut Self>,
            _: &mut Context<'_>,
        ) -> Poll<Result<Self::Substream, Self::Error>> {
            Poll::Ready(Ok(PendingSubstream {
                _weak: Arc::downgrade(&self.counter),
            }))
        }

        fn poll_outbound(
            self: Pin<&mut Self>,
            _: &mut Context<'_>,
        ) -> Poll<Result<Self::Substream, Self::Error>> {
            Poll::Pending
        }

        fn poll_close(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
            Poll::Ready(Ok(()))
        }

        fn poll(
            self: Pin<&mut Self>,
            _: &mut Context<'_>,
        ) -> Poll<Result<StreamMuxerEvent, Self::Error>> {
            Poll::Pending
        }
    }

    /// A [`StreamMuxer`] which never returns a stream.
    struct PendingStreamMuxer;

    impl StreamMuxer for PendingStreamMuxer {
        type Substream = PendingSubstream;
        type Error = Infallible;

        fn poll_inbound(
            self: Pin<&mut Self>,
            _: &mut Context<'_>,
        ) -> Poll<Result<Self::Substream, Self::Error>> {
            Poll::Pending
        }

        fn poll_outbound(
            self: Pin<&mut Self>,
            _: &mut Context<'_>,
        ) -> Poll<Result<Self::Substream, Self::Error>> {
            Poll::Pending
        }

        fn poll_close(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
            Poll::Pending
        }

        fn poll(
            self: Pin<&mut Self>,
            _: &mut Context<'_>,
        ) -> Poll<Result<StreamMuxerEvent, Self::Error>> {
            Poll::Pending
        }
    }

    /// A substream on which every read and write stays pending forever.
    struct PendingSubstream {
        _weak: Weak<()>,
    }

    impl AsyncRead for PendingSubstream {
        fn poll_read(
            self: Pin<&mut Self>,
            _cx: &mut Context<'_>,
            _buf: &mut [u8],
        ) -> Poll<std::io::Result<usize>> {
            Poll::Pending
        }
    }

    impl AsyncWrite for PendingSubstream {
        fn poll_write(
            self: Pin<&mut Self>,
            _cx: &mut Context<'_>,
            _buf: &[u8],
        ) -> Poll<std::io::Result<usize>> {
            Poll::Pending
        }

        fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
            Poll::Pending
        }

        fn poll_close(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
            Poll::Pending
        }
    }

    /// Handler that can request a single outbound stream on demand and
    /// records the last dial upgrade error it was given.
    struct MockConnectionHandler {
        outbound_requested: bool,
        error: Option<StreamUpgradeError<Infallible>>,
        upgrade_timeout: Duration,
    }

    impl MockConnectionHandler {
        fn new(upgrade_timeout: Duration) -> Self {
            Self {
                outbound_requested: false,
                error: None,
                upgrade_timeout,
            }
        }

        /// Makes the next `poll` emit an `OutboundSubstreamRequest`.
        fn open_new_outbound(&mut self) {
            self.outbound_requested = true;
        }
    }

    /// Handler whose inbound protocols and remote-protocol reports can be
    /// scripted, and which records every protocol change it is told about.
    #[derive(Default)]
    struct ConfigurableProtocolConnectionHandler {
        events: Vec<ConnectionHandlerEvent<DeniedUpgrade, (), Infallible>>,
        active_protocols: HashSet<StreamProtocol>,
        local_added: Vec<Vec<StreamProtocol>>,
        local_removed: Vec<Vec<StreamProtocol>>,
        remote_added: Vec<Vec<StreamProtocol>>,
        remote_removed: Vec<Vec<StreamProtocol>>,
    }

    impl ConfigurableProtocolConnectionHandler {
        /// Replaces the set of inbound protocols the handler advertises.
        fn listen_on(&mut self, protocols: &[&'static str]) {
            self.active_protocols = protocols.iter().copied().map(StreamProtocol::new).collect();
        }

        /// Queues a report that the remote added support for `protocols`.
        fn remote_adds_support_for(&mut self, protocols: &[&'static str]) {
            self.events
                .push(ConnectionHandlerEvent::ReportRemoteProtocols(
                    ProtocolSupport::Added(
                        protocols.iter().copied().map(StreamProtocol::new).collect(),
                    ),
                ));
        }

        /// Queues a report that the remote removed support for `protocols`.
        fn remote_removes_support_for(&mut self, protocols: &[&'static str]) {
            self.events
                .push(ConnectionHandlerEvent::ReportRemoteProtocols(
                    ProtocolSupport::Removed(
                        protocols.iter().copied().map(StreamProtocol::new).collect(),
                    ),
                ));
        }
    }

    impl ConnectionHandler for MockConnectionHandler {
        type FromBehaviour = Infallible;
        type ToBehaviour = Infallible;
        type InboundProtocol = DeniedUpgrade;
        type OutboundProtocol = DeniedUpgrade;
        type InboundOpenInfo = ();
        type OutboundOpenInfo = ();

        fn listen_protocol(&self) -> SubstreamProtocol<Self::InboundProtocol> {
            SubstreamProtocol::new(DeniedUpgrade, ()).with_timeout(self.upgrade_timeout)
        }

        fn on_connection_event(
            &mut self,
            event: ConnectionEvent<Self::InboundProtocol, Self::OutboundProtocol>,
        ) {
            match event {
                ConnectionEvent::FullyNegotiatedInbound(FullyNegotiatedInbound {
                    protocol,
                    ..
                }) => libp2p_core::util::unreachable(protocol),
                ConnectionEvent::FullyNegotiatedOutbound(FullyNegotiatedOutbound {
                    protocol,
                    ..
                }) => libp2p_core::util::unreachable(protocol),
                ConnectionEvent::DialUpgradeError(DialUpgradeError { error, .. }) => {
                    self.error = Some(error)
                }
                ConnectionEvent::AddressChange(_)
                | ConnectionEvent::ListenUpgradeError(_)
                | ConnectionEvent::LocalProtocolsChange(_)
                | ConnectionEvent::RemoteProtocolsChange(_) => {}
            }
        }

        fn on_behaviour_event(&mut self, event: Self::FromBehaviour) {
            libp2p_core::util::unreachable(event)
        }

        fn connection_keep_alive(&self) -> bool {
            true
        }

        fn poll(
            &mut self,
            _: &mut Context<'_>,
        ) -> Poll<ConnectionHandlerEvent<Self::OutboundProtocol, (), Self::ToBehaviour>> {
            if self.outbound_requested {
                self.outbound_requested = false;
                return Poll::Ready(ConnectionHandlerEvent::OutboundSubstreamRequest {
                    protocol: SubstreamProtocol::new(DeniedUpgrade, ())
                        .with_timeout(self.upgrade_timeout),
                });
            }

            Poll::Pending
        }
    }

    impl ConnectionHandler for ConfigurableProtocolConnectionHandler {
        type FromBehaviour = Infallible;
        type ToBehaviour = Infallible;
        type InboundProtocol = ManyProtocolsUpgrade;
        type OutboundProtocol = DeniedUpgrade;
        type InboundOpenInfo = ();
        type OutboundOpenInfo = ();

        fn listen_protocol(&self) -> SubstreamProtocol<Self::InboundProtocol> {
            SubstreamProtocol::new(
                ManyProtocolsUpgrade {
                    protocols: Vec::from_iter(self.active_protocols.clone()),
                },
                (),
            )
        }

        fn on_connection_event(
            &mut self,
            event: ConnectionEvent<Self::InboundProtocol, Self::OutboundProtocol>,
        ) {
            match event {
                ConnectionEvent::LocalProtocolsChange(ProtocolsChange::Added(added)) => {
                    self.local_added.push(added.cloned().collect())
                }
                ConnectionEvent::LocalProtocolsChange(ProtocolsChange::Removed(removed)) => {
                    self.local_removed.push(removed.cloned().collect())
                }
                ConnectionEvent::RemoteProtocolsChange(ProtocolsChange::Added(added)) => {
                    self.remote_added.push(added.cloned().collect())
                }
                ConnectionEvent::RemoteProtocolsChange(ProtocolsChange::Removed(removed)) => {
                    self.remote_removed.push(removed.cloned().collect())
                }
                _ => {}
            }
        }

        fn on_behaviour_event(&mut self, event: Self::FromBehaviour) {
            libp2p_core::util::unreachable(event)
        }

        fn connection_keep_alive(&self) -> bool {
            true
        }

        fn poll(
            &mut self,
            _: &mut Context<'_>,
        ) -> Poll<ConnectionHandlerEvent<Self::OutboundProtocol, (), Self::ToBehaviour>> {
            // Queued events are emitted last-in, first-out.
            if let Some(event) = self.events.pop() {
                return Poll::Ready(event);
            }

            Poll::Pending
        }
    }

    /// Upgrade that advertises an arbitrary list of protocols and passes the
    /// stream through unchanged.
    struct ManyProtocolsUpgrade {
        protocols: Vec<StreamProtocol>,
    }

    impl UpgradeInfo for ManyProtocolsUpgrade {
        type Info = StreamProtocol;
        type InfoIter = std::vec::IntoIter<Self::Info>;

        fn protocol_info(&self) -> Self::InfoIter {
            self.protocols.clone().into_iter()
        }
    }

    impl<C> InboundUpgrade<C> for ManyProtocolsUpgrade {
        type Output = C;
        type Error = Infallible;
        type Future = future::Ready<Result<Self::Output, Self::Error>>;

        fn upgrade_inbound(self, stream: C, _: Self::Info) -> Self::Future {
            future::ready(Ok(stream))
        }
    }

    impl<C> OutboundUpgrade<C> for ManyProtocolsUpgrade {
        type Output = C;
        type Error = Infallible;
        type Future = future::Ready<Result<Self::Output, Self::Error>>;

        fn upgrade_outbound(self, stream: C, _: Self::Info) -> Self::Future {
            future::ready(Ok(stream))
        }
    }
}
1339
1340#[derive(Debug, Clone, PartialEq, Eq, Hash)]
1342enum PendingPoint {
1343 Dialer {
1349 role_override: Endpoint,
1351 port_use: PortUse,
1352 },
1353 Listener {
1355 local_addr: Multiaddr,
1357 send_back_addr: Multiaddr,
1359 },
1360}
1361
1362impl From<ConnectedPoint> for PendingPoint {
1363 fn from(endpoint: ConnectedPoint) -> Self {
1364 match endpoint {
1365 ConnectedPoint::Dialer {
1366 role_override,
1367 port_use,
1368 ..
1369 } => PendingPoint::Dialer {
1370 role_override,
1371 port_use,
1372 },
1373 ConnectedPoint::Listener {
1374 local_addr,
1375 send_back_addr,
1376 } => PendingPoint::Listener {
1377 local_addr,
1378 send_back_addr,
1379 },
1380 }
1381 }
1382}