multistream_select/
length_delimited.rs

1// Copyright 2017 Parity Technologies (UK) Ltd.
2//
3// Permission is hereby granted, free of charge, to any person obtaining a
4// copy of this software and associated documentation files (the "Software"),
5// to deal in the Software without restriction, including without limitation
6// the rights to use, copy, modify, merge, publish, distribute, sublicense,
7// and/or sell copies of the Software, and to permit persons to whom the
8// Software is furnished to do so, subject to the following conditions:
9//
10// The above copyright notice and this permission notice shall be included in
11// all copies or substantial portions of the Software.
12//
13// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
14// OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
16// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
17// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
18// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
19// DEALINGS IN THE SOFTWARE.
20
21use std::{
22    convert::TryFrom as _,
23    io,
24    pin::Pin,
25    task::{Context, Poll},
26};
27
28use bytes::{Buf as _, BufMut as _, Bytes, BytesMut};
29use futures::{io::IoSlice, prelude::*};
30
/// Maximum number of bytes the unsigned-varint length prefix of a frame may occupy.
const MAX_LEN_BYTES: u16 = 2;
/// Largest frame payload length that fits into the length prefix.
///
/// Every prefix byte carries 7 bits of the length (its most significant bit is
/// the continuation flag), so `MAX_LEN_BYTES` bytes encode `7 * MAX_LEN_BYTES`
/// bits, i.e. values up to `2^14 - 1 = 16383`.
const MAX_FRAME_SIZE: u16 = (1 << (7 * MAX_LEN_BYTES)) - 1;
/// Base capacity, in bytes, used when pre-allocating the read and write buffers.
const DEFAULT_BUFFER_SIZE: usize = 64;
34
35/// A `Stream` and `Sink` for unsigned-varint length-delimited frames,
36/// wrapping an underlying `AsyncRead + AsyncWrite` I/O resource.
37///
38/// We purposely only support a frame sizes up to 16KiB (2 bytes unsigned varint
39/// frame length). Frames mostly consist in a short protocol name, which is highly
40/// unlikely to be more than 16KiB long.
41#[pin_project::pin_project]
42#[derive(Debug)]
43pub(crate) struct LengthDelimited<R> {
44    /// The inner I/O resource.
45    #[pin]
46    inner: R,
47    /// Read buffer for a single incoming unsigned-varint length-delimited frame.
48    read_buffer: BytesMut,
49    /// Write buffer for outgoing unsigned-varint length-delimited frames.
50    write_buffer: BytesMut,
51    /// The current read state, alternating between reading a frame
52    /// length and reading a frame payload.
53    read_state: ReadState,
54}
55
56#[derive(Debug, Copy, Clone, PartialEq, Eq)]
57enum ReadState {
58    /// We are currently reading the length of the next frame of data.
59    ReadLength {
60        buf: [u8; MAX_LEN_BYTES as usize],
61        pos: usize,
62    },
63    /// We are currently reading the frame of data itself.
64    ReadData { len: u16, pos: usize },
65}
66
67impl Default for ReadState {
68    fn default() -> Self {
69        ReadState::ReadLength {
70            buf: [0; MAX_LEN_BYTES as usize],
71            pos: 0,
72        }
73    }
74}
75
76impl<R> LengthDelimited<R> {
77    /// Creates a new I/O resource for reading and writing unsigned-varint
78    /// length delimited frames.
79    pub(crate) fn new(inner: R) -> LengthDelimited<R> {
80        LengthDelimited {
81            inner,
82            read_state: ReadState::default(),
83            read_buffer: BytesMut::with_capacity(DEFAULT_BUFFER_SIZE),
84            write_buffer: BytesMut::with_capacity(DEFAULT_BUFFER_SIZE + MAX_LEN_BYTES as usize),
85        }
86    }
87
88    /// Drops the [`LengthDelimited`] resource, yielding the underlying I/O stream.
89    ///
90    /// # Panic
91    ///
92    /// Will panic if called while there is data in the read or write buffer.
93    /// The read buffer is guaranteed to be empty whenever `Stream::poll` yields
94    /// a new `Bytes` frame. The write buffer is guaranteed to be empty after
95    /// flushing.
96    pub(crate) fn into_inner(self) -> R {
97        assert!(self.read_buffer.is_empty());
98        assert!(self.write_buffer.is_empty());
99        self.inner
100    }
101
102    /// Converts the [`LengthDelimited`] into a [`LengthDelimitedReader`], dropping the
103    /// uvi-framed `Sink` in favour of direct `AsyncWrite` access to the underlying
104    /// I/O stream.
105    ///
106    /// This is typically done if further uvi-framed messages are expected to be
107    /// received but no more such messages are written, allowing the writing of
108    /// follow-up protocol data to commence.
109    pub(crate) fn into_reader(self) -> LengthDelimitedReader<R> {
110        LengthDelimitedReader { inner: self }
111    }
112
113    /// Writes all buffered frame data to the underlying I/O stream,
114    /// _without flushing it_.
115    ///
116    /// After this method returns `Poll::Ready`, the write buffer of frames
117    /// submitted to the `Sink` is guaranteed to be empty.
118    fn poll_write_buffer(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>>
119    where
120        R: AsyncWrite,
121    {
122        let mut this = self.project();
123
124        while !this.write_buffer.is_empty() {
125            match this.inner.as_mut().poll_write(cx, this.write_buffer) {
126                Poll::Pending => return Poll::Pending,
127                Poll::Ready(Ok(0)) => {
128                    return Poll::Ready(Err(io::Error::new(
129                        io::ErrorKind::WriteZero,
130                        "Failed to write buffered frame.",
131                    )))
132                }
133                Poll::Ready(Ok(n)) => this.write_buffer.advance(n),
134                Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
135            }
136        }
137
138        Poll::Ready(Ok(()))
139    }
140}
141
142impl<R> Stream for LengthDelimited<R>
143where
144    R: AsyncRead,
145{
146    type Item = Result<Bytes, io::Error>;
147
148    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
149        let mut this = self.project();
150
151        loop {
152            match this.read_state {
153                ReadState::ReadLength { buf, pos } => {
154                    match this.inner.as_mut().poll_read(cx, &mut buf[*pos..*pos + 1]) {
155                        Poll::Ready(Ok(0)) => {
156                            if *pos == 0 {
157                                return Poll::Ready(None);
158                            } else {
159                                return Poll::Ready(Some(Err(io::ErrorKind::UnexpectedEof.into())));
160                            }
161                        }
162                        Poll::Ready(Ok(n)) => {
163                            debug_assert_eq!(n, 1);
164                            *pos += n;
165                        }
166                        Poll::Ready(Err(err)) => return Poll::Ready(Some(Err(err))),
167                        Poll::Pending => return Poll::Pending,
168                    };
169
170                    if (buf[*pos - 1] & 0x80) == 0 {
171                        // MSB is not set, indicating the end of the length prefix.
172                        let (len, _) = unsigned_varint::decode::u16(buf).map_err(|e| {
173                            tracing::debug!("invalid length prefix: {e}");
174                            io::Error::new(io::ErrorKind::InvalidData, "invalid length prefix")
175                        })?;
176
177                        if len >= 1 {
178                            *this.read_state = ReadState::ReadData { len, pos: 0 };
179                            this.read_buffer.resize(len as usize, 0);
180                        } else {
181                            debug_assert_eq!(len, 0);
182                            *this.read_state = ReadState::default();
183                            return Poll::Ready(Some(Ok(Bytes::new())));
184                        }
185                    } else if *pos == MAX_LEN_BYTES as usize {
186                        // MSB signals more length bytes but we have already read the maximum.
187                        // See the module documentation about the max frame len.
188                        return Poll::Ready(Some(Err(io::Error::new(
189                            io::ErrorKind::InvalidData,
190                            "Maximum frame length exceeded",
191                        ))));
192                    }
193                }
194                ReadState::ReadData { len, pos } => {
195                    match this
196                        .inner
197                        .as_mut()
198                        .poll_read(cx, &mut this.read_buffer[*pos..])
199                    {
200                        Poll::Ready(Ok(0)) => {
201                            return Poll::Ready(Some(Err(io::ErrorKind::UnexpectedEof.into())))
202                        }
203                        Poll::Ready(Ok(n)) => *pos += n,
204                        Poll::Pending => return Poll::Pending,
205                        Poll::Ready(Err(err)) => return Poll::Ready(Some(Err(err))),
206                    };
207
208                    if *pos == *len as usize {
209                        // Finished reading the frame.
210                        let frame = this.read_buffer.split_off(0).freeze();
211                        *this.read_state = ReadState::default();
212                        return Poll::Ready(Some(Ok(frame)));
213                    }
214                }
215            }
216        }
217    }
218}
219
220impl<R> Sink<Bytes> for LengthDelimited<R>
221where
222    R: AsyncWrite,
223{
224    type Error = io::Error;
225
226    fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
227        // Use the maximum frame length also as a (soft) upper limit
228        // for the entire write buffer. The actual (hard) limit is thus
229        // implied to be roughly 2 * MAX_FRAME_SIZE.
230        if self.as_mut().project().write_buffer.len() >= MAX_FRAME_SIZE as usize {
231            match self.as_mut().poll_write_buffer(cx) {
232                Poll::Ready(Ok(())) => {}
233                Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
234                Poll::Pending => return Poll::Pending,
235            }
236
237            debug_assert!(self.as_mut().project().write_buffer.is_empty());
238        }
239
240        Poll::Ready(Ok(()))
241    }
242
243    fn start_send(self: Pin<&mut Self>, item: Bytes) -> Result<(), Self::Error> {
244        let this = self.project();
245
246        let len = match u16::try_from(item.len()) {
247            Ok(len) if len <= MAX_FRAME_SIZE => len,
248            _ => {
249                return Err(io::Error::new(
250                    io::ErrorKind::InvalidData,
251                    "Maximum frame size exceeded.",
252                ))
253            }
254        };
255
256        let mut uvi_buf = unsigned_varint::encode::u16_buffer();
257        let uvi_len = unsigned_varint::encode::u16(len, &mut uvi_buf);
258        this.write_buffer.reserve(len as usize + uvi_len.len());
259        this.write_buffer.put(uvi_len);
260        this.write_buffer.put(item);
261
262        Ok(())
263    }
264
265    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
266        // Write all buffered frame data to the underlying I/O stream.
267        match LengthDelimited::poll_write_buffer(self.as_mut(), cx) {
268            Poll::Ready(Ok(())) => {}
269            Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
270            Poll::Pending => return Poll::Pending,
271        }
272
273        let this = self.project();
274        debug_assert!(this.write_buffer.is_empty());
275
276        // Flush the underlying I/O stream.
277        this.inner.poll_flush(cx)
278    }
279
280    fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
281        // Write all buffered frame data to the underlying I/O stream.
282        match LengthDelimited::poll_write_buffer(self.as_mut(), cx) {
283            Poll::Ready(Ok(())) => {}
284            Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
285            Poll::Pending => return Poll::Pending,
286        }
287
288        let this = self.project();
289        debug_assert!(this.write_buffer.is_empty());
290
291        // Close the underlying I/O stream.
292        this.inner.poll_close(cx)
293    }
294}
295
296/// A `LengthDelimitedReader` implements a `Stream` of uvi-length-delimited
297/// frames on an underlying I/O resource combined with direct `AsyncWrite` access.
298#[pin_project::pin_project]
299#[derive(Debug)]
300pub(crate) struct LengthDelimitedReader<R> {
301    #[pin]
302    inner: LengthDelimited<R>,
303}
304
305impl<R> LengthDelimitedReader<R> {
306    /// Destroys the `LengthDelimitedReader` and returns the underlying I/O stream.
307    ///
308    /// This method is guaranteed not to drop any data read from or not yet
309    /// submitted to the underlying I/O stream.
310    ///
311    /// # Panic
312    ///
313    /// Will panic if called while there is data in the read or write buffer.
314    /// The read buffer is guaranteed to be empty whenever [`Stream::poll_next`]
315    /// yield a new `Message`. The write buffer is guaranteed to be empty whenever
316    /// [`LengthDelimited::poll_write_buffer`] yields [`Poll::Ready`] or after
317    /// the [`Sink`] has been completely flushed via [`Sink::poll_flush`].
318    pub(crate) fn into_inner(self) -> R {
319        self.inner.into_inner()
320    }
321}
322
323impl<R> Stream for LengthDelimitedReader<R>
324where
325    R: AsyncRead,
326{
327    type Item = Result<Bytes, io::Error>;
328
329    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
330        self.project().inner.poll_next(cx)
331    }
332}
333
334impl<R> AsyncWrite for LengthDelimitedReader<R>
335where
336    R: AsyncWrite,
337{
338    fn poll_write(
339        self: Pin<&mut Self>,
340        cx: &mut Context<'_>,
341        buf: &[u8],
342    ) -> Poll<Result<usize, io::Error>> {
343        // `this` here designates the `LengthDelimited`.
344        let mut this = self.project().inner;
345
346        // We need to flush any data previously written with the `LengthDelimited`.
347        match LengthDelimited::poll_write_buffer(this.as_mut(), cx) {
348            Poll::Ready(Ok(())) => {}
349            Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
350            Poll::Pending => return Poll::Pending,
351        }
352        debug_assert!(this.write_buffer.is_empty());
353
354        this.project().inner.poll_write(cx, buf)
355    }
356
357    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
358        self.project().inner.poll_flush(cx)
359    }
360
361    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
362        self.project().inner.poll_close(cx)
363    }
364
365    fn poll_write_vectored(
366        self: Pin<&mut Self>,
367        cx: &mut Context<'_>,
368        bufs: &[IoSlice<'_>],
369    ) -> Poll<Result<usize, io::Error>> {
370        // `this` here designates the `LengthDelimited`.
371        let mut this = self.project().inner;
372
373        // We need to flush any data previously written with the `LengthDelimited`.
374        match LengthDelimited::poll_write_buffer(this.as_mut(), cx) {
375            Poll::Ready(Ok(())) => {}
376            Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
377            Poll::Pending => return Poll::Pending,
378        }
379        debug_assert!(this.write_buffer.is_empty());
380
381        this.project().inner.poll_write_vectored(cx, bufs)
382    }
383}
384
#[cfg(test)]
mod tests {
    use std::io::ErrorKind;

    use futures::{io::Cursor, prelude::*};
    use quickcheck::*;
    use tokio::runtime::Runtime;

    use crate::length_delimited::LengthDelimited;

    /// Decodes every frame contained in `wire`, stopping at the first error.
    fn decode_all(wire: Vec<u8>) -> std::io::Result<Vec<bytes::Bytes>> {
        let framed = LengthDelimited::new(Cursor::new(wire));
        futures::executor::block_on(framed.try_collect::<Vec<_>>())
    }

    /// A single frame with a one-byte length prefix is decoded.
    #[test]
    fn basic_read() {
        let recved = decode_all(vec![6, 9, 8, 7, 6, 5, 4]).unwrap();
        assert_eq!(recved, vec![vec![9, 8, 7, 6, 5, 4]]);
    }

    /// Two back-to-back frames are decoded in order.
    #[test]
    fn basic_read_two() {
        let recved = decode_all(vec![6, 9, 8, 7, 6, 5, 4, 3, 9, 8, 7]).unwrap();
        assert_eq!(recved, vec![vec![9, 8, 7, 6, 5, 4], vec![9, 8, 7]]);
    }

    /// A frame whose length needs a two-byte prefix is decoded.
    #[test]
    fn two_bytes_long_packet() {
        let len = 5000u16;
        assert!(len < (1 << 15));
        let frame: Vec<u8> = (0..len).map(|n| (n & 0xff) as u8).collect();

        // Hand-rolled two-byte unsigned-varint prefix followed by the payload.
        let mut data = vec![(len & 0x7f) as u8 | 0x80, (len >> 7) as u8];
        data.extend_from_slice(&frame);

        let mut framed = LengthDelimited::new(Cursor::new(data));
        let recved = futures::executor::block_on(framed.next()).unwrap();
        assert_eq!(recved.unwrap(), frame);
    }

    /// A length prefix of more than two bytes is rejected as invalid data.
    #[test]
    fn packet_len_too_long() {
        let mut data: Vec<u8> = vec![0x81, 0x81, 0x1];
        data.resize(data.len() + 16513, 0);

        let mut framed = LengthDelimited::new(Cursor::new(data));
        let recved = futures::executor::block_on(framed.next()).unwrap();

        match recved {
            Err(io_err) => assert_eq!(io_err.kind(), ErrorKind::InvalidData),
            Ok(_) => panic!(),
        }
    }

    /// Zero-length frames are yielded as empty frames, interleaved with others.
    #[test]
    fn empty_frames() {
        let recved = decode_all(vec![0, 0, 6, 9, 8, 7, 6, 5, 4, 0, 3, 9, 8, 7]).unwrap();
        assert_eq!(
            recved,
            vec![
                vec![],
                vec![],
                vec![9, 8, 7, 6, 5, 4],
                vec![],
                vec![9, 8, 7],
            ]
        );
    }

    /// EOF in the middle of a length prefix is an error.
    #[test]
    fn unexpected_eof_in_len() {
        let io_err = decode_all(vec![0x89]).unwrap_err();
        assert_eq!(io_err.kind(), ErrorKind::UnexpectedEof);
    }

    /// EOF right after the length prefix (no payload at all) is an error.
    #[test]
    fn unexpected_eof_in_data() {
        let io_err = decode_all(vec![5]).unwrap_err();
        assert_eq!(io_err.kind(), ErrorKind::UnexpectedEof);
    }

    /// EOF part-way through the payload is an error.
    #[test]
    fn unexpected_eof_in_data2() {
        let io_err = decode_all(vec![5, 9, 8, 7]).unwrap_err();
        assert_eq!(io_err.kind(), ErrorKind::UnexpectedEof);
    }

    /// Frames sent through the `Sink` come out of a `LengthDelimited` reader
    /// on the other end of a connection unchanged.
    #[test]
    fn writing_reading() {
        fn prop(frames: Vec<Vec<u8>>) -> TestResult {
            let (client_connection, server_connection) = futures_ringbuf::Endpoint::pair(100, 100);
            let expected_frames = frames.clone();

            let rt = Runtime::new().unwrap();
            rt.block_on(async move {
                let server = tokio::task::spawn(async move {
                    let framed = LengthDelimited::new(server_connection);
                    let mut connec = rw_stream_sink::RwStreamSink::new(framed);

                    let mut buf: Vec<u8> = Vec::new();
                    // Empty frames are skipped on the reading side.
                    for expected in expected_frames.into_iter().filter(|f| !f.is_empty()) {
                        if buf.len() < expected.len() {
                            buf.resize(expected.len(), 0);
                        }
                        let n = connec.read(&mut buf).await.unwrap();
                        assert_eq!(&buf[..n], &expected[..]);
                    }
                });

                let client = tokio::task::spawn(async move {
                    let mut connec = LengthDelimited::new(client_connection);
                    for frame in frames {
                        connec.send(bytes::Bytes::from(frame)).await.unwrap();
                    }
                });

                server.await.unwrap();
                client.await.unwrap();
            });

            TestResult::passed()
        }

        quickcheck(prop as fn(_) -> _)
    }
}