1use std::{
2 collections::{hash_map::Entry, HashMap},
3 io,
4 sync::{Arc, Mutex, MutexGuard},
5};
6
7use futures::channel::mpsc;
8use libp2p_identity::PeerId;
9use libp2p_swarm::{ConnectionId, Stream, StreamProtocol};
10use rand::seq::IteratorRandom as _;
11
12use crate::{handler::NewStream, AlreadyRegistered, IncomingStreams};
13
/// Mutable state shared behind an `Arc<Mutex<_>>` (acquired via `Shared::lock`).
///
/// It tracks which protocols accept inbound streams, which established connections lead to
/// which peers, and the channels used to hand outbound stream requests to connections.
pub(crate) struct Shared {
    /// Inbound protocols registered via `accept`, each mapped to the sending half of the
    /// channel whose receiving half was wrapped in the returned `IncomingStreams`.
    ///
    /// Entries whose channel has closed are pruned lazily (when listing protocols or when a
    /// delivery finds the receiver gone).
    supported_inbound_protocols: HashMap<StreamProtocol, mpsc::Sender<(PeerId, Stream)>>,

    /// Remote peer of every connection reported as established and not yet reported closed.
    connections: HashMap<ConnectionId, PeerId>,
    /// Per-connection sending half used to forward outbound stream requests; the matching
    /// receiver is the one returned by `receiver` for that connection.
    senders: HashMap<ConnectionId, mpsc::Sender<NewStream>>,

    /// Channel pairs created by `sender` for peers without a usable connection. A pair is
    /// taken over by the next connection to that peer passed to `receiver`, or drained and
    /// discarded by `on_dial_failure`.
    pending_channels: HashMap<PeerId, (mpsc::Sender<NewStream>, mpsc::Receiver<NewStream>)>,

    /// Queue onto which `sender` pushes peers that need to be dialed.
    // NOTE(review): the consumer of this queue lives outside this file — presumably the
    // behaviour's poll loop; confirm.
    dial_sender: mpsc::Sender<PeerId>,
}
34
35impl Shared {
36 pub(crate) fn lock(shared: &Arc<Mutex<Shared>>) -> MutexGuard<'_, Shared> {
37 shared.lock().unwrap_or_else(|e| e.into_inner())
38 }
39}
40
41impl Shared {
42 pub(crate) fn new(dial_sender: mpsc::Sender<PeerId>) -> Self {
43 Self {
44 dial_sender,
45 connections: Default::default(),
46 senders: Default::default(),
47 pending_channels: Default::default(),
48 supported_inbound_protocols: Default::default(),
49 }
50 }
51
52 pub(crate) fn accept(
53 &mut self,
54 protocol: StreamProtocol,
55 ) -> Result<IncomingStreams, AlreadyRegistered> {
56 if self.supported_inbound_protocols.contains_key(&protocol) {
57 return Err(AlreadyRegistered);
58 }
59
60 let (sender, receiver) = mpsc::channel(0);
61 self.supported_inbound_protocols
62 .insert(protocol.clone(), sender);
63
64 Ok(IncomingStreams::new(receiver))
65 }
66
67 pub(crate) fn supported_inbound_protocols(&mut self) -> Vec<StreamProtocol> {
69 self.supported_inbound_protocols
70 .retain(|_, sender| !sender.is_closed());
71
72 self.supported_inbound_protocols.keys().cloned().collect()
73 }
74
75 pub(crate) fn on_inbound_stream(
76 &mut self,
77 remote: PeerId,
78 stream: Stream,
79 protocol: StreamProtocol,
80 ) {
81 match self.supported_inbound_protocols.entry(protocol.clone()) {
82 Entry::Occupied(mut entry) => match entry.get_mut().try_send((remote, stream)) {
83 Ok(()) => {}
84 Err(e) if e.is_full() => {
85 tracing::debug!(%protocol, "Channel is full, dropping inbound stream");
86 }
87 Err(e) if e.is_disconnected() => {
88 tracing::debug!(%protocol, "Channel is gone, dropping inbound stream");
89 entry.remove();
90 }
91 _ => unreachable!(),
92 },
93 Entry::Vacant(_) => {
94 tracing::debug!(%protocol, "channel is gone, dropping inbound stream");
95 }
96 }
97 }
98
99 pub(crate) fn on_connection_established(&mut self, conn: ConnectionId, peer: PeerId) {
100 self.connections.insert(conn, peer);
101 }
102
103 pub(crate) fn on_connection_closed(&mut self, conn: ConnectionId) {
104 self.connections.remove(&conn);
105 }
106
107 pub(crate) fn on_dial_failure(&mut self, peer: PeerId, reason: String) {
108 let Some((_, mut receiver)) = self.pending_channels.remove(&peer) else {
109 return;
110 };
111
112 while let Ok(Some(new_stream)) = receiver.try_next() {
113 let _ = new_stream
114 .sender
115 .send(Err(crate::OpenStreamError::Io(io::Error::new(
116 io::ErrorKind::NotConnected,
117 reason.clone(),
118 ))));
119 }
120 }
121
122 pub(crate) fn sender(&mut self, peer: PeerId) -> mpsc::Sender<NewStream> {
123 let maybe_sender = self
124 .connections
125 .iter()
126 .filter_map(|(c, p)| (p == &peer).then_some(c))
127 .choose(&mut rand::thread_rng())
128 .and_then(|c| self.senders.get(c));
129
130 match maybe_sender {
131 Some(sender) => {
132 tracing::debug!("Returning sender to existing connection");
133
134 sender.clone()
135 }
136 None => {
137 tracing::debug!(%peer, "Not connected to peer, initiating dial");
138
139 let (sender, _) = self
140 .pending_channels
141 .entry(peer)
142 .or_insert_with(|| mpsc::channel(0));
143
144 let _ = self.dial_sender.try_send(peer);
145
146 sender.clone()
147 }
148 }
149 }
150
151 pub(crate) fn receiver(
152 &mut self,
153 peer: PeerId,
154 connection: ConnectionId,
155 ) -> mpsc::Receiver<NewStream> {
156 if let Some((sender, receiver)) = self.pending_channels.remove(&peer) {
157 tracing::debug!(%peer, %connection, "Returning existing pending receiver");
158
159 self.senders.insert(connection, sender);
160 return receiver;
161 }
162
163 tracing::debug!(%peer, %connection, "Creating new channel pair");
164
165 let (sender, receiver) = mpsc::channel(0);
166 self.senders.insert(connection, sender);
167
168 receiver
169 }
170}