libp2p_autonat/v2/client/handler/
dial_back.rs

1use std::{
2    convert::Infallible,
3    io,
4    task::{Context, Poll},
5    time::Duration,
6};
7
8use futures::channel::oneshot;
9use futures_bounded::StreamSet;
10use libp2p_core::upgrade::{DeniedUpgrade, ReadyUpgrade};
11use libp2p_swarm::{
12    handler::{ConnectionEvent, FullyNegotiatedInbound, ListenUpgradeError},
13    ConnectionHandler, ConnectionHandlerEvent, StreamProtocol, SubstreamProtocol,
14};
15
16use crate::v2::{protocol, Nonce, DIAL_BACK_PROTOCOL};
17
18pub struct Handler {
19    inbound: StreamSet<io::Result<IncomingNonce>>,
20}
21
22impl Handler {
23    pub(crate) fn new() -> Self {
24        Self {
25            inbound: StreamSet::new(Duration::from_secs(5), 2),
26        }
27    }
28}
29
30impl ConnectionHandler for Handler {
31    type FromBehaviour = Infallible;
32    type ToBehaviour = IncomingNonce;
33    type InboundProtocol = ReadyUpgrade<StreamProtocol>;
34    type OutboundProtocol = DeniedUpgrade;
35    type InboundOpenInfo = ();
36    type OutboundOpenInfo = ();
37
38    fn listen_protocol(&self) -> SubstreamProtocol<Self::InboundProtocol> {
39        SubstreamProtocol::new(ReadyUpgrade::new(DIAL_BACK_PROTOCOL), ())
40    }
41
42    fn poll(
43        &mut self,
44        cx: &mut Context<'_>,
45    ) -> Poll<ConnectionHandlerEvent<Self::OutboundProtocol, (), Self::ToBehaviour>> {
46        loop {
47            match self.inbound.poll_next_unpin(cx) {
48                Poll::Pending => return Poll::Pending,
49                Poll::Ready(None) => continue,
50                Poll::Ready(Some(Err(err))) => {
51                    tracing::debug!("Stream timed out: {err}");
52                    continue;
53                }
54                Poll::Ready(Some(Ok(Err(err)))) => {
55                    tracing::debug!("Dial back handler failed with: {err:?}");
56                    continue;
57                }
58                Poll::Ready(Some(Ok(Ok(incoming_nonce)))) => {
59                    return Poll::Ready(ConnectionHandlerEvent::NotifyBehaviour(incoming_nonce));
60                }
61            }
62        }
63    }
64
65    fn on_behaviour_event(&mut self, _event: Self::FromBehaviour) {}
66
67    fn on_connection_event(
68        &mut self,
69        event: ConnectionEvent<Self::InboundProtocol, Self::OutboundProtocol>,
70    ) {
71        match event {
72            ConnectionEvent::FullyNegotiatedInbound(FullyNegotiatedInbound {
73                protocol, ..
74            }) => {
75                if self.inbound.try_push(perform_dial_back(protocol)).is_err() {
76                    tracing::warn!("Dial back request dropped, too many requests in flight");
77                }
78            }
79            ConnectionEvent::ListenUpgradeError(ListenUpgradeError { error, .. }) => {
80                libp2p_core::util::unreachable(error);
81            }
82            _ => {}
83        }
84    }
85}
86
87struct State {
88    stream: libp2p_swarm::Stream,
89    oneshot: Option<oneshot::Receiver<io::Result<()>>>,
90}
91
92#[derive(Debug)]
93pub struct IncomingNonce {
94    pub nonce: Nonce,
95    pub sender: oneshot::Sender<io::Result<()>>,
96}
97
98fn perform_dial_back(
99    stream: libp2p_swarm::Stream,
100) -> impl futures::Stream<Item = io::Result<IncomingNonce>> {
101    let state = State {
102        stream,
103        oneshot: None,
104    };
105    futures::stream::unfold(state, |mut state| async move {
106        if let Some(ref mut receiver) = state.oneshot {
107            match receiver.await {
108                Ok(Ok(())) => {}
109                Ok(Err(e)) => return Some((Err(e), state)),
110                Err(_) => return None,
111            }
112            if let Err(e) = protocol::dial_back_response(&mut state.stream).await {
113                return Some((Err(e), state));
114            }
115            return None;
116        }
117
118        let nonce = match protocol::recv_dial_back(&mut state.stream).await {
119            Ok(nonce) => nonce,
120            Err(err) => {
121                return Some((Err(err), state));
122            }
123        };
124
125        let (sender, receiver) = oneshot::channel();
126        state.oneshot = Some(receiver);
127        Some((Ok(IncomingNonce { nonce, sender }), state))
128    })
129}