libp2p_webrtc_utils/
stream.rs

// Copyright 2022 Parity Technologies (UK) Ltd.
// Copyright 2023 Protocol Labs.
//
// Permission is hereby granted, free of charge, to any person obtaining a
// copy of this software and associated documentation files (the "Software"),
// to deal in the Software without restriction, including without limitation
// the rights to use, copy, modify, merge, publish, distribute, sublicense,
// and/or sell copies of the Software, and to permit persons to whom the
// Software is furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
// OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
// DEALINGS IN THE SOFTWARE.
21
22use std::{
23    io,
24    pin::Pin,
25    task::{Context, Poll},
26};
27
28use bytes::Bytes;
29use futures::{channel::oneshot, prelude::*, ready};
30
31use crate::{
32    proto::{Flag, Message},
33    stream::{
34        drop_listener::GracefullyClosed,
35        framed_dc::FramedDc,
36        state::{Closing, State},
37    },
38};
39
40mod drop_listener;
41mod framed_dc;
42mod state;
43
/// Upper bound, in bytes, on one length-prefixed message written to the data channel.
///
/// "As long as message interleaving is not supported, the sender SHOULD limit the maximum message
/// size to 16 KB to avoid monopolization."
/// Source: <https://www.rfc-editor.org/rfc/rfc8831#name-transferring-user-data-on-a>
pub const MAX_MSG_LEN: usize = 16_384;
/// Bytes reserved for the varint length prefix written in front of every encoded message.
const VARINT_LEN: usize = 2;
/// Bytes reserved for the protobuf envelope that wraps the payload.
const PROTO_OVERHEAD: usize = 5;
/// Largest payload that still fits into a single message once the length prefix and the
/// protobuf envelope are accounted for.
const MAX_DATA_LEN: usize = MAX_MSG_LEN - (VARINT_LEN + PROTO_OVERHEAD);
56
57pub use drop_listener::DropListener;
58/// A stream backed by a WebRTC data channel.
59///
60/// To be a proper libp2p stream, we need to implement [`AsyncRead`] and [`AsyncWrite`] as well
61/// as support a half-closed state which we do by framing messages in a protobuf envelope.
62pub struct Stream<T> {
63    io: FramedDc<T>,
64    state: State,
65    read_buffer: Bytes,
66    /// Dropping this will close the oneshot and notify the receiver by emitting `Canceled`.
67    drop_notifier: Option<oneshot::Sender<GracefullyClosed>>,
68}
69
70impl<T> Stream<T>
71where
72    T: AsyncRead + AsyncWrite + Unpin + Clone,
73{
74    /// Returns a new [`Stream`] and a [`DropListener`],
75    /// which will notify the receiver when/if the stream is dropped.
76    pub fn new(data_channel: T) -> (Self, DropListener<T>) {
77        let (sender, receiver) = oneshot::channel();
78
79        let stream = Self {
80            io: framed_dc::new(data_channel.clone()),
81            state: State::Open,
82            read_buffer: Bytes::default(),
83            drop_notifier: Some(sender),
84        };
85        let listener = DropListener::new(framed_dc::new(data_channel), receiver);
86
87        (stream, listener)
88    }
89
90    /// Gracefully closes the "read-half" of the stream.
91    pub fn poll_close_read(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
92        loop {
93            match self.state.close_read_barrier()? {
94                Some(Closing::Requested) => {
95                    ready!(self.io.poll_ready_unpin(cx))?;
96
97                    self.io.start_send_unpin(Message {
98                        flag: Some(Flag::STOP_SENDING),
99                        message: None,
100                    })?;
101                    self.state.close_read_message_sent();
102
103                    continue;
104                }
105                Some(Closing::MessageSent) => {
106                    ready!(self.io.poll_flush_unpin(cx))?;
107
108                    self.state.read_closed();
109
110                    return Poll::Ready(Ok(()));
111                }
112                None => return Poll::Ready(Ok(())),
113            }
114        }
115    }
116}
117
118impl<T> AsyncRead for Stream<T>
119where
120    T: AsyncRead + AsyncWrite + Unpin,
121{
122    fn poll_read(
123        mut self: Pin<&mut Self>,
124        cx: &mut Context<'_>,
125        buf: &mut [u8],
126    ) -> Poll<io::Result<usize>> {
127        loop {
128            self.state.read_barrier()?;
129
130            if !self.read_buffer.is_empty() {
131                let n = std::cmp::min(self.read_buffer.len(), buf.len());
132                let data = self.read_buffer.split_to(n);
133                buf[0..n].copy_from_slice(&data[..]);
134
135                return Poll::Ready(Ok(n));
136            }
137
138            let Self {
139                read_buffer,
140                io,
141                state,
142                ..
143            } = &mut *self;
144
145            match ready!(io_poll_next(io, cx))? {
146                Some((flag, message)) => {
147                    if let Some(flag) = flag {
148                        state.handle_inbound_flag(flag, read_buffer);
149                    }
150
151                    debug_assert!(read_buffer.is_empty());
152                    match message {
153                        Some(msg) if !msg.is_empty() => {
154                            *read_buffer = msg.into();
155                        }
156                        _ => {
157                            tracing::debug!("poll_read buffer is empty, received None");
158                            return Poll::Ready(Ok(0));
159                        }
160                    }
161                }
162                None => {
163                    state.handle_inbound_flag(Flag::FIN, read_buffer);
164                    return Poll::Ready(Ok(0));
165                }
166            }
167        }
168    }
169}
170
171impl<T> AsyncWrite for Stream<T>
172where
173    T: AsyncRead + AsyncWrite + Unpin,
174{
175    fn poll_write(
176        mut self: Pin<&mut Self>,
177        cx: &mut Context<'_>,
178        buf: &[u8],
179    ) -> Poll<io::Result<usize>> {
180        while self.state.read_flags_in_async_write() {
181            // TODO: In case AsyncRead::poll_read encountered an error or returned None earlier, we
182            // will poll the underlying I/O resource once more. Is that allowed? How
183            // about introducing a state IoReadClosed?
184
185            let Self {
186                read_buffer,
187                io,
188                state,
189                ..
190            } = &mut *self;
191
192            match io_poll_next(io, cx)? {
193                Poll::Ready(Some((Some(flag), message))) => {
194                    // Read side is closed. Discard any incoming messages.
195                    drop(message);
196                    // But still handle flags, e.g. a `Flag::StopSending`.
197                    state.handle_inbound_flag(flag, read_buffer)
198                }
199                Poll::Ready(Some((None, message))) => drop(message),
200                Poll::Ready(None) | Poll::Pending => break,
201            }
202        }
203
204        self.state.write_barrier()?;
205
206        ready!(self.io.poll_ready_unpin(cx))?;
207
208        let n = usize::min(buf.len(), MAX_DATA_LEN);
209
210        Pin::new(&mut self.io).start_send(Message {
211            flag: None,
212            message: Some(buf[0..n].into()),
213        })?;
214
215        Poll::Ready(Ok(n))
216    }
217
218    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
219        self.io.poll_flush_unpin(cx).map_err(Into::into)
220    }
221
222    fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
223        loop {
224            match self.state.close_write_barrier()? {
225                Some(Closing::Requested) => {
226                    ready!(self.io.poll_ready_unpin(cx))?;
227
228                    self.io.start_send_unpin(Message {
229                        flag: Some(Flag::FIN),
230                        message: None,
231                    })?;
232                    self.state.close_write_message_sent();
233
234                    continue;
235                }
236                Some(Closing::MessageSent) => {
237                    ready!(self.io.poll_flush_unpin(cx))?;
238
239                    self.state.write_closed();
240                    let _ = self
241                        .drop_notifier
242                        .take()
243                        .expect("to not close twice")
244                        .send(GracefullyClosed {});
245
246                    return Poll::Ready(Ok(()));
247                }
248                None => return Poll::Ready(Ok(())),
249            }
250        }
251    }
252}
253
254fn io_poll_next<T>(
255    io: &mut FramedDc<T>,
256    cx: &mut Context<'_>,
257) -> Poll<io::Result<Option<(Option<Flag>, Option<Vec<u8>>)>>>
258where
259    T: AsyncRead + AsyncWrite + Unpin,
260{
261    match ready!(io.poll_next_unpin(cx))
262        .transpose()
263        .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?
264    {
265        Some(Message { flag, message }) => Poll::Ready(Ok(Some((flag, message)))),
266        None => Poll::Ready(Ok(None)),
267    }
268}
269
#[cfg(test)]
mod tests {
    use asynchronous_codec::Encoder;
    use bytes::BytesMut;

    use super::*;
    use crate::stream::framed_dc::codec;

    /// The largest payload we ever put into a frame must, once wrapped in the protobuf envelope
    /// and varint-length-prefixed, add up to exactly the maximum message size allowed by the
    /// libp2p WebRTC specification.
    #[test]
    fn max_data_len() {
        let largest = Message {
            flag: Some(Flag::FIN),
            message: Some(vec![0; MAX_DATA_LEN]),
        };

        let mut encoded = BytesMut::new();
        codec().encode(largest, &mut encoded).unwrap();

        assert_eq!(encoded.len(), MAX_MSG_LEN);
    }
}