libp2p_swarm/handler/
multi.rs

1// Copyright 2020 Parity Technologies (UK) Ltd.
2//
3// Permission is hereby granted, free of charge, to any person obtaining a
4// copy of this software and associated documentation files (the "Software"),
5// to deal in the Software without restriction, including without limitation
6// the rights to use, copy, modify, merge, publish, distribute, sublicense,
7// and/or sell copies of the Software, and to permit persons to whom the
8// Software is furnished to do so, subject to the following conditions:
9//
10// The above copyright notice and this permission notice shall be included in
11// all copies or substantial portions of the Software.
12//
13// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
14// OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
16// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
17// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
18// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
19// DEALINGS IN THE SOFTWARE.
20
21//! A [`ConnectionHandler`] implementation that combines multiple other [`ConnectionHandler`]s
22//! indexed by some key.
23
24use std::{
25    cmp,
26    collections::{HashMap, HashSet},
27    error,
28    fmt::{self, Debug},
29    hash::Hash,
30    iter,
31    task::{Context, Poll},
32    time::Duration,
33};
34
35use futures::{future::BoxFuture, prelude::*, ready};
36use rand::Rng;
37
38use crate::{
39    handler::{
40        AddressChange, ConnectionEvent, ConnectionHandler, ConnectionHandlerEvent,
41        DialUpgradeError, FullyNegotiatedInbound, FullyNegotiatedOutbound, ListenUpgradeError,
42        SubstreamProtocol,
43    },
44    upgrade::{InboundUpgradeSend, OutboundUpgradeSend, UpgradeInfoSend},
45    Stream,
46};
47
/// A [`ConnectionHandler`] for multiple [`ConnectionHandler`]s of the same type.
///
/// Every inner handler is stored under a key `K`; messages exchanged with the
/// behaviour are tagged with the key of the handler they belong to.
///
/// `Debug` is derived so that printing only requires `K: Debug` and `H: Debug`
/// (no `Eq`/`Hash` bounds are needed just to format the map).
#[derive(Clone, Debug)]
pub struct MultiHandler<K, H> {
    handlers: HashMap<K, H>,
}
65
66impl<K, H> MultiHandler<K, H>
67where
68    K: Clone + Debug + Hash + Eq + Send + 'static,
69    H: ConnectionHandler,
70{
71    /// Create and populate a `MultiHandler` from the given handler iterator.
72    ///
73    /// It is an error for any two protocols handlers to share the same protocol name.
74    pub fn try_from_iter<I>(iter: I) -> Result<Self, DuplicateProtonameError>
75    where
76        I: IntoIterator<Item = (K, H)>,
77    {
78        let m = MultiHandler {
79            handlers: HashMap::from_iter(iter),
80        };
81        uniq_proto_names(
82            m.handlers
83                .values()
84                .map(|h| h.listen_protocol().into_upgrade().0),
85        )?;
86        Ok(m)
87    }
88
89    fn on_listen_upgrade_error(
90        &mut self,
91        ListenUpgradeError {
92            error: (key, error),
93            mut info,
94        }: ListenUpgradeError<
95            <Self as ConnectionHandler>::InboundOpenInfo,
96            <Self as ConnectionHandler>::InboundProtocol,
97        >,
98    ) {
99        if let Some(h) = self.handlers.get_mut(&key) {
100            if let Some(i) = info.take(&key) {
101                h.on_connection_event(ConnectionEvent::ListenUpgradeError(ListenUpgradeError {
102                    info: i,
103                    error,
104                }));
105            }
106        }
107    }
108}
109
110impl<K, H> ConnectionHandler for MultiHandler<K, H>
111where
112    K: Clone + Debug + Hash + Eq + Send + 'static,
113    H: ConnectionHandler,
114    H::InboundProtocol: InboundUpgradeSend,
115    H::OutboundProtocol: OutboundUpgradeSend,
116{
117    type FromBehaviour = (K, <H as ConnectionHandler>::FromBehaviour);
118    type ToBehaviour = (K, <H as ConnectionHandler>::ToBehaviour);
119    type InboundProtocol = Upgrade<K, <H as ConnectionHandler>::InboundProtocol>;
120    type OutboundProtocol = <H as ConnectionHandler>::OutboundProtocol;
121    type InboundOpenInfo = Info<K, <H as ConnectionHandler>::InboundOpenInfo>;
122    type OutboundOpenInfo = (K, <H as ConnectionHandler>::OutboundOpenInfo);
123
124    fn listen_protocol(&self) -> SubstreamProtocol<Self::InboundProtocol, Self::InboundOpenInfo> {
125        let (upgrade, info, timeout) = self
126            .handlers
127            .iter()
128            .map(|(key, handler)| {
129                let proto = handler.listen_protocol();
130                let timeout = *proto.timeout();
131                let (upgrade, info) = proto.into_upgrade();
132                (key.clone(), (upgrade, info, timeout))
133            })
134            .fold(
135                (Upgrade::new(), Info::new(), Duration::from_secs(0)),
136                |(mut upg, mut inf, mut timeout), (k, (u, i, t))| {
137                    upg.upgrades.push((k.clone(), u));
138                    inf.infos.push((k, i));
139                    timeout = cmp::max(timeout, t);
140                    (upg, inf, timeout)
141                },
142            );
143        SubstreamProtocol::new(upgrade, info).with_timeout(timeout)
144    }
145
146    fn on_connection_event(
147        &mut self,
148        event: ConnectionEvent<
149            Self::InboundProtocol,
150            Self::OutboundProtocol,
151            Self::InboundOpenInfo,
152            Self::OutboundOpenInfo,
153        >,
154    ) {
155        match event {
156            ConnectionEvent::FullyNegotiatedOutbound(FullyNegotiatedOutbound {
157                protocol,
158                info: (key, arg),
159            }) => {
160                if let Some(h) = self.handlers.get_mut(&key) {
161                    h.on_connection_event(ConnectionEvent::FullyNegotiatedOutbound(
162                        FullyNegotiatedOutbound {
163                            protocol,
164                            info: arg,
165                        },
166                    ));
167                } else {
168                    tracing::error!("FullyNegotiatedOutbound: no handler for key")
169                }
170            }
171            ConnectionEvent::FullyNegotiatedInbound(FullyNegotiatedInbound {
172                protocol: (key, arg),
173                mut info,
174            }) => {
175                if let Some(h) = self.handlers.get_mut(&key) {
176                    if let Some(i) = info.take(&key) {
177                        h.on_connection_event(ConnectionEvent::FullyNegotiatedInbound(
178                            FullyNegotiatedInbound {
179                                protocol: arg,
180                                info: i,
181                            },
182                        ));
183                    }
184                } else {
185                    tracing::error!("FullyNegotiatedInbound: no handler for key")
186                }
187            }
188            ConnectionEvent::AddressChange(AddressChange { new_address }) => {
189                for h in self.handlers.values_mut() {
190                    h.on_connection_event(ConnectionEvent::AddressChange(AddressChange {
191                        new_address,
192                    }));
193                }
194            }
195            ConnectionEvent::DialUpgradeError(DialUpgradeError {
196                info: (key, arg),
197                error,
198            }) => {
199                if let Some(h) = self.handlers.get_mut(&key) {
200                    h.on_connection_event(ConnectionEvent::DialUpgradeError(DialUpgradeError {
201                        info: arg,
202                        error,
203                    }));
204                } else {
205                    tracing::error!("DialUpgradeError: no handler for protocol")
206                }
207            }
208            ConnectionEvent::ListenUpgradeError(listen_upgrade_error) => {
209                self.on_listen_upgrade_error(listen_upgrade_error)
210            }
211            ConnectionEvent::LocalProtocolsChange(supported_protocols) => {
212                for h in self.handlers.values_mut() {
213                    h.on_connection_event(ConnectionEvent::LocalProtocolsChange(
214                        supported_protocols.clone(),
215                    ));
216                }
217            }
218            ConnectionEvent::RemoteProtocolsChange(supported_protocols) => {
219                for h in self.handlers.values_mut() {
220                    h.on_connection_event(ConnectionEvent::RemoteProtocolsChange(
221                        supported_protocols.clone(),
222                    ));
223                }
224            }
225        }
226    }
227
228    fn on_behaviour_event(&mut self, (key, event): Self::FromBehaviour) {
229        if let Some(h) = self.handlers.get_mut(&key) {
230            h.on_behaviour_event(event)
231        } else {
232            tracing::error!("on_behaviour_event: no handler for key")
233        }
234    }
235
236    fn connection_keep_alive(&self) -> bool {
237        self.handlers
238            .values()
239            .map(|h| h.connection_keep_alive())
240            .max()
241            .unwrap_or(false)
242    }
243
244    fn poll(
245        &mut self,
246        cx: &mut Context<'_>,
247    ) -> Poll<
248        ConnectionHandlerEvent<Self::OutboundProtocol, Self::OutboundOpenInfo, Self::ToBehaviour>,
249    > {
250        // Calling `gen_range(0, 0)` (see below) would panic, so we have return early to avoid
251        // that situation.
252        if self.handlers.is_empty() {
253            return Poll::Pending;
254        }
255
256        // Not always polling handlers in the same order
257        // should give anyone the chance to make progress.
258        let pos = rand::thread_rng().gen_range(0..self.handlers.len());
259
260        for (k, h) in self.handlers.iter_mut().skip(pos) {
261            if let Poll::Ready(e) = h.poll(cx) {
262                let e = e
263                    .map_outbound_open_info(|i| (k.clone(), i))
264                    .map_custom(|p| (k.clone(), p));
265                return Poll::Ready(e);
266            }
267        }
268
269        for (k, h) in self.handlers.iter_mut().take(pos) {
270            if let Poll::Ready(e) = h.poll(cx) {
271                let e = e
272                    .map_outbound_open_info(|i| (k.clone(), i))
273                    .map_custom(|p| (k.clone(), p));
274                return Poll::Ready(e);
275            }
276        }
277
278        Poll::Pending
279    }
280
281    fn poll_close(&mut self, cx: &mut Context<'_>) -> Poll<Option<Self::ToBehaviour>> {
282        for (k, h) in self.handlers.iter_mut() {
283            let Some(e) = ready!(h.poll_close(cx)) else {
284                continue;
285            };
286            return Poll::Ready(Some((k.clone(), e)));
287        }
288
289        Poll::Ready(None)
290    }
291}
292
293/// Split [`MultiHandler`] into parts.
294impl<K, H> IntoIterator for MultiHandler<K, H> {
295    type Item = <Self::IntoIter as Iterator>::Item;
296    type IntoIter = std::collections::hash_map::IntoIter<K, H>;
297
298    fn into_iter(self) -> Self::IntoIter {
299        self.handlers.into_iter()
300    }
301}
302
/// Index and protocol name pair used as `UpgradeInfo::Info`.
///
/// The index is the position of the inner upgrade that advertised the name; only
/// the name part is exposed through `AsRef<str>`.
#[derive(Debug, Clone)]
pub struct IndexedProtoName<H>(usize, H);

impl<H> AsRef<str> for IndexedProtoName<H>
where
    H: AsRef<str>,
{
    fn as_ref(&self) -> &str {
        <H as AsRef<str>>::as_ref(&self.1)
    }
}
312
/// The aggregated `InboundOpenInfo`s of supported inbound substream protocols.
///
/// Each entry pairs a handler key with the open info that handler supplied.
#[derive(Clone, Debug)]
pub struct Info<K, I> {
    infos: Vec<(K, I)>,
}

impl<K: Eq, I> Info<K, I> {
    /// Creates an empty collection of infos.
    fn new() -> Self {
        Info { infos: Vec::new() }
    }

    /// Removes and returns the info stored for key `k`.
    ///
    /// Returns `None` if there is no info for `k` (for example because it has
    /// already been taken). Only the first matching entry is removed.
    pub fn take(&mut self, k: &K) -> Option<I> {
        let pos = self.infos.iter().position(|(key, _)| key == k)?;
        Some(self.infos.remove(pos).1)
    }
}
331
/// Inbound and outbound upgrade for all [`ConnectionHandler`]s.
///
/// Holds one inner upgrade per handler, each tagged with that handler's key.
/// `Debug` is derived so formatting only requires `K: Debug` and `H: Debug`.
#[derive(Clone, Debug)]
pub struct Upgrade<K, H> {
    upgrades: Vec<(K, H)>,
}

impl<K, H> Upgrade<K, H> {
    /// Creates an upgrade that contains no inner upgrades yet.
    fn new() -> Self {
        Upgrade {
            upgrades: Vec::new(),
        }
    }
}
357
358impl<K, H> UpgradeInfoSend for Upgrade<K, H>
359where
360    H: UpgradeInfoSend,
361    K: Send + 'static,
362{
363    type Info = IndexedProtoName<H::Info>;
364    type InfoIter = std::vec::IntoIter<Self::Info>;
365
366    fn protocol_info(&self) -> Self::InfoIter {
367        self.upgrades
368            .iter()
369            .enumerate()
370            .flat_map(|(i, (_, h))| iter::repeat(i).zip(h.protocol_info()))
371            .map(|(i, h)| IndexedProtoName(i, h))
372            .collect::<Vec<_>>()
373            .into_iter()
374    }
375}
376
377impl<K, H> InboundUpgradeSend for Upgrade<K, H>
378where
379    H: InboundUpgradeSend,
380    K: Send + 'static,
381{
382    type Output = (K, <H as InboundUpgradeSend>::Output);
383    type Error = (K, <H as InboundUpgradeSend>::Error);
384    type Future = BoxFuture<'static, Result<Self::Output, Self::Error>>;
385
386    fn upgrade_inbound(mut self, resource: Stream, info: Self::Info) -> Self::Future {
387        let IndexedProtoName(index, info) = info;
388        let (key, upgrade) = self.upgrades.remove(index);
389        upgrade
390            .upgrade_inbound(resource, info)
391            .map(move |out| match out {
392                Ok(o) => Ok((key, o)),
393                Err(e) => Err((key, e)),
394            })
395            .boxed()
396    }
397}
398
399impl<K, H> OutboundUpgradeSend for Upgrade<K, H>
400where
401    H: OutboundUpgradeSend,
402    K: Send + 'static,
403{
404    type Output = (K, <H as OutboundUpgradeSend>::Output);
405    type Error = (K, <H as OutboundUpgradeSend>::Error);
406    type Future = BoxFuture<'static, Result<Self::Output, Self::Error>>;
407
408    fn upgrade_outbound(mut self, resource: Stream, info: Self::Info) -> Self::Future {
409        let IndexedProtoName(index, info) = info;
410        let (key, upgrade) = self.upgrades.remove(index);
411        upgrade
412            .upgrade_outbound(resource, info)
413            .map(move |out| match out {
414                Ok(o) => Ok((key, o)),
415                Err(e) => Err((key, e)),
416            })
417            .boxed()
418    }
419}
420
421/// Check that no two protocol names are equal.
422fn uniq_proto_names<I, T>(iter: I) -> Result<(), DuplicateProtonameError>
423where
424    I: Iterator<Item = T>,
425    T: UpgradeInfoSend,
426{
427    let mut set = HashSet::new();
428    for infos in iter {
429        for i in infos.protocol_info() {
430            let v = Vec::from(i.as_ref());
431            if set.contains(&v) {
432                return Err(DuplicateProtonameError(v));
433            } else {
434                set.insert(v);
435            }
436        }
437    }
438    Ok(())
439}
440
/// It is an error if two handlers share the same protocol name.
///
/// Carries the raw bytes of the offending protocol name.
#[derive(Debug, Clone)]
pub struct DuplicateProtonameError(Vec<u8>);

impl DuplicateProtonameError {
    /// The protocol name bytes that were advertised more than once.
    pub fn protocol_name(&self) -> &[u8] {
        self.0.as_slice()
    }
}

impl fmt::Display for DuplicateProtonameError {
    /// Prints the name as text when it is valid UTF-8, otherwise as a byte list.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match std::str::from_utf8(&self.0) {
            Ok(s) => write!(f, "duplicate protocol name: {s}"),
            Err(_) => write!(f, "duplicate protocol name: {:?}", self.0),
        }
    }
}

impl error::Error for DuplicateProtonameError {}