1use std::{
22 collections::{hash_map::Entry, VecDeque},
23 error, fmt, io,
24 num::NonZeroU64,
25 pin::Pin,
26 sync::LazyLock,
27};
28
29use fnv::FnvHashMap;
30use futures::{
31 channel::mpsc,
32 future::Ready,
33 prelude::*,
34 task::{Context, Poll},
35};
36use multiaddr::{Multiaddr, Protocol};
37use parking_lot::Mutex;
38use rw_stream_sink::RwStreamSink;
39
40use crate::transport::{DialOpts, ListenerId, Transport, TransportError, TransportEvent};
41
/// Process-wide registry shared by every `MemoryTransport`: maps a registered memory port to the
/// sending half of that port's connection queue.
static HUB: LazyLock<Hub> = LazyLock::new(|| Hub(Mutex::new(FnvHashMap::default())));

/// Port registry guarded by a mutex so that all transports in the process see the same ports.
struct Hub(Mutex<FnvHashMap<NonZeroU64, ChannelSender>>);

/// Sending half of a port's connection queue. Each message is a new connection end plus the
/// port registered for the dialer.
type ChannelSender = mpsc::Sender<(Channel<Vec<u8>>, NonZeroU64)>;

/// Receiving half of a port's connection queue (held by a `Listener`).
type ChannelReceiver = mpsc::Receiver<(Channel<Vec<u8>>, NonZeroU64)>;
53
54impl Hub {
55 fn register_port(&self, port: u64) -> Option<(ChannelReceiver, NonZeroU64)> {
60 let mut hub = self.0.lock();
61
62 let port = if let Some(port) = NonZeroU64::new(port) {
63 port
64 } else {
65 loop {
66 let Some(port) = NonZeroU64::new(rand::random()) else {
67 continue;
68 };
69 if !hub.contains_key(&port) {
70 break port;
71 }
72 }
73 };
74
75 let (tx, rx) = mpsc::channel(2);
76 match hub.entry(port) {
77 Entry::Occupied(_) => return None,
78 Entry::Vacant(e) => e.insert(tx),
79 };
80
81 Some((rx, port))
82 }
83
84 fn unregister_port(&self, port: &NonZeroU64) -> Option<ChannelSender> {
85 self.0.lock().remove(port)
86 }
87
88 fn get(&self, port: &NonZeroU64) -> Option<ChannelSender> {
89 self.0.lock().get(port).cloned()
90 }
91}
92
93#[derive(Default)]
95pub struct MemoryTransport {
96 listeners: VecDeque<Pin<Box<Listener>>>,
97}
98
99impl MemoryTransport {
100 pub fn new() -> Self {
101 Self::default()
102 }
103}
104
/// Future returned by `MemoryTransport::dial`: hands one end of a fresh connection to the
/// listener and resolves to the other end.
pub struct DialFuture {
    /// Port registered for the dialer; the listener receives it as the remote address.
    dial_port: NonZeroU64,
    /// Connection queue of the listener being dialed.
    sender: ChannelSender,
    /// Connection end delivered to the listener; taken once the listener's queue has capacity.
    channel_to_send: Option<Channel<Vec<u8>>>,
    /// Connection end returned to the dialer. It carries `dial_port` and unregisters it on drop.
    channel_to_return: Option<Channel<Vec<u8>>>,
}
118
119impl DialFuture {
120 fn new(port: NonZeroU64) -> Option<Self> {
121 let sender = HUB.get(&port)?;
122
123 let (_dial_port_channel, dial_port) = HUB
124 .register_port(0)
125 .expect("there to be some random unoccupied port.");
126
127 let (a_tx, a_rx) = mpsc::channel(4096);
128 let (b_tx, b_rx) = mpsc::channel(4096);
129 Some(DialFuture {
130 dial_port,
131 sender,
132 channel_to_send: Some(RwStreamSink::new(Chan {
133 incoming: a_rx,
134 outgoing: b_tx,
135 dial_port: None,
136 })),
137 channel_to_return: Some(RwStreamSink::new(Chan {
138 incoming: b_rx,
139 outgoing: a_tx,
140 dial_port: Some(dial_port),
141 })),
142 })
143 }
144}
145
146impl Future for DialFuture {
147 type Output = Result<Channel<Vec<u8>>, MemoryTransportError>;
148
149 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
150 match self.sender.poll_ready(cx) {
151 Poll::Pending => return Poll::Pending,
152 Poll::Ready(Ok(())) => {}
153 Poll::Ready(Err(_)) => return Poll::Ready(Err(MemoryTransportError::Unreachable)),
154 }
155
156 let channel_to_send = self
157 .channel_to_send
158 .take()
159 .expect("Future should not be polled again once complete");
160 let dial_port = self.dial_port;
161 if self
162 .sender
163 .start_send((channel_to_send, dial_port))
164 .is_err()
165 {
166 return Poll::Ready(Err(MemoryTransportError::Unreachable));
167 }
168
169 Poll::Ready(Ok(self
170 .channel_to_return
171 .take()
172 .expect("Future should not be polled again once complete")))
173 }
174}
175
176impl Transport for MemoryTransport {
177 type Output = Channel<Vec<u8>>;
178 type Error = MemoryTransportError;
179 type ListenerUpgrade = Ready<Result<Self::Output, Self::Error>>;
180 type Dial = DialFuture;
181
182 fn listen_on(
183 &mut self,
184 id: ListenerId,
185 addr: Multiaddr,
186 ) -> Result<(), TransportError<Self::Error>> {
187 let port =
188 parse_memory_addr(&addr).map_err(|_| TransportError::MultiaddrNotSupported(addr))?;
189
190 let (rx, port) = HUB
191 .register_port(port)
192 .ok_or(TransportError::Other(MemoryTransportError::Unreachable))?;
193
194 let listener = Listener {
195 id,
196 port,
197 addr: Protocol::Memory(port.get()).into(),
198 receiver: rx,
199 tell_listen_addr: true,
200 };
201 self.listeners.push_back(Box::pin(listener));
202
203 Ok(())
204 }
205
206 fn remove_listener(&mut self, id: ListenerId) -> bool {
207 if let Some(index) = self.listeners.iter().position(|listener| listener.id == id) {
208 let listener = self.listeners.get_mut(index).unwrap();
209 let val_in = HUB.unregister_port(&listener.port);
210 debug_assert!(val_in.is_some());
211 listener.receiver.close();
212 true
213 } else {
214 false
215 }
216 }
217
218 fn dial(
219 &mut self,
220 addr: Multiaddr,
221 _opts: DialOpts,
222 ) -> Result<DialFuture, TransportError<Self::Error>> {
223 let port = if let Ok(port) = parse_memory_addr(&addr) {
224 if let Some(port) = NonZeroU64::new(port) {
225 port
226 } else {
227 return Err(TransportError::Other(MemoryTransportError::Unreachable));
228 }
229 } else {
230 return Err(TransportError::MultiaddrNotSupported(addr));
231 };
232
233 DialFuture::new(port).ok_or(TransportError::Other(MemoryTransportError::Unreachable))
234 }
235
236 fn poll(
237 mut self: Pin<&mut Self>,
238 cx: &mut Context<'_>,
239 ) -> Poll<TransportEvent<Self::ListenerUpgrade, Self::Error>>
240 where
241 Self: Sized,
242 {
243 let mut remaining = self.listeners.len();
244 while let Some(mut listener) = self.listeners.pop_back() {
245 if listener.tell_listen_addr {
246 listener.tell_listen_addr = false;
247 let listen_addr = listener.addr.clone();
248 let listener_id = listener.id;
249 self.listeners.push_front(listener);
250 return Poll::Ready(TransportEvent::NewAddress {
251 listen_addr,
252 listener_id,
253 });
254 }
255
256 let event = match Stream::poll_next(Pin::new(&mut listener.receiver), cx) {
257 Poll::Pending => None,
258 Poll::Ready(Some((channel, dial_port))) => Some(TransportEvent::Incoming {
259 listener_id: listener.id,
260 upgrade: future::ready(Ok(channel)),
261 local_addr: listener.addr.clone(),
262 send_back_addr: Protocol::Memory(dial_port.get()).into(),
263 }),
264 Poll::Ready(None) => {
265 return Poll::Ready(TransportEvent::ListenerClosed {
267 listener_id: listener.id,
268 reason: Ok(()),
269 });
270 }
271 };
272
273 self.listeners.push_front(listener);
274 if let Some(event) = event {
275 return Poll::Ready(event);
276 } else {
277 remaining -= 1;
278 if remaining == 0 {
279 break;
280 }
281 }
282 }
283 Poll::Pending
284 }
285}
286
/// Errors produced by the memory transport.
#[derive(Debug, Copy, Clone)]
pub enum MemoryTransportError {
    /// Nothing is listening on the requested port.
    Unreachable,
    /// The requested port is already registered by another listener.
    AlreadyInUse,
}

impl fmt::Display for MemoryTransportError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let message = match self {
            Self::Unreachable => "No listener on the given port.",
            Self::AlreadyInUse => "Port already occupied.",
        };
        f.write_str(message)
    }
}

impl error::Error for MemoryTransportError {}
306
/// A listener registered on a memory port by `MemoryTransport::listen_on`.
pub struct Listener {
    /// Identifier passed to `listen_on`.
    id: ListenerId,
    /// Port this listener is registered on in the hub.
    port: NonZeroU64,
    /// The `/memory/<port>` address of this listener.
    addr: Multiaddr,
    /// Queue of incoming connections, each paired with the dialer's port.
    receiver: ChannelReceiver,
    /// Whether `addr` still has to be reported as a `NewAddress` event.
    tell_listen_addr: bool,
}
319
320fn parse_memory_addr(a: &Multiaddr) -> Result<u64, ()> {
322 let mut protocols = a.iter();
323 match protocols.next() {
324 Some(Protocol::Memory(port)) => match protocols.next() {
325 None | Some(Protocol::P2p(_)) => Ok(port),
326 _ => Err(()),
327 },
328 _ => Err(()),
329 }
330}
331
/// One end of an in-memory connection, wrapped in `RwStreamSink` so it can be read and written
/// as a byte stream.
pub type Channel<T> = RwStreamSink<Chan<T>>;

/// Raw end of an in-memory connection: one channel per direction.
pub struct Chan<T = Vec<u8>> {
    /// Items written by the remote end.
    incoming: mpsc::Receiver<T>,
    /// Items written by this end, read by the remote end.
    outgoing: mpsc::Sender<T>,

    /// Port registered for the dialer, set only on the dialer's end. It is unregistered from the
    /// hub when this end is dropped.
    dial_port: Option<NonZeroU64>,
}

// The poll methods access fields straight through `Pin<&mut Self>`, which requires `Chan` to be
// `Unpin` for every `T`.
impl<T> Unpin for Chan<T> {}
354
355impl<T> Stream for Chan<T> {
356 type Item = Result<T, io::Error>;
357
358 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
359 match Stream::poll_next(Pin::new(&mut self.incoming), cx) {
360 Poll::Pending => Poll::Pending,
361 Poll::Ready(None) => Poll::Ready(None),
362 Poll::Ready(Some(v)) => Poll::Ready(Some(Ok(v))),
363 }
364 }
365}
366
367impl<T> Sink<T> for Chan<T> {
368 type Error = io::Error;
369
370 fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
371 self.outgoing
372 .poll_ready(cx)
373 .map(|v| v.map_err(|_| io::ErrorKind::BrokenPipe.into()))
374 }
375
376 fn start_send(mut self: Pin<&mut Self>, item: T) -> Result<(), Self::Error> {
377 self.outgoing
378 .start_send(item)
379 .map_err(|_| io::ErrorKind::BrokenPipe.into())
380 }
381
382 fn poll_flush(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
383 Poll::Ready(Ok(()))
384 }
385
386 fn poll_close(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
387 Poll::Ready(Ok(()))
388 }
389}
390
391impl<T: AsRef<[u8]>> From<Chan<T>> for RwStreamSink<Chan<T>> {
392 fn from(channel: Chan<T>) -> RwStreamSink<Chan<T>> {
393 RwStreamSink::new(channel)
394 }
395}
396
397impl<T> Drop for Chan<T> {
398 fn drop(&mut self) {
399 if let Some(port) = self.dial_port {
400 let channel_sender = HUB.unregister_port(&port);
401 debug_assert!(channel_sender.is_some());
402 }
403 }
404}
405
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{transport::PortUse, Endpoint};

    /// `parse_memory_addr` accepts `/memory/<port>`, optionally followed by `/p2p/...`, and
    /// rejects every other shape.
    #[test]
    fn parse_memory_addr_works() {
        assert_eq!(parse_memory_addr(&"/memory/5".parse().unwrap()), Ok(5));
        assert_eq!(parse_memory_addr(&"/tcp/150".parse().unwrap()), Err(()));
        // Port 0 parses; its meaning ("any port" when listening) is decided by the caller.
        assert_eq!(parse_memory_addr(&"/memory/0".parse().unwrap()), Ok(0));
        assert_eq!(
            parse_memory_addr(&"/memory/5/tcp/150".parse().unwrap()),
            Err(())
        );
        assert_eq!(
            parse_memory_addr(&"/tcp/150/memory/5".parse().unwrap()),
            Err(())
        );
        assert_eq!(
            parse_memory_addr(&"/memory/1234567890".parse().unwrap()),
            Ok(1_234_567_890)
        );
        assert_eq!(
            parse_memory_addr(
                &"/memory/5/p2p/12D3KooWETLZBFBfkzvH3BQEtA1TJZPmjb4a18ss5TpwNU7DHDX6"
                    .parse()
                    .unwrap()
            ),
            Ok(5)
        );
        // Only the component directly after `/memory` is checked; the rest is ignored.
        assert_eq!(
            parse_memory_addr(
                &"/memory/5/p2p/12D3KooWETLZBFBfkzvH3BQEtA1TJZPmjb4a18ss5TpwNU7DHDX6/p2p-circuit/p2p/12D3KooWLiQ7i8sY6LkPvHmEymncicEgzrdpXegbxEr3xgN8oxMU"
                    .parse()
                    .unwrap()
            ),
            Ok(5)
        );
    }

    /// A port cannot be bound twice at once, but can be bound again after its listener is
    /// removed.
    #[test]
    fn listening_twice() {
        let mut transport = MemoryTransport::default();

        let addr_1: Multiaddr = "/memory/1639174018481".parse().unwrap();
        let addr_2: Multiaddr = "/memory/8459375923478".parse().unwrap();

        let listener_id_1 = ListenerId::next();

        transport.listen_on(listener_id_1, addr_1.clone()).unwrap();
        assert!(
            transport.remove_listener(listener_id_1),
            "Listener doesn't exist."
        );

        // Removing listener 1 freed `addr_1`, so it can be bound again.
        let listener_id_2 = ListenerId::next();
        transport.listen_on(listener_id_2, addr_1.clone()).unwrap();
        let listener_id_3 = ListenerId::next();
        transport.listen_on(listener_id_3, addr_2.clone()).unwrap();

        // Both ports are now taken.
        assert!(transport
            .listen_on(ListenerId::next(), addr_1.clone())
            .is_err());
        assert!(transport
            .listen_on(ListenerId::next(), addr_2.clone())
            .is_err());

        assert!(
            transport.remove_listener(listener_id_2),
            "Listener doesn't exist."
        );
        assert!(transport.listen_on(ListenerId::next(), addr_1).is_ok());
        assert!(transport
            .listen_on(ListenerId::next(), addr_2.clone())
            .is_err());

        assert!(
            transport.remove_listener(listener_id_3),
            "Listener doesn't exist."
        );
        assert!(transport.listen_on(ListenerId::next(), addr_2).is_ok());
    }

    /// Dialing fails while nothing listens on the port and succeeds once a listener exists.
    #[test]
    fn port_not_in_use() {
        let mut transport = MemoryTransport::default();
        assert!(transport
            .dial(
                "/memory/810172461024613".parse().unwrap(),
                DialOpts {
                    role: Endpoint::Dialer,
                    port_use: PortUse::New
                }
            )
            .is_err());
        transport
            .listen_on(
                ListenerId::next(),
                "/memory/810172461024613".parse().unwrap(),
            )
            .unwrap();
        assert!(transport
            .dial(
                "/memory/810172461024613".parse().unwrap(),
                DialOpts {
                    role: Endpoint::Dialer,
                    port_use: PortUse::New
                }
            )
            .is_ok());
    }

    /// Removing a listener yields a successful `ListenerClosed`, after which the id is unknown.
    #[test]
    fn stop_listening() {
        // `saturating_add(1)` keeps the port non-zero (0 would request a random port).
        let rand_port = rand::random::<u64>().saturating_add(1);
        let addr: Multiaddr = format!("/memory/{rand_port}").parse().unwrap();

        let mut transport = MemoryTransport::default().boxed();
        futures::executor::block_on(async {
            let listener_id = ListenerId::next();
            transport.listen_on(listener_id, addr.clone()).unwrap();
            let reported_addr = transport
                .select_next_some()
                .await
                .into_new_address()
                .expect("new address");
            assert_eq!(addr, reported_addr);
            assert!(transport.remove_listener(listener_id));
            match transport.select_next_some().await {
                TransportEvent::ListenerClosed {
                    listener_id: id,
                    reason,
                } => {
                    assert_eq!(id, listener_id);
                    assert!(reason.is_ok())
                }
                other => panic!("Unexpected transport event: {other:?}"),
            }
            assert!(!transport.remove_listener(listener_id));
        })
    }

    /// Bytes written by the dialer arrive unchanged at the listener.
    #[test]
    fn communicating_between_dialer_and_listener() {
        let msg = [1, 2, 3];

        // Non-zero random port, so the listener binds exactly this address.
        let rand_port = rand::random::<u64>().saturating_add(1);
        let t1_addr: Multiaddr = format!("/memory/{rand_port}").parse().unwrap();
        let cloned_t1_addr = t1_addr.clone();

        let mut t1 = MemoryTransport::default().boxed();

        let listener = async move {
            t1.listen_on(ListenerId::next(), t1_addr.clone()).unwrap();
            // Skip events (such as `NewAddress`) until a connection comes in.
            let upgrade = loop {
                let event = t1.select_next_some().await;
                if let Some(upgrade) = event.into_incoming() {
                    break upgrade;
                }
            };

            let mut socket = upgrade.0.await.unwrap();

            let mut buf = [0; 3];
            socket.read_exact(&mut buf).await.unwrap();

            assert_eq!(buf, msg);
        };

        let mut t2 = MemoryTransport::default();
        let dialer = async move {
            let mut socket = t2
                .dial(
                    cloned_t1_addr,
                    DialOpts {
                        role: Endpoint::Dialer,
                        port_use: PortUse::New,
                    },
                )
                .unwrap()
                .await
                .unwrap();
            socket.write_all(&msg).await.unwrap();
        };

        futures::executor::block_on(futures::future::join(listener, dialer));
    }

    /// The remote address the listener sees for a dialer is the dialer's own port.
    #[test]
    fn dialer_address_unequal_to_listener_address() {
        let listener_addr: Multiaddr =
            Protocol::Memory(rand::random::<u64>().saturating_add(1)).into();
        let listener_addr_cloned = listener_addr.clone();

        let mut listener_transport = MemoryTransport::default().boxed();

        let listener = async move {
            listener_transport
                .listen_on(ListenerId::next(), listener_addr.clone())
                .unwrap();
            loop {
                if let TransportEvent::Incoming { send_back_addr, .. } =
                    listener_transport.select_next_some().await
                {
                    assert!(
                        send_back_addr != listener_addr,
                        "Expect dialer address not to equal listener address."
                    );
                    return;
                }
            }
        };

        let dialer = async move {
            MemoryTransport::default()
                .dial(
                    listener_addr_cloned,
                    DialOpts {
                        role: Endpoint::Dialer,
                        port_use: PortUse::New,
                    },
                )
                .unwrap()
                .await
                .unwrap();
        };

        futures::executor::block_on(futures::future::join(listener, dialer));
    }

    /// The dialer's port stays registered while its connection end is alive and is released
    /// when that end is dropped.
    #[test]
    fn dialer_port_is_deregistered() {
        // Two one-shot handshakes sequence the check: "please drop now" and "dropped".
        let (terminate, should_terminate) = futures::channel::oneshot::channel();
        let (terminated, is_terminated) = futures::channel::oneshot::channel();

        let listener_addr: Multiaddr =
            Protocol::Memory(rand::random::<u64>().saturating_add(1)).into();
        let listener_addr_cloned = listener_addr.clone();

        let mut listener_transport = MemoryTransport::default().boxed();

        let listener = async move {
            listener_transport
                .listen_on(ListenerId::next(), listener_addr.clone())
                .unwrap();
            loop {
                if let TransportEvent::Incoming { send_back_addr, .. } =
                    listener_transport.select_next_some().await
                {
                    let dialer_port =
                        NonZeroU64::new(parse_memory_addr(&send_back_addr).unwrap()).unwrap();

                    assert!(
                        HUB.get(&dialer_port).is_some(),
                        "Expect dialer port to stay registered while connection is in use.",
                    );

                    terminate.send(()).unwrap();
                    is_terminated.await.unwrap();

                    assert!(
                        HUB.get(&dialer_port).is_none(),
                        "Expect dialer port to be deregistered once connection is dropped.",
                    );

                    return;
                }
            }
        };

        let dialer = async move {
            let chan = MemoryTransport::default()
                .dial(
                    listener_addr_cloned,
                    DialOpts {
                        role: Endpoint::Dialer,
                        port_use: PortUse::New,
                    },
                )
                .unwrap()
                .await
                .unwrap();

            should_terminate.await.unwrap();
            // Dropping the dialer's end runs `Chan::drop`, which unregisters the dialer's port.
            drop(chan);
            terminated.send(()).unwrap();
        };

        futures::executor::block_on(futures::future::join(listener, dialer));
    }
}