1use std::{
2 convert::Infallible,
3 io,
4 task::{Context, Poll},
5 time::Duration,
6};
7
8use either::Either;
9use futures::{
10 channel::{mpsc, oneshot},
11 AsyncRead, AsyncWrite, SinkExt, StreamExt,
12};
13use futures_bounded::FuturesSet;
14use libp2p_core::{
15 upgrade::{DeniedUpgrade, ReadyUpgrade},
16 Multiaddr,
17};
18use libp2p_identity::PeerId;
19use libp2p_swarm::{
20 handler::{ConnectionEvent, FullyNegotiatedInbound, ListenUpgradeError},
21 ConnectionHandler, ConnectionHandlerEvent, StreamProtocol, SubstreamProtocol,
22};
23use rand_core::RngCore;
24
25use crate::v2::{
26 generated::structs::{mod_DialResponse::ResponseStatus, DialStatus},
27 protocol::{Coder, DialDataRequest, DialRequest, DialResponse, Request, Response},
28 server::behaviour::Event,
29 Nonce, DIAL_REQUEST_PROTOCOL,
30};
31
/// Failure reasons reported back for a dial-back attempt.
///
/// Each variant is translated into a wire-level `DialStatus` by the
/// `From<HandleFail> for DialResponse` conversion in this module.
#[derive(PartialEq, Debug)]
pub(crate) enum DialBackStatus {
    /// Sent to the client as `DialStatus::E_DIAL_ERROR`
    /// (presumably: dialing the address itself failed — confirm in behaviour).
    DialErr,
    /// Sent to the client as `DialStatus::E_DIAL_BACK_ERROR`
    /// (presumably: the dial succeeded but the dial-back exchange failed).
    DialBackErr,
}
39
40#[derive(Debug)]
41pub struct DialBackCommand {
42 pub(crate) addr: Multiaddr,
43 pub(crate) nonce: Nonce,
44 pub(crate) back_channel: oneshot::Sender<Result<(), DialBackStatus>>,
45}
46
47pub struct Handler<R> {
48 client_id: PeerId,
49 observed_multiaddr: Multiaddr,
50 dial_back_cmd_sender: mpsc::Sender<DialBackCommand>,
51 dial_back_cmd_receiver: mpsc::Receiver<DialBackCommand>,
52 inbound: FuturesSet<Event>,
53 rng: R,
54}
55
56impl<R> Handler<R>
57where
58 R: RngCore,
59{
60 pub(crate) fn new(client_id: PeerId, observed_multiaddr: Multiaddr, rng: R) -> Self {
61 let (dial_back_cmd_sender, dial_back_cmd_receiver) = mpsc::channel(10);
62 Self {
63 client_id,
64 observed_multiaddr,
65 dial_back_cmd_sender,
66 dial_back_cmd_receiver,
67 inbound: FuturesSet::new(Duration::from_secs(10), 10),
68 rng,
69 }
70 }
71}
72
73impl<R> ConnectionHandler for Handler<R>
74where
75 R: RngCore + Send + Clone + 'static,
76{
77 type FromBehaviour = Infallible;
78 type ToBehaviour = Either<DialBackCommand, Event>;
79 type InboundProtocol = ReadyUpgrade<StreamProtocol>;
80 type OutboundProtocol = DeniedUpgrade;
81 type InboundOpenInfo = ();
82 type OutboundOpenInfo = ();
83
84 fn listen_protocol(&self) -> SubstreamProtocol<Self::InboundProtocol> {
85 SubstreamProtocol::new(ReadyUpgrade::new(DIAL_REQUEST_PROTOCOL), ())
86 }
87
88 fn poll(
89 &mut self,
90 cx: &mut Context<'_>,
91 ) -> Poll<ConnectionHandlerEvent<Self::OutboundProtocol, (), Self::ToBehaviour>> {
92 loop {
93 match self.inbound.poll_unpin(cx) {
94 Poll::Ready(Ok(event)) => {
95 if let Err(e) = &event.result {
96 tracing::warn!("inbound request handle failed: {:?}", e);
97 }
98 return Poll::Ready(ConnectionHandlerEvent::NotifyBehaviour(Either::Right(
99 event,
100 )));
101 }
102 Poll::Ready(Err(e)) => {
103 tracing::warn!("inbound request handle timed out {e:?}");
104 }
105 Poll::Pending => break,
106 }
107 }
108 if let Poll::Ready(Some(cmd)) = self.dial_back_cmd_receiver.poll_next_unpin(cx) {
109 return Poll::Ready(ConnectionHandlerEvent::NotifyBehaviour(Either::Left(cmd)));
110 }
111 Poll::Pending
112 }
113
114 fn on_behaviour_event(&mut self, _event: Self::FromBehaviour) {}
115
116 fn on_connection_event(
117 &mut self,
118 event: ConnectionEvent<Self::InboundProtocol, Self::OutboundProtocol>,
119 ) {
120 match event {
121 ConnectionEvent::FullyNegotiatedInbound(FullyNegotiatedInbound {
122 protocol, ..
123 }) => {
124 if self
125 .inbound
126 .try_push(handle_request(
127 protocol,
128 self.observed_multiaddr.clone(),
129 self.client_id,
130 self.dial_back_cmd_sender.clone(),
131 self.rng.clone(),
132 ))
133 .is_err()
134 {
135 tracing::warn!(
136 "failed to push inbound request handler, too many requests in flight"
137 );
138 }
139 }
140 ConnectionEvent::ListenUpgradeError(ListenUpgradeError { error, .. }) => {
141 tracing::debug!("inbound request failed: {:?}", error);
142 }
143 _ => {}
144 }
145 }
146}
147
148enum HandleFail {
149 InternalError(usize),
150 RequestRejected,
151 DialRefused,
152 DialBack {
153 idx: usize,
154 result: Result<(), DialBackStatus>,
155 },
156}
157
158impl From<HandleFail> for DialResponse {
159 fn from(value: HandleFail) -> Self {
160 match value {
161 HandleFail::InternalError(addr_idx) => Self {
162 status: ResponseStatus::E_INTERNAL_ERROR,
163 addr_idx,
164 dial_status: DialStatus::UNUSED,
165 },
166 HandleFail::RequestRejected => Self {
167 status: ResponseStatus::E_REQUEST_REJECTED,
168 addr_idx: 0,
169 dial_status: DialStatus::UNUSED,
170 },
171 HandleFail::DialRefused => Self {
172 status: ResponseStatus::E_DIAL_REFUSED,
173 addr_idx: 0,
174 dial_status: DialStatus::UNUSED,
175 },
176 HandleFail::DialBack { idx, result } => Self {
177 status: ResponseStatus::OK,
178 addr_idx: idx,
179 dial_status: match result {
180 Err(DialBackStatus::DialErr) => DialStatus::E_DIAL_ERROR,
181 Err(DialBackStatus::DialBackErr) => DialStatus::E_DIAL_BACK_ERROR,
182 Ok(()) => DialStatus::OK,
183 },
184 },
185 }
186 }
187}
188
189async fn handle_request(
190 stream: impl AsyncRead + AsyncWrite + Unpin,
191 observed_multiaddr: Multiaddr,
192 client: PeerId,
193 dial_back_cmd_sender: mpsc::Sender<DialBackCommand>,
194 rng: impl RngCore,
195) -> Event {
196 let mut coder = Coder::new(stream);
197 let mut all_addrs = Vec::new();
198 let mut tested_addr_opt = None;
199 let mut data_amount = 0;
200 let response = handle_request_internal(
201 &mut coder,
202 observed_multiaddr.clone(),
203 dial_back_cmd_sender,
204 rng,
205 &mut all_addrs,
206 &mut tested_addr_opt,
207 &mut data_amount,
208 )
209 .await
210 .unwrap_or_else(|e| e.into());
211 let Some(tested_addr) = tested_addr_opt else {
212 return Event {
213 all_addrs,
214 tested_addr: observed_multiaddr,
215 client,
216 data_amount,
217 result: Err(io::Error::new(
218 io::ErrorKind::Other,
219 "client is not conformint to protocol. the tested address is not the observed address",
220 )),
221 };
222 };
223 if let Err(e) = coder.send(Response::Dial(response)).await {
224 return Event {
225 all_addrs,
226 tested_addr,
227 client,
228 data_amount,
229 result: Err(e),
230 };
231 }
232 if let Err(e) = coder.close().await {
233 return Event {
234 all_addrs,
235 tested_addr,
236 client,
237 data_amount,
238 result: Err(e),
239 };
240 }
241 Event {
242 all_addrs,
243 tested_addr,
244 client,
245 data_amount,
246 result: Ok(()),
247 }
248}
249
250async fn handle_request_internal<I>(
251 coder: &mut Coder<I>,
252 observed_multiaddr: Multiaddr,
253 dial_back_cmd_sender: mpsc::Sender<DialBackCommand>,
254 mut rng: impl RngCore,
255 all_addrs: &mut Vec<Multiaddr>,
256 tested_addrs: &mut Option<Multiaddr>,
257 data_amount: &mut usize,
258) -> Result<DialResponse, HandleFail>
259where
260 I: AsyncRead + AsyncWrite + Unpin,
261{
262 let DialRequest { mut addrs, nonce } = match coder
263 .next()
264 .await
265 .map_err(|_| HandleFail::InternalError(0))?
266 {
267 Request::Dial(dial_request) => dial_request,
268 Request::Data(_) => {
269 return Err(HandleFail::RequestRejected);
270 }
271 };
272 all_addrs.clone_from(&addrs);
273 let idx = 0;
274 let addr = addrs.pop().ok_or(HandleFail::DialRefused)?;
275 *tested_addrs = Some(addr.clone());
276 *data_amount = 0;
277 if addr != observed_multiaddr {
278 let dial_data_request = DialDataRequest::from_rng(idx, &mut rng);
279 let mut rem_data = dial_data_request.num_bytes;
280 coder
281 .send(Response::Data(dial_data_request))
282 .await
283 .map_err(|_| HandleFail::InternalError(idx))?;
284 while rem_data > 0 {
285 let data_count = match coder
286 .next()
287 .await
288 .map_err(|_e| HandleFail::InternalError(idx))?
289 {
290 Request::Dial(_) => {
291 return Err(HandleFail::RequestRejected);
292 }
293 Request::Data(dial_data_response) => dial_data_response.get_data_count(),
294 };
295 rem_data = rem_data.saturating_sub(data_count);
296 *data_amount += data_count;
297 }
298 }
299 let (back_channel, rx) = oneshot::channel();
300 let dial_back_cmd = DialBackCommand {
301 addr,
302 nonce,
303 back_channel,
304 };
305 dial_back_cmd_sender
306 .clone()
307 .send(dial_back_cmd)
308 .await
309 .map_err(|_| HandleFail::DialBack {
310 idx,
311 result: Err(DialBackStatus::DialErr),
312 })?;
313
314 let dial_back = rx.await.map_err(|_e| HandleFail::InternalError(idx))?;
315 if let Err(err) = dial_back {
316 return Err(HandleFail::DialBack {
317 idx,
318 result: Err(err),
319 });
320 }
321 Ok(DialResponse {
322 status: ResponseStatus::OK,
323 addr_idx: idx,
324 dial_status: DialStatus::OK,
325 })
326}