libp2p_request_response/
handler.rs

1// Copyright 2020 Parity Technologies (UK) Ltd.
2//
3// Permission is hereby granted, free of charge, to any person obtaining a
4// copy of this software and associated documentation files (the "Software"),
5// to deal in the Software without restriction, including without limitation
6// the rights to use, copy, modify, merge, publish, distribute, sublicense,
7// and/or sell copies of the Software, and to permit persons to whom the
8// Software is furnished to do so, subject to the following conditions:
9//
10// The above copyright notice and this permission notice shall be included in
11// all copies or substantial portions of the Software.
12//
13// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
14// OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
16// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
17// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
18// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
19// DEALINGS IN THE SOFTWARE.
20
21pub(crate) mod protocol;
22
23use std::{
24    collections::VecDeque,
25    fmt, io,
26    sync::{
27        atomic::{AtomicU64, Ordering},
28        Arc,
29    },
30    task::{Context, Poll},
31    time::Duration,
32};
33
34use futures::{
35    channel::{mpsc, oneshot},
36    prelude::*,
37};
38use libp2p_swarm::{
39    handler::{
40        ConnectionEvent, ConnectionHandler, ConnectionHandlerEvent, DialUpgradeError,
41        FullyNegotiatedInbound, FullyNegotiatedOutbound, ListenUpgradeError, StreamUpgradeError,
42    },
43    SubstreamProtocol,
44};
45pub use protocol::ProtocolSupport;
46use smallvec::SmallVec;
47
48use crate::{
49    codec::Codec, handler::protocol::Protocol, InboundRequestId, OutboundRequestId,
50    EMPTY_QUEUE_SHRINK_THRESHOLD,
51};
52
53/// A connection handler for a request response [`Behaviour`](super::Behaviour) protocol.
54pub struct Handler<TCodec>
55where
56    TCodec: Codec,
57{
58    /// The supported inbound protocols.
59    inbound_protocols: SmallVec<[TCodec::Protocol; 2]>,
60    /// The request/response message codec.
61    codec: TCodec,
62    /// Queue of events to emit in `poll()`.
63    pending_events: VecDeque<Event<TCodec>>,
64    /// Outbound upgrades waiting to be emitted as an `OutboundSubstreamRequest`.
65    pending_outbound: VecDeque<OutboundMessage<TCodec>>,
66
67    requested_outbound: VecDeque<OutboundMessage<TCodec>>,
68    /// A channel for receiving inbound requests.
69    inbound_receiver: mpsc::Receiver<(
70        InboundRequestId,
71        TCodec::Request,
72        oneshot::Sender<TCodec::Response>,
73    )>,
74    /// The [`mpsc::Sender`] for the above receiver. Cloned for each inbound request.
75    inbound_sender: mpsc::Sender<(
76        InboundRequestId,
77        TCodec::Request,
78        oneshot::Sender<TCodec::Response>,
79    )>,
80
81    inbound_request_id: Arc<AtomicU64>,
82
83    worker_streams: futures_bounded::FuturesMap<RequestId, Result<Event<TCodec>, io::Error>>,
84}
85
86#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
87enum RequestId {
88    Inbound(InboundRequestId),
89    Outbound(OutboundRequestId),
90}
91
92impl<TCodec> Handler<TCodec>
93where
94    TCodec: Codec + Send + Clone + 'static,
95{
96    pub(super) fn new(
97        inbound_protocols: SmallVec<[TCodec::Protocol; 2]>,
98        codec: TCodec,
99        substream_timeout: Duration,
100        inbound_request_id: Arc<AtomicU64>,
101        max_concurrent_streams: usize,
102    ) -> Self {
103        let (inbound_sender, inbound_receiver) = mpsc::channel(0);
104        Self {
105            inbound_protocols,
106            codec,
107            pending_outbound: VecDeque::new(),
108            requested_outbound: Default::default(),
109            inbound_receiver,
110            inbound_sender,
111            pending_events: VecDeque::new(),
112            inbound_request_id,
113            worker_streams: futures_bounded::FuturesMap::new(
114                substream_timeout,
115                max_concurrent_streams,
116            ),
117        }
118    }
119
120    /// Returns the next inbound request ID.
121    fn next_inbound_request_id(&mut self) -> InboundRequestId {
122        InboundRequestId(self.inbound_request_id.fetch_add(1, Ordering::Relaxed))
123    }
124
125    fn on_fully_negotiated_inbound(
126        &mut self,
127        FullyNegotiatedInbound {
128            protocol: (mut stream, protocol),
129            info: (),
130        }: FullyNegotiatedInbound<<Self as ConnectionHandler>::InboundProtocol>,
131    ) {
132        let mut codec = self.codec.clone();
133        let request_id = self.next_inbound_request_id();
134        let mut sender = self.inbound_sender.clone();
135
136        let recv = async move {
137            // A channel for notifying the inbound upgrade when the
138            // response is sent.
139            let (rs_send, rs_recv) = oneshot::channel();
140
141            let read = codec.read_request(&protocol, &mut stream);
142            let request = read.await?;
143            sender
144                .send((request_id, request, rs_send))
145                .await
146                .expect("`ConnectionHandler` owns both ends of the channel");
147            drop(sender);
148
149            if let Ok(response) = rs_recv.await {
150                let write = codec.write_response(&protocol, &mut stream, response);
151                write.await?;
152
153                stream.close().await?;
154                Ok(Event::ResponseSent(request_id))
155            } else {
156                stream.close().await?;
157                Ok(Event::ResponseOmission(request_id))
158            }
159        };
160
161        // Inbound connections are reported to the upper layer from within the above task,
162        // so by failing to schedule it, it means the upper layer will never know about the
163        // inbound request. Because of that we do not report any inbound failure.
164        if self
165            .worker_streams
166            .try_push(RequestId::Inbound(request_id), recv.boxed())
167            .is_err()
168        {
169            tracing::warn!("Dropping inbound stream because we are at capacity")
170        }
171    }
172
173    fn on_fully_negotiated_outbound(
174        &mut self,
175        FullyNegotiatedOutbound {
176            protocol: (mut stream, protocol),
177            info: (),
178        }: FullyNegotiatedOutbound<<Self as ConnectionHandler>::OutboundProtocol>,
179    ) {
180        let message = self
181            .requested_outbound
182            .pop_front()
183            .expect("negotiated a stream without a pending message");
184
185        let mut codec = self.codec.clone();
186        let request_id = message.request_id;
187
188        let send = async move {
189            let write = codec.write_request(&protocol, &mut stream, message.request);
190            write.await?;
191            stream.close().await?;
192            let read = codec.read_response(&protocol, &mut stream);
193            let response = read.await?;
194
195            Ok(Event::Response {
196                request_id,
197                response,
198            })
199        };
200
201        if self
202            .worker_streams
203            .try_push(RequestId::Outbound(request_id), send.boxed())
204            .is_err()
205        {
206            self.pending_events.push_back(Event::OutboundStreamFailed {
207                request_id: message.request_id,
208                error: io::Error::new(io::ErrorKind::Other, "max sub-streams reached"),
209            });
210        }
211    }
212
213    fn on_dial_upgrade_error(
214        &mut self,
215        DialUpgradeError { error, info: () }: DialUpgradeError<
216            (),
217            <Self as ConnectionHandler>::OutboundProtocol,
218        >,
219    ) {
220        let message = self
221            .requested_outbound
222            .pop_front()
223            .expect("negotiated a stream without a pending message");
224
225        match error {
226            StreamUpgradeError::Timeout => {
227                self.pending_events
228                    .push_back(Event::OutboundTimeout(message.request_id));
229            }
230            StreamUpgradeError::NegotiationFailed => {
231                // The remote merely doesn't support the protocol(s) we requested.
232                // This is no reason to close the connection, which may
233                // successfully communicate with other protocols already.
234                // An event is reported to permit user code to react to the fact that
235                // the remote peer does not support the requested protocol(s).
236                self.pending_events
237                    .push_back(Event::OutboundUnsupportedProtocols(message.request_id));
238            }
239            StreamUpgradeError::Apply(e) => libp2p_core::util::unreachable(e),
240            StreamUpgradeError::Io(e) => {
241                self.pending_events.push_back(Event::OutboundStreamFailed {
242                    request_id: message.request_id,
243                    error: e,
244                });
245            }
246        }
247    }
248    fn on_listen_upgrade_error(
249        &mut self,
250        ListenUpgradeError { error, .. }: ListenUpgradeError<
251            (),
252            <Self as ConnectionHandler>::InboundProtocol,
253        >,
254    ) {
255        libp2p_core::util::unreachable(error)
256    }
257}
258
259/// The events emitted by the [`Handler`].
260pub enum Event<TCodec>
261where
262    TCodec: Codec,
263{
264    /// A request has been received.
265    Request {
266        request_id: InboundRequestId,
267        request: TCodec::Request,
268        sender: oneshot::Sender<TCodec::Response>,
269    },
270    /// A response has been received.
271    Response {
272        request_id: OutboundRequestId,
273        response: TCodec::Response,
274    },
275    /// A response to an inbound request has been sent.
276    ResponseSent(InboundRequestId),
277    /// A response to an inbound request was omitted as a result
278    /// of dropping the response `sender` of an inbound `Request`.
279    ResponseOmission(InboundRequestId),
280    /// An outbound request timed out while sending the request
281    /// or waiting for the response.
282    OutboundTimeout(OutboundRequestId),
283    /// An outbound request failed to negotiate a mutually supported protocol.
284    OutboundUnsupportedProtocols(OutboundRequestId),
285    OutboundStreamFailed {
286        request_id: OutboundRequestId,
287        error: io::Error,
288    },
289    /// An inbound request timed out while waiting for the request
290    /// or sending the response.
291    InboundTimeout(InboundRequestId),
292    InboundStreamFailed {
293        request_id: InboundRequestId,
294        error: io::Error,
295    },
296}
297
298impl<TCodec: Codec> fmt::Debug for Event<TCodec> {
299    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
300        match self {
301            Event::Request {
302                request_id,
303                request: _,
304                sender: _,
305            } => f
306                .debug_struct("Event::Request")
307                .field("request_id", request_id)
308                .finish(),
309            Event::Response {
310                request_id,
311                response: _,
312            } => f
313                .debug_struct("Event::Response")
314                .field("request_id", request_id)
315                .finish(),
316            Event::ResponseSent(request_id) => f
317                .debug_tuple("Event::ResponseSent")
318                .field(request_id)
319                .finish(),
320            Event::ResponseOmission(request_id) => f
321                .debug_tuple("Event::ResponseOmission")
322                .field(request_id)
323                .finish(),
324            Event::OutboundTimeout(request_id) => f
325                .debug_tuple("Event::OutboundTimeout")
326                .field(request_id)
327                .finish(),
328            Event::OutboundUnsupportedProtocols(request_id) => f
329                .debug_tuple("Event::OutboundUnsupportedProtocols")
330                .field(request_id)
331                .finish(),
332            Event::OutboundStreamFailed { request_id, error } => f
333                .debug_struct("Event::OutboundStreamFailed")
334                .field("request_id", &request_id)
335                .field("error", &error)
336                .finish(),
337            Event::InboundTimeout(request_id) => f
338                .debug_tuple("Event::InboundTimeout")
339                .field(request_id)
340                .finish(),
341            Event::InboundStreamFailed { request_id, error } => f
342                .debug_struct("Event::InboundStreamFailed")
343                .field("request_id", &request_id)
344                .field("error", &error)
345                .finish(),
346        }
347    }
348}
349
350pub struct OutboundMessage<TCodec: Codec> {
351    pub(crate) request_id: OutboundRequestId,
352    pub(crate) request: TCodec::Request,
353    pub(crate) protocols: SmallVec<[TCodec::Protocol; 2]>,
354}
355
356impl<TCodec> fmt::Debug for OutboundMessage<TCodec>
357where
358    TCodec: Codec,
359{
360    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
361        f.debug_struct("OutboundMessage").finish_non_exhaustive()
362    }
363}
364
365impl<TCodec> ConnectionHandler for Handler<TCodec>
366where
367    TCodec: Codec + Send + Clone + 'static,
368{
369    type FromBehaviour = OutboundMessage<TCodec>;
370    type ToBehaviour = Event<TCodec>;
371    type InboundProtocol = Protocol<TCodec::Protocol>;
372    type OutboundProtocol = Protocol<TCodec::Protocol>;
373    type OutboundOpenInfo = ();
374    type InboundOpenInfo = ();
375
376    fn listen_protocol(&self) -> SubstreamProtocol<Self::InboundProtocol> {
377        SubstreamProtocol::new(
378            Protocol {
379                protocols: self.inbound_protocols.clone(),
380            },
381            (),
382        )
383    }
384
385    fn on_behaviour_event(&mut self, request: Self::FromBehaviour) {
386        self.pending_outbound.push_back(request);
387    }
388
389    #[tracing::instrument(level = "trace", name = "ConnectionHandler::poll", skip(self, cx))]
390    fn poll(
391        &mut self,
392        cx: &mut Context<'_>,
393    ) -> Poll<ConnectionHandlerEvent<Protocol<TCodec::Protocol>, (), Self::ToBehaviour>> {
394        match self.worker_streams.poll_unpin(cx) {
395            Poll::Ready((_, Ok(Ok(event)))) => {
396                return Poll::Ready(ConnectionHandlerEvent::NotifyBehaviour(event));
397            }
398            Poll::Ready((RequestId::Inbound(id), Ok(Err(e)))) => {
399                return Poll::Ready(ConnectionHandlerEvent::NotifyBehaviour(
400                    Event::InboundStreamFailed {
401                        request_id: id,
402                        error: e,
403                    },
404                ));
405            }
406            Poll::Ready((RequestId::Outbound(id), Ok(Err(e)))) => {
407                return Poll::Ready(ConnectionHandlerEvent::NotifyBehaviour(
408                    Event::OutboundStreamFailed {
409                        request_id: id,
410                        error: e,
411                    },
412                ));
413            }
414            Poll::Ready((RequestId::Inbound(id), Err(futures_bounded::Timeout { .. }))) => {
415                return Poll::Ready(ConnectionHandlerEvent::NotifyBehaviour(
416                    Event::InboundTimeout(id),
417                ));
418            }
419            Poll::Ready((RequestId::Outbound(id), Err(futures_bounded::Timeout { .. }))) => {
420                return Poll::Ready(ConnectionHandlerEvent::NotifyBehaviour(
421                    Event::OutboundTimeout(id),
422                ));
423            }
424            Poll::Pending => {}
425        }
426
427        // Drain pending events that were produced by `worker_streams`.
428        if let Some(event) = self.pending_events.pop_front() {
429            return Poll::Ready(ConnectionHandlerEvent::NotifyBehaviour(event));
430        } else if self.pending_events.capacity() > EMPTY_QUEUE_SHRINK_THRESHOLD {
431            self.pending_events.shrink_to_fit();
432        }
433
434        // Check for inbound requests.
435        if let Poll::Ready(Some((id, rq, rs_sender))) = self.inbound_receiver.poll_next_unpin(cx) {
436            // We received an inbound request.
437
438            return Poll::Ready(ConnectionHandlerEvent::NotifyBehaviour(Event::Request {
439                request_id: id,
440                request: rq,
441                sender: rs_sender,
442            }));
443        }
444
445        // Emit outbound requests.
446        if let Some(request) = self.pending_outbound.pop_front() {
447            let protocols = request.protocols.clone();
448            self.requested_outbound.push_back(request);
449
450            return Poll::Ready(ConnectionHandlerEvent::OutboundSubstreamRequest {
451                protocol: SubstreamProtocol::new(Protocol { protocols }, ()),
452            });
453        }
454
455        debug_assert!(self.pending_outbound.is_empty());
456
457        if self.pending_outbound.capacity() > EMPTY_QUEUE_SHRINK_THRESHOLD {
458            self.pending_outbound.shrink_to_fit();
459        }
460
461        Poll::Pending
462    }
463
464    fn on_connection_event(
465        &mut self,
466        event: ConnectionEvent<Self::InboundProtocol, Self::OutboundProtocol>,
467    ) {
468        match event {
469            ConnectionEvent::FullyNegotiatedInbound(fully_negotiated_inbound) => {
470                self.on_fully_negotiated_inbound(fully_negotiated_inbound)
471            }
472            ConnectionEvent::FullyNegotiatedOutbound(fully_negotiated_outbound) => {
473                self.on_fully_negotiated_outbound(fully_negotiated_outbound)
474            }
475            ConnectionEvent::DialUpgradeError(dial_upgrade_error) => {
476                self.on_dial_upgrade_error(dial_upgrade_error)
477            }
478            ConnectionEvent::ListenUpgradeError(listen_upgrade_error) => {
479                self.on_listen_upgrade_error(listen_upgrade_error)
480            }
481            _ => {}
482        }
483    }
484}