multistream_select/
length_delimited.rs

1// Copyright 2017 Parity Technologies (UK) Ltd.
2//
3// Permission is hereby granted, free of charge, to any person obtaining a
4// copy of this software and associated documentation files (the "Software"),
5// to deal in the Software without restriction, including without limitation
6// the rights to use, copy, modify, merge, publish, distribute, sublicense,
7// and/or sell copies of the Software, and to permit persons to whom the
8// Software is furnished to do so, subject to the following conditions:
9//
10// The above copyright notice and this permission notice shall be included in
11// all copies or substantial portions of the Software.
12//
13// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
14// OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
16// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
17// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
18// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
19// DEALINGS IN THE SOFTWARE.
20
21use std::{
22    convert::TryFrom as _,
23    io,
24    pin::Pin,
25    task::{Context, Poll},
26};
27
28use bytes::{Buf as _, BufMut as _, Bytes, BytesMut};
29use futures::{io::IoSlice, prelude::*};
30
/// Maximum number of unsigned-varint bytes accepted for a frame's length prefix.
const MAX_LEN_BYTES: u16 = 2;
/// Largest frame payload, in bytes, whose length fits into `MAX_LEN_BYTES`
/// unsigned-varint bytes (every varint byte carries 7 bits of the value).
const MAX_FRAME_SIZE: u16 = (1 << (7 * MAX_LEN_BYTES)) - 1;
/// Initial capacity of the read buffer; the write buffer additionally
/// reserves `MAX_LEN_BYTES` for a length prefix.
const DEFAULT_BUFFER_SIZE: usize = 64;
34
35/// A `Stream` and `Sink` for unsigned-varint length-delimited frames,
36/// wrapping an underlying `AsyncRead + AsyncWrite` I/O resource.
37///
38/// We purposely only support a frame sizes up to 16KiB (2 bytes unsigned varint
39/// frame length). Frames mostly consist in a short protocol name, which is highly
40/// unlikely to be more than 16KiB long.
41#[pin_project::pin_project]
42#[derive(Debug)]
43pub(crate) struct LengthDelimited<R> {
44    /// The inner I/O resource.
45    #[pin]
46    inner: R,
47    /// Read buffer for a single incoming unsigned-varint length-delimited frame.
48    read_buffer: BytesMut,
49    /// Write buffer for outgoing unsigned-varint length-delimited frames.
50    write_buffer: BytesMut,
51    /// The current read state, alternating between reading a frame
52    /// length and reading a frame payload.
53    read_state: ReadState,
54}
55
56#[derive(Debug, Copy, Clone, PartialEq, Eq)]
57enum ReadState {
58    /// We are currently reading the length of the next frame of data.
59    ReadLength {
60        buf: [u8; MAX_LEN_BYTES as usize],
61        pos: usize,
62    },
63    /// We are currently reading the frame of data itself.
64    ReadData { len: u16, pos: usize },
65}
66
67impl Default for ReadState {
68    fn default() -> Self {
69        ReadState::ReadLength {
70            buf: [0; MAX_LEN_BYTES as usize],
71            pos: 0,
72        }
73    }
74}
75
76impl<R> LengthDelimited<R> {
77    /// Creates a new I/O resource for reading and writing unsigned-varint
78    /// length delimited frames.
79    pub(crate) fn new(inner: R) -> LengthDelimited<R> {
80        LengthDelimited {
81            inner,
82            read_state: ReadState::default(),
83            read_buffer: BytesMut::with_capacity(DEFAULT_BUFFER_SIZE),
84            write_buffer: BytesMut::with_capacity(DEFAULT_BUFFER_SIZE + MAX_LEN_BYTES as usize),
85        }
86    }
87
88    /// Drops the [`LengthDelimited`] resource, yielding the underlying I/O stream.
89    ///
90    /// # Panic
91    ///
92    /// Will panic if called while there is data in the read or write buffer.
93    /// The read buffer is guaranteed to be empty whenever `Stream::poll` yields
94    /// a new `Bytes` frame. The write buffer is guaranteed to be empty after
95    /// flushing.
96    pub(crate) fn into_inner(self) -> R {
97        assert!(self.read_buffer.is_empty());
98        assert!(self.write_buffer.is_empty());
99        self.inner
100    }
101
102    /// Converts the [`LengthDelimited`] into a [`LengthDelimitedReader`], dropping the
103    /// uvi-framed `Sink` in favour of direct `AsyncWrite` access to the underlying
104    /// I/O stream.
105    ///
106    /// This is typically done if further uvi-framed messages are expected to be
107    /// received but no more such messages are written, allowing the writing of
108    /// follow-up protocol data to commence.
109    pub(crate) fn into_reader(self) -> LengthDelimitedReader<R> {
110        LengthDelimitedReader { inner: self }
111    }
112
113    /// Writes all buffered frame data to the underlying I/O stream,
114    /// _without flushing it_.
115    ///
116    /// After this method returns `Poll::Ready`, the write buffer of frames
117    /// submitted to the `Sink` is guaranteed to be empty.
118    fn poll_write_buffer(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>>
119    where
120        R: AsyncWrite,
121    {
122        let mut this = self.project();
123
124        while !this.write_buffer.is_empty() {
125            match this.inner.as_mut().poll_write(cx, this.write_buffer) {
126                Poll::Pending => return Poll::Pending,
127                Poll::Ready(Ok(0)) => {
128                    return Poll::Ready(Err(io::Error::new(
129                        io::ErrorKind::WriteZero,
130                        "Failed to write buffered frame.",
131                    )))
132                }
133                Poll::Ready(Ok(n)) => this.write_buffer.advance(n),
134                Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
135            }
136        }
137
138        Poll::Ready(Ok(()))
139    }
140}
141
142impl<R> Stream for LengthDelimited<R>
143where
144    R: AsyncRead,
145{
146    type Item = Result<Bytes, io::Error>;
147
148    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
149        let mut this = self.project();
150
151        loop {
152            match this.read_state {
153                ReadState::ReadLength { buf, pos } => {
154                    match this.inner.as_mut().poll_read(cx, &mut buf[*pos..*pos + 1]) {
155                        Poll::Ready(Ok(0)) => {
156                            if *pos == 0 {
157                                return Poll::Ready(None);
158                            } else {
159                                return Poll::Ready(Some(Err(io::ErrorKind::UnexpectedEof.into())));
160                            }
161                        }
162                        Poll::Ready(Ok(n)) => {
163                            debug_assert_eq!(n, 1);
164                            *pos += n;
165                        }
166                        Poll::Ready(Err(err)) => return Poll::Ready(Some(Err(err))),
167                        Poll::Pending => return Poll::Pending,
168                    };
169
170                    if (buf[*pos - 1] & 0x80) == 0 {
171                        // MSB is not set, indicating the end of the length prefix.
172                        let (len, _) = unsigned_varint::decode::u16(buf).map_err(|e| {
173                            tracing::debug!("invalid length prefix: {e}");
174                            io::Error::new(io::ErrorKind::InvalidData, "invalid length prefix")
175                        })?;
176
177                        if len >= 1 {
178                            *this.read_state = ReadState::ReadData { len, pos: 0 };
179                            this.read_buffer.resize(len as usize, 0);
180                        } else {
181                            debug_assert_eq!(len, 0);
182                            *this.read_state = ReadState::default();
183                            return Poll::Ready(Some(Ok(Bytes::new())));
184                        }
185                    } else if *pos == MAX_LEN_BYTES as usize {
186                        // MSB signals more length bytes but we have already read the maximum.
187                        // See the module documentation about the max frame len.
188                        return Poll::Ready(Some(Err(io::Error::new(
189                            io::ErrorKind::InvalidData,
190                            "Maximum frame length exceeded",
191                        ))));
192                    }
193                }
194                ReadState::ReadData { len, pos } => {
195                    match this
196                        .inner
197                        .as_mut()
198                        .poll_read(cx, &mut this.read_buffer[*pos..])
199                    {
200                        Poll::Ready(Ok(0)) => {
201                            return Poll::Ready(Some(Err(io::ErrorKind::UnexpectedEof.into())))
202                        }
203                        Poll::Ready(Ok(n)) => *pos += n,
204                        Poll::Pending => return Poll::Pending,
205                        Poll::Ready(Err(err)) => return Poll::Ready(Some(Err(err))),
206                    };
207
208                    if *pos == *len as usize {
209                        // Finished reading the frame.
210                        let frame = this.read_buffer.split_off(0).freeze();
211                        *this.read_state = ReadState::default();
212                        return Poll::Ready(Some(Ok(frame)));
213                    }
214                }
215            }
216        }
217    }
218}
219
220impl<R> Sink<Bytes> for LengthDelimited<R>
221where
222    R: AsyncWrite,
223{
224    type Error = io::Error;
225
226    fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
227        // Use the maximum frame length also as a (soft) upper limit
228        // for the entire write buffer. The actual (hard) limit is thus
229        // implied to be roughly 2 * MAX_FRAME_SIZE.
230        if self.as_mut().project().write_buffer.len() >= MAX_FRAME_SIZE as usize {
231            match self.as_mut().poll_write_buffer(cx) {
232                Poll::Ready(Ok(())) => {}
233                Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
234                Poll::Pending => return Poll::Pending,
235            }
236
237            debug_assert!(self.as_mut().project().write_buffer.is_empty());
238        }
239
240        Poll::Ready(Ok(()))
241    }
242
243    fn start_send(self: Pin<&mut Self>, item: Bytes) -> Result<(), Self::Error> {
244        let this = self.project();
245
246        let len = match u16::try_from(item.len()) {
247            Ok(len) if len <= MAX_FRAME_SIZE => len,
248            _ => {
249                return Err(io::Error::new(
250                    io::ErrorKind::InvalidData,
251                    "Maximum frame size exceeded.",
252                ))
253            }
254        };
255
256        let mut uvi_buf = unsigned_varint::encode::u16_buffer();
257        let uvi_len = unsigned_varint::encode::u16(len, &mut uvi_buf);
258        this.write_buffer.reserve(len as usize + uvi_len.len());
259        this.write_buffer.put(uvi_len);
260        this.write_buffer.put(item);
261
262        Ok(())
263    }
264
265    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
266        // Write all buffered frame data to the underlying I/O stream.
267        match LengthDelimited::poll_write_buffer(self.as_mut(), cx) {
268            Poll::Ready(Ok(())) => {}
269            Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
270            Poll::Pending => return Poll::Pending,
271        }
272
273        let this = self.project();
274        debug_assert!(this.write_buffer.is_empty());
275
276        // Flush the underlying I/O stream.
277        this.inner.poll_flush(cx)
278    }
279
280    fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
281        // Write all buffered frame data to the underlying I/O stream.
282        match LengthDelimited::poll_write_buffer(self.as_mut(), cx) {
283            Poll::Ready(Ok(())) => {}
284            Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
285            Poll::Pending => return Poll::Pending,
286        }
287
288        let this = self.project();
289        debug_assert!(this.write_buffer.is_empty());
290
291        // Close the underlying I/O stream.
292        this.inner.poll_close(cx)
293    }
294}
295
296/// A `LengthDelimitedReader` implements a `Stream` of uvi-length-delimited
297/// frames on an underlying I/O resource combined with direct `AsyncWrite` access.
298#[pin_project::pin_project]
299#[derive(Debug)]
300pub(crate) struct LengthDelimitedReader<R> {
301    #[pin]
302    inner: LengthDelimited<R>,
303}
304
305impl<R> LengthDelimitedReader<R> {
306    /// Destroys the `LengthDelimitedReader` and returns the underlying I/O stream.
307    ///
308    /// This method is guaranteed not to drop any data read from or not yet
309    /// submitted to the underlying I/O stream.
310    ///
311    /// # Panic
312    ///
313    /// Will panic if called while there is data in the read or write buffer.
314    /// The read buffer is guaranteed to be empty whenever [`Stream::poll_next`]
315    /// yield a new `Message`. The write buffer is guaranteed to be empty whenever
316    /// [`LengthDelimited::poll_write_buffer`] yields [`Poll::Ready`] or after
317    /// the [`Sink`] has been completely flushed via [`Sink::poll_flush`].
318    pub(crate) fn into_inner(self) -> R {
319        self.inner.into_inner()
320    }
321}
322
323impl<R> Stream for LengthDelimitedReader<R>
324where
325    R: AsyncRead,
326{
327    type Item = Result<Bytes, io::Error>;
328
329    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
330        self.project().inner.poll_next(cx)
331    }
332}
333
334impl<R> AsyncWrite for LengthDelimitedReader<R>
335where
336    R: AsyncWrite,
337{
338    fn poll_write(
339        self: Pin<&mut Self>,
340        cx: &mut Context<'_>,
341        buf: &[u8],
342    ) -> Poll<Result<usize, io::Error>> {
343        // `this` here designates the `LengthDelimited`.
344        let mut this = self.project().inner;
345
346        // We need to flush any data previously written with the `LengthDelimited`.
347        match LengthDelimited::poll_write_buffer(this.as_mut(), cx) {
348            Poll::Ready(Ok(())) => {}
349            Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
350            Poll::Pending => return Poll::Pending,
351        }
352        debug_assert!(this.write_buffer.is_empty());
353
354        this.project().inner.poll_write(cx, buf)
355    }
356
357    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
358        self.project().inner.poll_flush(cx)
359    }
360
361    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
362        self.project().inner.poll_close(cx)
363    }
364
365    fn poll_write_vectored(
366        self: Pin<&mut Self>,
367        cx: &mut Context<'_>,
368        bufs: &[IoSlice<'_>],
369    ) -> Poll<Result<usize, io::Error>> {
370        // `this` here designates the `LengthDelimited`.
371        let mut this = self.project().inner;
372
373        // We need to flush any data previously written with the `LengthDelimited`.
374        match LengthDelimited::poll_write_buffer(this.as_mut(), cx) {
375            Poll::Ready(Ok(())) => {}
376            Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
377            Poll::Pending => return Poll::Pending,
378        }
379        debug_assert!(this.write_buffer.is_empty());
380
381        this.project().inner.poll_write_vectored(cx, bufs)
382    }
383}
384
#[cfg(test)]
mod tests {
    use std::io::ErrorKind;

    use bytes::Bytes;
    use futures::{io::Cursor, prelude::*};
    use quickcheck::*;

    use crate::length_delimited::{LengthDelimited, MAX_FRAME_SIZE};

    #[test]
    fn basic_read() {
        let data = vec![6, 9, 8, 7, 6, 5, 4];
        let framed = LengthDelimited::new(Cursor::new(data));
        let recved = futures::executor::block_on(framed.try_collect::<Vec<_>>()).unwrap();
        assert_eq!(recved, vec![vec![9, 8, 7, 6, 5, 4]]);
    }

    #[test]
    fn basic_read_two() {
        let data = vec![6, 9, 8, 7, 6, 5, 4, 3, 9, 8, 7];
        let framed = LengthDelimited::new(Cursor::new(data));
        let recved = futures::executor::block_on(framed.try_collect::<Vec<_>>()).unwrap();
        assert_eq!(recved, vec![vec![9, 8, 7, 6, 5, 4], vec![9, 8, 7]]);
    }

    #[test]
    fn two_bytes_long_packet() {
        let len = 5000u16;
        // A two-byte unsigned-varint prefix only carries 14 bits.
        assert!(len <= MAX_FRAME_SIZE);
        let frame = (0..len).map(|n| (n & 0xff) as u8).collect::<Vec<_>>();
        let mut data = vec![(len & 0x7f) as u8 | 0x80, (len >> 7) as u8];
        data.extend(frame.clone());
        let mut framed = LengthDelimited::new(Cursor::new(data));
        let recved = futures::executor::block_on(async move { framed.next().await }).unwrap();
        assert_eq!(recved.unwrap(), frame);
    }

    #[test]
    fn packet_len_too_long() {
        let mut data = vec![0x81, 0x81, 0x1];
        data.extend((0..16513).map(|_| 0));
        let mut framed = LengthDelimited::new(Cursor::new(data));
        let recved = futures::executor::block_on(async move { framed.next().await.unwrap() });

        if let Err(io_err) = recved {
            assert_eq!(io_err.kind(), ErrorKind::InvalidData)
        } else {
            panic!()
        }
    }

    #[test]
    fn empty_frames() {
        let data = vec![0, 0, 6, 9, 8, 7, 6, 5, 4, 0, 3, 9, 8, 7];
        let framed = LengthDelimited::new(Cursor::new(data));
        let recved = futures::executor::block_on(framed.try_collect::<Vec<_>>()).unwrap();
        assert_eq!(
            recved,
            vec![
                vec![],
                vec![],
                vec![9, 8, 7, 6, 5, 4],
                vec![],
                vec![9, 8, 7],
            ]
        );
    }

    #[test]
    fn unexpected_eof_in_len() {
        let data = vec![0x89];
        let framed = LengthDelimited::new(Cursor::new(data));
        let recved = futures::executor::block_on(framed.try_collect::<Vec<_>>());
        if let Err(io_err) = recved {
            assert_eq!(io_err.kind(), ErrorKind::UnexpectedEof)
        } else {
            panic!()
        }
    }

    #[test]
    fn unexpected_eof_in_data() {
        let data = vec![5];
        let framed = LengthDelimited::new(Cursor::new(data));
        let recved = futures::executor::block_on(framed.try_collect::<Vec<_>>());
        if let Err(io_err) = recved {
            assert_eq!(io_err.kind(), ErrorKind::UnexpectedEof)
        } else {
            panic!()
        }
    }

    #[test]
    fn unexpected_eof_in_data2() {
        let data = vec![5, 9, 8, 7];
        let framed = LengthDelimited::new(Cursor::new(data));
        let recved = futures::executor::block_on(framed.try_collect::<Vec<_>>());
        if let Err(io_err) = recved {
            assert_eq!(io_err.kind(), ErrorKind::UnexpectedEof)
        } else {
            panic!()
        }
    }

    #[test]
    fn basic_write() {
        let mut framed = LengthDelimited::new(Cursor::new(Vec::new()));
        futures::executor::block_on(framed.send(Bytes::from_static(&[9, 8, 7]))).unwrap();
        // One length byte followed by the payload.
        assert_eq!(framed.into_inner().into_inner(), vec![3, 9, 8, 7]);
    }

    #[test]
    fn write_frame_too_long() {
        let mut framed = LengthDelimited::new(Cursor::new(Vec::new()));

        // The largest representable frame is accepted...
        let max = Bytes::from(vec![0u8; MAX_FRAME_SIZE as usize]);
        futures::executor::block_on(framed.send(max)).unwrap();

        // ...but one byte more is rejected.
        let too_long = Bytes::from(vec![0u8; MAX_FRAME_SIZE as usize + 1]);
        let err = futures::executor::block_on(framed.send(too_long)).unwrap_err();
        assert_eq!(err.kind(), ErrorKind::InvalidData);
    }

    #[test]
    fn writing_reading() {
        fn prop(frames: Vec<Vec<u8>>) -> TestResult {
            let (client_connection, server_connection) = futures_ringbuf::Endpoint::pair(100, 100);

            async_std::task::block_on(async move {
                let expected_frames = frames.clone();
                let server = async_std::task::spawn(async move {
                    let mut connec =
                        rw_stream_sink::RwStreamSink::new(LengthDelimited::new(server_connection));

                    let mut buf = vec![0u8; 0];
                    for expected in expected_frames {
                        if expected.is_empty() {
                            continue;
                        }
                        if buf.len() < expected.len() {
                            buf.resize(expected.len(), 0);
                        }
                        let n = connec.read(&mut buf).await.unwrap();
                        assert_eq!(&buf[..n], &expected[..]);
                    }
                });

                let client = async_std::task::spawn(async move {
                    let mut connec = LengthDelimited::new(client_connection);
                    for frame in frames {
                        connec.send(From::from(frame)).await.unwrap();
                    }
                });

                server.await;
                client.await;
            });

            TestResult::passed()
        }

        quickcheck(prop as fn(_) -> _)
    }
}