1use std::{
22 cmp,
23 task::{Context, Poll},
24};
25
26use either::Either;
27use futures::{future, ready};
28use libp2p_core::upgrade::SelectUpgrade;
29
30use crate::{
31 handler::{
32 AddressChange, ConnectionEvent, ConnectionHandler, ConnectionHandlerEvent,
33 DialUpgradeError, FullyNegotiatedInbound, FullyNegotiatedOutbound, InboundUpgradeSend,
34 ListenUpgradeError, OutboundUpgradeSend, StreamUpgradeError, SubstreamProtocol,
35 },
36 upgrade::SendWrapper,
37};
38
/// A connection handler that composes two handlers into one.
///
/// Inbound substreams are negotiated against the protocols of both handlers at once, and each
/// event delivered to the combined handler is routed to the handler it belongs to (or to both,
/// for connection-wide events).
#[derive(Debug, Clone)]
pub struct ConnectionHandlerSelect<TProto1, TProto2> {
    /// The first (left) handler.
    proto1: TProto1,
    /// The second (right) handler.
    proto2: TProto2,
}

impl<TProto1, TProto2> ConnectionHandlerSelect<TProto1, TProto2> {
    /// Combines `proto1` (the left side) and `proto2` (the right side) into a single handler.
    pub(crate) fn new(proto1: TProto1, proto2: TProto2) -> Self {
        Self { proto1, proto2 }
    }

    /// Consumes the combined handler and returns the two inner handlers in the same order they
    /// were passed to `new`.
    pub fn into_inner(self) -> (TProto1, TProto2) {
        let Self { proto1, proto2 } = self;
        (proto1, proto2)
    }
}
58
59impl<S1OOI, S2OOI, S1OP, S2OP>
60 FullyNegotiatedOutbound<Either<SendWrapper<S1OP>, SendWrapper<S2OP>>, Either<S1OOI, S2OOI>>
61where
62 S1OP: OutboundUpgradeSend,
63 S2OP: OutboundUpgradeSend,
64 S1OOI: Send + 'static,
65 S2OOI: Send + 'static,
66{
67 pub(crate) fn transpose(
68 self,
69 ) -> Either<FullyNegotiatedOutbound<S1OP, S1OOI>, FullyNegotiatedOutbound<S2OP, S2OOI>> {
70 match self {
71 FullyNegotiatedOutbound {
72 protocol: future::Either::Left(protocol),
73 info: Either::Left(info),
74 } => Either::Left(FullyNegotiatedOutbound { protocol, info }),
75 FullyNegotiatedOutbound {
76 protocol: future::Either::Right(protocol),
77 info: Either::Right(info),
78 } => Either::Right(FullyNegotiatedOutbound { protocol, info }),
79 _ => panic!("wrong API usage: the protocol doesn't match the upgrade info"),
80 }
81 }
82}
83
84impl<S1IP, S1IOI, S2IP, S2IOI>
85 FullyNegotiatedInbound<SelectUpgrade<SendWrapper<S1IP>, SendWrapper<S2IP>>, (S1IOI, S2IOI)>
86where
87 S1IP: InboundUpgradeSend,
88 S2IP: InboundUpgradeSend,
89{
90 pub(crate) fn transpose(
91 self,
92 ) -> Either<FullyNegotiatedInbound<S1IP, S1IOI>, FullyNegotiatedInbound<S2IP, S2IOI>> {
93 match self {
94 FullyNegotiatedInbound {
95 protocol: future::Either::Left(protocol),
96 info: (i1, _i2),
97 } => Either::Left(FullyNegotiatedInbound { protocol, info: i1 }),
98 FullyNegotiatedInbound {
99 protocol: future::Either::Right(protocol),
100 info: (_i1, i2),
101 } => Either::Right(FullyNegotiatedInbound { protocol, info: i2 }),
102 }
103 }
104}
105
106impl<S1OOI, S2OOI, S1OP, S2OP>
107 DialUpgradeError<Either<S1OOI, S2OOI>, Either<SendWrapper<S1OP>, SendWrapper<S2OP>>>
108where
109 S1OP: OutboundUpgradeSend,
110 S2OP: OutboundUpgradeSend,
111 S1OOI: Send + 'static,
112 S2OOI: Send + 'static,
113{
114 pub(crate) fn transpose(
115 self,
116 ) -> Either<DialUpgradeError<S1OOI, S1OP>, DialUpgradeError<S2OOI, S2OP>> {
117 match self {
118 DialUpgradeError {
119 info: Either::Left(info),
120 error: StreamUpgradeError::Apply(Either::Left(err)),
121 } => Either::Left(DialUpgradeError {
122 info,
123 error: StreamUpgradeError::Apply(err),
124 }),
125 DialUpgradeError {
126 info: Either::Right(info),
127 error: StreamUpgradeError::Apply(Either::Right(err)),
128 } => Either::Right(DialUpgradeError {
129 info,
130 error: StreamUpgradeError::Apply(err),
131 }),
132 DialUpgradeError {
133 info: Either::Left(info),
134 error: e,
135 } => Either::Left(DialUpgradeError {
136 info,
137 error: e.map_upgrade_err(|_| panic!("already handled above")),
138 }),
139 DialUpgradeError {
140 info: Either::Right(info),
141 error: e,
142 } => Either::Right(DialUpgradeError {
143 info,
144 error: e.map_upgrade_err(|_| panic!("already handled above")),
145 }),
146 }
147 }
148}
149
150impl<TProto1, TProto2> ConnectionHandlerSelect<TProto1, TProto2>
151where
152 TProto1: ConnectionHandler,
153 TProto2: ConnectionHandler,
154{
155 fn on_listen_upgrade_error(
156 &mut self,
157 ListenUpgradeError {
158 info: (i1, i2),
159 error,
160 }: ListenUpgradeError<
161 <Self as ConnectionHandler>::InboundOpenInfo,
162 <Self as ConnectionHandler>::InboundProtocol,
163 >,
164 ) {
165 match error {
166 Either::Left(error) => {
167 self.proto1
168 .on_connection_event(ConnectionEvent::ListenUpgradeError(ListenUpgradeError {
169 info: i1,
170 error,
171 }));
172 }
173 Either::Right(error) => {
174 self.proto2
175 .on_connection_event(ConnectionEvent::ListenUpgradeError(ListenUpgradeError {
176 info: i2,
177 error,
178 }));
179 }
180 }
181 }
182}
183
184impl<TProto1, TProto2> ConnectionHandler for ConnectionHandlerSelect<TProto1, TProto2>
185where
186 TProto1: ConnectionHandler,
187 TProto2: ConnectionHandler,
188{
189 type FromBehaviour = Either<TProto1::FromBehaviour, TProto2::FromBehaviour>;
190 type ToBehaviour = Either<TProto1::ToBehaviour, TProto2::ToBehaviour>;
191 type InboundProtocol = SelectUpgrade<
192 SendWrapper<<TProto1 as ConnectionHandler>::InboundProtocol>,
193 SendWrapper<<TProto2 as ConnectionHandler>::InboundProtocol>,
194 >;
195 type OutboundProtocol =
196 Either<SendWrapper<TProto1::OutboundProtocol>, SendWrapper<TProto2::OutboundProtocol>>;
197 type OutboundOpenInfo = Either<TProto1::OutboundOpenInfo, TProto2::OutboundOpenInfo>;
198 type InboundOpenInfo = (TProto1::InboundOpenInfo, TProto2::InboundOpenInfo);
199
200 fn listen_protocol(&self) -> SubstreamProtocol<Self::InboundProtocol, Self::InboundOpenInfo> {
201 let proto1 = self.proto1.listen_protocol();
202 let proto2 = self.proto2.listen_protocol();
203 let timeout = *std::cmp::max(proto1.timeout(), proto2.timeout());
204 let (u1, i1) = proto1.into_upgrade();
205 let (u2, i2) = proto2.into_upgrade();
206 let choice = SelectUpgrade::new(SendWrapper(u1), SendWrapper(u2));
207 SubstreamProtocol::new(choice, (i1, i2)).with_timeout(timeout)
208 }
209
210 fn on_behaviour_event(&mut self, event: Self::FromBehaviour) {
211 match event {
212 Either::Left(event) => self.proto1.on_behaviour_event(event),
213 Either::Right(event) => self.proto2.on_behaviour_event(event),
214 }
215 }
216
217 fn connection_keep_alive(&self) -> bool {
218 cmp::max(
219 self.proto1.connection_keep_alive(),
220 self.proto2.connection_keep_alive(),
221 )
222 }
223
224 fn poll(
225 &mut self,
226 cx: &mut Context<'_>,
227 ) -> Poll<
228 ConnectionHandlerEvent<Self::OutboundProtocol, Self::OutboundOpenInfo, Self::ToBehaviour>,
229 > {
230 match self.proto1.poll(cx) {
231 Poll::Ready(ConnectionHandlerEvent::NotifyBehaviour(event)) => {
232 return Poll::Ready(ConnectionHandlerEvent::NotifyBehaviour(Either::Left(event)));
233 }
234 Poll::Ready(ConnectionHandlerEvent::OutboundSubstreamRequest { protocol }) => {
235 return Poll::Ready(ConnectionHandlerEvent::OutboundSubstreamRequest {
236 protocol: protocol
237 .map_upgrade(|u| Either::Left(SendWrapper(u)))
238 .map_info(Either::Left),
239 });
240 }
241 Poll::Ready(ConnectionHandlerEvent::ReportRemoteProtocols(support)) => {
242 return Poll::Ready(ConnectionHandlerEvent::ReportRemoteProtocols(support));
243 }
244 Poll::Pending => (),
245 };
246
247 match self.proto2.poll(cx) {
248 Poll::Ready(ConnectionHandlerEvent::NotifyBehaviour(event)) => {
249 return Poll::Ready(ConnectionHandlerEvent::NotifyBehaviour(Either::Right(
250 event,
251 )));
252 }
253 Poll::Ready(ConnectionHandlerEvent::OutboundSubstreamRequest { protocol }) => {
254 return Poll::Ready(ConnectionHandlerEvent::OutboundSubstreamRequest {
255 protocol: protocol
256 .map_upgrade(|u| Either::Right(SendWrapper(u)))
257 .map_info(Either::Right),
258 });
259 }
260 Poll::Ready(ConnectionHandlerEvent::ReportRemoteProtocols(support)) => {
261 return Poll::Ready(ConnectionHandlerEvent::ReportRemoteProtocols(support));
262 }
263 Poll::Pending => (),
264 };
265
266 Poll::Pending
267 }
268
269 fn poll_close(&mut self, cx: &mut Context<'_>) -> Poll<Option<Self::ToBehaviour>> {
270 if let Some(e) = ready!(self.proto1.poll_close(cx)) {
271 return Poll::Ready(Some(Either::Left(e)));
272 }
273
274 if let Some(e) = ready!(self.proto2.poll_close(cx)) {
275 return Poll::Ready(Some(Either::Right(e)));
276 }
277
278 Poll::Ready(None)
279 }
280
281 fn on_connection_event(
282 &mut self,
283 event: ConnectionEvent<
284 Self::InboundProtocol,
285 Self::OutboundProtocol,
286 Self::InboundOpenInfo,
287 Self::OutboundOpenInfo,
288 >,
289 ) {
290 match event {
291 ConnectionEvent::FullyNegotiatedOutbound(fully_negotiated_outbound) => {
292 match fully_negotiated_outbound.transpose() {
293 Either::Left(f) => self
294 .proto1
295 .on_connection_event(ConnectionEvent::FullyNegotiatedOutbound(f)),
296 Either::Right(f) => self
297 .proto2
298 .on_connection_event(ConnectionEvent::FullyNegotiatedOutbound(f)),
299 }
300 }
301 ConnectionEvent::FullyNegotiatedInbound(fully_negotiated_inbound) => {
302 match fully_negotiated_inbound.transpose() {
303 Either::Left(f) => self
304 .proto1
305 .on_connection_event(ConnectionEvent::FullyNegotiatedInbound(f)),
306 Either::Right(f) => self
307 .proto2
308 .on_connection_event(ConnectionEvent::FullyNegotiatedInbound(f)),
309 }
310 }
311 ConnectionEvent::AddressChange(address) => {
312 self.proto1
313 .on_connection_event(ConnectionEvent::AddressChange(AddressChange {
314 new_address: address.new_address,
315 }));
316
317 self.proto2
318 .on_connection_event(ConnectionEvent::AddressChange(AddressChange {
319 new_address: address.new_address,
320 }));
321 }
322 ConnectionEvent::DialUpgradeError(dial_upgrade_error) => {
323 match dial_upgrade_error.transpose() {
324 Either::Left(err) => self
325 .proto1
326 .on_connection_event(ConnectionEvent::DialUpgradeError(err)),
327 Either::Right(err) => self
328 .proto2
329 .on_connection_event(ConnectionEvent::DialUpgradeError(err)),
330 }
331 }
332 ConnectionEvent::ListenUpgradeError(listen_upgrade_error) => {
333 self.on_listen_upgrade_error(listen_upgrade_error)
334 }
335 ConnectionEvent::LocalProtocolsChange(supported_protocols) => {
336 self.proto1
337 .on_connection_event(ConnectionEvent::LocalProtocolsChange(
338 supported_protocols.clone(),
339 ));
340 self.proto2
341 .on_connection_event(ConnectionEvent::LocalProtocolsChange(
342 supported_protocols,
343 ));
344 }
345 ConnectionEvent::RemoteProtocolsChange(supported_protocols) => {
346 self.proto1
347 .on_connection_event(ConnectionEvent::RemoteProtocolsChange(
348 supported_protocols.clone(),
349 ));
350 self.proto2
351 .on_connection_event(ConnectionEvent::RemoteProtocolsChange(
352 supported_protocols,
353 ));
354 }
355 }
356 }
357}