libp2p_stream/
handler.rs

1use std::{
2    convert::Infallible,
3    io,
4    sync::{Arc, Mutex},
5    task::{Context, Poll},
6};
7
8use futures::{
9    channel::{mpsc, oneshot},
10    StreamExt as _,
11};
12use libp2p_identity::PeerId;
13use libp2p_swarm::{
14    self as swarm,
15    handler::{ConnectionEvent, DialUpgradeError, FullyNegotiatedInbound, FullyNegotiatedOutbound},
16    ConnectionHandler, Stream, StreamProtocol,
17};
18
19use crate::{shared::Shared, upgrade::Upgrade, OpenStreamError};
20
21pub struct Handler {
22    remote: PeerId,
23    shared: Arc<Mutex<Shared>>,
24
25    receiver: mpsc::Receiver<NewStream>,
26    pending_upgrade: Option<(
27        StreamProtocol,
28        oneshot::Sender<Result<Stream, OpenStreamError>>,
29    )>,
30}
31
32impl Handler {
33    pub(crate) fn new(
34        remote: PeerId,
35        shared: Arc<Mutex<Shared>>,
36        receiver: mpsc::Receiver<NewStream>,
37    ) -> Self {
38        Self {
39            shared,
40            receiver,
41            pending_upgrade: None,
42            remote,
43        }
44    }
45}
46
47impl ConnectionHandler for Handler {
48    type FromBehaviour = Infallible;
49    type ToBehaviour = Infallible;
50    type InboundProtocol = Upgrade;
51    type OutboundProtocol = Upgrade;
52    type InboundOpenInfo = ();
53    type OutboundOpenInfo = ();
54
55    fn listen_protocol(&self) -> swarm::SubstreamProtocol<Self::InboundProtocol> {
56        swarm::SubstreamProtocol::new(
57            Upgrade {
58                supported_protocols: Shared::lock(&self.shared).supported_inbound_protocols(),
59            },
60            (),
61        )
62    }
63
64    fn poll(
65        &mut self,
66        cx: &mut Context<'_>,
67    ) -> Poll<swarm::ConnectionHandlerEvent<Self::OutboundProtocol, (), Self::ToBehaviour>> {
68        if self.pending_upgrade.is_some() {
69            return Poll::Pending;
70        }
71
72        match self.receiver.poll_next_unpin(cx) {
73            Poll::Ready(Some(new_stream)) => {
74                self.pending_upgrade = Some((new_stream.protocol.clone(), new_stream.sender));
75                return Poll::Ready(swarm::ConnectionHandlerEvent::OutboundSubstreamRequest {
76                    protocol: swarm::SubstreamProtocol::new(
77                        Upgrade {
78                            supported_protocols: vec![new_stream.protocol],
79                        },
80                        (),
81                    ),
82                });
83            }
84            Poll::Ready(None) => {} // Sender is gone, no more work to do.
85            Poll::Pending => {}
86        }
87
88        Poll::Pending
89    }
90
91    fn on_behaviour_event(&mut self, event: Self::FromBehaviour) {
92        libp2p_core::util::unreachable(event)
93    }
94
95    fn on_connection_event(
96        &mut self,
97        event: ConnectionEvent<Self::InboundProtocol, Self::OutboundProtocol>,
98    ) {
99        match event {
100            ConnectionEvent::FullyNegotiatedInbound(FullyNegotiatedInbound {
101                protocol: (stream, protocol),
102                info: (),
103            }) => {
104                Shared::lock(&self.shared).on_inbound_stream(self.remote, stream, protocol);
105            }
106            ConnectionEvent::FullyNegotiatedOutbound(FullyNegotiatedOutbound {
107                protocol: (stream, actual_protocol),
108                info: (),
109            }) => {
110                let Some((expected_protocol, sender)) = self.pending_upgrade.take() else {
111                    debug_assert!(
112                        false,
113                        "Negotiated an outbound stream without a back channel"
114                    );
115                    return;
116                };
117                debug_assert_eq!(expected_protocol, actual_protocol);
118
119                let _ = sender.send(Ok(stream));
120            }
121            ConnectionEvent::DialUpgradeError(DialUpgradeError { error, info: () }) => {
122                let Some((p, sender)) = self.pending_upgrade.take() else {
123                    debug_assert!(
124                        false,
125                        "Received a `DialUpgradeError` without a back channel"
126                    );
127                    return;
128                };
129
130                let error = match error {
131                    swarm::StreamUpgradeError::Timeout => {
132                        OpenStreamError::Io(io::Error::from(io::ErrorKind::TimedOut))
133                    }
134                    swarm::StreamUpgradeError::Apply(v) => libp2p_core::util::unreachable(v),
135                    swarm::StreamUpgradeError::NegotiationFailed => {
136                        OpenStreamError::UnsupportedProtocol(p)
137                    }
138                    swarm::StreamUpgradeError::Io(io) => OpenStreamError::Io(io),
139                };
140
141                let _ = sender.send(Err(error));
142            }
143            _ => {}
144        }
145    }
146}
147
148/// Message from a [`Control`](crate::Control) to
149/// a [`ConnectionHandler`] to negotiate a new outbound stream.
150#[derive(Debug)]
151pub(crate) struct NewStream {
152    pub(crate) protocol: StreamProtocol,
153    pub(crate) sender: oneshot::Sender<Result<Stream, OpenStreamError>>,
154}