1pub(crate) mod protocol;
22
23use std::{
24 collections::VecDeque,
25 fmt, io,
26 sync::{
27 atomic::{AtomicU64, Ordering},
28 Arc,
29 },
30 task::{Context, Poll},
31 time::Duration,
32};
33
34use futures::{
35 channel::{mpsc, oneshot},
36 prelude::*,
37};
38use libp2p_swarm::{
39 handler::{
40 ConnectionEvent, ConnectionHandler, ConnectionHandlerEvent, DialUpgradeError,
41 FullyNegotiatedInbound, FullyNegotiatedOutbound, ListenUpgradeError, StreamUpgradeError,
42 },
43 SubstreamProtocol,
44};
45pub use protocol::ProtocolSupport;
46use smallvec::SmallVec;
47
48use crate::{
49 codec::Codec, handler::protocol::Protocol, InboundRequestId, OutboundRequestId,
50 EMPTY_QUEUE_SHRINK_THRESHOLD,
51};
52
53pub struct Handler<TCodec>
55where
56 TCodec: Codec,
57{
58 inbound_protocols: SmallVec<[TCodec::Protocol; 2]>,
60 codec: TCodec,
62 pending_events: VecDeque<Event<TCodec>>,
64 pending_outbound: VecDeque<OutboundMessage<TCodec>>,
66
67 requested_outbound: VecDeque<OutboundMessage<TCodec>>,
68 inbound_receiver: mpsc::Receiver<(
70 InboundRequestId,
71 TCodec::Request,
72 oneshot::Sender<TCodec::Response>,
73 )>,
74 inbound_sender: mpsc::Sender<(
76 InboundRequestId,
77 TCodec::Request,
78 oneshot::Sender<TCodec::Response>,
79 )>,
80
81 inbound_request_id: Arc<AtomicU64>,
82
83 worker_streams: futures_bounded::FuturesMap<RequestId, Result<Event<TCodec>, io::Error>>,
84}
85
86#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
87enum RequestId {
88 Inbound(InboundRequestId),
89 Outbound(OutboundRequestId),
90}
91
92impl<TCodec> Handler<TCodec>
93where
94 TCodec: Codec + Send + Clone + 'static,
95{
96 pub(super) fn new(
97 inbound_protocols: SmallVec<[TCodec::Protocol; 2]>,
98 codec: TCodec,
99 substream_timeout: Duration,
100 inbound_request_id: Arc<AtomicU64>,
101 max_concurrent_streams: usize,
102 ) -> Self {
103 let (inbound_sender, inbound_receiver) = mpsc::channel(0);
104 Self {
105 inbound_protocols,
106 codec,
107 pending_outbound: VecDeque::new(),
108 requested_outbound: Default::default(),
109 inbound_receiver,
110 inbound_sender,
111 pending_events: VecDeque::new(),
112 inbound_request_id,
113 worker_streams: futures_bounded::FuturesMap::new(
114 substream_timeout,
115 max_concurrent_streams,
116 ),
117 }
118 }
119
120 fn next_inbound_request_id(&mut self) -> InboundRequestId {
122 InboundRequestId(self.inbound_request_id.fetch_add(1, Ordering::Relaxed))
123 }
124
125 fn on_fully_negotiated_inbound(
126 &mut self,
127 FullyNegotiatedInbound {
128 protocol: (mut stream, protocol),
129 info: (),
130 }: FullyNegotiatedInbound<<Self as ConnectionHandler>::InboundProtocol>,
131 ) {
132 let mut codec = self.codec.clone();
133 let request_id = self.next_inbound_request_id();
134 let mut sender = self.inbound_sender.clone();
135
136 let recv = async move {
137 let (rs_send, rs_recv) = oneshot::channel();
140
141 let read = codec.read_request(&protocol, &mut stream);
142 let request = read.await?;
143 sender
144 .send((request_id, request, rs_send))
145 .await
146 .expect("`ConnectionHandler` owns both ends of the channel");
147 drop(sender);
148
149 if let Ok(response) = rs_recv.await {
150 let write = codec.write_response(&protocol, &mut stream, response);
151 write.await?;
152
153 stream.close().await?;
154 Ok(Event::ResponseSent(request_id))
155 } else {
156 stream.close().await?;
157 Ok(Event::ResponseOmission(request_id))
158 }
159 };
160
161 if self
165 .worker_streams
166 .try_push(RequestId::Inbound(request_id), recv.boxed())
167 .is_err()
168 {
169 tracing::warn!("Dropping inbound stream because we are at capacity")
170 }
171 }
172
173 fn on_fully_negotiated_outbound(
174 &mut self,
175 FullyNegotiatedOutbound {
176 protocol: (mut stream, protocol),
177 info: (),
178 }: FullyNegotiatedOutbound<<Self as ConnectionHandler>::OutboundProtocol>,
179 ) {
180 let message = self
181 .requested_outbound
182 .pop_front()
183 .expect("negotiated a stream without a pending message");
184
185 let mut codec = self.codec.clone();
186 let request_id = message.request_id;
187
188 let send = async move {
189 let write = codec.write_request(&protocol, &mut stream, message.request);
190 write.await?;
191 stream.close().await?;
192 let read = codec.read_response(&protocol, &mut stream);
193 let response = read.await?;
194
195 Ok(Event::Response {
196 request_id,
197 response,
198 })
199 };
200
201 if self
202 .worker_streams
203 .try_push(RequestId::Outbound(request_id), send.boxed())
204 .is_err()
205 {
206 self.pending_events.push_back(Event::OutboundStreamFailed {
207 request_id: message.request_id,
208 error: io::Error::new(io::ErrorKind::Other, "max sub-streams reached"),
209 });
210 }
211 }
212
213 fn on_dial_upgrade_error(
214 &mut self,
215 DialUpgradeError { error, info: () }: DialUpgradeError<
216 (),
217 <Self as ConnectionHandler>::OutboundProtocol,
218 >,
219 ) {
220 let message = self
221 .requested_outbound
222 .pop_front()
223 .expect("negotiated a stream without a pending message");
224
225 match error {
226 StreamUpgradeError::Timeout => {
227 self.pending_events
228 .push_back(Event::OutboundTimeout(message.request_id));
229 }
230 StreamUpgradeError::NegotiationFailed => {
231 self.pending_events
237 .push_back(Event::OutboundUnsupportedProtocols(message.request_id));
238 }
239 StreamUpgradeError::Apply(e) => libp2p_core::util::unreachable(e),
240 StreamUpgradeError::Io(e) => {
241 self.pending_events.push_back(Event::OutboundStreamFailed {
242 request_id: message.request_id,
243 error: e,
244 });
245 }
246 }
247 }
248 fn on_listen_upgrade_error(
249 &mut self,
250 ListenUpgradeError { error, .. }: ListenUpgradeError<
251 (),
252 <Self as ConnectionHandler>::InboundProtocol,
253 >,
254 ) {
255 libp2p_core::util::unreachable(error)
256 }
257}
258
259pub enum Event<TCodec>
261where
262 TCodec: Codec,
263{
264 Request {
266 request_id: InboundRequestId,
267 request: TCodec::Request,
268 sender: oneshot::Sender<TCodec::Response>,
269 },
270 Response {
272 request_id: OutboundRequestId,
273 response: TCodec::Response,
274 },
275 ResponseSent(InboundRequestId),
277 ResponseOmission(InboundRequestId),
280 OutboundTimeout(OutboundRequestId),
283 OutboundUnsupportedProtocols(OutboundRequestId),
285 OutboundStreamFailed {
286 request_id: OutboundRequestId,
287 error: io::Error,
288 },
289 InboundTimeout(InboundRequestId),
292 InboundStreamFailed {
293 request_id: InboundRequestId,
294 error: io::Error,
295 },
296}
297
298impl<TCodec: Codec> fmt::Debug for Event<TCodec> {
299 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
300 match self {
301 Event::Request {
302 request_id,
303 request: _,
304 sender: _,
305 } => f
306 .debug_struct("Event::Request")
307 .field("request_id", request_id)
308 .finish(),
309 Event::Response {
310 request_id,
311 response: _,
312 } => f
313 .debug_struct("Event::Response")
314 .field("request_id", request_id)
315 .finish(),
316 Event::ResponseSent(request_id) => f
317 .debug_tuple("Event::ResponseSent")
318 .field(request_id)
319 .finish(),
320 Event::ResponseOmission(request_id) => f
321 .debug_tuple("Event::ResponseOmission")
322 .field(request_id)
323 .finish(),
324 Event::OutboundTimeout(request_id) => f
325 .debug_tuple("Event::OutboundTimeout")
326 .field(request_id)
327 .finish(),
328 Event::OutboundUnsupportedProtocols(request_id) => f
329 .debug_tuple("Event::OutboundUnsupportedProtocols")
330 .field(request_id)
331 .finish(),
332 Event::OutboundStreamFailed { request_id, error } => f
333 .debug_struct("Event::OutboundStreamFailed")
334 .field("request_id", &request_id)
335 .field("error", &error)
336 .finish(),
337 Event::InboundTimeout(request_id) => f
338 .debug_tuple("Event::InboundTimeout")
339 .field(request_id)
340 .finish(),
341 Event::InboundStreamFailed { request_id, error } => f
342 .debug_struct("Event::InboundStreamFailed")
343 .field("request_id", &request_id)
344 .field("error", &error)
345 .finish(),
346 }
347 }
348}
349
350pub struct OutboundMessage<TCodec: Codec> {
351 pub(crate) request_id: OutboundRequestId,
352 pub(crate) request: TCodec::Request,
353 pub(crate) protocols: SmallVec<[TCodec::Protocol; 2]>,
354}
355
356impl<TCodec> fmt::Debug for OutboundMessage<TCodec>
357where
358 TCodec: Codec,
359{
360 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
361 f.debug_struct("OutboundMessage").finish_non_exhaustive()
362 }
363}
364
365impl<TCodec> ConnectionHandler for Handler<TCodec>
366where
367 TCodec: Codec + Send + Clone + 'static,
368{
369 type FromBehaviour = OutboundMessage<TCodec>;
370 type ToBehaviour = Event<TCodec>;
371 type InboundProtocol = Protocol<TCodec::Protocol>;
372 type OutboundProtocol = Protocol<TCodec::Protocol>;
373 type OutboundOpenInfo = ();
374 type InboundOpenInfo = ();
375
376 fn listen_protocol(&self) -> SubstreamProtocol<Self::InboundProtocol> {
377 SubstreamProtocol::new(
378 Protocol {
379 protocols: self.inbound_protocols.clone(),
380 },
381 (),
382 )
383 }
384
385 fn on_behaviour_event(&mut self, request: Self::FromBehaviour) {
386 self.pending_outbound.push_back(request);
387 }
388
389 #[tracing::instrument(level = "trace", name = "ConnectionHandler::poll", skip(self, cx))]
390 fn poll(
391 &mut self,
392 cx: &mut Context<'_>,
393 ) -> Poll<ConnectionHandlerEvent<Protocol<TCodec::Protocol>, (), Self::ToBehaviour>> {
394 match self.worker_streams.poll_unpin(cx) {
395 Poll::Ready((_, Ok(Ok(event)))) => {
396 return Poll::Ready(ConnectionHandlerEvent::NotifyBehaviour(event));
397 }
398 Poll::Ready((RequestId::Inbound(id), Ok(Err(e)))) => {
399 return Poll::Ready(ConnectionHandlerEvent::NotifyBehaviour(
400 Event::InboundStreamFailed {
401 request_id: id,
402 error: e,
403 },
404 ));
405 }
406 Poll::Ready((RequestId::Outbound(id), Ok(Err(e)))) => {
407 return Poll::Ready(ConnectionHandlerEvent::NotifyBehaviour(
408 Event::OutboundStreamFailed {
409 request_id: id,
410 error: e,
411 },
412 ));
413 }
414 Poll::Ready((RequestId::Inbound(id), Err(futures_bounded::Timeout { .. }))) => {
415 return Poll::Ready(ConnectionHandlerEvent::NotifyBehaviour(
416 Event::InboundTimeout(id),
417 ));
418 }
419 Poll::Ready((RequestId::Outbound(id), Err(futures_bounded::Timeout { .. }))) => {
420 return Poll::Ready(ConnectionHandlerEvent::NotifyBehaviour(
421 Event::OutboundTimeout(id),
422 ));
423 }
424 Poll::Pending => {}
425 }
426
427 if let Some(event) = self.pending_events.pop_front() {
429 return Poll::Ready(ConnectionHandlerEvent::NotifyBehaviour(event));
430 } else if self.pending_events.capacity() > EMPTY_QUEUE_SHRINK_THRESHOLD {
431 self.pending_events.shrink_to_fit();
432 }
433
434 if let Poll::Ready(Some((id, rq, rs_sender))) = self.inbound_receiver.poll_next_unpin(cx) {
436 return Poll::Ready(ConnectionHandlerEvent::NotifyBehaviour(Event::Request {
439 request_id: id,
440 request: rq,
441 sender: rs_sender,
442 }));
443 }
444
445 if let Some(request) = self.pending_outbound.pop_front() {
447 let protocols = request.protocols.clone();
448 self.requested_outbound.push_back(request);
449
450 return Poll::Ready(ConnectionHandlerEvent::OutboundSubstreamRequest {
451 protocol: SubstreamProtocol::new(Protocol { protocols }, ()),
452 });
453 }
454
455 debug_assert!(self.pending_outbound.is_empty());
456
457 if self.pending_outbound.capacity() > EMPTY_QUEUE_SHRINK_THRESHOLD {
458 self.pending_outbound.shrink_to_fit();
459 }
460
461 Poll::Pending
462 }
463
464 fn on_connection_event(
465 &mut self,
466 event: ConnectionEvent<Self::InboundProtocol, Self::OutboundProtocol>,
467 ) {
468 match event {
469 ConnectionEvent::FullyNegotiatedInbound(fully_negotiated_inbound) => {
470 self.on_fully_negotiated_inbound(fully_negotiated_inbound)
471 }
472 ConnectionEvent::FullyNegotiatedOutbound(fully_negotiated_outbound) => {
473 self.on_fully_negotiated_outbound(fully_negotiated_outbound)
474 }
475 ConnectionEvent::DialUpgradeError(dial_upgrade_error) => {
476 self.on_dial_upgrade_error(dial_upgrade_error)
477 }
478 ConnectionEvent::ListenUpgradeError(listen_upgrade_error) => {
479 self.on_listen_upgrade_error(listen_upgrade_error)
480 }
481 _ => {}
482 }
483 }
484}