1#![cfg_attr(docsrs, feature(doc_cfg, doc_auto_cfg))]
24
25mod codec;
26mod config;
27mod io;
28
29use std::{
30    cmp, iter,
31    pin::Pin,
32    sync::Arc,
33    task::{Context, Poll},
34};
35
36use bytes::Bytes;
37use codec::LocalStreamId;
38pub use config::{Config, MaxBufferBehaviour};
39use futures::{prelude::*, ready};
40use libp2p_core::{
41    muxing::{StreamMuxer, StreamMuxerEvent},
42    upgrade::{InboundConnectionUpgrade, OutboundConnectionUpgrade, UpgradeInfo},
43};
44use parking_lot::Mutex;
45
/// Deprecated alias for [`Config`]; refer to [`Config`] directly instead.
#[deprecated = "Use `Config` instead"]
pub type MplexConfig = Config;
48
49impl UpgradeInfo for Config {
50    type Info = &'static str;
51    type InfoIter = iter::Once<Self::Info>;
52
53    fn protocol_info(&self) -> Self::InfoIter {
54        iter::once(self.protocol_name)
55    }
56}
57
58impl<C> InboundConnectionUpgrade<C> for Config
59where
60    C: AsyncRead + AsyncWrite + Unpin,
61{
62    type Output = Multiplex<C>;
63    type Error = io::Error;
64    type Future = future::Ready<Result<Self::Output, io::Error>>;
65
66    fn upgrade_inbound(self, socket: C, _: Self::Info) -> Self::Future {
67        future::ready(Ok(Multiplex {
68            #[allow(unknown_lints, clippy::arc_with_non_send_sync)] io: Arc::new(Mutex::new(io::Multiplexed::new(socket, self))),
70        }))
71    }
72}
73
74impl<C> OutboundConnectionUpgrade<C> for Config
75where
76    C: AsyncRead + AsyncWrite + Unpin,
77{
78    type Output = Multiplex<C>;
79    type Error = io::Error;
80    type Future = future::Ready<Result<Self::Output, io::Error>>;
81
82    fn upgrade_outbound(self, socket: C, _: Self::Info) -> Self::Future {
83        future::ready(Ok(Multiplex {
84            #[allow(unknown_lints, clippy::arc_with_non_send_sync)] io: Arc::new(Mutex::new(io::Multiplexed::new(socket, self))),
86        }))
87    }
88}
89
/// Multiplexer over a connection of type `C`.
///
/// Produced by upgrading a connection with [`Config`]; implements [`StreamMuxer`].
pub struct Multiplex<C> {
    /// Multiplexed I/O state, shared (via `Arc`) with every [`Substream`] this muxer hands out.
    io: Arc<Mutex<io::Multiplexed<C>>>,
}
94
95impl<C> StreamMuxer for Multiplex<C>
96where
97    C: AsyncRead + AsyncWrite + Unpin,
98{
99    type Substream = Substream<C>;
100    type Error = io::Error;
101
102    fn poll_inbound(
103        self: Pin<&mut Self>,
104        cx: &mut Context<'_>,
105    ) -> Poll<Result<Self::Substream, Self::Error>> {
106        self.io
107            .lock()
108            .poll_next_stream(cx)
109            .map_ok(|stream_id| Substream::new(stream_id, self.io.clone()))
110    }
111
112    fn poll_outbound(
113        self: Pin<&mut Self>,
114        cx: &mut Context<'_>,
115    ) -> Poll<Result<Self::Substream, Self::Error>> {
116        self.io
117            .lock()
118            .poll_open_stream(cx)
119            .map_ok(|stream_id| Substream::new(stream_id, self.io.clone()))
120    }
121
122    fn poll(
123        self: Pin<&mut Self>,
124        _: &mut Context<'_>,
125    ) -> Poll<Result<StreamMuxerEvent, Self::Error>> {
126        Poll::Pending
127    }
128
129    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
130        self.io.lock().poll_close(cx)
131    }
132}
133
134impl<C> AsyncRead for Substream<C>
135where
136    C: AsyncRead + AsyncWrite + Unpin,
137{
138    fn poll_read(
139        self: Pin<&mut Self>,
140        cx: &mut Context<'_>,
141        buf: &mut [u8],
142    ) -> Poll<io::Result<usize>> {
143        let this = self.get_mut();
144
145        loop {
146            if !this.current_data.is_empty() {
148                let len = cmp::min(this.current_data.len(), buf.len());
149                buf[..len].copy_from_slice(&this.current_data.split_to(len));
150                return Poll::Ready(Ok(len));
151            }
152
153            match ready!(this.io.lock().poll_read_stream(cx, this.id))? {
155                Some(data) => {
156                    this.current_data = data;
157                }
158                None => return Poll::Ready(Ok(0)),
159            }
160        }
161    }
162}
163
164impl<C> AsyncWrite for Substream<C>
165where
166    C: AsyncRead + AsyncWrite + Unpin,
167{
168    fn poll_write(
169        self: Pin<&mut Self>,
170        cx: &mut Context<'_>,
171        buf: &[u8],
172    ) -> Poll<io::Result<usize>> {
173        let this = self.get_mut();
174
175        this.io.lock().poll_write_stream(cx, this.id, buf)
176    }
177
178    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
179        let this = self.get_mut();
180
181        this.io.lock().poll_flush_stream(cx, this.id)
182    }
183
184    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
185        let this = self.get_mut();
186        let mut io = this.io.lock();
187
188        ready!(io.poll_close_stream(cx, this.id))?;
189        ready!(io.poll_flush_stream(cx, this.id))?;
190
191        Poll::Ready(Ok(()))
192    }
193}
194
/// Handle to a single substream of a [`Multiplex`].
///
/// Implements `AsyncRead` and `AsyncWrite`; dropping the handle calls `drop_stream` on the
/// shared I/O state.
pub struct Substream<C>
where
    C: AsyncRead + AsyncWrite + Unpin,
{
    /// ID passed to the shared I/O state for every operation on this substream.
    id: LocalStreamId,
    /// Data already obtained from `poll_read_stream` but not yet copied out to a reader.
    current_data: Bytes,
    /// Multiplexed I/O state shared with the [`Multiplex`] that created this substream.
    io: Arc<Mutex<io::Multiplexed<C>>>,
}
207
208impl<C> Substream<C>
209where
210    C: AsyncRead + AsyncWrite + Unpin,
211{
212    fn new(id: LocalStreamId, io: Arc<Mutex<io::Multiplexed<C>>>) -> Self {
213        Self {
214            id,
215            current_data: Bytes::new(),
216            io,
217        }
218    }
219}
220
221impl<C> Drop for Substream<C>
222where
223    C: AsyncRead + AsyncWrite + Unpin,
224{
225    fn drop(&mut self) {
226        self.io.lock().drop_stream(self.id);
227    }
228}