1mod error;
22
23pub(crate) mod pool;
24mod supported_protocols;
25
26use std::{
27 collections::{HashMap, HashSet},
28 fmt,
29 fmt::{Display, Formatter},
30 future::Future,
31 io, mem,
32 pin::Pin,
33 sync::atomic::{AtomicUsize, Ordering},
34 task::{Context, Poll, Waker},
35 time::Duration,
36};
37
38pub use error::ConnectionError;
39pub(crate) use error::{PendingInboundConnectionError, PendingOutboundConnectionError};
40use futures::{future::BoxFuture, stream, stream::FuturesUnordered, FutureExt, StreamExt};
41use futures_timer::Delay;
42use libp2p_core::{
43 connection::ConnectedPoint,
44 multiaddr::Multiaddr,
45 muxing::{StreamMuxerBox, StreamMuxerEvent, StreamMuxerExt, SubstreamBox},
46 transport::PortUse,
47 upgrade,
48 upgrade::{NegotiationError, ProtocolError},
49 Endpoint,
50};
51use libp2p_identity::PeerId;
52pub use supported_protocols::SupportedProtocols;
53use web_time::Instant;
54
55use crate::{
56 handler::{
57 AddressChange, ConnectionEvent, ConnectionHandler, DialUpgradeError,
58 FullyNegotiatedInbound, FullyNegotiatedOutbound, ListenUpgradeError, ProtocolSupport,
59 ProtocolsChange, UpgradeInfoSend,
60 },
61 stream::ActiveStreamCounter,
62 upgrade::{InboundUpgradeSend, OutboundUpgradeSend},
63 ConnectionHandlerEvent, Stream, StreamProtocol, StreamUpgradeError, SubstreamProtocol,
64};
65
/// Process-wide source of fresh connection identifiers. Starts at 1 and is bumped atomically
/// by [`ConnectionId::next`].
static NEXT_CONNECTION_ID: AtomicUsize = AtomicUsize::new(1);

/// Opaque identifier of a connection.
#[derive(Debug, Copy, Clone, Hash, PartialEq, Eq, PartialOrd, Ord)]
pub struct ConnectionId(usize);

impl ConnectionId {
    /// Wraps a raw `usize` as a `ConnectionId`.
    ///
    /// No check is performed against identifiers already handed out by the internal counter,
    /// so the caller is responsible for avoiding collisions.
    pub fn new_unchecked(id: usize) -> Self {
        ConnectionId(id)
    }

    /// Allocates the next identifier from the process-wide atomic counter.
    pub(crate) fn next() -> Self {
        let raw = NEXT_CONNECTION_ID.fetch_add(1, Ordering::SeqCst);
        ConnectionId(raw)
    }
}

impl Display for ConnectionId {
    /// Renders the identifier as its bare decimal number.
    fn fmt(&self, f: &mut Formatter) -> fmt::Result {
        let ConnectionId(raw) = self;
        write!(f, "{raw}")
    }
}
95
/// Information about an established connection: the endpoint it was established on and
/// the peer it is associated with.
#[derive(Debug, Clone, PartialEq, Eq)]
pub(crate) struct Connected {
    /// The connected endpoint (dialer or listener side, with its addresses).
    pub(crate) endpoint: ConnectedPoint,
    /// The peer id associated with this connection.
    pub(crate) peer_id: PeerId,
}
104
/// Event produced by [`Connection::poll`].
#[derive(Debug, Clone)]
pub(crate) enum Event<T> {
    /// An event the connection handler wants delivered to the behaviour
    /// (emitted for `ConnectionHandlerEvent::NotifyBehaviour`).
    Handler(T),
    /// The muxer reported an address change (`StreamMuxerEvent::AddressChange`); the handler
    /// has already been notified before this event is returned.
    AddressChange(Multiaddr),
}
113
/// A multiplexed connection to a peer, driven together with its [`ConnectionHandler`].
pub(crate) struct Connection<THandler>
where
    THandler: ConnectionHandler,
{
    /// The stream multiplexer providing inbound and outbound substreams.
    muxing: StreamMuxerBox,
    /// The handler that owns protocol logic for this connection.
    handler: THandler,
    /// In-flight negotiations/upgrades of inbound substreams.
    negotiating_in: FuturesUnordered<
        StreamUpgrade<
            THandler::InboundOpenInfo,
            <THandler::InboundProtocol as InboundUpgradeSend>::Output,
            <THandler::InboundProtocol as InboundUpgradeSend>::Error,
        >,
    >,
    /// In-flight negotiations/upgrades of outbound substreams.
    negotiating_out: FuturesUnordered<
        StreamUpgrade<
            THandler::OutboundOpenInfo,
            <THandler::OutboundProtocol as OutboundUpgradeSend>::Output,
            <THandler::OutboundProtocol as OutboundUpgradeSend>::Error,
        >,
    >,
    /// The currently planned idle shutdown of the connection.
    shutdown: Shutdown,
    /// If set, multistream-select version used for outbound negotiations instead of the default.
    substream_upgrade_protocol_override: Option<upgrade::Version>,
    /// Upper bound on `negotiating_in.len()`; new inbound substreams are not accepted from the
    /// muxer while this many are still negotiating.
    max_negotiating_inbound_streams: usize,
    /// Outbound substream requests from the handler still waiting for the muxer to open a
    /// substream. Each entry resolves to `Err(info)` if its timeout fires first, or to
    /// `Ok(())` once its data has been extracted into `negotiating_out`.
    requested_substreams: FuturesUnordered<
        SubstreamRequested<THandler::OutboundOpenInfo, THandler::OutboundProtocol>,
    >,

    /// Protocols last reported by `handler.listen_protocol()`, used to detect local changes.
    local_supported_protocols:
        HashMap<AsStrHashEq<<THandler::InboundProtocol as UpgradeInfoSend>::Info>, bool>,
    /// Protocols the remote has been reported (by the handler) to support.
    remote_supported_protocols: HashSet<StreamProtocol>,
    /// Scratch buffer reused when computing protocol changes.
    protocol_buffer: Vec<StreamProtocol>,

    /// How long the connection may stay idle before a keep-alive shutdown; zero means
    /// shut down as soon as it is idle.
    idle_timeout: Duration,
    /// Shared counter handed to every negotiated `Stream`, used to tell whether any streams
    /// are still alive when deciding on idle shutdown.
    stream_counter: ActiveStreamCounter,
}
168
169impl<THandler> fmt::Debug for Connection<THandler>
170where
171 THandler: ConnectionHandler + fmt::Debug,
172 THandler::OutboundOpenInfo: fmt::Debug,
173{
174 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
175 f.debug_struct("Connection")
176 .field("handler", &self.handler)
177 .finish()
178 }
179}
180
181impl<THandler> Unpin for Connection<THandler> where THandler: ConnectionHandler {}
182
impl<THandler> Connection<THandler>
where
    THandler: ConnectionHandler,
{
    /// Builds a new `Connection` from a stream muxer and a connection handler.
    ///
    /// If the handler's `listen_protocol()` advertises any protocols, the handler is notified
    /// of them right away via `ConnectionEvent::LocalProtocolsChange`.
    pub(crate) fn new(
        muxer: StreamMuxerBox,
        mut handler: THandler,
        substream_upgrade_protocol_override: Option<upgrade::Version>,
        max_negotiating_inbound_streams: usize,
        idle_timeout: Duration,
    ) -> Self {
        let initial_protocols = gather_supported_protocols(&handler);
        let mut buffer = Vec::new();

        if !initial_protocols.is_empty() {
            handler.on_connection_event(ConnectionEvent::LocalProtocolsChange(
                ProtocolsChange::from_initial_protocols(
                    initial_protocols.keys().map(|e| &e.0),
                    &mut buffer,
                ),
            ));
        }

        Connection {
            muxing: muxer,
            handler,
            negotiating_in: Default::default(),
            negotiating_out: Default::default(),
            shutdown: Shutdown::None,
            substream_upgrade_protocol_override,
            max_negotiating_inbound_streams,
            requested_substreams: Default::default(),
            local_supported_protocols: initial_protocols,
            remote_supported_protocols: Default::default(),
            protocol_buffer: buffer,
            idle_timeout,
            stream_counter: ActiveStreamCounter::default(),
        }
    }

    /// Forwards an event from the behaviour to the connection handler.
    pub(crate) fn on_behaviour_event(&mut self, event: THandler::FromBehaviour) {
        self.handler.on_behaviour_event(event);
    }

    /// Consumes the connection and starts closing it.
    ///
    /// Returns a stream that drives `handler.poll_close` (yielding whatever events the handler
    /// still produces) and a future that closes the underlying muxer.
    pub(crate) fn close(
        self,
    ) -> (
        impl futures::Stream<Item = THandler::ToBehaviour>,
        impl Future<Output = io::Result<()>>,
    ) {
        let Connection {
            mut handler,
            muxing,
            ..
        } = self;

        (
            stream::poll_fn(move |cx| handler.poll_close(cx)),
            muxing.close(),
        )
    }

    /// Drives the handler, substream negotiations, idle timeout and muxer.
    ///
    /// Returns `Ready(Ok(_))` for handler events and address changes,
    /// `Ready(Err(ConnectionError::KeepAliveTimeout))` when the idle shutdown fires, and
    /// propagates muxer errors via `?`. Every step that makes progress restarts the loop, so
    /// `Pending` is only returned once no component can make progress.
    #[tracing::instrument(level = "debug", name = "Connection::poll", skip(self, cx))]
    pub(crate) fn poll(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Result<Event<THandler::ToBehaviour>, ConnectionError>> {
        let Self {
            requested_substreams,
            muxing,
            handler,
            negotiating_out,
            negotiating_in,
            shutdown,
            max_negotiating_inbound_streams,
            substream_upgrade_protocol_override,
            local_supported_protocols: supported_protocols,
            remote_supported_protocols,
            protocol_buffer,
            idle_timeout,
            stream_counter,
            ..
        } = self.get_mut();

        loop {
            // Pending outbound requests: `Ok(())` means the request was already turned into a
            // negotiation; `Err(info)` means its timeout fired before the muxer opened a
            // substream, which is reported to the handler as a dial upgrade timeout.
            match requested_substreams.poll_next_unpin(cx) {
                Poll::Ready(Some(Ok(()))) => continue,
                Poll::Ready(Some(Err(info))) => {
                    handler.on_connection_event(ConnectionEvent::DialUpgradeError(
                        DialUpgradeError {
                            info,
                            error: StreamUpgradeError::Timeout,
                        },
                    ));
                    continue;
                }
                Poll::Ready(None) | Poll::Pending => {}
            }

            // Let the handler request substreams, emit events, or report remote protocols.
            match handler.poll(cx) {
                Poll::Pending => {}
                Poll::Ready(ConnectionHandlerEvent::OutboundSubstreamRequest { protocol }) => {
                    // The request's timeout starts counting now, before the muxer has
                    // provided a substream.
                    let timeout = *protocol.timeout();
                    let (upgrade, user_data) = protocol.into_upgrade();

                    requested_substreams.push(SubstreamRequested::new(user_data, timeout, upgrade));
                    continue;
                }
                Poll::Ready(ConnectionHandlerEvent::NotifyBehaviour(event)) => {
                    return Poll::Ready(Ok(Event::Handler(event)));
                }
                Poll::Ready(ConnectionHandlerEvent::ReportRemoteProtocols(
                    ProtocolSupport::Added(protocols),
                )) => {
                    // Only notify (and record) when the set actually gained new entries;
                    // the newly added protocols are left in `protocol_buffer`.
                    if let Some(added) =
                        ProtocolsChange::add(remote_supported_protocols, protocols, protocol_buffer)
                    {
                        handler.on_connection_event(ConnectionEvent::RemoteProtocolsChange(added));
                        remote_supported_protocols.extend(protocol_buffer.drain(..));
                    }
                    continue;
                }
                Poll::Ready(ConnectionHandlerEvent::ReportRemoteProtocols(
                    ProtocolSupport::Removed(protocols),
                )) => {
                    // Only notify when something was actually removed.
                    if let Some(removed) = ProtocolsChange::remove(
                        remote_supported_protocols,
                        protocols,
                        protocol_buffer,
                    ) {
                        handler
                            .on_connection_event(ConnectionEvent::RemoteProtocolsChange(removed));
                    }
                    continue;
                }
            }

            // Finished outbound negotiations: every failure kind goes back to the handler.
            match negotiating_out.poll_next_unpin(cx) {
                Poll::Pending | Poll::Ready(None) => {}
                Poll::Ready(Some((info, Ok(protocol)))) => {
                    handler.on_connection_event(ConnectionEvent::FullyNegotiatedOutbound(
                        FullyNegotiatedOutbound { protocol, info },
                    ));
                    continue;
                }
                Poll::Ready(Some((info, Err(error)))) => {
                    handler.on_connection_event(ConnectionEvent::DialUpgradeError(
                        DialUpgradeError { info, error },
                    ));
                    continue;
                }
            }

            // Finished inbound negotiations: only `Apply` errors (from the upgrade itself)
            // reach the handler; I/O, negotiation and timeout failures are only logged.
            match negotiating_in.poll_next_unpin(cx) {
                Poll::Pending | Poll::Ready(None) => {}
                Poll::Ready(Some((info, Ok(protocol)))) => {
                    handler.on_connection_event(ConnectionEvent::FullyNegotiatedInbound(
                        FullyNegotiatedInbound { protocol, info },
                    ));
                    continue;
                }
                Poll::Ready(Some((info, Err(StreamUpgradeError::Apply(error))))) => {
                    handler.on_connection_event(ConnectionEvent::ListenUpgradeError(
                        ListenUpgradeError { info, error },
                    ));
                    continue;
                }
                Poll::Ready(Some((_, Err(StreamUpgradeError::Io(e))))) => {
                    tracing::debug!("failed to upgrade inbound stream: {e}");
                    continue;
                }
                Poll::Ready(Some((_, Err(StreamUpgradeError::NegotiationFailed)))) => {
                    tracing::debug!("no protocol could be agreed upon for inbound stream");
                    continue;
                }
                Poll::Ready(Some((_, Err(StreamUpgradeError::Timeout)))) => {
                    tracing::debug!("inbound stream upgrade timed out");
                    continue;
                }
            }

            // Idle shutdown is only considered while nothing is negotiating, no outbound
            // substream is pending and no stream is alive; any activity cancels the plan.
            if negotiating_in.is_empty()
                && negotiating_out.is_empty()
                && requested_substreams.is_empty()
                && stream_counter.has_no_active_streams()
            {
                if let Some(new_timeout) =
                    compute_new_shutdown(handler.connection_keep_alive(), shutdown, *idle_timeout)
                {
                    *shutdown = new_timeout;
                }

                match shutdown {
                    Shutdown::None => {}
                    Shutdown::Asap => return Poll::Ready(Err(ConnectionError::KeepAliveTimeout)),
                    Shutdown::Later(delay) => match Future::poll(Pin::new(delay), cx) {
                        Poll::Ready(_) => {
                            return Poll::Ready(Err(ConnectionError::KeepAliveTimeout))
                        }
                        Poll::Pending => {}
                    },
                }
            } else {
                *shutdown = Shutdown::None;
            }

            // Connection-level muxer events; muxer errors are propagated through `?`
            // (relies on a conversion into `ConnectionError` defined elsewhere).
            match muxing.poll_unpin(cx)? {
                Poll::Pending => {}
                Poll::Ready(StreamMuxerEvent::AddressChange(address)) => {
                    handler.on_connection_event(ConnectionEvent::AddressChange(AddressChange {
                        new_address: &address,
                    }));
                    return Poll::Ready(Ok(Event::AddressChange(address)));
                }
            }

            // Ask the muxer for an outbound substream only when a request is waiting. The
            // request served is whichever `FuturesUnordered::iter_mut` yields first
            // (NOTE(review): not necessarily FIFO — confirm if ordering matters).
            if let Some(requested_substream) = requested_substreams.iter_mut().next() {
                match muxing.poll_outbound_unpin(cx)? {
                    Poll::Pending => {}
                    Poll::Ready(substream) => {
                        // `extract` leaves the request in the `Done` state and wakes it, so it
                        // resolves to `Ok(())` the next time `requested_substreams` is polled.
                        let (user_data, timeout, upgrade) = requested_substream.extract();

                        negotiating_out.push(StreamUpgrade::new_outbound(
                            substream,
                            user_data,
                            timeout,
                            upgrade,
                            *substream_upgrade_protocol_override,
                            stream_counter.clone(),
                        ));

                        continue;
                    }
                }
            }

            // Accept new inbound substreams only while below the negotiation limit, so the
            // muxer applies back-pressure otherwise.
            if negotiating_in.len() < *max_negotiating_inbound_streams {
                match muxing.poll_inbound_unpin(cx)? {
                    Poll::Pending => {}
                    Poll::Ready(substream) => {
                        let protocol = handler.listen_protocol();

                        negotiating_in.push(StreamUpgrade::new_inbound(
                            substream,
                            protocol,
                            stream_counter.clone(),
                        ));

                        continue;
                    }
                }
            }

            // Detect changes in the protocols the handler currently listens on.
            let changes = ProtocolsChange::from_full_sets(
                supported_protocols,
                handler.listen_protocol().upgrade().protocol_info(),
                protocol_buffer,
            );

            if !changes.is_empty() {
                for change in changes {
                    handler.on_connection_event(ConnectionEvent::LocalProtocolsChange(change));
                }
                // The handler may be able to make progress again after being notified.
                continue;
            }

            return Poll::Pending;
        }
    }

    /// Test helper: polls the connection once with a no-op waker.
    #[cfg(test)]
    fn poll_noop_waker(&mut self) -> Poll<Result<Event<THandler::ToBehaviour>, ConnectionError>> {
        Pin::new(self).poll(&mut Context::from_waker(futures::task::noop_waker_ref()))
    }
}
479
480fn gather_supported_protocols<C: ConnectionHandler>(
481 handler: &C,
482) -> HashMap<AsStrHashEq<<C::InboundProtocol as UpgradeInfoSend>::Info>, bool> {
483 handler
484 .listen_protocol()
485 .upgrade()
486 .protocol_info()
487 .map(|info| (AsStrHashEq(info), true))
488 .collect()
489}
490
491fn compute_new_shutdown(
492 handler_keep_alive: bool,
493 current_shutdown: &Shutdown,
494 idle_timeout: Duration,
495) -> Option<Shutdown> {
496 match (current_shutdown, handler_keep_alive) {
497 (_, false) if idle_timeout == Duration::ZERO => Some(Shutdown::Asap),
498 (Shutdown::Later(_), false) => None,
500 (_, false) => {
501 let now = Instant::now();
502 let safe_keep_alive = checked_add_fraction(now, idle_timeout);
503
504 Some(Shutdown::Later(Delay::new(safe_keep_alive)))
505 }
506 (_, true) => Some(Shutdown::None),
507 }
508}
509
510fn checked_add_fraction(start: Instant, mut duration: Duration) -> Duration {
517 while start.checked_add(duration).is_none() {
518 tracing::debug!(start=?start, duration=?duration, "start + duration cannot be presented, halving duration");
519
520 duration /= 2;
521 }
522
523 duration
524}
525
526#[derive(Debug, Copy, Clone)]
528pub(crate) struct IncomingInfo<'a> {
529 pub(crate) local_addr: &'a Multiaddr,
531 pub(crate) send_back_addr: &'a Multiaddr,
533}
534
535impl IncomingInfo<'_> {
536 pub(crate) fn create_connected_point(&self) -> ConnectedPoint {
538 ConnectedPoint::Listener {
539 local_addr: self.local_addr.clone(),
540 send_back_addr: self.send_back_addr.clone(),
541 }
542 }
543}
544
/// Future that negotiates and upgrades a single substream while racing a timeout.
///
/// Resolves to the stored user data paired with either the upgrade output or a
/// `StreamUpgradeError` (including `Timeout` when the delay fires first).
struct StreamUpgrade<UserData, TOk, TErr> {
    /// Taken exactly once, when the future resolves; polling again afterwards panics.
    user_data: Option<UserData>,
    /// Deadline for negotiation plus upgrade; checked before the upgrade on every poll.
    timeout: Delay,
    /// The boxed multistream-select negotiation followed by the protocol upgrade.
    upgrade: BoxFuture<'static, Result<TOk, StreamUpgradeError<TErr>>>,
}
550
551impl<UserData, TOk, TErr> StreamUpgrade<UserData, TOk, TErr> {
552 fn new_outbound<Upgrade>(
553 substream: SubstreamBox,
554 user_data: UserData,
555 timeout: Delay,
556 upgrade: Upgrade,
557 version_override: Option<upgrade::Version>,
558 counter: ActiveStreamCounter,
559 ) -> Self
560 where
561 Upgrade: OutboundUpgradeSend<Output = TOk, Error = TErr>,
562 {
563 let effective_version = match version_override {
564 Some(version_override) if version_override != upgrade::Version::default() => {
565 tracing::debug!(
566 "Substream upgrade protocol override: {:?} -> {:?}",
567 upgrade::Version::default(),
568 version_override
569 );
570
571 version_override
572 }
573 _ => upgrade::Version::default(),
574 };
575 let protocols = upgrade.protocol_info();
576
577 Self {
578 user_data: Some(user_data),
579 timeout,
580 upgrade: Box::pin(async move {
581 let (info, stream) = multistream_select::dialer_select_proto(
582 substream,
583 protocols,
584 effective_version,
585 )
586 .await
587 .map_err(to_stream_upgrade_error)?;
588
589 let output = upgrade
590 .upgrade_outbound(Stream::new(stream, counter), info)
591 .await
592 .map_err(StreamUpgradeError::Apply)?;
593
594 Ok(output)
595 }),
596 }
597 }
598}
599
600impl<UserData, TOk, TErr> StreamUpgrade<UserData, TOk, TErr> {
601 fn new_inbound<Upgrade>(
602 substream: SubstreamBox,
603 protocol: SubstreamProtocol<Upgrade, UserData>,
604 counter: ActiveStreamCounter,
605 ) -> Self
606 where
607 Upgrade: InboundUpgradeSend<Output = TOk, Error = TErr>,
608 {
609 let timeout = *protocol.timeout();
610 let (upgrade, open_info) = protocol.into_upgrade();
611 let protocols = upgrade.protocol_info();
612
613 Self {
614 user_data: Some(open_info),
615 timeout: Delay::new(timeout),
616 upgrade: Box::pin(async move {
617 let (info, stream) =
618 multistream_select::listener_select_proto(substream, protocols)
619 .await
620 .map_err(to_stream_upgrade_error)?;
621
622 let output = upgrade
623 .upgrade_inbound(Stream::new(stream, counter), info)
624 .await
625 .map_err(StreamUpgradeError::Apply)?;
626
627 Ok(output)
628 }),
629 }
630 }
631}
632
633fn to_stream_upgrade_error<T>(e: NegotiationError) -> StreamUpgradeError<T> {
634 match e {
635 NegotiationError::Failed => StreamUpgradeError::NegotiationFailed,
636 NegotiationError::ProtocolError(ProtocolError::IoError(e)) => StreamUpgradeError::Io(e),
637 NegotiationError::ProtocolError(other) => StreamUpgradeError::Io(io::Error::other(other)),
638 }
639}
640
641impl<UserData, TOk, TErr> Unpin for StreamUpgrade<UserData, TOk, TErr> {}
642
643impl<UserData, TOk, TErr> Future for StreamUpgrade<UserData, TOk, TErr> {
644 type Output = (UserData, Result<TOk, StreamUpgradeError<TErr>>);
645
646 fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
647 match self.timeout.poll_unpin(cx) {
648 Poll::Ready(()) => {
649 return Poll::Ready((
650 self.user_data
651 .take()
652 .expect("Future not to be polled again once ready."),
653 Err(StreamUpgradeError::Timeout),
654 ))
655 }
656
657 Poll::Pending => {}
658 }
659
660 let result = futures::ready!(self.upgrade.poll_unpin(cx));
661 let user_data = self
662 .user_data
663 .take()
664 .expect("Future not to be polled again once ready.");
665
666 Poll::Ready((user_data, result))
667 }
668}
669
670enum SubstreamRequested<UserData, Upgrade> {
671 Waiting {
672 user_data: UserData,
673 timeout: Delay,
674 upgrade: Upgrade,
675 extracted_waker: Option<Waker>,
680 },
681 Done,
682}
683
684impl<UserData, Upgrade> SubstreamRequested<UserData, Upgrade> {
685 fn new(user_data: UserData, timeout: Duration, upgrade: Upgrade) -> Self {
686 Self::Waiting {
687 user_data,
688 timeout: Delay::new(timeout),
689 upgrade,
690 extracted_waker: None,
691 }
692 }
693
694 fn extract(&mut self) -> (UserData, Delay, Upgrade) {
695 match mem::replace(self, Self::Done) {
696 SubstreamRequested::Waiting {
697 user_data,
698 timeout,
699 upgrade,
700 extracted_waker: waker,
701 } => {
702 if let Some(waker) = waker {
703 waker.wake();
704 }
705
706 (user_data, timeout, upgrade)
707 }
708 SubstreamRequested::Done => panic!("cannot extract twice"),
709 }
710 }
711}
712
713impl<UserData, Upgrade> Unpin for SubstreamRequested<UserData, Upgrade> {}
714
715impl<UserData, Upgrade> Future for SubstreamRequested<UserData, Upgrade> {
716 type Output = Result<(), UserData>;
717
718 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
719 let this = self.get_mut();
720
721 match mem::replace(this, Self::Done) {
722 SubstreamRequested::Waiting {
723 user_data,
724 upgrade,
725 mut timeout,
726 ..
727 } => match timeout.poll_unpin(cx) {
728 Poll::Ready(()) => Poll::Ready(Err(user_data)),
729 Poll::Pending => {
730 *this = Self::Waiting {
731 user_data,
732 upgrade,
733 timeout,
734 extracted_waker: Some(cx.waker().clone()),
735 };
736 Poll::Pending
737 }
738 },
739 SubstreamRequested::Done => Poll::Ready(Ok(())),
740 }
741 }
742}
743
/// The planned idle shutdown of a connection, as decided by `compute_new_shutdown`.
#[derive(Debug)]
enum Shutdown {
    /// No shutdown is planned.
    None,
    /// Shut down (with `ConnectionError::KeepAliveTimeout`) as soon as the connection is idle.
    Asap,
    /// Shut down (with `ConnectionError::KeepAliveTimeout`) once this delay fires while the
    /// connection is idle.
    Later(Delay),
}
762
/// Wrapper that compares and hashes its value by the value's `&str` view.
///
/// This lets any `AsRef<str>` type be used as a `HashMap`/`HashSet` key with string
/// semantics.
pub(crate) struct AsStrHashEq<T>(pub(crate) T);

impl<T> Eq for AsStrHashEq<T> where T: AsRef<str> {}

impl<T> PartialEq for AsStrHashEq<T>
where
    T: AsRef<str>,
{
    fn eq(&self, other: &Self) -> bool {
        let lhs: &str = self.0.as_ref();
        let rhs: &str = other.0.as_ref();
        lhs == rhs
    }
}

impl<T> std::hash::Hash for AsStrHashEq<T>
where
    T: AsRef<str>,
{
    /// Hashes exactly like the underlying `str`, consistent with `eq`.
    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
        let key: &str = self.0.as_ref();
        <str as std::hash::Hash>::hash(key, state)
    }
}
781
#[cfg(test)]
mod tests {
    use std::{
        convert::Infallible,
        sync::{Arc, Weak},
        time::Instant,
    };

    use futures::{future, AsyncRead, AsyncWrite};
    use libp2p_core::{
        upgrade::{DeniedUpgrade, InboundUpgrade, OutboundUpgrade, UpgradeInfo},
        StreamMuxer,
    };
    use quickcheck::*;
    use tracing_subscriber::EnvFilter;

    use super::*;
    use crate::dummy;

    /// The connection must never hold more inbound substreams in negotiation than allowed.
    #[test]
    fn max_negotiating_inbound_streams() {
        let _ = tracing_subscriber::fmt()
            .with_env_filter(EnvFilter::from_default_env())
            .try_init();

        fn prop(max_negotiating_inbound_streams: u8) {
            let max_negotiating_inbound_streams: usize = max_negotiating_inbound_streams.into();

            let alive_substream_counter = Arc::new(());
            let mut connection = Connection::new(
                StreamMuxerBox::new(DummyStreamMuxer {
                    counter: alive_substream_counter.clone(),
                }),
                MockConnectionHandler::new(Duration::from_secs(10)),
                None,
                max_negotiating_inbound_streams,
                Duration::ZERO,
            );

            let result = connection.poll_noop_waker();

            assert!(result.is_pending());
            assert_eq!(
                Arc::weak_count(&alive_substream_counter),
                max_negotiating_inbound_streams,
                "Expect no more than the maximum number of allowed streams"
            );
        }

        QuickCheck::new().quickcheck(prop as fn(_));
    }

    /// The outbound upgrade timeout must start counting when the handler requests the
    /// substream, even if the muxer never provides one.
    #[test]
    fn outbound_stream_timeout_starts_on_request() {
        let upgrade_timeout = Duration::from_secs(1);
        let mut connection = Connection::new(
            StreamMuxerBox::new(PendingStreamMuxer),
            MockConnectionHandler::new(upgrade_timeout),
            None,
            2,
            Duration::ZERO,
        );

        connection.handler.open_new_outbound();
        let _ = connection.poll_noop_waker();

        std::thread::sleep(upgrade_timeout + Duration::from_secs(1));

        let _ = connection.poll_noop_waker();

        assert!(matches!(
            connection.handler.error.unwrap(),
            StreamUpgradeError::Timeout
        ))
    }

    #[test]
    fn propagates_changes_to_supported_inbound_protocols() {
        let mut connection = Connection::new(
            StreamMuxerBox::new(PendingStreamMuxer),
            ConfigurableProtocolConnectionHandler::default(),
            None,
            0,
            Duration::ZERO,
        );

        // The connection picks up newly listened-on protocols on the next poll.
        connection.handler.listen_on(&["/foo"]);
        let _ = connection.poll_noop_waker();

        assert_eq!(connection.handler.local_added, vec![vec!["/foo"]]);
        assert!(connection.handler.local_removed.is_empty());

        connection.handler.listen_on(&["/foo", "/bar"]);
        let _ = connection.poll_noop_waker();

        assert_eq!(
            connection.handler.local_added,
            vec![vec!["/foo"], vec!["/bar"]],
            "expect to only receive an event for the newly added protocols"
        );
        assert!(connection.handler.local_removed.is_empty());

        connection.handler.listen_on(&["/bar"]);
        let _ = connection.poll_noop_waker();

        assert_eq!(
            connection.handler.local_added,
            vec![vec!["/foo"], vec!["/bar"]]
        );
        assert_eq!(connection.handler.local_removed, vec![vec!["/foo"]]);
    }

    #[test]
    fn only_propagates_actual_changes_to_remote_protocols_to_handler() {
        let mut connection = Connection::new(
            StreamMuxerBox::new(PendingStreamMuxer),
            ConfigurableProtocolConnectionHandler::default(),
            None,
            0,
            Duration::ZERO,
        );

        // The handler reports remote protocol support; the connection filters out no-ops.
        connection.handler.remote_adds_support_for(&["/foo"]);
        let _ = connection.poll_noop_waker();

        assert_eq!(connection.handler.remote_added, vec![vec!["/foo"]]);
        assert!(connection.handler.remote_removed.is_empty());

        connection
            .handler
            .remote_adds_support_for(&["/foo", "/bar"]);
        let _ = connection.poll_noop_waker();

        assert_eq!(
            connection.handler.remote_added,
            vec![vec!["/foo"], vec!["/bar"]],
            "expect to only receive an event for the newly added protocol"
        );
        assert!(connection.handler.remote_removed.is_empty());

        // Removing a protocol that was never supported must not produce an event.
        connection.handler.remote_removes_support_for(&["/baz"]);
        let _ = connection.poll_noop_waker();

        assert_eq!(
            connection.handler.remote_added,
            vec![vec!["/foo"], vec!["/bar"]]
        );
        assert!(connection.handler.remote_removed.is_empty());

        connection.handler.remote_removes_support_for(&["/bar"]);
        let _ = connection.poll_noop_waker();

        assert_eq!(
            connection.handler.remote_added,
            vec![vec!["/foo"], vec!["/bar"]]
        );
        assert_eq!(connection.handler.remote_removed, vec![vec!["/bar"]]);
    }

    #[tokio::test]
    async fn idle_timeout_with_keep_alive_no() {
        let idle_timeout = Duration::from_millis(100);

        let mut connection = Connection::new(
            StreamMuxerBox::new(PendingStreamMuxer),
            dummy::ConnectionHandler,
            None,
            0,
            idle_timeout,
        );

        assert!(connection.poll_noop_waker().is_pending());

        tokio::time::sleep(idle_timeout).await;

        assert!(matches!(
            connection.poll_noop_waker(),
            Poll::Ready(Err(ConnectionError::KeepAliveTimeout))
        ));
    }

    #[test]
    fn checked_add_fraction_can_add_u64_max() {
        let _ = tracing_subscriber::fmt()
            .with_env_filter(tracing_subscriber::EnvFilter::from_default_env())
            .try_init();
        let start = Instant::now();

        let duration = checked_add_fraction(start, Duration::from_secs(u64::MAX));

        assert!(start.checked_add(duration).is_some())
    }

    #[test]
    fn compute_new_shutdown_does_not_panic() {
        let _ = tracing_subscriber::fmt()
            .with_env_filter(EnvFilter::from_default_env())
            .try_init();

        #[derive(Debug)]
        struct ArbitraryShutdown(Shutdown);

        impl Clone for ArbitraryShutdown {
            fn clone(&self) -> Self {
                let shutdown = match self.0 {
                    Shutdown::None => Shutdown::None,
                    Shutdown::Asap => Shutdown::Asap,
                    // A fresh `Delay` is created instead of copying the original timer's
                    // deadline; only the variant matters for this property.
                    Shutdown::Later(_) => Shutdown::Later(Delay::new(Duration::from_secs(1))),
                };

                ArbitraryShutdown(shutdown)
            }
        }

        impl Arbitrary for ArbitraryShutdown {
            fn arbitrary(g: &mut Gen) -> Self {
                let shutdown = match g.gen_range(1u8..4) {
                    1 => Shutdown::None,
                    2 => Shutdown::Asap,
                    3 => Shutdown::Later(Delay::new(Duration::from_secs(u32::arbitrary(g) as u64))),
                    _ => unreachable!(),
                };

                Self(shutdown)
            }
        }

        fn prop(
            handler_keep_alive: bool,
            current_shutdown: ArbitraryShutdown,
            idle_timeout: Duration,
        ) {
            compute_new_shutdown(handler_keep_alive, &current_shutdown.0, idle_timeout);
        }

        QuickCheck::new().quickcheck(prop as fn(_, _, _));
    }

    /// Muxer that always yields a new inbound substream and never an outbound one. Each
    /// substream holds a `Weak` to `counter`, so live substreams can be counted.
    struct DummyStreamMuxer {
        counter: Arc<()>,
    }

    impl StreamMuxer for DummyStreamMuxer {
        type Substream = PendingSubstream;
        type Error = Infallible;

        fn poll_inbound(
            self: Pin<&mut Self>,
            _: &mut Context<'_>,
        ) -> Poll<Result<Self::Substream, Self::Error>> {
            Poll::Ready(Ok(PendingSubstream {
                _weak: Arc::downgrade(&self.counter),
            }))
        }

        fn poll_outbound(
            self: Pin<&mut Self>,
            _: &mut Context<'_>,
        ) -> Poll<Result<Self::Substream, Self::Error>> {
            Poll::Pending
        }

        fn poll_close(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
            Poll::Ready(Ok(()))
        }

        fn poll(
            self: Pin<&mut Self>,
            _: &mut Context<'_>,
        ) -> Poll<Result<StreamMuxerEvent, Self::Error>> {
            Poll::Pending
        }
    }

    /// Muxer on which every operation stays pending forever.
    struct PendingStreamMuxer;

    impl StreamMuxer for PendingStreamMuxer {
        type Substream = PendingSubstream;
        type Error = Infallible;

        fn poll_inbound(
            self: Pin<&mut Self>,
            _: &mut Context<'_>,
        ) -> Poll<Result<Self::Substream, Self::Error>> {
            Poll::Pending
        }

        fn poll_outbound(
            self: Pin<&mut Self>,
            _: &mut Context<'_>,
        ) -> Poll<Result<Self::Substream, Self::Error>> {
            Poll::Pending
        }

        fn poll_close(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
            Poll::Pending
        }

        fn poll(
            self: Pin<&mut Self>,
            _: &mut Context<'_>,
        ) -> Poll<Result<StreamMuxerEvent, Self::Error>> {
            Poll::Pending
        }
    }

    /// Substream whose reads and writes never complete.
    struct PendingSubstream {
        _weak: Weak<()>,
    }

    impl AsyncRead for PendingSubstream {
        fn poll_read(
            self: Pin<&mut Self>,
            _cx: &mut Context<'_>,
            _buf: &mut [u8],
        ) -> Poll<std::io::Result<usize>> {
            Poll::Pending
        }
    }

    impl AsyncWrite for PendingSubstream {
        fn poll_write(
            self: Pin<&mut Self>,
            _cx: &mut Context<'_>,
            _buf: &[u8],
        ) -> Poll<std::io::Result<usize>> {
            Poll::Pending
        }

        fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
            Poll::Pending
        }

        fn poll_close(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
            Poll::Pending
        }
    }

    /// Handler that can request one outbound substream and records the last dial error.
    struct MockConnectionHandler {
        outbound_requested: bool,
        error: Option<StreamUpgradeError<Infallible>>,
        upgrade_timeout: Duration,
    }

    impl MockConnectionHandler {
        fn new(upgrade_timeout: Duration) -> Self {
            Self {
                outbound_requested: false,
                error: None,
                upgrade_timeout,
            }
        }

        fn open_new_outbound(&mut self) {
            self.outbound_requested = true;
        }
    }

    /// Handler whose listen protocols can be changed at runtime and which records every
    /// local/remote protocol change it is notified about.
    #[derive(Default)]
    struct ConfigurableProtocolConnectionHandler {
        events: Vec<ConnectionHandlerEvent<DeniedUpgrade, (), Infallible>>,
        active_protocols: HashSet<StreamProtocol>,
        local_added: Vec<Vec<StreamProtocol>>,
        local_removed: Vec<Vec<StreamProtocol>>,
        remote_added: Vec<Vec<StreamProtocol>>,
        remote_removed: Vec<Vec<StreamProtocol>>,
    }

    impl ConfigurableProtocolConnectionHandler {
        fn listen_on(&mut self, protocols: &[&'static str]) {
            self.active_protocols = protocols.iter().copied().map(StreamProtocol::new).collect();
        }

        fn remote_adds_support_for(&mut self, protocols: &[&'static str]) {
            self.events
                .push(ConnectionHandlerEvent::ReportRemoteProtocols(
                    ProtocolSupport::Added(
                        protocols.iter().copied().map(StreamProtocol::new).collect(),
                    ),
                ));
        }

        fn remote_removes_support_for(&mut self, protocols: &[&'static str]) {
            self.events
                .push(ConnectionHandlerEvent::ReportRemoteProtocols(
                    ProtocolSupport::Removed(
                        protocols.iter().copied().map(StreamProtocol::new).collect(),
                    ),
                ));
        }
    }

    impl ConnectionHandler for MockConnectionHandler {
        type FromBehaviour = Infallible;
        type ToBehaviour = Infallible;
        type InboundProtocol = DeniedUpgrade;
        type OutboundProtocol = DeniedUpgrade;
        type InboundOpenInfo = ();
        type OutboundOpenInfo = ();

        fn listen_protocol(&self) -> SubstreamProtocol<Self::InboundProtocol> {
            SubstreamProtocol::new(DeniedUpgrade, ()).with_timeout(self.upgrade_timeout)
        }

        fn on_connection_event(
            &mut self,
            event: ConnectionEvent<Self::InboundProtocol, Self::OutboundProtocol>,
        ) {
            match event {
                ConnectionEvent::FullyNegotiatedInbound(FullyNegotiatedInbound {
                    protocol,
                    ..
                }) => libp2p_core::util::unreachable(protocol),
                ConnectionEvent::FullyNegotiatedOutbound(FullyNegotiatedOutbound {
                    protocol,
                    ..
                }) => libp2p_core::util::unreachable(protocol),
                ConnectionEvent::DialUpgradeError(DialUpgradeError { error, .. }) => {
                    self.error = Some(error)
                }
                ConnectionEvent::AddressChange(_)
                | ConnectionEvent::ListenUpgradeError(_)
                | ConnectionEvent::LocalProtocolsChange(_)
                | ConnectionEvent::RemoteProtocolsChange(_) => {}
            }
        }

        fn on_behaviour_event(&mut self, event: Self::FromBehaviour) {
            libp2p_core::util::unreachable(event)
        }

        fn connection_keep_alive(&self) -> bool {
            true
        }

        fn poll(
            &mut self,
            _: &mut Context<'_>,
        ) -> Poll<ConnectionHandlerEvent<Self::OutboundProtocol, (), Self::ToBehaviour>> {
            if self.outbound_requested {
                self.outbound_requested = false;
                return Poll::Ready(ConnectionHandlerEvent::OutboundSubstreamRequest {
                    protocol: SubstreamProtocol::new(DeniedUpgrade, ())
                        .with_timeout(self.upgrade_timeout),
                });
            }

            Poll::Pending
        }
    }

    impl ConnectionHandler for ConfigurableProtocolConnectionHandler {
        type FromBehaviour = Infallible;
        type ToBehaviour = Infallible;
        type InboundProtocol = ManyProtocolsUpgrade;
        type OutboundProtocol = DeniedUpgrade;
        type InboundOpenInfo = ();
        type OutboundOpenInfo = ();

        fn listen_protocol(&self) -> SubstreamProtocol<Self::InboundProtocol> {
            SubstreamProtocol::new(
                ManyProtocolsUpgrade {
                    protocols: Vec::from_iter(self.active_protocols.clone()),
                },
                (),
            )
        }

        fn on_connection_event(
            &mut self,
            event: ConnectionEvent<Self::InboundProtocol, Self::OutboundProtocol>,
        ) {
            match event {
                ConnectionEvent::LocalProtocolsChange(ProtocolsChange::Added(added)) => {
                    self.local_added.push(added.cloned().collect())
                }
                ConnectionEvent::LocalProtocolsChange(ProtocolsChange::Removed(removed)) => {
                    self.local_removed.push(removed.cloned().collect())
                }
                ConnectionEvent::RemoteProtocolsChange(ProtocolsChange::Added(added)) => {
                    self.remote_added.push(added.cloned().collect())
                }
                ConnectionEvent::RemoteProtocolsChange(ProtocolsChange::Removed(removed)) => {
                    self.remote_removed.push(removed.cloned().collect())
                }
                _ => {}
            }
        }

        fn on_behaviour_event(&mut self, event: Self::FromBehaviour) {
            libp2p_core::util::unreachable(event)
        }

        fn connection_keep_alive(&self) -> bool {
            true
        }

        fn poll(
            &mut self,
            _: &mut Context<'_>,
        ) -> Poll<ConnectionHandlerEvent<Self::OutboundProtocol, (), Self::ToBehaviour>> {
            if let Some(event) = self.events.pop() {
                return Poll::Ready(event);
            }

            Poll::Pending
        }
    }

    /// Upgrade advertising an arbitrary list of protocols and passing the stream through.
    struct ManyProtocolsUpgrade {
        protocols: Vec<StreamProtocol>,
    }

    impl UpgradeInfo for ManyProtocolsUpgrade {
        type Info = StreamProtocol;
        type InfoIter = std::vec::IntoIter<Self::Info>;

        fn protocol_info(&self) -> Self::InfoIter {
            self.protocols.clone().into_iter()
        }
    }

    impl<C> InboundUpgrade<C> for ManyProtocolsUpgrade {
        type Output = C;
        type Error = Infallible;
        type Future = future::Ready<Result<Self::Output, Self::Error>>;

        fn upgrade_inbound(self, stream: C, _: Self::Info) -> Self::Future {
            future::ready(Ok(stream))
        }
    }

    impl<C> OutboundUpgrade<C> for ManyProtocolsUpgrade {
        type Output = C;
        type Error = Infallible;
        type Future = future::Ready<Result<Self::Output, Self::Error>>;

        fn upgrade_outbound(self, stream: C, _: Self::Info) -> Self::Future {
            future::ready(Ok(stream))
        }
    }
}
1337
1338#[derive(Debug, Clone, PartialEq, Eq, Hash)]
1340enum PendingPoint {
1341 Dialer {
1347 role_override: Endpoint,
1349 port_use: PortUse,
1350 },
1351 Listener {
1353 local_addr: Multiaddr,
1355 send_back_addr: Multiaddr,
1357 },
1358}
1359
1360impl From<ConnectedPoint> for PendingPoint {
1361 fn from(endpoint: ConnectedPoint) -> Self {
1362 match endpoint {
1363 ConnectedPoint::Dialer {
1364 role_override,
1365 port_use,
1366 ..
1367 } => PendingPoint::Dialer {
1368 role_override,
1369 port_use,
1370 },
1371 ConnectedPoint::Listener {
1372 local_addr,
1373 send_back_addr,
1374 } => PendingPoint::Listener {
1375 local_addr,
1376 send_back_addr,
1377 },
1378 }
1379 }
1380}