libp2p_autonat/v2/client/handler/
dial_request.rs

1use std::{
2    collections::VecDeque,
3    convert::Infallible,
4    io,
5    iter::{once, repeat_n},
6    task::{Context, Poll},
7    time::Duration,
8};
9
10use futures::{channel::oneshot, AsyncWrite};
11use futures_bounded::FuturesMap;
12use libp2p_core::{
13    upgrade::{DeniedUpgrade, ReadyUpgrade},
14    Multiaddr,
15};
16use libp2p_swarm::{
17    handler::{
18        ConnectionEvent, DialUpgradeError, FullyNegotiatedOutbound, OutboundUpgradeSend,
19        ProtocolsChange,
20    },
21    ConnectionHandler, ConnectionHandlerEvent, Stream, StreamProtocol, StreamUpgradeError,
22    SubstreamProtocol,
23};
24
25use crate::v2::{
26    generated::structs::{mod_DialResponse::ResponseStatus, DialStatus},
27    protocol::{
28        Coder, DialDataRequest, DialDataResponse, DialRequest, Response,
29        DATA_FIELD_LEN_UPPER_BOUND, DATA_LEN_LOWER_BOUND, DATA_LEN_UPPER_BOUND,
30    },
31    Nonce, DIAL_REQUEST_PROTOCOL,
32};
33
/// Events reported by the dial-request handler to the client behaviour.
#[derive(Debug)]
pub enum ToBehaviour {
    /// Result of the dial request identified by `nonce`.
    ///
    /// On success, carries the address the server tested and the number of
    /// dial-data bytes the client sent for it (0 if the server requested none).
    TestOutcome {
        nonce: Nonce,
        outcome: Result<(Multiaddr, usize), Error>,
    },
    /// The remote peer advertised support for the dial-request protocol.
    PeerHasServerSupport,
}
42
/// Ways a dial request can fail.
#[derive(thiserror::Error, Debug)]
pub enum Error {
    /// The server reported that dialing back `address` failed.
    ///
    /// `bytes_sent` is the amount of dial data the client sent before the
    /// response arrived (0 if none was requested); `error` tells which stage
    /// of the dial back failed.
    #[error("Address is not reachable: {error}")]
    AddressNotReachable {
        address: Multiaddr,
        bytes_sent: usize,
        error: DialBackError,
    },
    /// Negotiating the dial-request protocol with the peer failed.
    #[error("Peer does not support AutoNAT dial-request protocol")]
    UnsupportedProtocol,
    /// Any other failure: stream I/O errors and timeouts, malformed or
    /// out-of-bounds server messages, and non-OK server response statuses.
    #[error("IO error: {0}")]
    Io(io::Error),
}
56
57impl From<io::Error> for Error {
58    fn from(value: io::Error) -> Self {
59        Self::Io(value)
60    }
61}
62
/// Why the server's dial back to the tested address failed, as reported in
/// its dial response.
#[derive(thiserror::Error, Debug)]
pub enum DialBackError {
    /// The server reported `E_DIAL_ERROR` for the tested address.
    #[error("server failed to establish a connection")]
    NoConnection,
    /// The server reported `E_DIAL_BACK_ERROR` for the tested address.
    #[error("dial back stream failed")]
    StreamFailed,
}
70
/// Connection handler that issues dial requests to a remote AutoNAT server.
///
/// Each request runs as a future in `outbound`. The outbound stream that
/// future needs is delivered to it through a oneshot channel whose sender
/// waits in `queued_streams` until the swarm reports the negotiated stream
/// (or its upgrade error).
pub struct Handler {
    /// Events returned from `poll` before any request future is polled.
    queued_events: VecDeque<
        ConnectionHandlerEvent<
            <Self as ConnectionHandler>::OutboundProtocol,
            (),
            <Self as ConnectionHandler>::ToBehaviour,
        >,
    >,
    /// In-flight dial request futures, keyed by request nonce.
    outbound: FuturesMap<Nonce, Result<(Multiaddr, usize), Error>>,
    /// Senders waiting for an outbound stream; served in FIFO order as
    /// substreams are negotiated or fail.
    queued_streams: VecDeque<
        oneshot::Sender<
            Result<
                Stream,
                StreamUpgradeError<<ReadyUpgrade<StreamProtocol> as OutboundUpgradeSend>::Error>,
            >,
        >,
    >,
}
89
90impl Handler {
91    pub(crate) fn new() -> Self {
92        Self {
93            queued_events: VecDeque::new(),
94            outbound: FuturesMap::new(Duration::from_secs(10), 10),
95            queued_streams: VecDeque::default(),
96        }
97    }
98
99    fn perform_request(&mut self, req: DialRequest) {
100        let (tx, rx) = oneshot::channel();
101        self.queued_streams.push_back(tx);
102        self.queued_events
103            .push_back(ConnectionHandlerEvent::OutboundSubstreamRequest {
104                protocol: SubstreamProtocol::new(ReadyUpgrade::new(DIAL_REQUEST_PROTOCOL), ()),
105            });
106        if self
107            .outbound
108            .try_push(req.nonce, start_stream_handle(req, rx))
109            .is_err()
110        {
111            tracing::debug!("Dial request dropped, too many requests in flight");
112        }
113    }
114}
115
116impl ConnectionHandler for Handler {
117    type FromBehaviour = DialRequest;
118    type ToBehaviour = ToBehaviour;
119    type InboundProtocol = DeniedUpgrade;
120    type OutboundProtocol = ReadyUpgrade<StreamProtocol>;
121    type InboundOpenInfo = ();
122    type OutboundOpenInfo = ();
123
124    fn listen_protocol(&self) -> SubstreamProtocol<Self::InboundProtocol> {
125        SubstreamProtocol::new(DeniedUpgrade, ())
126    }
127
128    fn poll(
129        &mut self,
130        cx: &mut Context<'_>,
131    ) -> Poll<ConnectionHandlerEvent<Self::OutboundProtocol, (), Self::ToBehaviour>> {
132        if let Some(event) = self.queued_events.pop_front() {
133            return Poll::Ready(event);
134        }
135
136        match self.outbound.poll_unpin(cx) {
137            Poll::Ready((nonce, Ok(outcome))) => {
138                return Poll::Ready(ConnectionHandlerEvent::NotifyBehaviour(
139                    ToBehaviour::TestOutcome { nonce, outcome },
140                ))
141            }
142            Poll::Ready((nonce, Err(_))) => {
143                return Poll::Ready(ConnectionHandlerEvent::NotifyBehaviour(
144                    ToBehaviour::TestOutcome {
145                        nonce,
146                        outcome: Err(Error::Io(io::ErrorKind::TimedOut.into())),
147                    },
148                ));
149            }
150            Poll::Pending => {}
151        }
152
153        Poll::Pending
154    }
155
156    fn on_behaviour_event(&mut self, event: Self::FromBehaviour) {
157        self.perform_request(event);
158    }
159
160    fn on_connection_event(
161        &mut self,
162        event: ConnectionEvent<Self::InboundProtocol, Self::OutboundProtocol>,
163    ) {
164        match event {
165            ConnectionEvent::DialUpgradeError(DialUpgradeError { error, .. }) => {
166                tracing::debug!("Dial request failed: {}", error);
167                match self.queued_streams.pop_front() {
168                    Some(stream_tx) => {
169                        let _ = stream_tx.send(Err(error));
170                    }
171                    None => {
172                        tracing::warn!(
173                            "Opened unexpected substream without a pending dial request"
174                        );
175                    }
176                }
177            }
178            ConnectionEvent::FullyNegotiatedOutbound(FullyNegotiatedOutbound {
179                protocol, ..
180            }) => match self.queued_streams.pop_front() {
181                Some(stream_tx) => {
182                    if stream_tx.send(Ok(protocol)).is_err() {
183                        tracing::debug!("Failed to send stream to dead handler");
184                    }
185                }
186                None => {
187                    tracing::warn!("Opened unexpected substream without a pending dial request");
188                }
189            },
190            ConnectionEvent::RemoteProtocolsChange(ProtocolsChange::Added(mut added)) => {
191                if added.any(|p| p.as_ref() == DIAL_REQUEST_PROTOCOL) {
192                    self.queued_events
193                        .push_back(ConnectionHandlerEvent::NotifyBehaviour(
194                            ToBehaviour::PeerHasServerSupport,
195                        ));
196                }
197            }
198            _ => {}
199        }
200    }
201}
202
203async fn start_stream_handle(
204    req: DialRequest,
205    stream_recv: oneshot::Receiver<Result<Stream, StreamUpgradeError<Infallible>>>,
206) -> Result<(Multiaddr, usize), Error> {
207    let stream = stream_recv
208        .await
209        .map_err(|_| io::Error::from(io::ErrorKind::BrokenPipe))?
210        .map_err(|e| match e {
211            StreamUpgradeError::NegotiationFailed => Error::UnsupportedProtocol,
212            StreamUpgradeError::Timeout => Error::Io(io::ErrorKind::TimedOut.into()),
213            StreamUpgradeError::Apply(v) => libp2p_core::util::unreachable(v),
214            StreamUpgradeError::Io(e) => Error::Io(e),
215        })?;
216
217    let mut coder = Coder::new(stream);
218    coder.send(req.clone()).await?;
219
220    let (res, bytes_sent) = match coder.next().await? {
221        Response::Data(DialDataRequest {
222            addr_idx,
223            num_bytes,
224        }) => {
225            if addr_idx >= req.addrs.len() {
226                return Err(Error::Io(io::Error::new(
227                    io::ErrorKind::InvalidInput,
228                    "address index out of bounds",
229                )));
230            }
231            if !(DATA_LEN_LOWER_BOUND..=DATA_LEN_UPPER_BOUND).contains(&num_bytes) {
232                return Err(Error::Io(io::Error::new(
233                    io::ErrorKind::InvalidInput,
234                    "requested bytes out of bounds",
235                )));
236            }
237
238            send_aap_data(&mut coder, num_bytes).await?;
239
240            let Response::Dial(dial_response) = coder.next().await? else {
241                return Err(Error::Io(io::Error::new(
242                    io::ErrorKind::InvalidInput,
243                    "expected message",
244                )));
245            };
246
247            (dial_response, num_bytes)
248        }
249        Response::Dial(dial_response) => (dial_response, 0),
250    };
251    match coder.close().await {
252        Ok(_) => {}
253        Err(err) => {
254            if err.kind() == io::ErrorKind::ConnectionReset {
255                // The AutoNAT server may have already closed the stream
256                // (this is normal because the probe is finished),
257                // in this case we have this error:
258                // Err(Custom { kind: ConnectionReset, error: Stopped(0) })
259                // so we silently ignore this error
260            } else {
261                return Err(err.into());
262            }
263        }
264    }
265
266    match res.status {
267        ResponseStatus::E_REQUEST_REJECTED => {
268            return Err(Error::Io(io::Error::other("server rejected request")))
269        }
270        ResponseStatus::E_DIAL_REFUSED => {
271            return Err(Error::Io(io::Error::other("server refused dial")))
272        }
273        ResponseStatus::E_INTERNAL_ERROR => {
274            return Err(Error::Io(io::Error::other(
275                "server encountered internal error",
276            )))
277        }
278        ResponseStatus::OK => {}
279    }
280
281    let tested_address = req
282        .addrs
283        .get(res.addr_idx)
284        .ok_or_else(|| io::Error::new(io::ErrorKind::InvalidInput, "address index out of bounds"))?
285        .clone();
286
287    match res.dial_status {
288        DialStatus::UNUSED => {
289            return Err(Error::Io(io::Error::new(
290                io::ErrorKind::InvalidInput,
291                "unexpected message",
292            )))
293        }
294        DialStatus::E_DIAL_ERROR => {
295            return Err(Error::AddressNotReachable {
296                address: tested_address,
297                bytes_sent,
298                error: DialBackError::NoConnection,
299            })
300        }
301        DialStatus::E_DIAL_BACK_ERROR => {
302            return Err(Error::AddressNotReachable {
303                address: tested_address,
304                bytes_sent,
305                error: DialBackError::StreamFailed,
306            })
307        }
308        DialStatus::OK => {}
309    }
310
311    Ok((tested_address, bytes_sent))
312}
313
314async fn send_aap_data<I>(stream: &mut Coder<I>, num_bytes: usize) -> io::Result<()>
315where
316    I: AsyncWrite + Unpin,
317{
318    let count_full = num_bytes / DATA_FIELD_LEN_UPPER_BOUND;
319    let partial_len = num_bytes % DATA_FIELD_LEN_UPPER_BOUND;
320    for req in repeat_n(DATA_FIELD_LEN_UPPER_BOUND, count_full)
321        .chain(once(partial_len))
322        .filter(|e| *e > 0)
323        .map(|data_count| {
324            DialDataResponse::new(data_count).expect("data count is unexpectedly too big")
325        })
326    {
327        stream.send(req).await?;
328    }
329
330    Ok(())
331}