libp2p_swarm/connection/pool/concurrent_dial.rs
1// Copyright 2021 Protocol Labs.
2//
3// Permission is hereby granted, free of charge, to any person obtaining a
4// copy of this software and associated documentation files (the "Software"),
5// to deal in the Software without restriction, including without limitation
6// the rights to use, copy, modify, merge, publish, distribute, sublicense,
7// and/or sell copies of the Software, and to permit persons to whom the
8// Software is furnished to do so, subject to the following conditions:
9//
10// The above copyright notice and this permission notice shall be included in
11// all copies or substantial portions of the Software.
12//
13// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
14// OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
16// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
17// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
18// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
19// DEALINGS IN THE SOFTWARE.
20
21use std::{
22 num::NonZeroU8,
23 pin::Pin,
24 task::{Context, Poll},
25};
26
27use futures::{
28 future::{BoxFuture, Future},
29 ready,
30 stream::{FuturesUnordered, StreamExt},
31};
32use libp2p_core::muxing::StreamMuxerBox;
33use libp2p_identity::PeerId;
34
35use crate::{transport::TransportError, Multiaddr};
36
37type Dial = BoxFuture<
38 'static,
39 (
40 Multiaddr,
41 Result<(PeerId, StreamMuxerBox), TransportError<std::io::Error>>,
42 ),
43>;
44
45pub(crate) struct ConcurrentDial {
46 dials: FuturesUnordered<Dial>,
47 pending_dials: Box<dyn Iterator<Item = Dial> + Send>,
48 errors: Vec<(Multiaddr, TransportError<std::io::Error>)>,
49}
50
51impl Unpin for ConcurrentDial {}
52
53impl ConcurrentDial {
54 pub(crate) fn new(pending_dials: Vec<Dial>, concurrency_factor: NonZeroU8) -> Self {
55 let mut pending_dials = pending_dials.into_iter();
56
57 let dials = FuturesUnordered::new();
58 for dial in pending_dials.by_ref() {
59 dials.push(dial);
60 if dials.len() == concurrency_factor.get() as usize {
61 break;
62 }
63 }
64
65 Self {
66 dials,
67 errors: Default::default(),
68 pending_dials: Box::new(pending_dials),
69 }
70 }
71}
72
73impl Future for ConcurrentDial {
74 type Output = Result<
75 // Either one dial succeeded, returning the negotiated [`PeerId`], the address, the
76 // muxer and the addresses and errors of the dials that failed before.
77 (
78 Multiaddr,
79 (PeerId, StreamMuxerBox),
80 Vec<(Multiaddr, TransportError<std::io::Error>)>,
81 ),
82 // Or all dials failed, thus returning the address and error for each dial.
83 Vec<(Multiaddr, TransportError<std::io::Error>)>,
84 >;
85
86 fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
87 loop {
88 match ready!(self.dials.poll_next_unpin(cx)) {
89 Some((addr, Ok(output))) => {
90 let errors = std::mem::take(&mut self.errors);
91 return Poll::Ready(Ok((addr, output, errors)));
92 }
93 Some((addr, Err(e))) => {
94 self.errors.push((addr, e));
95 if let Some(dial) = self.pending_dials.next() {
96 self.dials.push(dial)
97 }
98 }
99 None => {
100 return Poll::Ready(Err(std::mem::take(&mut self.errors)));
101 }
102 }
103 }
104 }
105}