libp2p_swarm/connection/pool/
concurrent_dial.rs

1// Copyright 2021 Protocol Labs.
2//
3// Permission is hereby granted, free of charge, to any person obtaining a
4// copy of this software and associated documentation files (the "Software"),
5// to deal in the Software without restriction, including without limitation
6// the rights to use, copy, modify, merge, publish, distribute, sublicense,
7// and/or sell copies of the Software, and to permit persons to whom the
8// Software is furnished to do so, subject to the following conditions:
9//
10// The above copyright notice and this permission notice shall be included in
11// all copies or substantial portions of the Software.
12//
13// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
14// OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
16// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
17// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
18// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
19// DEALINGS IN THE SOFTWARE.
20
21use std::{
22    num::NonZeroU8,
23    pin::Pin,
24    task::{Context, Poll},
25};
26
27use futures::{
28    future::{BoxFuture, Future},
29    ready,
30    stream::{FuturesUnordered, StreamExt},
31};
32use libp2p_core::muxing::StreamMuxerBox;
33use libp2p_identity::PeerId;
34
35use crate::{transport::TransportError, Multiaddr};
36
37type Dial = BoxFuture<
38    'static,
39    (
40        Multiaddr,
41        Result<(PeerId, StreamMuxerBox), TransportError<std::io::Error>>,
42    ),
43>;
44
/// Future that drives a batch of dial attempts, polling several of them at once.
///
/// It resolves with the first dial that succeeds, or with every collected
/// error once all dials have failed.
pub(crate) struct ConcurrentDial {
    /// Dial attempts that are currently being polled.
    dials: FuturesUnordered<Dial>,
    /// Dial attempts not yet handed to `dials`; one is pulled each time an
    /// in-flight dial fails.
    pending_dials: Box<dyn Iterator<Item = Dial> + Send>,
    /// Address and error of every dial that has failed so far.
    errors: Vec<(Multiaddr, TransportError<std::io::Error>)>,
}
50
// `poll` mutates the fields directly through `Pin<&mut Self>`, which is only
// permitted when `Self: Unpin`.
impl Unpin for ConcurrentDial {}
52
53impl ConcurrentDial {
54    pub(crate) fn new(pending_dials: Vec<Dial>, concurrency_factor: NonZeroU8) -> Self {
55        let mut pending_dials = pending_dials.into_iter();
56
57        let dials = FuturesUnordered::new();
58        for dial in pending_dials.by_ref() {
59            dials.push(dial);
60            if dials.len() == concurrency_factor.get() as usize {
61                break;
62            }
63        }
64
65        Self {
66            dials,
67            errors: Default::default(),
68            pending_dials: Box::new(pending_dials),
69        }
70    }
71}
72
73impl Future for ConcurrentDial {
74    type Output = Result<
75        // Either one dial succeeded, returning the negotiated [`PeerId`], the address, the
76        // muxer and the addresses and errors of the dials that failed before.
77        (
78            Multiaddr,
79            (PeerId, StreamMuxerBox),
80            Vec<(Multiaddr, TransportError<std::io::Error>)>,
81        ),
82        // Or all dials failed, thus returning the address and error for each dial.
83        Vec<(Multiaddr, TransportError<std::io::Error>)>,
84    >;
85
86    fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
87        loop {
88            match ready!(self.dials.poll_next_unpin(cx)) {
89                Some((addr, Ok(output))) => {
90                    let errors = std::mem::take(&mut self.errors);
91                    return Poll::Ready(Ok((addr, output, errors)));
92                }
93                Some((addr, Err(e))) => {
94                    self.errors.push((addr, e));
95                    if let Some(dial) = self.pending_dials.next() {
96                        self.dials.push(dial)
97                    }
98                }
99                None => {
100                    return Poll::Ready(Err(std::mem::take(&mut self.errors)));
101                }
102            }
103        }
104    }
105}