1use std::{
22 convert::TryFrom as _,
23 io,
24 pin::Pin,
25 task::{Context, Poll},
26};
27
28use bytes::{Buf as _, BufMut as _, Bytes, BytesMut};
29use futures::{io::IoSlice, prelude::*};
30
/// Maximum number of bytes the unsigned-varint length prefix of a frame may span.
const MAX_LEN_BYTES: u16 = 2;
/// Largest frame body whose length fits in a `MAX_LEN_BYTES`-byte unsigned-varint prefix.
///
/// Each varint byte contributes 7 payload bits (its high bit is the continuation flag),
/// so a two-byte prefix can encode lengths up to `2^14 - 1 = 16383`.
const MAX_FRAME_SIZE: u16 = (1 << (MAX_LEN_BYTES * 7)) - 1;
/// Initial capacity, in bytes, used when allocating the read and write buffers.
const DEFAULT_BUFFER_SIZE: usize = 64;
34
#[pin_project::pin_project]
/// A `Stream` and `Sink` of unsigned-varint length-delimited frames on top of
/// an underlying I/O resource.
///
/// Outgoing frames larger than `MAX_FRAME_SIZE` are rejected, and incoming
/// length prefixes longer than `MAX_LEN_BYTES` bytes are treated as invalid data.
#[derive(Debug)]
pub(crate) struct LengthDelimited<R> {
    /// The underlying I/O resource.
    #[pin]
    inner: R,
    /// Holds the body of the incoming frame currently being read; empty between frames.
    read_buffer: BytesMut,
    /// Encoded (length prefix + body) outgoing frames not yet written to `inner`.
    write_buffer: BytesMut,
    /// Whether the reader is currently decoding a length prefix or a frame body.
    read_state: ReadState,
}
55
/// Decoding progress for the next incoming frame.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
enum ReadState {
    /// Reading the unsigned-varint length prefix, one byte at a time.
    ReadLength {
        /// The prefix bytes received so far (unused trailing bytes stay zero).
        buf: [u8; MAX_LEN_BYTES as usize],
        /// Number of prefix bytes received so far, i.e. the next index of `buf` to fill.
        pos: usize,
    },
    /// Reading a frame body of `len` bytes into the read buffer; `pos` counts the
    /// body bytes received so far.
    ReadData { len: u16, pos: usize },
}
66
67impl Default for ReadState {
68 fn default() -> Self {
69 ReadState::ReadLength {
70 buf: [0; MAX_LEN_BYTES as usize],
71 pos: 0,
72 }
73 }
74}
75
76impl<R> LengthDelimited<R> {
77 pub(crate) fn new(inner: R) -> LengthDelimited<R> {
80 LengthDelimited {
81 inner,
82 read_state: ReadState::default(),
83 read_buffer: BytesMut::with_capacity(DEFAULT_BUFFER_SIZE),
84 write_buffer: BytesMut::with_capacity(DEFAULT_BUFFER_SIZE + MAX_LEN_BYTES as usize),
85 }
86 }
87
88 pub(crate) fn into_inner(self) -> R {
97 assert!(self.read_buffer.is_empty());
98 assert!(self.write_buffer.is_empty());
99 self.inner
100 }
101
102 pub(crate) fn into_reader(self) -> LengthDelimitedReader<R> {
110 LengthDelimitedReader { inner: self }
111 }
112
113 fn poll_write_buffer(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>>
119 where
120 R: AsyncWrite,
121 {
122 let mut this = self.project();
123
124 while !this.write_buffer.is_empty() {
125 match this.inner.as_mut().poll_write(cx, this.write_buffer) {
126 Poll::Pending => return Poll::Pending,
127 Poll::Ready(Ok(0)) => {
128 return Poll::Ready(Err(io::Error::new(
129 io::ErrorKind::WriteZero,
130 "Failed to write buffered frame.",
131 )))
132 }
133 Poll::Ready(Ok(n)) => this.write_buffer.advance(n),
134 Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
135 }
136 }
137
138 Poll::Ready(Ok(()))
139 }
140}
141
impl<R> Stream for LengthDelimited<R>
where
    R: AsyncRead,
{
    type Item = Result<Bytes, io::Error>;

    /// Polls for the next complete frame.
    ///
    /// Yields `Some(Ok(frame))` for every frame (including empty ones) and
    /// `None` on EOF before the first byte of a new length prefix. Yields an error if:
    /// - EOF occurs inside a length prefix or a frame body (`UnexpectedEof`),
    /// - the length prefix is not a valid unsigned varint, or does not end
    ///   within `MAX_LEN_BYTES` bytes (`InvalidData`),
    /// - the underlying resource reports an I/O error.
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let mut this = self.project();

        loop {
            match this.read_state {
                ReadState::ReadLength { buf, pos } => {
                    // Read the prefix a single byte at a time: its length is not known
                    // up front, and bytes belonging to the frame body must not be
                    // consumed here.
                    match this.inner.as_mut().poll_read(cx, &mut buf[*pos..*pos + 1]) {
                        Poll::Ready(Ok(0)) => {
                            // EOF at a frame boundary ends the stream cleanly;
                            // EOF in the middle of a prefix is an error.
                            if *pos == 0 {
                                return Poll::Ready(None);
                            } else {
                                return Poll::Ready(Some(Err(io::ErrorKind::UnexpectedEof.into())));
                            }
                        }
                        Poll::Ready(Ok(n)) => {
                            debug_assert_eq!(n, 1);
                            *pos += n;
                        }
                        Poll::Ready(Err(err)) => return Poll::Ready(Some(Err(err))),
                        Poll::Pending => return Poll::Pending,
                    };

                    // A clear high (continuation) bit marks the last byte of the varint.
                    if (buf[*pos - 1] & 0x80) == 0 {
                        let (len, _) = unsigned_varint::decode::u16(buf).map_err(|e| {
                            tracing::debug!("invalid length prefix: {e}");
                            io::Error::new(io::ErrorKind::InvalidData, "invalid length prefix")
                        })?;

                        if len >= 1 {
                            *this.read_state = ReadState::ReadData { len, pos: 0 };
                            // Size the buffer to exactly this frame's body.
                            this.read_buffer.resize(len as usize, 0);
                        } else {
                            // A zero-length frame has no body: emit it immediately.
                            debug_assert_eq!(len, 0);
                            *this.read_state = ReadState::default();
                            return Poll::Ready(Some(Ok(Bytes::new())));
                        }
                    } else if *pos == MAX_LEN_BYTES as usize {
                        // Another continuation byte, but the prefix already has its
                        // maximum allowed length.
                        return Poll::Ready(Some(Err(io::Error::new(
                            io::ErrorKind::InvalidData,
                            "Maximum frame length exceeded",
                        ))));
                    }
                }
                ReadState::ReadData { len, pos } => {
                    match this
                        .inner
                        .as_mut()
                        .poll_read(cx, &mut this.read_buffer[*pos..])
                    {
                        Poll::Ready(Ok(0)) => {
                            return Poll::Ready(Some(Err(io::ErrorKind::UnexpectedEof.into())))
                        }
                        Poll::Ready(Ok(n)) => *pos += n,
                        Poll::Pending => return Poll::Pending,
                        Poll::Ready(Err(err)) => return Poll::Ready(Some(Err(err))),
                    };

                    if *pos == *len as usize {
                        // Hand out the whole buffer as the frame; `split_off(0)` leaves
                        // `read_buffer` empty for the next frame.
                        let frame = this.read_buffer.split_off(0).freeze();
                        *this.read_state = ReadState::default();
                        return Poll::Ready(Some(Ok(frame)));
                    }
                }
            }
        }
    }
}
219
220impl<R> Sink<Bytes> for LengthDelimited<R>
221where
222 R: AsyncWrite,
223{
224 type Error = io::Error;
225
226 fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
227 if self.as_mut().project().write_buffer.len() >= MAX_FRAME_SIZE as usize {
231 match self.as_mut().poll_write_buffer(cx) {
232 Poll::Ready(Ok(())) => {}
233 Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
234 Poll::Pending => return Poll::Pending,
235 }
236
237 debug_assert!(self.as_mut().project().write_buffer.is_empty());
238 }
239
240 Poll::Ready(Ok(()))
241 }
242
243 fn start_send(self: Pin<&mut Self>, item: Bytes) -> Result<(), Self::Error> {
244 let this = self.project();
245
246 let len = match u16::try_from(item.len()) {
247 Ok(len) if len <= MAX_FRAME_SIZE => len,
248 _ => {
249 return Err(io::Error::new(
250 io::ErrorKind::InvalidData,
251 "Maximum frame size exceeded.",
252 ))
253 }
254 };
255
256 let mut uvi_buf = unsigned_varint::encode::u16_buffer();
257 let uvi_len = unsigned_varint::encode::u16(len, &mut uvi_buf);
258 this.write_buffer.reserve(len as usize + uvi_len.len());
259 this.write_buffer.put(uvi_len);
260 this.write_buffer.put(item);
261
262 Ok(())
263 }
264
265 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
266 match LengthDelimited::poll_write_buffer(self.as_mut(), cx) {
268 Poll::Ready(Ok(())) => {}
269 Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
270 Poll::Pending => return Poll::Pending,
271 }
272
273 let this = self.project();
274 debug_assert!(this.write_buffer.is_empty());
275
276 this.inner.poll_flush(cx)
278 }
279
280 fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
281 match LengthDelimited::poll_write_buffer(self.as_mut(), cx) {
283 Poll::Ready(Ok(())) => {}
284 Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
285 Poll::Pending => return Poll::Pending,
286 }
287
288 let this = self.project();
289 debug_assert!(this.write_buffer.is_empty());
290
291 this.inner.poll_close(cx)
293 }
294}
295
#[pin_project::pin_project]
/// A `LengthDelimited` that reads frames through `Stream`, while its
/// `AsyncWrite` implementation writes raw, unframed bytes to the underlying
/// resource once any previously buffered frames have been written out.
#[derive(Debug)]
pub(crate) struct LengthDelimitedReader<R> {
    /// The wrapped framing layer.
    #[pin]
    inner: LengthDelimited<R>,
}
304
305impl<R> LengthDelimitedReader<R> {
306 pub(crate) fn into_inner(self) -> R {
319 self.inner.into_inner()
320 }
321}
322
323impl<R> Stream for LengthDelimitedReader<R>
324where
325 R: AsyncRead,
326{
327 type Item = Result<Bytes, io::Error>;
328
329 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
330 self.project().inner.poll_next(cx)
331 }
332}
333
334impl<R> AsyncWrite for LengthDelimitedReader<R>
335where
336 R: AsyncWrite,
337{
338 fn poll_write(
339 self: Pin<&mut Self>,
340 cx: &mut Context<'_>,
341 buf: &[u8],
342 ) -> Poll<Result<usize, io::Error>> {
343 let mut this = self.project().inner;
345
346 match LengthDelimited::poll_write_buffer(this.as_mut(), cx) {
348 Poll::Ready(Ok(())) => {}
349 Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
350 Poll::Pending => return Poll::Pending,
351 }
352 debug_assert!(this.write_buffer.is_empty());
353
354 this.project().inner.poll_write(cx, buf)
355 }
356
357 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
358 self.project().inner.poll_flush(cx)
359 }
360
361 fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
362 self.project().inner.poll_close(cx)
363 }
364
365 fn poll_write_vectored(
366 self: Pin<&mut Self>,
367 cx: &mut Context<'_>,
368 bufs: &[IoSlice<'_>],
369 ) -> Poll<Result<usize, io::Error>> {
370 let mut this = self.project().inner;
372
373 match LengthDelimited::poll_write_buffer(this.as_mut(), cx) {
375 Poll::Ready(Ok(())) => {}
376 Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
377 Poll::Pending => return Poll::Pending,
378 }
379 debug_assert!(this.write_buffer.is_empty());
380
381 this.project().inner.poll_write_vectored(cx, bufs)
382 }
383}
384
#[cfg(test)]
mod tests {
    use std::io::ErrorKind;

    use futures::{io::Cursor, prelude::*};
    use quickcheck::*;

    use crate::length_delimited::LengthDelimited;

    /// A single one-byte-prefixed frame is decoded.
    #[test]
    fn basic_read() {
        let data = vec![6, 9, 8, 7, 6, 5, 4];
        let framed = LengthDelimited::new(Cursor::new(data));
        let recved = futures::executor::block_on(framed.try_collect::<Vec<_>>()).unwrap();
        assert_eq!(recved, vec![vec![9, 8, 7, 6, 5, 4]]);
    }

    /// Two back-to-back frames are decoded in order.
    #[test]
    fn basic_read_two() {
        let data = vec![6, 9, 8, 7, 6, 5, 4, 3, 9, 8, 7];
        let framed = LengthDelimited::new(Cursor::new(data));
        let recved = futures::executor::block_on(framed.try_collect::<Vec<_>>()).unwrap();
        assert_eq!(recved, vec![vec![9, 8, 7, 6, 5, 4], vec![9, 8, 7]]);
    }

    /// A frame whose length needs a two-byte varint prefix is decoded.
    #[test]
    fn two_bytes_long_packet() {
        let len = 5000u16;
        // The hand-built header below is a valid two-byte varint only for lengths
        // that fit in 14 bits, which is exactly `MAX_FRAME_SIZE`.
        assert!(len <= super::MAX_FRAME_SIZE);
        let frame = (0..len).map(|n| (n & 0xff) as u8).collect::<Vec<_>>();
        let mut data = vec![(len & 0x7f) as u8 | 0x80, (len >> 7) as u8];
        data.extend(frame.clone());
        let mut framed = LengthDelimited::new(Cursor::new(data));
        let recved = futures::executor::block_on(async move { framed.next().await }).unwrap();
        assert_eq!(recved.unwrap(), frame);
    }

    /// A length prefix longer than `MAX_LEN_BYTES` is rejected as invalid data.
    #[test]
    fn packet_len_too_long() {
        let mut data = vec![0x81, 0x81, 0x1];
        data.extend((0..16513).map(|_| 0));
        let mut framed = LengthDelimited::new(Cursor::new(data));
        let recved = futures::executor::block_on(async move { framed.next().await.unwrap() });

        if let Err(io_err) = recved {
            assert_eq!(io_err.kind(), ErrorKind::InvalidData)
        } else {
            panic!()
        }
    }

    /// Zero-length frames are yielded as empty frames.
    #[test]
    fn empty_frames() {
        let data = vec![0, 0, 6, 9, 8, 7, 6, 5, 4, 0, 3, 9, 8, 7];
        let framed = LengthDelimited::new(Cursor::new(data));
        let recved = futures::executor::block_on(framed.try_collect::<Vec<_>>()).unwrap();
        assert_eq!(
            recved,
            vec![
                vec![],
                vec![],
                vec![9, 8, 7, 6, 5, 4],
                vec![],
                vec![9, 8, 7],
            ]
        );
    }

    /// EOF in the middle of a length prefix is `UnexpectedEof`.
    #[test]
    fn unexpected_eof_in_len() {
        let data = vec![0x89];
        let framed = LengthDelimited::new(Cursor::new(data));
        let recved = futures::executor::block_on(framed.try_collect::<Vec<_>>());
        if let Err(io_err) = recved {
            assert_eq!(io_err.kind(), ErrorKind::UnexpectedEof)
        } else {
            panic!()
        }
    }

    /// EOF right after the length prefix is `UnexpectedEof`.
    #[test]
    fn unexpected_eof_in_data() {
        let data = vec![5];
        let framed = LengthDelimited::new(Cursor::new(data));
        let recved = futures::executor::block_on(framed.try_collect::<Vec<_>>());
        if let Err(io_err) = recved {
            assert_eq!(io_err.kind(), ErrorKind::UnexpectedEof)
        } else {
            panic!()
        }
    }

    /// EOF partway through a frame body is `UnexpectedEof`.
    #[test]
    fn unexpected_eof_in_data2() {
        let data = vec![5, 9, 8, 7];
        let framed = LengthDelimited::new(Cursor::new(data));
        let recved = futures::executor::block_on(framed.try_collect::<Vec<_>>());
        if let Err(io_err) = recved {
            assert_eq!(io_err.kind(), ErrorKind::UnexpectedEof)
        } else {
            panic!()
        }
    }

    /// Frames sent through the `Sink` side arrive intact on the `Stream` side.
    #[test]
    fn writing_reading() {
        fn prop(frames: Vec<Vec<u8>>) -> TestResult {
            let (client_connection, server_connection) = futures_ringbuf::Endpoint::pair(100, 100);

            async_std::task::block_on(async move {
                let expected_frames = frames.clone();
                let server = async_std::task::spawn(async move {
                    let mut connec =
                        rw_stream_sink::RwStreamSink::new(LengthDelimited::new(server_connection));

                    let mut buf = vec![0u8; 0];
                    for expected in expected_frames {
                        if expected.is_empty() {
                            continue;
                        }
                        if buf.len() < expected.len() {
                            buf.resize(expected.len(), 0);
                        }
                        let n = connec.read(&mut buf).await.unwrap();
                        assert_eq!(&buf[..n], &expected[..]);
                    }
                });

                let client = async_std::task::spawn(async move {
                    let mut connec = LengthDelimited::new(client_connection);
                    for frame in frames {
                        connec.send(From::from(frame)).await.unwrap();
                    }
                });

                server.await;
                client.await;
            });

            TestResult::passed()
        }

        quickcheck(prop as fn(_) -> _)
    }
}