libp2p_autonat/v2/client/handler/
dial_request.rs

1use std::{
2    collections::VecDeque,
3    convert::Infallible,
4    io,
5    iter::{once, repeat_n},
6    task::{Context, Poll},
7    time::Duration,
8};
9
10use futures::{channel::oneshot, AsyncWrite};
11use futures_bounded::FuturesMap;
12use libp2p_core::{
13    upgrade::{DeniedUpgrade, ReadyUpgrade},
14    Multiaddr,
15};
16use libp2p_swarm::{
17    handler::{
18        ConnectionEvent, DialUpgradeError, FullyNegotiatedOutbound, OutboundUpgradeSend,
19        ProtocolsChange,
20    },
21    ConnectionHandler, ConnectionHandlerEvent, Stream, StreamProtocol, StreamUpgradeError,
22    SubstreamProtocol,
23};
24
25use crate::v2::{
26    generated::structs::{mod_DialResponse::ResponseStatus, DialStatus},
27    protocol::{
28        Coder, DialDataRequest, DialDataResponse, DialRequest, Response,
29        DATA_FIELD_LEN_UPPER_BOUND, DATA_LEN_LOWER_BOUND, DATA_LEN_UPPER_BOUND,
30    },
31    Nonce, DIAL_REQUEST_PROTOCOL,
32};
33
/// Events reported by the dial-request [`Handler`] to the client behaviour.
#[derive(Debug)]
pub enum ToBehaviour {
    /// A dial request issued on this connection has finished.
    ///
    /// On success, `outcome` carries the address the server tested and the number of
    /// dial-data bytes the client sent (0 when the server did not ask for data).
    TestOutcome {
        nonce: Nonce,
        outcome: Result<(Multiaddr, usize), Error>,
    },
    /// The remote peer advertised the AutoNAT v2 dial-request protocol.
    PeerHasServerSupport,
}
42
/// Failure of a single AutoNAT v2 dial request.
#[derive(thiserror::Error, Debug)]
pub enum Error {
    /// The server answered, but reported that dialing the tested address failed.
    ///
    /// `bytes_sent` is the amount of dial data the client paid before the verdict.
    #[error("Address is not reachable: {error}")]
    AddressNotReachable {
        address: Multiaddr,
        bytes_sent: usize,
        error: DialBackError,
    },
    /// Negotiating the dial-request protocol with the peer failed.
    #[error("Peer does not support AutoNAT dial-request protocol")]
    UnsupportedProtocol,
    /// Any other failure. This covers stream I/O errors, timeouts, malformed or
    /// out-of-range server messages, and non-`OK` response statuses, all of which
    /// are encoded as `io::Error`s.
    #[error("IO error: {0}")]
    Io(io::Error),
}
56
57impl From<io::Error> for Error {
58    fn from(value: io::Error) -> Self {
59        Self::Io(value)
60    }
61}
62
/// Why the server considered the tested address unreachable.
#[derive(thiserror::Error, Debug)]
pub enum DialBackError {
    /// The server reported `DialStatus::E_DIAL_ERROR`.
    #[error("server failed to establish a connection")]
    NoConnection,
    /// The server reported `DialStatus::E_DIAL_BACK_ERROR`.
    #[error("dial back stream failed")]
    StreamFailed,
}
70
/// Connection handler that runs AutoNAT v2 dial requests against the remote peer.
pub struct Handler {
    /// Events handed out by `poll` before any request results are checked.
    queued_events: VecDeque<
        ConnectionHandlerEvent<
            <Self as ConnectionHandler>::OutboundProtocol,
            (),
            <Self as ConnectionHandler>::ToBehaviour,
        >,
    >,
    /// In-flight dial request futures, keyed by the request nonce.
    outbound: FuturesMap<Nonce, Result<(Multiaddr, usize), Error>>,
    /// Senders waiting for a requested outbound stream, in request order.
    ///
    /// Each negotiated stream or upgrade failure is delivered to the front entry.
    queued_streams: VecDeque<
        oneshot::Sender<
            Result<
                Stream,
                StreamUpgradeError<<ReadyUpgrade<StreamProtocol> as OutboundUpgradeSend>::Error>,
            >,
        >,
    >,
}
89
90impl Handler {
91    pub(crate) fn new() -> Self {
92        Self {
93            queued_events: VecDeque::new(),
94            outbound: FuturesMap::new(Duration::from_secs(10), 10),
95            queued_streams: VecDeque::default(),
96        }
97    }
98
99    fn perform_request(&mut self, req: DialRequest) {
100        let (tx, rx) = oneshot::channel();
101        self.queued_streams.push_back(tx);
102        self.queued_events
103            .push_back(ConnectionHandlerEvent::OutboundSubstreamRequest {
104                protocol: SubstreamProtocol::new(ReadyUpgrade::new(DIAL_REQUEST_PROTOCOL), ()),
105            });
106        if self
107            .outbound
108            .try_push(req.nonce, start_stream_handle(req, rx))
109            .is_err()
110        {
111            tracing::debug!("Dial request dropped, too many requests in flight");
112        }
113    }
114}
115
116impl ConnectionHandler for Handler {
117    type FromBehaviour = DialRequest;
118    type ToBehaviour = ToBehaviour;
119    type InboundProtocol = DeniedUpgrade;
120    type OutboundProtocol = ReadyUpgrade<StreamProtocol>;
121    type InboundOpenInfo = ();
122    type OutboundOpenInfo = ();
123
124    fn listen_protocol(&self) -> SubstreamProtocol<Self::InboundProtocol> {
125        SubstreamProtocol::new(DeniedUpgrade, ())
126    }
127
128    fn poll(
129        &mut self,
130        cx: &mut Context<'_>,
131    ) -> Poll<ConnectionHandlerEvent<Self::OutboundProtocol, (), Self::ToBehaviour>> {
132        if let Some(event) = self.queued_events.pop_front() {
133            return Poll::Ready(event);
134        }
135
136        match self.outbound.poll_unpin(cx) {
137            Poll::Ready((nonce, Ok(outcome))) => {
138                return Poll::Ready(ConnectionHandlerEvent::NotifyBehaviour(
139                    ToBehaviour::TestOutcome { nonce, outcome },
140                ))
141            }
142            Poll::Ready((nonce, Err(_))) => {
143                return Poll::Ready(ConnectionHandlerEvent::NotifyBehaviour(
144                    ToBehaviour::TestOutcome {
145                        nonce,
146                        outcome: Err(Error::Io(io::ErrorKind::TimedOut.into())),
147                    },
148                ));
149            }
150            Poll::Pending => {}
151        }
152
153        Poll::Pending
154    }
155
156    fn on_behaviour_event(&mut self, event: Self::FromBehaviour) {
157        self.perform_request(event);
158    }
159
160    fn on_connection_event(
161        &mut self,
162        event: ConnectionEvent<Self::InboundProtocol, Self::OutboundProtocol>,
163    ) {
164        match event {
165            ConnectionEvent::DialUpgradeError(DialUpgradeError { error, .. }) => {
166                tracing::debug!("Dial request failed: {}", error);
167                match self.queued_streams.pop_front() {
168                    Some(stream_tx) => {
169                        let _ = stream_tx.send(Err(error));
170                    }
171                    None => {
172                        tracing::warn!(
173                            "Opened unexpected substream without a pending dial request"
174                        );
175                    }
176                }
177            }
178            ConnectionEvent::FullyNegotiatedOutbound(FullyNegotiatedOutbound {
179                protocol, ..
180            }) => match self.queued_streams.pop_front() {
181                Some(stream_tx) => {
182                    if stream_tx.send(Ok(protocol)).is_err() {
183                        tracing::debug!("Failed to send stream to dead handler");
184                    }
185                }
186                None => {
187                    tracing::warn!("Opened unexpected substream without a pending dial request");
188                }
189            },
190            ConnectionEvent::RemoteProtocolsChange(ProtocolsChange::Added(mut added)) => {
191                if added.any(|p| p.as_ref() == DIAL_REQUEST_PROTOCOL) {
192                    self.queued_events
193                        .push_back(ConnectionHandlerEvent::NotifyBehaviour(
194                            ToBehaviour::PeerHasServerSupport,
195                        ));
196                }
197            }
198            _ => {}
199        }
200    }
201}
202
203async fn start_stream_handle(
204    req: DialRequest,
205    stream_recv: oneshot::Receiver<Result<Stream, StreamUpgradeError<Infallible>>>,
206) -> Result<(Multiaddr, usize), Error> {
207    let stream = stream_recv
208        .await
209        .map_err(|_| io::Error::from(io::ErrorKind::BrokenPipe))?
210        .map_err(|e| match e {
211            StreamUpgradeError::NegotiationFailed => Error::UnsupportedProtocol,
212            StreamUpgradeError::Timeout => Error::Io(io::ErrorKind::TimedOut.into()),
213            StreamUpgradeError::Apply(v) => libp2p_core::util::unreachable(v),
214            StreamUpgradeError::Io(e) => Error::Io(e),
215        })?;
216
217    let mut coder = Coder::new(stream);
218    coder.send(req.clone()).await?;
219
220    let (res, bytes_sent) = match coder.next().await? {
221        Response::Data(DialDataRequest {
222            addr_idx,
223            num_bytes,
224        }) => {
225            if addr_idx >= req.addrs.len() {
226                return Err(Error::Io(io::Error::new(
227                    io::ErrorKind::InvalidInput,
228                    "address index out of bounds",
229                )));
230            }
231            if !(DATA_LEN_LOWER_BOUND..=DATA_LEN_UPPER_BOUND).contains(&num_bytes) {
232                return Err(Error::Io(io::Error::new(
233                    io::ErrorKind::InvalidInput,
234                    "requested bytes out of bounds",
235                )));
236            }
237
238            send_aap_data(&mut coder, num_bytes).await?;
239
240            let Response::Dial(dial_response) = coder.next().await? else {
241                return Err(Error::Io(io::Error::new(
242                    io::ErrorKind::InvalidInput,
243                    "expected message",
244                )));
245            };
246
247            (dial_response, num_bytes)
248        }
249        Response::Dial(dial_response) => (dial_response, 0),
250    };
251    match coder.close().await {
252        Ok(_) => {}
253        Err(err) => {
254            if err.kind() == io::ErrorKind::ConnectionReset {
255                // The AutoNAT server may have already closed the stream
256                // (this is normal because the probe is finished),
257                // in this case we have this error:
258                // Err(Custom { kind: ConnectionReset, error: Stopped(0) })
259                // so we silently ignore this error
260            } else {
261                return Err(err.into());
262            }
263        }
264    }
265
266    match res.status {
267        ResponseStatus::E_REQUEST_REJECTED => {
268            return Err(Error::Io(io::Error::new(
269                io::ErrorKind::Other,
270                "server rejected request",
271            )))
272        }
273        ResponseStatus::E_DIAL_REFUSED => {
274            return Err(Error::Io(io::Error::new(
275                io::ErrorKind::Other,
276                "server refused dial",
277            )))
278        }
279        ResponseStatus::E_INTERNAL_ERROR => {
280            return Err(Error::Io(io::Error::new(
281                io::ErrorKind::Other,
282                "server encountered internal error",
283            )))
284        }
285        ResponseStatus::OK => {}
286    }
287
288    let tested_address = req
289        .addrs
290        .get(res.addr_idx)
291        .ok_or_else(|| io::Error::new(io::ErrorKind::InvalidInput, "address index out of bounds"))?
292        .clone();
293
294    match res.dial_status {
295        DialStatus::UNUSED => {
296            return Err(Error::Io(io::Error::new(
297                io::ErrorKind::InvalidInput,
298                "unexpected message",
299            )))
300        }
301        DialStatus::E_DIAL_ERROR => {
302            return Err(Error::AddressNotReachable {
303                address: tested_address,
304                bytes_sent,
305                error: DialBackError::NoConnection,
306            })
307        }
308        DialStatus::E_DIAL_BACK_ERROR => {
309            return Err(Error::AddressNotReachable {
310                address: tested_address,
311                bytes_sent,
312                error: DialBackError::StreamFailed,
313            })
314        }
315        DialStatus::OK => {}
316    }
317
318    Ok((tested_address, bytes_sent))
319}
320
321async fn send_aap_data<I>(stream: &mut Coder<I>, num_bytes: usize) -> io::Result<()>
322where
323    I: AsyncWrite + Unpin,
324{
325    let count_full = num_bytes / DATA_FIELD_LEN_UPPER_BOUND;
326    let partial_len = num_bytes % DATA_FIELD_LEN_UPPER_BOUND;
327    for req in repeat_n(DATA_FIELD_LEN_UPPER_BOUND, count_full)
328        .chain(once(partial_len))
329        .filter(|e| *e > 0)
330        .map(|data_count| {
331            DialDataResponse::new(data_count).expect("data count is unexpectedly too big")
332        })
333    {
334        stream.send(req).await?;
335    }
336
337    Ok(())
338}