libp2p_stream/
shared.rs

1use std::{
2    collections::{hash_map::Entry, HashMap},
3    io,
4    sync::{Arc, Mutex, MutexGuard},
5};
6
7use futures::channel::mpsc;
8use libp2p_identity::PeerId;
9use libp2p_swarm::{ConnectionId, Stream, StreamProtocol};
10use rand::seq::IteratorRandom as _;
11
12use crate::{handler::NewStream, AlreadyRegistered, IncomingStreams};
13
/// Mutable state shared behind an `Arc<Mutex<Shared>>` (acquired through
/// `Shared::lock`).
///
/// It tracks:
/// - which inbound protocols are currently accepted;
/// - which connections exist and to which peers;
/// - the per-connection channels carrying [`NewStream`] requests;
/// - the channels that hold such requests while a peer is being dialed.
pub(crate) struct Shared {
    /// Tracks the supported inbound protocols created via
    /// [`Control::accept`](crate::Control::accept).
    ///
    /// For each [`StreamProtocol`], we hold the [`mpsc::Sender`] corresponding to the
    /// [`mpsc::Receiver`] in [`IncomingStreams`]. Entries whose receiver is gone
    /// are pruned lazily.
    supported_inbound_protocols: HashMap<StreamProtocol, mpsc::Sender<(PeerId, Stream)>>,

    /// The peer each established connection belongs to, keyed by connection.
    connections: HashMap<ConnectionId, PeerId>,
    /// Sending half of each connection's [`NewStream`] channel.
    ///
    /// The matching receiving half is returned by `Shared::receiver`, and
    /// `Shared::sender` hands out clones of these senders.
    senders: HashMap<ConnectionId, mpsc::Sender<NewStream>>,

    /// Channel pairs for peers we are currently dialing.
    ///
    /// Requests sent on the sender queue up in the receiver until a connection
    /// takes the pair over (`Shared::receiver`) or the dial fails
    /// (`Shared::on_dial_failure`).
    pending_channels: HashMap<PeerId, (mpsc::Sender<NewStream>, mpsc::Receiver<NewStream>)>,

    /// Sender for peers we want to dial.
    ///
    /// We manage this through a channel to avoid locks as part of
    /// [`NetworkBehaviour::poll`](libp2p_swarm::NetworkBehaviour::poll).
    dial_sender: mpsc::Sender<PeerId>,
}
34
35impl Shared {
36    pub(crate) fn lock(shared: &Arc<Mutex<Shared>>) -> MutexGuard<'_, Shared> {
37        shared.lock().unwrap_or_else(|e| e.into_inner())
38    }
39}
40
41impl Shared {
42    pub(crate) fn new(dial_sender: mpsc::Sender<PeerId>) -> Self {
43        Self {
44            dial_sender,
45            connections: Default::default(),
46            senders: Default::default(),
47            pending_channels: Default::default(),
48            supported_inbound_protocols: Default::default(),
49        }
50    }
51
52    pub(crate) fn accept(
53        &mut self,
54        protocol: StreamProtocol,
55    ) -> Result<IncomingStreams, AlreadyRegistered> {
56        self.supported_inbound_protocols
57            .retain(|_, sender| !sender.is_closed());
58
59        if self.supported_inbound_protocols.contains_key(&protocol) {
60            return Err(AlreadyRegistered);
61        }
62
63        let (sender, receiver) = mpsc::channel(0);
64        self.supported_inbound_protocols
65            .insert(protocol.clone(), sender);
66
67        Ok(IncomingStreams::new(receiver))
68    }
69
70    /// Lists the protocols for which we have an active [`IncomingStreams`] instance.
71    pub(crate) fn supported_inbound_protocols(&mut self) -> Vec<StreamProtocol> {
72        self.supported_inbound_protocols
73            .retain(|_, sender| !sender.is_closed());
74
75        self.supported_inbound_protocols.keys().cloned().collect()
76    }
77
78    pub(crate) fn on_inbound_stream(
79        &mut self,
80        remote: PeerId,
81        stream: Stream,
82        protocol: StreamProtocol,
83    ) {
84        match self.supported_inbound_protocols.entry(protocol.clone()) {
85            Entry::Occupied(mut entry) => match entry.get_mut().try_send((remote, stream)) {
86                Ok(()) => {}
87                Err(e) if e.is_full() => {
88                    tracing::debug!(%protocol, "Channel is full, dropping inbound stream");
89                }
90                Err(e) if e.is_disconnected() => {
91                    tracing::debug!(%protocol, "Channel is gone, dropping inbound stream");
92                    entry.remove();
93                }
94                _ => unreachable!(),
95            },
96            Entry::Vacant(_) => {
97                tracing::debug!(%protocol, "channel is gone, dropping inbound stream");
98            }
99        }
100    }
101
102    pub(crate) fn on_connection_established(&mut self, conn: ConnectionId, peer: PeerId) {
103        self.connections.insert(conn, peer);
104    }
105
106    pub(crate) fn on_connection_closed(&mut self, conn: ConnectionId) {
107        self.connections.remove(&conn);
108    }
109
110    pub(crate) fn on_dial_failure(&mut self, peer: PeerId, reason: String) {
111        let Some((_, mut receiver)) = self.pending_channels.remove(&peer) else {
112            return;
113        };
114
115        while let Ok(Some(new_stream)) = receiver.try_next() {
116            let _ = new_stream
117                .sender
118                .send(Err(crate::OpenStreamError::Io(io::Error::new(
119                    io::ErrorKind::NotConnected,
120                    reason.clone(),
121                ))));
122        }
123    }
124
125    pub(crate) fn sender(&mut self, peer: PeerId) -> mpsc::Sender<NewStream> {
126        let maybe_sender = self
127            .connections
128            .iter()
129            .filter_map(|(c, p)| (p == &peer).then_some(c))
130            .choose(&mut rand::thread_rng())
131            .and_then(|c| self.senders.get(c));
132
133        match maybe_sender {
134            Some(sender) => {
135                tracing::debug!("Returning sender to existing connection");
136
137                sender.clone()
138            }
139            None => {
140                tracing::debug!(%peer, "Not connected to peer, initiating dial");
141
142                let (sender, _) = self
143                    .pending_channels
144                    .entry(peer)
145                    .or_insert_with(|| mpsc::channel(0));
146
147                let _ = self.dial_sender.try_send(peer);
148
149                sender.clone()
150            }
151        }
152    }
153
154    pub(crate) fn receiver(
155        &mut self,
156        peer: PeerId,
157        connection: ConnectionId,
158    ) -> mpsc::Receiver<NewStream> {
159        if let Some((sender, receiver)) = self.pending_channels.remove(&peer) {
160            tracing::debug!(%peer, %connection, "Returning existing pending receiver");
161
162            self.senders.insert(connection, sender);
163            return receiver;
164        }
165
166        tracing::debug!(%peer, %connection, "Creating new channel pair");
167
168        let (sender, receiver) = mpsc::channel(0);
169        self.senders.insert(connection, sender);
170
171        receiver
172    }
173}