1use std::{
25 cmp,
26 collections::{HashMap, HashSet},
27 error,
28 fmt::{self, Debug},
29 hash::Hash,
30 iter,
31 task::{Context, Poll},
32 time::Duration,
33};
34
35use futures::{future::BoxFuture, prelude::*, ready};
36use rand::Rng;
37
38use crate::{
39 handler::{
40 AddressChange, ConnectionEvent, ConnectionHandler, ConnectionHandlerEvent,
41 DialUpgradeError, FullyNegotiatedInbound, FullyNegotiatedOutbound, ListenUpgradeError,
42 SubstreamProtocol,
43 },
44 upgrade::{InboundUpgradeSend, OutboundUpgradeSend, UpgradeInfoSend},
45 Stream,
46};
47
/// A collection of handlers of the same type `H`, each stored under a distinct key `K`.
///
/// The `Debug` implementation is derived: it prints the same
/// `MultiHandler { handlers: .. }` representation a hand-written `debug_struct` would, and
/// only requires `K: Debug` and `H: Debug` (no `Eq`/`Hash` bound on the key).
#[derive(Clone, Debug)]
pub struct MultiHandler<K, H> {
    /// The wrapped handlers, indexed by their key.
    handlers: HashMap<K, H>,
}
65
66impl<K, H> MultiHandler<K, H>
67where
68 K: Clone + Debug + Hash + Eq + Send + 'static,
69 H: ConnectionHandler,
70{
71 pub fn try_from_iter<I>(iter: I) -> Result<Self, DuplicateProtonameError>
75 where
76 I: IntoIterator<Item = (K, H)>,
77 {
78 let m = MultiHandler {
79 handlers: HashMap::from_iter(iter),
80 };
81 uniq_proto_names(
82 m.handlers
83 .values()
84 .map(|h| h.listen_protocol().into_upgrade().0),
85 )?;
86 Ok(m)
87 }
88
89 fn on_listen_upgrade_error(
90 &mut self,
91 ListenUpgradeError {
92 error: (key, error),
93 mut info,
94 }: ListenUpgradeError<
95 <Self as ConnectionHandler>::InboundOpenInfo,
96 <Self as ConnectionHandler>::InboundProtocol,
97 >,
98 ) {
99 if let Some(h) = self.handlers.get_mut(&key) {
100 if let Some(i) = info.take(&key) {
101 h.on_connection_event(ConnectionEvent::ListenUpgradeError(ListenUpgradeError {
102 info: i,
103 error,
104 }));
105 }
106 }
107 }
108}
109
110impl<K, H> ConnectionHandler for MultiHandler<K, H>
111where
112 K: Clone + Debug + Hash + Eq + Send + 'static,
113 H: ConnectionHandler,
114 H::InboundProtocol: InboundUpgradeSend,
115 H::OutboundProtocol: OutboundUpgradeSend,
116{
117 type FromBehaviour = (K, <H as ConnectionHandler>::FromBehaviour);
118 type ToBehaviour = (K, <H as ConnectionHandler>::ToBehaviour);
119 type InboundProtocol = Upgrade<K, <H as ConnectionHandler>::InboundProtocol>;
120 type OutboundProtocol = <H as ConnectionHandler>::OutboundProtocol;
121 type InboundOpenInfo = Info<K, <H as ConnectionHandler>::InboundOpenInfo>;
122 type OutboundOpenInfo = (K, <H as ConnectionHandler>::OutboundOpenInfo);
123
124 fn listen_protocol(&self) -> SubstreamProtocol<Self::InboundProtocol, Self::InboundOpenInfo> {
125 let (upgrade, info, timeout) = self
126 .handlers
127 .iter()
128 .map(|(key, handler)| {
129 let proto = handler.listen_protocol();
130 let timeout = *proto.timeout();
131 let (upgrade, info) = proto.into_upgrade();
132 (key.clone(), (upgrade, info, timeout))
133 })
134 .fold(
135 (Upgrade::new(), Info::new(), Duration::from_secs(0)),
136 |(mut upg, mut inf, mut timeout), (k, (u, i, t))| {
137 upg.upgrades.push((k.clone(), u));
138 inf.infos.push((k, i));
139 timeout = cmp::max(timeout, t);
140 (upg, inf, timeout)
141 },
142 );
143 SubstreamProtocol::new(upgrade, info).with_timeout(timeout)
144 }
145
146 fn on_connection_event(
147 &mut self,
148 event: ConnectionEvent<
149 Self::InboundProtocol,
150 Self::OutboundProtocol,
151 Self::InboundOpenInfo,
152 Self::OutboundOpenInfo,
153 >,
154 ) {
155 match event {
156 ConnectionEvent::FullyNegotiatedOutbound(FullyNegotiatedOutbound {
157 protocol,
158 info: (key, arg),
159 }) => {
160 if let Some(h) = self.handlers.get_mut(&key) {
161 h.on_connection_event(ConnectionEvent::FullyNegotiatedOutbound(
162 FullyNegotiatedOutbound {
163 protocol,
164 info: arg,
165 },
166 ));
167 } else {
168 tracing::error!("FullyNegotiatedOutbound: no handler for key")
169 }
170 }
171 ConnectionEvent::FullyNegotiatedInbound(FullyNegotiatedInbound {
172 protocol: (key, arg),
173 mut info,
174 }) => {
175 if let Some(h) = self.handlers.get_mut(&key) {
176 if let Some(i) = info.take(&key) {
177 h.on_connection_event(ConnectionEvent::FullyNegotiatedInbound(
178 FullyNegotiatedInbound {
179 protocol: arg,
180 info: i,
181 },
182 ));
183 }
184 } else {
185 tracing::error!("FullyNegotiatedInbound: no handler for key")
186 }
187 }
188 ConnectionEvent::AddressChange(AddressChange { new_address }) => {
189 for h in self.handlers.values_mut() {
190 h.on_connection_event(ConnectionEvent::AddressChange(AddressChange {
191 new_address,
192 }));
193 }
194 }
195 ConnectionEvent::DialUpgradeError(DialUpgradeError {
196 info: (key, arg),
197 error,
198 }) => {
199 if let Some(h) = self.handlers.get_mut(&key) {
200 h.on_connection_event(ConnectionEvent::DialUpgradeError(DialUpgradeError {
201 info: arg,
202 error,
203 }));
204 } else {
205 tracing::error!("DialUpgradeError: no handler for protocol")
206 }
207 }
208 ConnectionEvent::ListenUpgradeError(listen_upgrade_error) => {
209 self.on_listen_upgrade_error(listen_upgrade_error)
210 }
211 ConnectionEvent::LocalProtocolsChange(supported_protocols) => {
212 for h in self.handlers.values_mut() {
213 h.on_connection_event(ConnectionEvent::LocalProtocolsChange(
214 supported_protocols.clone(),
215 ));
216 }
217 }
218 ConnectionEvent::RemoteProtocolsChange(supported_protocols) => {
219 for h in self.handlers.values_mut() {
220 h.on_connection_event(ConnectionEvent::RemoteProtocolsChange(
221 supported_protocols.clone(),
222 ));
223 }
224 }
225 }
226 }
227
228 fn on_behaviour_event(&mut self, (key, event): Self::FromBehaviour) {
229 if let Some(h) = self.handlers.get_mut(&key) {
230 h.on_behaviour_event(event)
231 } else {
232 tracing::error!("on_behaviour_event: no handler for key")
233 }
234 }
235
236 fn connection_keep_alive(&self) -> bool {
237 self.handlers
238 .values()
239 .map(|h| h.connection_keep_alive())
240 .max()
241 .unwrap_or(false)
242 }
243
244 fn poll(
245 &mut self,
246 cx: &mut Context<'_>,
247 ) -> Poll<
248 ConnectionHandlerEvent<Self::OutboundProtocol, Self::OutboundOpenInfo, Self::ToBehaviour>,
249 > {
250 if self.handlers.is_empty() {
253 return Poll::Pending;
254 }
255
256 let pos = rand::thread_rng().gen_range(0..self.handlers.len());
259
260 for (k, h) in self.handlers.iter_mut().skip(pos) {
261 if let Poll::Ready(e) = h.poll(cx) {
262 let e = e
263 .map_outbound_open_info(|i| (k.clone(), i))
264 .map_custom(|p| (k.clone(), p));
265 return Poll::Ready(e);
266 }
267 }
268
269 for (k, h) in self.handlers.iter_mut().take(pos) {
270 if let Poll::Ready(e) = h.poll(cx) {
271 let e = e
272 .map_outbound_open_info(|i| (k.clone(), i))
273 .map_custom(|p| (k.clone(), p));
274 return Poll::Ready(e);
275 }
276 }
277
278 Poll::Pending
279 }
280
281 fn poll_close(&mut self, cx: &mut Context<'_>) -> Poll<Option<Self::ToBehaviour>> {
282 for (k, h) in self.handlers.iter_mut() {
283 let Some(e) = ready!(h.poll_close(cx)) else {
284 continue;
285 };
286 return Poll::Ready(Some((k.clone(), e)));
287 }
288
289 Poll::Ready(None)
290 }
291}
292
293impl<K, H> IntoIterator for MultiHandler<K, H> {
295 type Item = <Self::IntoIter as Iterator>::Item;
296 type IntoIter = std::collections::hash_map::IntoIter<K, H>;
297
298 fn into_iter(self) -> Self::IntoIter {
299 self.handlers.into_iter()
300 }
301}
302
/// A protocol name `H` paired with a `usize` index.
///
/// In this file the index is the position of the originating entry inside an `Upgrade`.
#[derive(Debug, Clone)]
pub struct IndexedProtoName<H>(usize, H);

impl<H: AsRef<str>> AsRef<str> for IndexedProtoName<H> {
    /// Returns the wrapped protocol name as a string slice; the index is not part of it.
    fn as_ref(&self) -> &str {
        let IndexedProtoName(_, name) = self;
        name.as_ref()
    }
}
312
/// A list of values `I`, each stored together with a key `K`.
#[derive(Clone)]
pub struct Info<K, I> {
    infos: Vec<(K, I)>,
}

impl<K: Eq, I> Info<K, I> {
    /// Creates an empty list.
    fn new() -> Self {
        Info { infos: Vec::new() }
    }

    /// Removes and returns the value of the first entry whose key equals `k`.
    ///
    /// Returns `None` if no entry has that key. The order of the other entries is kept.
    pub fn take(&mut self, k: &K) -> Option<I> {
        let index = self.infos.iter().position(|entry| entry.0 == *k)?;
        let (_, info) = self.infos.remove(index);
        Some(info)
    }
}
331
/// A list of upgrades `H`, each stored together with a key `K`.
///
/// The `Debug` implementation is derived: it prints the same `Upgrade { upgrades: .. }`
/// representation a hand-written `debug_struct` would, and only requires `K: Debug` and
/// `H: Debug` (no `Eq`/`Hash` bound on the key).
#[derive(Clone, Debug)]
pub struct Upgrade<K, H> {
    upgrades: Vec<(K, H)>,
}

impl<K, H> Upgrade<K, H> {
    /// Creates an `Upgrade` without any entries.
    fn new() -> Self {
        Upgrade {
            upgrades: Vec::new(),
        }
    }
}
357
358impl<K, H> UpgradeInfoSend for Upgrade<K, H>
359where
360 H: UpgradeInfoSend,
361 K: Send + 'static,
362{
363 type Info = IndexedProtoName<H::Info>;
364 type InfoIter = std::vec::IntoIter<Self::Info>;
365
366 fn protocol_info(&self) -> Self::InfoIter {
367 self.upgrades
368 .iter()
369 .enumerate()
370 .flat_map(|(i, (_, h))| iter::repeat(i).zip(h.protocol_info()))
371 .map(|(i, h)| IndexedProtoName(i, h))
372 .collect::<Vec<_>>()
373 .into_iter()
374 }
375}
376
377impl<K, H> InboundUpgradeSend for Upgrade<K, H>
378where
379 H: InboundUpgradeSend,
380 K: Send + 'static,
381{
382 type Output = (K, <H as InboundUpgradeSend>::Output);
383 type Error = (K, <H as InboundUpgradeSend>::Error);
384 type Future = BoxFuture<'static, Result<Self::Output, Self::Error>>;
385
386 fn upgrade_inbound(mut self, resource: Stream, info: Self::Info) -> Self::Future {
387 let IndexedProtoName(index, info) = info;
388 let (key, upgrade) = self.upgrades.remove(index);
389 upgrade
390 .upgrade_inbound(resource, info)
391 .map(move |out| match out {
392 Ok(o) => Ok((key, o)),
393 Err(e) => Err((key, e)),
394 })
395 .boxed()
396 }
397}
398
399impl<K, H> OutboundUpgradeSend for Upgrade<K, H>
400where
401 H: OutboundUpgradeSend,
402 K: Send + 'static,
403{
404 type Output = (K, <H as OutboundUpgradeSend>::Output);
405 type Error = (K, <H as OutboundUpgradeSend>::Error);
406 type Future = BoxFuture<'static, Result<Self::Output, Self::Error>>;
407
408 fn upgrade_outbound(mut self, resource: Stream, info: Self::Info) -> Self::Future {
409 let IndexedProtoName(index, info) = info;
410 let (key, upgrade) = self.upgrades.remove(index);
411 upgrade
412 .upgrade_outbound(resource, info)
413 .map(move |out| match out {
414 Ok(o) => Ok((key, o)),
415 Err(e) => Err((key, e)),
416 })
417 .boxed()
418 }
419}
420
421fn uniq_proto_names<I, T>(iter: I) -> Result<(), DuplicateProtonameError>
423where
424 I: Iterator<Item = T>,
425 T: UpgradeInfoSend,
426{
427 let mut set = HashSet::new();
428 for infos in iter {
429 for i in infos.protocol_info() {
430 let v = Vec::from(i.as_ref());
431 if set.contains(&v) {
432 return Err(DuplicateProtonameError(v));
433 } else {
434 set.insert(v);
435 }
436 }
437 }
438 Ok(())
439}
440
/// Error carrying the bytes of a protocol name that occurred more than once.
#[derive(Debug, Clone)]
pub struct DuplicateProtonameError(Vec<u8>);

impl DuplicateProtonameError {
    /// The bytes of the duplicated protocol name.
    pub fn protocol_name(&self) -> &[u8] {
        self.0.as_slice()
    }
}

impl fmt::Display for DuplicateProtonameError {
    /// Prints the name as text when it is valid UTF-8 and as a list of bytes otherwise.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match std::str::from_utf8(&self.0) {
            Ok(s) => write!(f, "duplicate protocol name: {s}"),
            Err(_) => write!(f, "duplicate protocol name: {:?}", self.0),
        }
    }
}

impl error::Error for DuplicateProtonameError {}