libp2p_stream/
shared.rs

1use std::{
2    collections::{hash_map::Entry, HashMap},
3    io,
4    sync::{Arc, Mutex, MutexGuard},
5};
6
7use futures::channel::mpsc;
8use libp2p_identity::PeerId;
9use libp2p_swarm::{ConnectionId, Stream, StreamProtocol};
10use rand::seq::IteratorRandom as _;
11
12use crate::{handler::NewStream, AlreadyRegistered, IncomingStreams};
13
14pub(crate) struct Shared {
15    /// Tracks the supported inbound protocols created via
16    /// [`Control::accept`](crate::Control::accept).
17    ///
18    /// For each [`StreamProtocol`], we hold the [`mpsc::Sender`] corresponding to the
19    /// [`mpsc::Receiver`] in [`IncomingStreams`].
20    supported_inbound_protocols: HashMap<StreamProtocol, mpsc::Sender<(PeerId, Stream)>>,
21
22    connections: HashMap<ConnectionId, PeerId>,
23    senders: HashMap<ConnectionId, mpsc::Sender<NewStream>>,
24
25    /// Tracks channel pairs for a peer whilst we are dialing them.
26    pending_channels: HashMap<PeerId, (mpsc::Sender<NewStream>, mpsc::Receiver<NewStream>)>,
27
28    /// Sender for peers we want to dial.
29    ///
30    /// We manage this through a channel to avoid locks as part of
31    /// [`NetworkBehaviour::poll`](libp2p_swarm::NetworkBehaviour::poll).
32    dial_sender: mpsc::Sender<PeerId>,
33}
34
35impl Shared {
36    pub(crate) fn lock(shared: &Arc<Mutex<Shared>>) -> MutexGuard<'_, Shared> {
37        shared.lock().unwrap_or_else(|e| e.into_inner())
38    }
39}
40
41impl Shared {
42    pub(crate) fn new(dial_sender: mpsc::Sender<PeerId>) -> Self {
43        Self {
44            dial_sender,
45            connections: Default::default(),
46            senders: Default::default(),
47            pending_channels: Default::default(),
48            supported_inbound_protocols: Default::default(),
49        }
50    }
51
52    pub(crate) fn accept(
53        &mut self,
54        protocol: StreamProtocol,
55    ) -> Result<IncomingStreams, AlreadyRegistered> {
56        if self.supported_inbound_protocols.contains_key(&protocol) {
57            return Err(AlreadyRegistered);
58        }
59
60        let (sender, receiver) = mpsc::channel(0);
61        self.supported_inbound_protocols
62            .insert(protocol.clone(), sender);
63
64        Ok(IncomingStreams::new(receiver))
65    }
66
67    /// Lists the protocols for which we have an active [`IncomingStreams`] instance.
68    pub(crate) fn supported_inbound_protocols(&mut self) -> Vec<StreamProtocol> {
69        self.supported_inbound_protocols
70            .retain(|_, sender| !sender.is_closed());
71
72        self.supported_inbound_protocols.keys().cloned().collect()
73    }
74
75    pub(crate) fn on_inbound_stream(
76        &mut self,
77        remote: PeerId,
78        stream: Stream,
79        protocol: StreamProtocol,
80    ) {
81        match self.supported_inbound_protocols.entry(protocol.clone()) {
82            Entry::Occupied(mut entry) => match entry.get_mut().try_send((remote, stream)) {
83                Ok(()) => {}
84                Err(e) if e.is_full() => {
85                    tracing::debug!(%protocol, "Channel is full, dropping inbound stream");
86                }
87                Err(e) if e.is_disconnected() => {
88                    tracing::debug!(%protocol, "Channel is gone, dropping inbound stream");
89                    entry.remove();
90                }
91                _ => unreachable!(),
92            },
93            Entry::Vacant(_) => {
94                tracing::debug!(%protocol, "channel is gone, dropping inbound stream");
95            }
96        }
97    }
98
99    pub(crate) fn on_connection_established(&mut self, conn: ConnectionId, peer: PeerId) {
100        self.connections.insert(conn, peer);
101    }
102
103    pub(crate) fn on_connection_closed(&mut self, conn: ConnectionId) {
104        self.connections.remove(&conn);
105    }
106
107    pub(crate) fn on_dial_failure(&mut self, peer: PeerId, reason: String) {
108        let Some((_, mut receiver)) = self.pending_channels.remove(&peer) else {
109            return;
110        };
111
112        while let Ok(Some(new_stream)) = receiver.try_next() {
113            let _ = new_stream
114                .sender
115                .send(Err(crate::OpenStreamError::Io(io::Error::new(
116                    io::ErrorKind::NotConnected,
117                    reason.clone(),
118                ))));
119        }
120    }
121
122    pub(crate) fn sender(&mut self, peer: PeerId) -> mpsc::Sender<NewStream> {
123        let maybe_sender = self
124            .connections
125            .iter()
126            .filter_map(|(c, p)| (p == &peer).then_some(c))
127            .choose(&mut rand::thread_rng())
128            .and_then(|c| self.senders.get(c));
129
130        match maybe_sender {
131            Some(sender) => {
132                tracing::debug!("Returning sender to existing connection");
133
134                sender.clone()
135            }
136            None => {
137                tracing::debug!(%peer, "Not connected to peer, initiating dial");
138
139                let (sender, _) = self
140                    .pending_channels
141                    .entry(peer)
142                    .or_insert_with(|| mpsc::channel(0));
143
144                let _ = self.dial_sender.try_send(peer);
145
146                sender.clone()
147            }
148        }
149    }
150
151    pub(crate) fn receiver(
152        &mut self,
153        peer: PeerId,
154        connection: ConnectionId,
155    ) -> mpsc::Receiver<NewStream> {
156        if let Some((sender, receiver)) = self.pending_channels.remove(&peer) {
157            tracing::debug!(%peer, %connection, "Returning existing pending receiver");
158
159            self.senders.insert(connection, sender);
160            return receiver;
161        }
162
163        tracing::debug!(%peer, %connection, "Creating new channel pair");
164
165        let (sender, receiver) = mpsc::channel(0);
166        self.senders.insert(connection, sender);
167
168        receiver
169    }
170}