1use std::{
2 convert::Infallible,
3 io,
4 sync::{Arc, Mutex},
5 task::{Context, Poll},
6};
7
8use futures::{
9 channel::{mpsc, oneshot},
10 StreamExt as _,
11};
12use libp2p_identity::PeerId;
13use libp2p_swarm::{
14 self as swarm,
15 handler::{ConnectionEvent, DialUpgradeError, FullyNegotiatedInbound, FullyNegotiatedOutbound},
16 ConnectionHandler, Stream, StreamProtocol,
17};
18
19use crate::{shared::Shared, upgrade::Upgrade, OpenStreamError};
20
21pub struct Handler {
22 remote: PeerId,
23 shared: Arc<Mutex<Shared>>,
24
25 receiver: mpsc::Receiver<NewStream>,
26 pending_upgrade: Option<(
27 StreamProtocol,
28 oneshot::Sender<Result<Stream, OpenStreamError>>,
29 )>,
30}
31
32impl Handler {
33 pub(crate) fn new(
34 remote: PeerId,
35 shared: Arc<Mutex<Shared>>,
36 receiver: mpsc::Receiver<NewStream>,
37 ) -> Self {
38 Self {
39 shared,
40 receiver,
41 pending_upgrade: None,
42 remote,
43 }
44 }
45}
46
47impl ConnectionHandler for Handler {
48 type FromBehaviour = Infallible;
49 type ToBehaviour = Infallible;
50 type InboundProtocol = Upgrade;
51 type OutboundProtocol = Upgrade;
52 type InboundOpenInfo = ();
53 type OutboundOpenInfo = ();
54
55 fn listen_protocol(&self) -> swarm::SubstreamProtocol<Self::InboundProtocol> {
56 swarm::SubstreamProtocol::new(
57 Upgrade {
58 supported_protocols: Shared::lock(&self.shared).supported_inbound_protocols(),
59 },
60 (),
61 )
62 }
63
64 fn poll(
65 &mut self,
66 cx: &mut Context<'_>,
67 ) -> Poll<swarm::ConnectionHandlerEvent<Self::OutboundProtocol, (), Self::ToBehaviour>> {
68 if self.pending_upgrade.is_some() {
69 return Poll::Pending;
70 }
71
72 match self.receiver.poll_next_unpin(cx) {
73 Poll::Ready(Some(new_stream)) => {
74 self.pending_upgrade = Some((new_stream.protocol.clone(), new_stream.sender));
75 return Poll::Ready(swarm::ConnectionHandlerEvent::OutboundSubstreamRequest {
76 protocol: swarm::SubstreamProtocol::new(
77 Upgrade {
78 supported_protocols: vec![new_stream.protocol],
79 },
80 (),
81 ),
82 });
83 }
84 Poll::Ready(None) => {} Poll::Pending => {}
86 }
87
88 Poll::Pending
89 }
90
91 fn on_behaviour_event(&mut self, event: Self::FromBehaviour) {
92 libp2p_core::util::unreachable(event)
93 }
94
95 fn on_connection_event(
96 &mut self,
97 event: ConnectionEvent<Self::InboundProtocol, Self::OutboundProtocol>,
98 ) {
99 match event {
100 ConnectionEvent::FullyNegotiatedInbound(FullyNegotiatedInbound {
101 protocol: (stream, protocol),
102 info: (),
103 }) => {
104 Shared::lock(&self.shared).on_inbound_stream(self.remote, stream, protocol);
105 }
106 ConnectionEvent::FullyNegotiatedOutbound(FullyNegotiatedOutbound {
107 protocol: (stream, actual_protocol),
108 info: (),
109 }) => {
110 let Some((expected_protocol, sender)) = self.pending_upgrade.take() else {
111 debug_assert!(
112 false,
113 "Negotiated an outbound stream without a back channel"
114 );
115 return;
116 };
117 debug_assert_eq!(expected_protocol, actual_protocol);
118
119 let _ = sender.send(Ok(stream));
120 }
121 ConnectionEvent::DialUpgradeError(DialUpgradeError { error, info: () }) => {
122 let Some((p, sender)) = self.pending_upgrade.take() else {
123 debug_assert!(
124 false,
125 "Received a `DialUpgradeError` without a back channel"
126 );
127 return;
128 };
129
130 let error = match error {
131 swarm::StreamUpgradeError::Timeout => {
132 OpenStreamError::Io(io::Error::from(io::ErrorKind::TimedOut))
133 }
134 swarm::StreamUpgradeError::Apply(v) => libp2p_core::util::unreachable(v),
135 swarm::StreamUpgradeError::NegotiationFailed => {
136 OpenStreamError::UnsupportedProtocol(p)
137 }
138 swarm::StreamUpgradeError::Io(io) => OpenStreamError::Io(io),
139 };
140
141 let _ = sender.send(Err(error));
142 }
143 _ => {}
144 }
145 }
146}
147
148#[derive(Debug)]
151pub(crate) struct NewStream {
152 pub(crate) protocol: StreamProtocol,
153 pub(crate) sender: oneshot::Sender<Result<Stream, OpenStreamError>>,
154}