libp2p_autonat/v2/server/handler/dial_request.rs

1use std::{
2    convert::Infallible,
3    io,
4    task::{Context, Poll},
5    time::Duration,
6};
7
8use either::Either;
9use futures::{
10    channel::{mpsc, oneshot},
11    AsyncRead, AsyncWrite, SinkExt, StreamExt,
12};
13use futures_bounded::FuturesSet;
14use libp2p_core::{
15    upgrade::{DeniedUpgrade, ReadyUpgrade},
16    Multiaddr,
17};
18use libp2p_identity::PeerId;
19use libp2p_swarm::{
20    handler::{ConnectionEvent, FullyNegotiatedInbound, ListenUpgradeError},
21    ConnectionHandler, ConnectionHandlerEvent, StreamProtocol, SubstreamProtocol,
22};
23use rand_core::RngCore;
24
25use crate::v2::{
26    generated::structs::{mod_DialResponse::ResponseStatus, DialStatus},
27    protocol::{Coder, DialDataRequest, DialRequest, DialResponse, Request, Response},
28    server::behaviour::Event,
29    Nonce, DIAL_REQUEST_PROTOCOL,
30};
31
/// Why a dial-back attempt for a tested address did not succeed.
///
/// Converted into the `dial_status` of the `DialResponse` sent to the client.
#[derive(PartialEq, Debug)]
pub(crate) enum DialBackStatus {
    /// The dial itself failed; reported to the client as `DialStatus::E_DIAL_ERROR`.
    DialErr,
    /// The dial-back step failed; reported to the client as `DialStatus::E_DIAL_BACK_ERROR`.
    DialBackErr,
}
39
40#[derive(Debug)]
41pub struct DialBackCommand {
42    pub(crate) addr: Multiaddr,
43    pub(crate) nonce: Nonce,
44    pub(crate) back_channel: oneshot::Sender<Result<(), DialBackStatus>>,
45}
46
47pub struct Handler<R> {
48    client_id: PeerId,
49    observed_multiaddr: Multiaddr,
50    dial_back_cmd_sender: mpsc::Sender<DialBackCommand>,
51    dial_back_cmd_receiver: mpsc::Receiver<DialBackCommand>,
52    inbound: FuturesSet<Event>,
53    rng: R,
54}
55
56impl<R> Handler<R>
57where
58    R: RngCore,
59{
60    pub(crate) fn new(client_id: PeerId, observed_multiaddr: Multiaddr, rng: R) -> Self {
61        let (dial_back_cmd_sender, dial_back_cmd_receiver) = mpsc::channel(10);
62        Self {
63            client_id,
64            observed_multiaddr,
65            dial_back_cmd_sender,
66            dial_back_cmd_receiver,
67            inbound: FuturesSet::new(Duration::from_secs(10), 10),
68            rng,
69        }
70    }
71}
72
73impl<R> ConnectionHandler for Handler<R>
74where
75    R: RngCore + Send + Clone + 'static,
76{
77    type FromBehaviour = Infallible;
78    type ToBehaviour = Either<DialBackCommand, Event>;
79    type InboundProtocol = ReadyUpgrade<StreamProtocol>;
80    type OutboundProtocol = DeniedUpgrade;
81    type InboundOpenInfo = ();
82    type OutboundOpenInfo = ();
83
84    fn listen_protocol(&self) -> SubstreamProtocol<Self::InboundProtocol> {
85        SubstreamProtocol::new(ReadyUpgrade::new(DIAL_REQUEST_PROTOCOL), ())
86    }
87
88    fn poll(
89        &mut self,
90        cx: &mut Context<'_>,
91    ) -> Poll<ConnectionHandlerEvent<Self::OutboundProtocol, (), Self::ToBehaviour>> {
92        loop {
93            match self.inbound.poll_unpin(cx) {
94                Poll::Ready(Ok(event)) => {
95                    if let Err(e) = &event.result {
96                        tracing::warn!("inbound request handle failed: {:?}", e);
97                    }
98                    return Poll::Ready(ConnectionHandlerEvent::NotifyBehaviour(Either::Right(
99                        event,
100                    )));
101                }
102                Poll::Ready(Err(e)) => {
103                    tracing::warn!("inbound request handle timed out {e:?}");
104                }
105                Poll::Pending => break,
106            }
107        }
108        if let Poll::Ready(Some(cmd)) = self.dial_back_cmd_receiver.poll_next_unpin(cx) {
109            return Poll::Ready(ConnectionHandlerEvent::NotifyBehaviour(Either::Left(cmd)));
110        }
111        Poll::Pending
112    }
113
114    fn on_behaviour_event(&mut self, _event: Self::FromBehaviour) {}
115
116    fn on_connection_event(
117        &mut self,
118        event: ConnectionEvent<Self::InboundProtocol, Self::OutboundProtocol>,
119    ) {
120        match event {
121            ConnectionEvent::FullyNegotiatedInbound(FullyNegotiatedInbound {
122                protocol, ..
123            }) => {
124                if self
125                    .inbound
126                    .try_push(handle_request(
127                        protocol,
128                        self.observed_multiaddr.clone(),
129                        self.client_id,
130                        self.dial_back_cmd_sender.clone(),
131                        self.rng.clone(),
132                    ))
133                    .is_err()
134                {
135                    tracing::warn!(
136                        "failed to push inbound request handler, too many requests in flight"
137                    );
138                }
139            }
140            ConnectionEvent::ListenUpgradeError(ListenUpgradeError { error, .. }) => {
141                tracing::debug!("inbound request failed: {:?}", error);
142            }
143            _ => {}
144        }
145    }
146}
147
148enum HandleFail {
149    InternalError(usize),
150    RequestRejected,
151    DialRefused,
152    DialBack {
153        idx: usize,
154        result: Result<(), DialBackStatus>,
155    },
156}
157
158impl From<HandleFail> for DialResponse {
159    fn from(value: HandleFail) -> Self {
160        match value {
161            HandleFail::InternalError(addr_idx) => Self {
162                status: ResponseStatus::E_INTERNAL_ERROR,
163                addr_idx,
164                dial_status: DialStatus::UNUSED,
165            },
166            HandleFail::RequestRejected => Self {
167                status: ResponseStatus::E_REQUEST_REJECTED,
168                addr_idx: 0,
169                dial_status: DialStatus::UNUSED,
170            },
171            HandleFail::DialRefused => Self {
172                status: ResponseStatus::E_DIAL_REFUSED,
173                addr_idx: 0,
174                dial_status: DialStatus::UNUSED,
175            },
176            HandleFail::DialBack { idx, result } => Self {
177                status: ResponseStatus::OK,
178                addr_idx: idx,
179                dial_status: match result {
180                    Err(DialBackStatus::DialErr) => DialStatus::E_DIAL_ERROR,
181                    Err(DialBackStatus::DialBackErr) => DialStatus::E_DIAL_BACK_ERROR,
182                    Ok(()) => DialStatus::OK,
183                },
184            },
185        }
186    }
187}
188
189async fn handle_request(
190    stream: impl AsyncRead + AsyncWrite + Unpin,
191    observed_multiaddr: Multiaddr,
192    client: PeerId,
193    dial_back_cmd_sender: mpsc::Sender<DialBackCommand>,
194    rng: impl RngCore,
195) -> Event {
196    let mut coder = Coder::new(stream);
197    let mut all_addrs = Vec::new();
198    let mut tested_addr_opt = None;
199    let mut data_amount = 0;
200    let response = handle_request_internal(
201        &mut coder,
202        observed_multiaddr.clone(),
203        dial_back_cmd_sender,
204        rng,
205        &mut all_addrs,
206        &mut tested_addr_opt,
207        &mut data_amount,
208    )
209    .await
210    .unwrap_or_else(|e| e.into());
211    let Some(tested_addr) = tested_addr_opt else {
212        return Event {
213            all_addrs,
214            tested_addr: observed_multiaddr,
215            client,
216            data_amount,
217            result: Err(io::Error::new(
218                io::ErrorKind::Other,
219                "client is not conformint to protocol. the tested address is not the observed address",
220            )),
221        };
222    };
223    if let Err(e) = coder.send(Response::Dial(response)).await {
224        return Event {
225            all_addrs,
226            tested_addr,
227            client,
228            data_amount,
229            result: Err(e),
230        };
231    }
232    if let Err(e) = coder.close().await {
233        return Event {
234            all_addrs,
235            tested_addr,
236            client,
237            data_amount,
238            result: Err(e),
239        };
240    }
241    Event {
242        all_addrs,
243        tested_addr,
244        client,
245        data_amount,
246        result: Ok(()),
247    }
248}
249
250async fn handle_request_internal<I>(
251    coder: &mut Coder<I>,
252    observed_multiaddr: Multiaddr,
253    dial_back_cmd_sender: mpsc::Sender<DialBackCommand>,
254    mut rng: impl RngCore,
255    all_addrs: &mut Vec<Multiaddr>,
256    tested_addrs: &mut Option<Multiaddr>,
257    data_amount: &mut usize,
258) -> Result<DialResponse, HandleFail>
259where
260    I: AsyncRead + AsyncWrite + Unpin,
261{
262    let DialRequest { mut addrs, nonce } = match coder
263        .next()
264        .await
265        .map_err(|_| HandleFail::InternalError(0))?
266    {
267        Request::Dial(dial_request) => dial_request,
268        Request::Data(_) => {
269            return Err(HandleFail::RequestRejected);
270        }
271    };
272    all_addrs.clone_from(&addrs);
273    let idx = 0;
274    let addr = addrs.pop().ok_or(HandleFail::DialRefused)?;
275    *tested_addrs = Some(addr.clone());
276    *data_amount = 0;
277    if addr != observed_multiaddr {
278        let dial_data_request = DialDataRequest::from_rng(idx, &mut rng);
279        let mut rem_data = dial_data_request.num_bytes;
280        coder
281            .send(Response::Data(dial_data_request))
282            .await
283            .map_err(|_| HandleFail::InternalError(idx))?;
284        while rem_data > 0 {
285            let data_count = match coder
286                .next()
287                .await
288                .map_err(|_e| HandleFail::InternalError(idx))?
289            {
290                Request::Dial(_) => {
291                    return Err(HandleFail::RequestRejected);
292                }
293                Request::Data(dial_data_response) => dial_data_response.get_data_count(),
294            };
295            rem_data = rem_data.saturating_sub(data_count);
296            *data_amount += data_count;
297        }
298    }
299    let (back_channel, rx) = oneshot::channel();
300    let dial_back_cmd = DialBackCommand {
301        addr,
302        nonce,
303        back_channel,
304    };
305    dial_back_cmd_sender
306        .clone()
307        .send(dial_back_cmd)
308        .await
309        .map_err(|_| HandleFail::DialBack {
310            idx,
311            result: Err(DialBackStatus::DialErr),
312        })?;
313
314    let dial_back = rx.await.map_err(|_e| HandleFail::InternalError(idx))?;
315    if let Err(err) = dial_back {
316        return Err(HandleFail::DialBack {
317            idx,
318            result: Err(err),
319        });
320    }
321    Ok(DialResponse {
322        status: ResponseStatus::OK,
323        addr_idx: idx,
324        dial_status: DialStatus::OK,
325    })
326}