1use std::{
2 collections::{hash_map::Entry, HashMap},
3 io,
4 sync::{Arc, Mutex, MutexGuard},
5};
6
7use futures::channel::mpsc;
8use libp2p_identity::PeerId;
9use libp2p_swarm::{ConnectionId, Stream, StreamProtocol};
10use rand::seq::IteratorRandom as _;
11
12use crate::{handler::NewStream, AlreadyRegistered, IncomingStreams};
13
/// Bookkeeping for inbound protocol registrations, known connections and the channels used to
/// request new outbound streams.
///
/// Intended to live behind an `Arc<Mutex<_>>`; see [`Shared::lock`].
pub(crate) struct Shared {
    /// Inbound protocols registered through [`Shared::accept`].
    ///
    /// Each protocol maps to the sending half of the channel whose receiving half was handed
    /// to the [`IncomingStreams`] returned by `accept`.
    supported_inbound_protocols: HashMap<StreamProtocol, mpsc::Sender<(PeerId, Stream)>>,

    /// Connections reported as established and not yet closed, with the peer on the other end.
    connections: HashMap<ConnectionId, PeerId>,
    /// Per-connection sender for stream requests; populated by [`Shared::receiver`].
    senders: HashMap<ConnectionId, mpsc::Sender<NewStream>>,

    /// Channel pairs created for peers we are not connected to yet.
    ///
    /// An entry is created by [`Shared::sender`] and consumed either by [`Shared::receiver`]
    /// or by [`Shared::on_dial_failure`].
    pending_channels: HashMap<PeerId, (mpsc::Sender<NewStream>, mpsc::Receiver<NewStream>)>,

    /// Channel on which peers that should be dialed are announced.
    dial_sender: mpsc::Sender<PeerId>,
}
34
35impl Shared {
36 pub(crate) fn lock(shared: &Arc<Mutex<Shared>>) -> MutexGuard<'_, Shared> {
37 shared.lock().unwrap_or_else(|e| e.into_inner())
38 }
39}
40
41impl Shared {
42 pub(crate) fn new(dial_sender: mpsc::Sender<PeerId>) -> Self {
43 Self {
44 dial_sender,
45 connections: Default::default(),
46 senders: Default::default(),
47 pending_channels: Default::default(),
48 supported_inbound_protocols: Default::default(),
49 }
50 }
51
52 pub(crate) fn accept(
53 &mut self,
54 protocol: StreamProtocol,
55 ) -> Result<IncomingStreams, AlreadyRegistered> {
56 self.supported_inbound_protocols
57 .retain(|_, sender| !sender.is_closed());
58
59 if self.supported_inbound_protocols.contains_key(&protocol) {
60 return Err(AlreadyRegistered);
61 }
62
63 let (sender, receiver) = mpsc::channel(0);
64 self.supported_inbound_protocols
65 .insert(protocol.clone(), sender);
66
67 Ok(IncomingStreams::new(receiver))
68 }
69
70 pub(crate) fn supported_inbound_protocols(&mut self) -> Vec<StreamProtocol> {
72 self.supported_inbound_protocols
73 .retain(|_, sender| !sender.is_closed());
74
75 self.supported_inbound_protocols.keys().cloned().collect()
76 }
77
78 pub(crate) fn on_inbound_stream(
79 &mut self,
80 remote: PeerId,
81 stream: Stream,
82 protocol: StreamProtocol,
83 ) {
84 match self.supported_inbound_protocols.entry(protocol.clone()) {
85 Entry::Occupied(mut entry) => match entry.get_mut().try_send((remote, stream)) {
86 Ok(()) => {}
87 Err(e) if e.is_full() => {
88 tracing::debug!(%protocol, "Channel is full, dropping inbound stream");
89 }
90 Err(e) if e.is_disconnected() => {
91 tracing::debug!(%protocol, "Channel is gone, dropping inbound stream");
92 entry.remove();
93 }
94 _ => unreachable!(),
95 },
96 Entry::Vacant(_) => {
97 tracing::debug!(%protocol, "channel is gone, dropping inbound stream");
98 }
99 }
100 }
101
102 pub(crate) fn on_connection_established(&mut self, conn: ConnectionId, peer: PeerId) {
103 self.connections.insert(conn, peer);
104 }
105
106 pub(crate) fn on_connection_closed(&mut self, conn: ConnectionId) {
107 self.connections.remove(&conn);
108 }
109
110 pub(crate) fn on_dial_failure(&mut self, peer: PeerId, reason: String) {
111 let Some((_, mut receiver)) = self.pending_channels.remove(&peer) else {
112 return;
113 };
114
115 while let Ok(Some(new_stream)) = receiver.try_next() {
116 let _ = new_stream
117 .sender
118 .send(Err(crate::OpenStreamError::Io(io::Error::new(
119 io::ErrorKind::NotConnected,
120 reason.clone(),
121 ))));
122 }
123 }
124
125 pub(crate) fn sender(&mut self, peer: PeerId) -> mpsc::Sender<NewStream> {
126 let maybe_sender = self
127 .connections
128 .iter()
129 .filter_map(|(c, p)| (p == &peer).then_some(c))
130 .choose(&mut rand::thread_rng())
131 .and_then(|c| self.senders.get(c));
132
133 match maybe_sender {
134 Some(sender) => {
135 tracing::debug!("Returning sender to existing connection");
136
137 sender.clone()
138 }
139 None => {
140 tracing::debug!(%peer, "Not connected to peer, initiating dial");
141
142 let (sender, _) = self
143 .pending_channels
144 .entry(peer)
145 .or_insert_with(|| mpsc::channel(0));
146
147 let _ = self.dial_sender.try_send(peer);
148
149 sender.clone()
150 }
151 }
152 }
153
154 pub(crate) fn receiver(
155 &mut self,
156 peer: PeerId,
157 connection: ConnectionId,
158 ) -> mpsc::Receiver<NewStream> {
159 if let Some((sender, receiver)) = self.pending_channels.remove(&peer) {
160 tracing::debug!(%peer, %connection, "Returning existing pending receiver");
161
162 self.senders.insert(connection, sender);
163 return receiver;
164 }
165
166 tracing::debug!(%peer, %connection, "Creating new channel pair");
167
168 let (sender, receiver) = mpsc::channel(0);
169 self.senders.insert(connection, sender);
170
171 receiver
172 }
173}