libp2p_core/transport/
memory.rs

1// Copyright 2018 Parity Technologies (UK) Ltd.
2//
3// Permission is hereby granted, free of charge, to any person obtaining a
4// copy of this software and associated documentation files (the "Software"),
5// to deal in the Software without restriction, including without limitation
6// the rights to use, copy, modify, merge, publish, distribute, sublicense,
7// and/or sell copies of the Software, and to permit persons to whom the
8// Software is furnished to do so, subject to the following conditions:
9//
10// The above copyright notice and this permission notice shall be included in
11// all copies or substantial portions of the Software.
12//
13// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
14// OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
16// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
17// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
18// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
19// DEALINGS IN THE SOFTWARE.
20
21use std::{
22    collections::{hash_map::Entry, VecDeque},
23    error, fmt, io,
24    num::NonZeroU64,
25    pin::Pin,
26    sync::LazyLock,
27};
28
29use fnv::FnvHashMap;
30use futures::{
31    channel::mpsc,
32    future::Ready,
33    prelude::*,
34    task::{Context, Poll},
35};
36use multiaddr::{Multiaddr, Protocol};
37use parking_lot::Mutex;
38use rw_stream_sink::RwStreamSink;
39
40use crate::transport::{DialOpts, ListenerId, Transport, TransportError, TransportEvent};
41
42static HUB: LazyLock<Hub> = LazyLock::new(|| Hub(Mutex::new(FnvHashMap::default())));
43
44struct Hub(Mutex<FnvHashMap<NonZeroU64, ChannelSender>>);
45
46/// A [`mpsc::Sender`] enabling a [`DialFuture`] to send a [`Channel`] and the
47/// port of the dialer to a [`Listener`].
48type ChannelSender = mpsc::Sender<(Channel<Vec<u8>>, NonZeroU64)>;
49
50/// A [`mpsc::Receiver`] enabling a [`Listener`] to receive a [`Channel`] and
51/// the port of the dialer from a [`DialFuture`].
52type ChannelReceiver = mpsc::Receiver<(Channel<Vec<u8>>, NonZeroU64)>;
53
54impl Hub {
55    /// Registers the given port on the hub.
56    ///
57    /// Randomizes port when given port is `0`. Returns [`None`] when given port
58    /// is already occupied.
59    fn register_port(&self, port: u64) -> Option<(ChannelReceiver, NonZeroU64)> {
60        let mut hub = self.0.lock();
61
62        let port = if let Some(port) = NonZeroU64::new(port) {
63            port
64        } else {
65            loop {
66                let Some(port) = NonZeroU64::new(rand::random()) else {
67                    continue;
68                };
69                if !hub.contains_key(&port) {
70                    break port;
71                }
72            }
73        };
74
75        let (tx, rx) = mpsc::channel(2);
76        match hub.entry(port) {
77            Entry::Occupied(_) => return None,
78            Entry::Vacant(e) => e.insert(tx),
79        };
80
81        Some((rx, port))
82    }
83
84    fn unregister_port(&self, port: &NonZeroU64) -> Option<ChannelSender> {
85        self.0.lock().remove(port)
86    }
87
88    fn get(&self, port: &NonZeroU64) -> Option<ChannelSender> {
89        self.0.lock().get(port).cloned()
90    }
91}
92
93/// Transport that supports `/memory/N` multiaddresses.
94#[derive(Default)]
95pub struct MemoryTransport {
96    listeners: VecDeque<Pin<Box<Listener>>>,
97}
98
99impl MemoryTransport {
100    pub fn new() -> Self {
101        Self::default()
102    }
103}
104
105/// Connection to a `MemoryTransport` currently being opened.
106pub struct DialFuture {
107    /// Ephemeral source port.
108    ///
109    /// These ports mimic TCP ephemeral source ports but are not actually used
110    /// by the memory transport due to the direct use of channels. They merely
111    /// ensure that every connection has a unique address for each dialer, which
112    /// is not at the same time a listen address (analogous to TCP).
113    dial_port: NonZeroU64,
114    sender: ChannelSender,
115    channel_to_send: Option<Channel<Vec<u8>>>,
116    channel_to_return: Option<Channel<Vec<u8>>>,
117}
118
119impl DialFuture {
120    fn new(port: NonZeroU64) -> Option<Self> {
121        let sender = HUB.get(&port)?;
122
123        let (_dial_port_channel, dial_port) = HUB
124            .register_port(0)
125            .expect("there to be some random unoccupied port.");
126
127        let (a_tx, a_rx) = mpsc::channel(4096);
128        let (b_tx, b_rx) = mpsc::channel(4096);
129        Some(DialFuture {
130            dial_port,
131            sender,
132            channel_to_send: Some(RwStreamSink::new(Chan {
133                incoming: a_rx,
134                outgoing: b_tx,
135                dial_port: None,
136            })),
137            channel_to_return: Some(RwStreamSink::new(Chan {
138                incoming: b_rx,
139                outgoing: a_tx,
140                dial_port: Some(dial_port),
141            })),
142        })
143    }
144}
145
146impl Future for DialFuture {
147    type Output = Result<Channel<Vec<u8>>, MemoryTransportError>;
148
149    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
150        match self.sender.poll_ready(cx) {
151            Poll::Pending => return Poll::Pending,
152            Poll::Ready(Ok(())) => {}
153            Poll::Ready(Err(_)) => return Poll::Ready(Err(MemoryTransportError::Unreachable)),
154        }
155
156        let channel_to_send = self
157            .channel_to_send
158            .take()
159            .expect("Future should not be polled again once complete");
160        let dial_port = self.dial_port;
161        if self
162            .sender
163            .start_send((channel_to_send, dial_port))
164            .is_err()
165        {
166            return Poll::Ready(Err(MemoryTransportError::Unreachable));
167        }
168
169        Poll::Ready(Ok(self
170            .channel_to_return
171            .take()
172            .expect("Future should not be polled again once complete")))
173    }
174}
175
176impl Transport for MemoryTransport {
177    type Output = Channel<Vec<u8>>;
178    type Error = MemoryTransportError;
179    type ListenerUpgrade = Ready<Result<Self::Output, Self::Error>>;
180    type Dial = DialFuture;
181
182    fn listen_on(
183        &mut self,
184        id: ListenerId,
185        addr: Multiaddr,
186    ) -> Result<(), TransportError<Self::Error>> {
187        let port =
188            parse_memory_addr(&addr).map_err(|_| TransportError::MultiaddrNotSupported(addr))?;
189
190        let (rx, port) = HUB
191            .register_port(port)
192            .ok_or(TransportError::Other(MemoryTransportError::Unreachable))?;
193
194        let listener = Listener {
195            id,
196            port,
197            addr: Protocol::Memory(port.get()).into(),
198            receiver: rx,
199            tell_listen_addr: true,
200        };
201        self.listeners.push_back(Box::pin(listener));
202
203        Ok(())
204    }
205
206    fn remove_listener(&mut self, id: ListenerId) -> bool {
207        if let Some(index) = self.listeners.iter().position(|listener| listener.id == id) {
208            let listener = self.listeners.get_mut(index).unwrap();
209            let val_in = HUB.unregister_port(&listener.port);
210            debug_assert!(val_in.is_some());
211            listener.receiver.close();
212            true
213        } else {
214            false
215        }
216    }
217
218    fn dial(
219        &mut self,
220        addr: Multiaddr,
221        _opts: DialOpts,
222    ) -> Result<DialFuture, TransportError<Self::Error>> {
223        let port = if let Ok(port) = parse_memory_addr(&addr) {
224            if let Some(port) = NonZeroU64::new(port) {
225                port
226            } else {
227                return Err(TransportError::Other(MemoryTransportError::Unreachable));
228            }
229        } else {
230            return Err(TransportError::MultiaddrNotSupported(addr));
231        };
232
233        DialFuture::new(port).ok_or(TransportError::Other(MemoryTransportError::Unreachable))
234    }
235
236    fn poll(
237        mut self: Pin<&mut Self>,
238        cx: &mut Context<'_>,
239    ) -> Poll<TransportEvent<Self::ListenerUpgrade, Self::Error>>
240    where
241        Self: Sized,
242    {
243        let mut remaining = self.listeners.len();
244        while let Some(mut listener) = self.listeners.pop_back() {
245            if listener.tell_listen_addr {
246                listener.tell_listen_addr = false;
247                let listen_addr = listener.addr.clone();
248                let listener_id = listener.id;
249                self.listeners.push_front(listener);
250                return Poll::Ready(TransportEvent::NewAddress {
251                    listen_addr,
252                    listener_id,
253                });
254            }
255
256            let event = match Stream::poll_next(Pin::new(&mut listener.receiver), cx) {
257                Poll::Pending => None,
258                Poll::Ready(Some((channel, dial_port))) => Some(TransportEvent::Incoming {
259                    listener_id: listener.id,
260                    upgrade: future::ready(Ok(channel)),
261                    local_addr: listener.addr.clone(),
262                    send_back_addr: Protocol::Memory(dial_port.get()).into(),
263                }),
264                Poll::Ready(None) => {
265                    // Listener was closed.
266                    return Poll::Ready(TransportEvent::ListenerClosed {
267                        listener_id: listener.id,
268                        reason: Ok(()),
269                    });
270                }
271            };
272
273            self.listeners.push_front(listener);
274            if let Some(event) = event {
275                return Poll::Ready(event);
276            } else {
277                remaining -= 1;
278                if remaining == 0 {
279                    break;
280                }
281            }
282        }
283        Poll::Pending
284    }
285}
286
/// Error that can be produced from the `MemoryTransport`.
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
pub enum MemoryTransportError {
    /// There's no listener on the given port.
    Unreachable,
    /// Tries to listen on a port that is already in use.
    AlreadyInUse,
}

impl fmt::Display for MemoryTransportError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let message = match self {
            Self::Unreachable => "No listener on the given port.",
            Self::AlreadyInUse => "Port already occupied.",
        };
        f.write_str(message)
    }
}

impl error::Error for MemoryTransportError {}
306
307/// Listener for memory connections.
308pub struct Listener {
309    id: ListenerId,
310    /// Port we're listening on.
311    port: NonZeroU64,
312    /// The address we are listening on.
313    addr: Multiaddr,
314    /// Receives incoming connections.
315    receiver: ChannelReceiver,
316    /// Generate [`TransportEvent::NewAddress`] to inform about our listen address.
317    tell_listen_addr: bool,
318}
319
320/// If the address is `/memory/n`, returns the value of `n`.
321fn parse_memory_addr(a: &Multiaddr) -> Result<u64, ()> {
322    let mut protocols = a.iter();
323    match protocols.next() {
324        Some(Protocol::Memory(port)) => match protocols.next() {
325            None | Some(Protocol::P2p(_)) => Ok(port),
326            _ => Err(()),
327        },
328        _ => Err(()),
329    }
330}
331
332/// A channel represents an established, in-memory, logical connection between two endpoints.
333///
334/// Implements `AsyncRead` and `AsyncWrite`.
335pub type Channel<T> = RwStreamSink<Chan<T>>;
336
337/// A channel represents an established, in-memory, logical connection between two endpoints.
338///
339/// Implements `Sink` and `Stream`.
340pub struct Chan<T = Vec<u8>> {
341    incoming: mpsc::Receiver<T>,
342    outgoing: mpsc::Sender<T>,
343
344    // Needed in [`Drop`] implementation of [`Chan`] to unregister the dialing
345    // port with the global [`HUB`]. Is [`Some`] when [`Chan`] of dialer and
346    // [`None`] when [`Chan`] of listener.
347    //
348    // Note: Listening port is unregistered in [`Drop`] implementation of
349    // [`Listener`].
350    dial_port: Option<NonZeroU64>,
351}
352
353impl<T> Unpin for Chan<T> {}
354
355impl<T> Stream for Chan<T> {
356    type Item = Result<T, io::Error>;
357
358    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
359        match Stream::poll_next(Pin::new(&mut self.incoming), cx) {
360            Poll::Pending => Poll::Pending,
361            Poll::Ready(None) => Poll::Ready(None),
362            Poll::Ready(Some(v)) => Poll::Ready(Some(Ok(v))),
363        }
364    }
365}
366
367impl<T> Sink<T> for Chan<T> {
368    type Error = io::Error;
369
370    fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
371        self.outgoing
372            .poll_ready(cx)
373            .map(|v| v.map_err(|_| io::ErrorKind::BrokenPipe.into()))
374    }
375
376    fn start_send(mut self: Pin<&mut Self>, item: T) -> Result<(), Self::Error> {
377        self.outgoing
378            .start_send(item)
379            .map_err(|_| io::ErrorKind::BrokenPipe.into())
380    }
381
382    fn poll_flush(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
383        Poll::Ready(Ok(()))
384    }
385
386    fn poll_close(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
387        Poll::Ready(Ok(()))
388    }
389}
390
391impl<T: AsRef<[u8]>> From<Chan<T>> for RwStreamSink<Chan<T>> {
392    fn from(channel: Chan<T>) -> RwStreamSink<Chan<T>> {
393        RwStreamSink::new(channel)
394    }
395}
396
397impl<T> Drop for Chan<T> {
398    fn drop(&mut self) {
399        if let Some(port) = self.dial_port {
400            let channel_sender = HUB.unregister_port(&port);
401            debug_assert!(channel_sender.is_some());
402        }
403    }
404}
405
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{transport::PortUse, Endpoint};

    // Accepts `/memory/N` optionally followed by `/p2p/...`; rejects anything else.
    #[test]
    fn parse_memory_addr_works() {
        assert_eq!(parse_memory_addr(&"/memory/5".parse().unwrap()), Ok(5));
        assert_eq!(parse_memory_addr(&"/tcp/150".parse().unwrap()), Err(()));
        assert_eq!(parse_memory_addr(&"/memory/0".parse().unwrap()), Ok(0));
        assert_eq!(
            parse_memory_addr(&"/memory/5/tcp/150".parse().unwrap()),
            Err(())
        );
        assert_eq!(
            parse_memory_addr(&"/tcp/150/memory/5".parse().unwrap()),
            Err(())
        );
        assert_eq!(
            parse_memory_addr(&"/memory/1234567890".parse().unwrap()),
            Ok(1_234_567_890)
        );
        assert_eq!(
            parse_memory_addr(
                &"/memory/5/p2p/12D3KooWETLZBFBfkzvH3BQEtA1TJZPmjb4a18ss5TpwNU7DHDX6"
                    .parse()
                    .unwrap()
            ),
            Ok(5)
        );
        // Only the component right after `/memory/N` is checked; later ones are ignored.
        assert_eq!(
            parse_memory_addr(
                &"/memory/5/p2p/12D3KooWETLZBFBfkzvH3BQEtA1TJZPmjb4a18ss5TpwNU7DHDX6/p2p-circuit/p2p/12D3KooWLiQ7i8sY6LkPvHmEymncicEgzrdpXegbxEr3xgN8oxMU"
                    .parse()
                    .unwrap()
            ),
            Ok(5)
        );
    }

    // A port can be listened on at most once at a time, and becomes free again right
    // after `remove_listener` (without polling the transport).
    #[test]
    fn listening_twice() {
        let mut transport = MemoryTransport::default();

        let addr_1: Multiaddr = "/memory/1639174018481".parse().unwrap();
        let addr_2: Multiaddr = "/memory/8459375923478".parse().unwrap();

        let listener_id_1 = ListenerId::next();

        transport.listen_on(listener_id_1, addr_1.clone()).unwrap();
        assert!(
            transport.remove_listener(listener_id_1),
            "Listener doesn't exist."
        );

        let listener_id_2 = ListenerId::next();
        transport.listen_on(listener_id_2, addr_1.clone()).unwrap();
        let listener_id_3 = ListenerId::next();
        transport.listen_on(listener_id_3, addr_2.clone()).unwrap();

        // Both ports are occupied now.
        assert!(transport
            .listen_on(ListenerId::next(), addr_1.clone())
            .is_err());
        assert!(transport
            .listen_on(ListenerId::next(), addr_2.clone())
            .is_err());

        assert!(
            transport.remove_listener(listener_id_2),
            "Listener doesn't exist."
        );
        assert!(transport.listen_on(ListenerId::next(), addr_1).is_ok());
        assert!(transport
            .listen_on(ListenerId::next(), addr_2.clone())
            .is_err());

        assert!(
            transport.remove_listener(listener_id_3),
            "Listener doesn't exist."
        );
        assert!(transport.listen_on(ListenerId::next(), addr_2).is_ok());
    }

    // Dialing fails synchronously until something listens on the target port.
    #[test]
    fn port_not_in_use() {
        let mut transport = MemoryTransport::default();
        assert!(transport
            .dial(
                "/memory/810172461024613".parse().unwrap(),
                DialOpts {
                    role: Endpoint::Dialer,
                    port_use: PortUse::New
                }
            )
            .is_err());
        transport
            .listen_on(
                ListenerId::next(),
                "/memory/810172461024613".parse().unwrap(),
            )
            .unwrap();
        assert!(transport
            .dial(
                "/memory/810172461024613".parse().unwrap(),
                DialOpts {
                    role: Endpoint::Dialer,
                    port_use: PortUse::New
                }
            )
            .is_ok());
    }

    // A listener first reports its address, and after `remove_listener` reports
    // `ListenerClosed` exactly once; afterwards its id is unknown.
    #[test]
    fn stop_listening() {
        let rand_port = rand::random::<u64>().saturating_add(1);
        let addr: Multiaddr = format!("/memory/{rand_port}").parse().unwrap();

        let mut transport = MemoryTransport::default().boxed();
        futures::executor::block_on(async {
            let listener_id = ListenerId::next();
            transport.listen_on(listener_id, addr.clone()).unwrap();
            let reported_addr = transport
                .select_next_some()
                .await
                .into_new_address()
                .expect("new address");
            assert_eq!(addr, reported_addr);
            assert!(transport.remove_listener(listener_id));
            match transport.select_next_some().await {
                TransportEvent::ListenerClosed {
                    listener_id: id,
                    reason,
                } => {
                    assert_eq!(id, listener_id);
                    assert!(reason.is_ok())
                }
                other => panic!("Unexpected transport event: {other:?}"),
            }
            assert!(!transport.remove_listener(listener_id));
        })
    }

    // Bytes written by the dialer arrive unchanged at the listener.
    #[test]
    fn communicating_between_dialer_and_listener() {
        let msg = [1, 2, 3];

        // Setup listener.

        let rand_port = rand::random::<u64>().saturating_add(1);
        let t1_addr: Multiaddr = format!("/memory/{rand_port}").parse().unwrap();
        let cloned_t1_addr = t1_addr.clone();

        let mut t1 = MemoryTransport::default().boxed();

        let listener = async move {
            t1.listen_on(ListenerId::next(), t1_addr.clone()).unwrap();
            let upgrade = loop {
                let event = t1.select_next_some().await;
                if let Some(upgrade) = event.into_incoming() {
                    break upgrade;
                }
            };

            let mut socket = upgrade.0.await.unwrap();

            let mut buf = [0; 3];
            socket.read_exact(&mut buf).await.unwrap();

            assert_eq!(buf, msg);
        };

        // Setup dialer.

        let mut t2 = MemoryTransport::default();
        let dialer = async move {
            let mut socket = t2
                .dial(
                    cloned_t1_addr,
                    DialOpts {
                        role: Endpoint::Dialer,
                        port_use: PortUse::New,
                    },
                )
                .unwrap()
                .await
                .unwrap();
            socket.write_all(&msg).await.unwrap();
        };

        // Wait for both to finish.

        futures::executor::block_on(futures::future::join(listener, dialer));
    }

    // The dialer's ephemeral port (the listener's `send_back_addr`) never equals the
    // listen address.
    #[test]
    fn dialer_address_unequal_to_listener_address() {
        let listener_addr: Multiaddr =
            Protocol::Memory(rand::random::<u64>().saturating_add(1)).into();
        let listener_addr_cloned = listener_addr.clone();

        let mut listener_transport = MemoryTransport::default().boxed();

        let listener = async move {
            listener_transport
                .listen_on(ListenerId::next(), listener_addr.clone())
                .unwrap();
            loop {
                if let TransportEvent::Incoming { send_back_addr, .. } =
                    listener_transport.select_next_some().await
                {
                    assert!(
                        send_back_addr != listener_addr,
                        "Expect dialer address not to equal listener address."
                    );
                    return;
                }
            }
        };

        let dialer = async move {
            MemoryTransport::default()
                .dial(
                    listener_addr_cloned,
                    DialOpts {
                        role: Endpoint::Dialer,
                        port_use: PortUse::New,
                    },
                )
                .unwrap()
                .await
                .unwrap();
        };

        futures::executor::block_on(futures::future::join(listener, dialer));
    }

    // The dialer's ephemeral port stays registered in `HUB` while the dialer's channel
    // is alive and is released when that channel is dropped. The two oneshot channels
    // sequence the listener's checks around the dialer's `drop(chan)`.
    #[test]
    fn dialer_port_is_deregistered() {
        let (terminate, should_terminate) = futures::channel::oneshot::channel();
        let (terminated, is_terminated) = futures::channel::oneshot::channel();

        let listener_addr: Multiaddr =
            Protocol::Memory(rand::random::<u64>().saturating_add(1)).into();
        let listener_addr_cloned = listener_addr.clone();

        let mut listener_transport = MemoryTransport::default().boxed();

        let listener = async move {
            listener_transport
                .listen_on(ListenerId::next(), listener_addr.clone())
                .unwrap();
            loop {
                if let TransportEvent::Incoming { send_back_addr, .. } =
                    listener_transport.select_next_some().await
                {
                    let dialer_port =
                        NonZeroU64::new(parse_memory_addr(&send_back_addr).unwrap()).unwrap();

                    assert!(
                        HUB.get(&dialer_port).is_some(),
                        "Expect dialer port to stay registered while connection is in use.",
                    );

                    terminate.send(()).unwrap();
                    is_terminated.await.unwrap();

                    assert!(
                        HUB.get(&dialer_port).is_none(),
                        "Expect dialer port to be deregistered once connection is dropped.",
                    );

                    return;
                }
            }
        };

        let dialer = async move {
            let chan = MemoryTransport::default()
                .dial(
                    listener_addr_cloned,
                    DialOpts {
                        role: Endpoint::Dialer,
                        port_use: PortUse::New,
                    },
                )
                .unwrap()
                .await
                .unwrap();

            should_terminate.await.unwrap();
            drop(chan);
            terminated.send(()).unwrap();
        };

        futures::executor::block_on(futures::future::join(listener, dialer));
    }
}