quick_protobuf_codec/
lib.rs

1#![cfg_attr(docsrs, feature(doc_cfg, doc_auto_cfg))]
2
3use std::{io, marker::PhantomData};
4
5use asynchronous_codec::{Decoder, Encoder};
6use bytes::{Buf, BufMut, BytesMut};
7use quick_protobuf::{BytesReader, MessageRead, MessageWrite, Writer, WriterBackend};
8
9mod generated;
10
11#[doc(hidden)] // NOT public API. Do not use.
12pub use generated::test as proto;
13
/// [`Codec`] implements [`Encoder`] and [`Decoder`], uses [`unsigned_varint`]
/// to prefix messages with their length and uses [`quick_protobuf`] and a provided
/// `struct` implementing [`MessageRead`] and [`MessageWrite`] to do the encoding.
///
/// `In` is the message type accepted by the encoder and `Out` is the type
/// produced by the decoder; `Out` defaults to `In`.
pub struct Codec<In, Out = In> {
    /// Upper bound, in bytes, on the length of a decoded message body. The
    /// varint length prefix is not counted towards it.
    max_message_len_bytes: usize,
    /// Ties the codec to its message types without storing values of them.
    phantom: PhantomData<(In, Out)>,
}
22
23impl<In, Out> Codec<In, Out> {
24    /// Create new [`Codec`].
25    ///
26    /// Parameter `max_message_len_bytes` determines the maximum length of the
27    /// Protobuf message. The limit does not include the bytes needed for the
28    /// [`unsigned_varint`].
29    pub fn new(max_message_len_bytes: usize) -> Self {
30        Self {
31            max_message_len_bytes,
32            phantom: PhantomData,
33        }
34    }
35}
36
37impl<In: MessageWrite, Out> Encoder for Codec<In, Out> {
38    type Item<'a> = In;
39    type Error = Error;
40
41    fn encode(&mut self, item: Self::Item<'_>, dst: &mut BytesMut) -> Result<(), Self::Error> {
42        write_length(&item, dst);
43        write_message(&item, dst)?;
44
45        Ok(())
46    }
47}
48
49/// Write the message's length (i.e. `size`) to `dst` as a variable-length integer.
50fn write_length(message: &impl MessageWrite, dst: &mut BytesMut) {
51    let message_length = message.get_size();
52
53    let mut uvi_buf = unsigned_varint::encode::usize_buffer();
54    let encoded_length = unsigned_varint::encode::usize(message_length, &mut uvi_buf);
55
56    dst.extend_from_slice(encoded_length);
57}
58
59/// Write the message itself to `dst`.
60fn write_message(item: &impl MessageWrite, dst: &mut BytesMut) -> io::Result<()> {
61    let mut writer = Writer::new(BytesMutWriterBackend::new(dst));
62    item.write_message(&mut writer)
63        .map_err(|e| io::Error::new(io::ErrorKind::Other, e))?;
64
65    Ok(())
66}
67
68impl<In, Out> Decoder for Codec<In, Out>
69where
70    Out: for<'a> MessageRead<'a>,
71{
72    type Item = Out;
73    type Error = Error;
74
75    fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
76        let (message_length, remaining) = match unsigned_varint::decode::usize(src) {
77            Ok((len, remaining)) => (len, remaining),
78            Err(unsigned_varint::decode::Error::Insufficient) => return Ok(None),
79            Err(e) => return Err(Error(io::Error::new(io::ErrorKind::InvalidData, e))),
80        };
81
82        if message_length > self.max_message_len_bytes {
83            return Err(Error(io::Error::new(
84                io::ErrorKind::PermissionDenied,
85                format!(
86                    "message with {message_length}b exceeds maximum of {}b",
87                    self.max_message_len_bytes
88                ),
89            )));
90        }
91
92        // Compute how many bytes the varint itself consumed.
93        let varint_length = src.len() - remaining.len();
94
95        // Ensure we can read an entire message.
96        if src.len() < (message_length + varint_length) {
97            return Ok(None);
98        }
99
100        // Safe to advance buffer now.
101        src.advance(varint_length);
102
103        let message = src.split_to(message_length);
104
105        let mut reader = BytesReader::from_bytes(&message);
106        let message = Self::Item::from_reader(&mut reader, &message)
107            .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?;
108
109        Ok(Some(message))
110    }
111}
112
/// [`WriterBackend`] adapter that appends everything written through it to a
/// borrowed [`BytesMut`].
struct BytesMutWriterBackend<'a> {
    /// Buffer that receives the encoded bytes.
    dst: &'a mut BytesMut,
}
116
117impl<'a> BytesMutWriterBackend<'a> {
118    fn new(dst: &'a mut BytesMut) -> Self {
119        Self { dst }
120    }
121}
122
123impl WriterBackend for BytesMutWriterBackend<'_> {
124    fn pb_write_u8(&mut self, x: u8) -> quick_protobuf::Result<()> {
125        self.dst.put_u8(x);
126
127        Ok(())
128    }
129
130    fn pb_write_u32(&mut self, x: u32) -> quick_protobuf::Result<()> {
131        self.dst.put_u32_le(x);
132
133        Ok(())
134    }
135
136    fn pb_write_i32(&mut self, x: i32) -> quick_protobuf::Result<()> {
137        self.dst.put_i32_le(x);
138
139        Ok(())
140    }
141
142    fn pb_write_f32(&mut self, x: f32) -> quick_protobuf::Result<()> {
143        self.dst.put_f32_le(x);
144
145        Ok(())
146    }
147
148    fn pb_write_u64(&mut self, x: u64) -> quick_protobuf::Result<()> {
149        self.dst.put_u64_le(x);
150
151        Ok(())
152    }
153
154    fn pb_write_i64(&mut self, x: i64) -> quick_protobuf::Result<()> {
155        self.dst.put_i64_le(x);
156
157        Ok(())
158    }
159
160    fn pb_write_f64(&mut self, x: f64) -> quick_protobuf::Result<()> {
161        self.dst.put_f64_le(x);
162
163        Ok(())
164    }
165
166    fn pb_write_all(&mut self, buf: &[u8]) -> quick_protobuf::Result<()> {
167        self.dst.put_slice(buf);
168
169        Ok(())
170    }
171}
172
/// Error returned by [`Codec`] when encoding or decoding a message fails.
///
/// Wraps the underlying [`io::Error`]; `#[from]` makes it this error's source
/// and provides a `From<io::Error>` conversion.
#[derive(thiserror::Error, Debug)]
#[error("Failed to encode/decode message")]
pub struct Error(#[from] io::Error);
176
177impl From<Error> for io::Error {
178    fn from(e: Error) -> Self {
179        e.0
180    }
181}
182
#[cfg(test)]
mod tests {
    use std::error::Error;

    use asynchronous_codec::FramedRead;
    use futures::{io::Cursor, FutureExt, StreamExt};
    use quickcheck::{Arbitrary, Gen, QuickCheck};

    use super::*;

    #[test]
    fn honors_max_message_length() {
        let codec = Codec::<Dummy>::new(1);
        let mut src = varint_zeroes(100);

        let mut read = FramedRead::new(Cursor::new(&mut src), codec);
        let err = read.next().now_or_never().unwrap().unwrap().unwrap_err();

        assert_eq!(
            err.source().unwrap().to_string(),
            "message with 100b exceeds maximum of 1b"
        )
    }

    #[test]
    fn only_partial_message_in_bytes_mut_does_not_panic() {
        let mut codec = Codec::<Dummy>::new(100);

        // A complete length prefix announcing 100 bytes, but only part of the body.
        let mut src = varint_zeroes(100);
        src.truncate(50);

        let result = codec.decode(&mut src);

        assert!(result.unwrap().is_none());
        assert_eq!(
            src.len(),
            50,
            "to not modify `src` if we cannot read a full message"
        )
    }

    #[test]
    fn empty_bytes_mut_does_not_panic() {
        let mut codec = Codec::<Dummy>::new(100);

        let result = codec.decode(&mut BytesMut::new());

        assert!(result.unwrap().is_none());
    }

    #[test]
    fn handles_arbitrary_initial_capacity() {
        fn prop(message: proto::Message, initial_capacity: u16) {
            let mut buffer = BytesMut::with_capacity(initial_capacity as usize);
            let mut codec = Codec::<proto::Message>::new(u32::MAX as usize);

            codec.encode(message.clone(), &mut buffer).unwrap();
            let decoded = codec.decode(&mut buffer).unwrap().unwrap();

            assert_eq!(message, decoded);
        }

        QuickCheck::new().quickcheck(prop as fn(_, _) -> _)
    }

    /// Constructs a [`BytesMut`] holding the varint encoding of `length`
    /// followed by `length` zero bytes, i.e. a frame that announces a
    /// `length`-byte message consisting entirely of zeros.
    fn varint_zeroes(length: usize) -> BytesMut {
        let mut buf = unsigned_varint::encode::usize_buffer();
        let encoded_length = unsigned_varint::encode::usize(length, &mut buf);

        let mut src = BytesMut::new();
        src.extend_from_slice(encoded_length);
        src.extend(std::iter::repeat_n(0, length));
        src
    }

    impl Arbitrary for proto::Message {
        fn arbitrary(g: &mut Gen) -> Self {
            Self {
                data: Vec::arbitrary(g),
            }
        }
    }

    /// Stand-in message type for tests whose decode path never reaches
    /// `from_reader`; calling it panics via `todo!()`.
    #[derive(Debug)]
    struct Dummy;

    impl<'a> MessageRead<'a> for Dummy {
        fn from_reader(_: &mut BytesReader, _: &'a [u8]) -> quick_protobuf::Result<Self> {
            todo!()
        }
    }
}