libp2p_autonat/v2/server/handler/
dial_back.rs

1use std::{
2    convert::identity,
3    io,
4    task::{Context, Poll},
5    time::Duration,
6};
7
8use futures::{AsyncRead, AsyncWrite};
9use futures_bounded::FuturesSet;
10use libp2p_core::upgrade::{DeniedUpgrade, ReadyUpgrade};
11use libp2p_swarm::{
12    handler::{ConnectionEvent, DialUpgradeError, FullyNegotiatedOutbound},
13    ConnectionHandler, ConnectionHandlerEvent, StreamProtocol, StreamUpgradeError,
14    SubstreamProtocol,
15};
16
17use super::dial_request::{DialBackCommand, DialBackStatus as DialBackRes};
18use crate::v2::{
19    protocol::{dial_back, recv_dial_back_response},
20    DIAL_BACK_PROTOCOL,
21};
22
/// Event reported from this handler to the behaviour: `Ok(())` once a dial-back
/// exchange has been handed off, or the I/O error (including timeouts) that stopped it.
pub(crate) type ToBehaviour = Result<(), io::Error>;
24
25pub struct Handler {
26    pending_nonce: Option<DialBackCommand>,
27    requested_substream_nonce: Option<DialBackCommand>,
28    outbound: FuturesSet<ToBehaviour>,
29}
30
31impl Handler {
32    pub(crate) fn new(cmd: DialBackCommand) -> Self {
33        Self {
34            pending_nonce: Some(cmd),
35            requested_substream_nonce: None,
36            outbound: FuturesSet::new(Duration::from_secs(10), 5),
37        }
38    }
39}
40
41impl ConnectionHandler for Handler {
42    type FromBehaviour = ();
43    type ToBehaviour = ToBehaviour;
44    type InboundProtocol = DeniedUpgrade;
45    type OutboundProtocol = ReadyUpgrade<StreamProtocol>;
46    type InboundOpenInfo = ();
47    type OutboundOpenInfo = ();
48
49    fn listen_protocol(&self) -> SubstreamProtocol<Self::InboundProtocol> {
50        SubstreamProtocol::new(DeniedUpgrade, ())
51    }
52
53    fn poll(
54        &mut self,
55        cx: &mut Context<'_>,
56    ) -> Poll<ConnectionHandlerEvent<Self::OutboundProtocol, (), Self::ToBehaviour>> {
57        if let Poll::Ready(result) = self.outbound.poll_unpin(cx) {
58            return Poll::Ready(ConnectionHandlerEvent::NotifyBehaviour(
59                result
60                    .map_err(|timeout| io::Error::new(io::ErrorKind::TimedOut, timeout))
61                    .and_then(identity),
62            ));
63        }
64        if let Some(cmd) = self.pending_nonce.take() {
65            self.requested_substream_nonce = Some(cmd);
66            return Poll::Ready(ConnectionHandlerEvent::OutboundSubstreamRequest {
67                protocol: SubstreamProtocol::new(ReadyUpgrade::new(DIAL_BACK_PROTOCOL), ()),
68            });
69        }
70        Poll::Pending
71    }
72
73    fn on_behaviour_event(&mut self, _event: Self::FromBehaviour) {}
74
75    fn on_connection_event(
76        &mut self,
77        event: ConnectionEvent<Self::InboundProtocol, Self::OutboundProtocol>,
78    ) {
79        match event {
80            ConnectionEvent::FullyNegotiatedOutbound(FullyNegotiatedOutbound {
81                protocol, ..
82            }) => {
83                if let Some(cmd) = self.requested_substream_nonce.take() {
84                    if self
85                        .outbound
86                        .try_push(perform_dial_back(protocol, cmd))
87                        .is_err()
88                    {
89                        tracing::warn!("Dial back dropped, too many requests in flight");
90                    }
91                } else {
92                    tracing::warn!("received dial back substream without nonce");
93                }
94            }
95            ConnectionEvent::DialUpgradeError(DialUpgradeError {
96                error: StreamUpgradeError::NegotiationFailed | StreamUpgradeError::Timeout,
97                ..
98            }) => {
99                if let Some(cmd) = self.requested_substream_nonce.take() {
100                    let _ = cmd.back_channel.send(Err(DialBackRes::DialBackErr));
101                }
102            }
103            _ => {}
104        }
105    }
106}
107
108async fn perform_dial_back(
109    mut stream: impl AsyncRead + AsyncWrite + Unpin,
110    DialBackCommand {
111        nonce,
112        back_channel,
113        ..
114    }: DialBackCommand,
115) -> io::Result<()> {
116    let res = dial_back(&mut stream, nonce)
117        .await
118        .map_err(|_| DialBackRes::DialBackErr)
119        .map(|_| ());
120
121    let res = match res {
122        Ok(()) => recv_dial_back_response(stream)
123            .await
124            .map_err(|_| DialBackRes::DialBackErr)
125            .map(|_| ()),
126        Err(e) => Err(e),
127    };
128    back_channel
129        .send(res)
130        .map_err(|_| io::Error::new(io::ErrorKind::Other, "send error"))?;
131    Ok(())
132}