// quick_protobuf_codec/lib.rs
lib.rs1#![cfg_attr(docsrs, feature(doc_cfg, doc_auto_cfg))]
2
3use std::{io, marker::PhantomData};
4
5use asynchronous_codec::{Decoder, Encoder};
6use bytes::{Buf, BufMut, BytesMut};
7use quick_protobuf::{BytesReader, MessageRead, MessageWrite, Writer, WriterBackend};
8
9mod generated;
10
11#[doc(hidden)] pub use generated::test as proto;
13
/// A length-delimited protobuf codec for `asynchronous_codec`.
///
/// Implements [`Encoder`] for `In: MessageWrite` and [`Decoder`] for
/// `Out: for<'a> MessageRead<'a>`. Each frame is the message's encoded size as an
/// unsigned varint, followed by the protobuf-encoded message bytes. Decoding rejects
/// any frame whose announced size exceeds the configured maximum.
///
/// `Out` defaults to `In`, so `Codec<M>` both encodes and decodes `M`.
pub struct Codec<In, Out = In> {
    /// Largest accepted size, in bytes, of an incoming message body. The bytes of the
    /// varint length prefix are not counted towards this limit.
    max_message_len_bytes: usize,
    /// Ties the codec to its message types without storing any values of them.
    phantom: PhantomData<(In, Out)>,
}
22
23impl<In, Out> Codec<In, Out> {
24 pub fn new(max_message_len_bytes: usize) -> Self {
30 Self {
31 max_message_len_bytes,
32 phantom: PhantomData,
33 }
34 }
35}
36
37impl<In: MessageWrite, Out> Encoder for Codec<In, Out> {
38 type Item<'a> = In;
39 type Error = Error;
40
41 fn encode(&mut self, item: Self::Item<'_>, dst: &mut BytesMut) -> Result<(), Self::Error> {
42 write_length(&item, dst);
43 write_message(&item, dst)?;
44
45 Ok(())
46 }
47}
48
49fn write_length(message: &impl MessageWrite, dst: &mut BytesMut) {
51 let message_length = message.get_size();
52
53 let mut uvi_buf = unsigned_varint::encode::usize_buffer();
54 let encoded_length = unsigned_varint::encode::usize(message_length, &mut uvi_buf);
55
56 dst.extend_from_slice(encoded_length);
57}
58
59fn write_message(item: &impl MessageWrite, dst: &mut BytesMut) -> io::Result<()> {
61 let mut writer = Writer::new(BytesMutWriterBackend::new(dst));
62 item.write_message(&mut writer)
63 .map_err(|e| io::Error::new(io::ErrorKind::Other, e))?;
64
65 Ok(())
66}
67
68impl<In, Out> Decoder for Codec<In, Out>
69where
70 Out: for<'a> MessageRead<'a>,
71{
72 type Item = Out;
73 type Error = Error;
74
75 fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
76 let (message_length, remaining) = match unsigned_varint::decode::usize(src) {
77 Ok((len, remaining)) => (len, remaining),
78 Err(unsigned_varint::decode::Error::Insufficient) => return Ok(None),
79 Err(e) => return Err(Error(io::Error::new(io::ErrorKind::InvalidData, e))),
80 };
81
82 if message_length > self.max_message_len_bytes {
83 return Err(Error(io::Error::new(
84 io::ErrorKind::PermissionDenied,
85 format!(
86 "message with {message_length}b exceeds maximum of {}b",
87 self.max_message_len_bytes
88 ),
89 )));
90 }
91
92 let varint_length = src.len() - remaining.len();
94
95 if src.len() < (message_length + varint_length) {
97 return Ok(None);
98 }
99
100 src.advance(varint_length);
102
103 let message = src.split_to(message_length);
104
105 let mut reader = BytesReader::from_bytes(&message);
106 let message = Self::Item::from_reader(&mut reader, &message)
107 .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?;
108
109 Ok(Some(message))
110 }
111}
112
113struct BytesMutWriterBackend<'a> {
114 dst: &'a mut BytesMut,
115}
116
117impl<'a> BytesMutWriterBackend<'a> {
118 fn new(dst: &'a mut BytesMut) -> Self {
119 Self { dst }
120 }
121}
122
123impl WriterBackend for BytesMutWriterBackend<'_> {
124 fn pb_write_u8(&mut self, x: u8) -> quick_protobuf::Result<()> {
125 self.dst.put_u8(x);
126
127 Ok(())
128 }
129
130 fn pb_write_u32(&mut self, x: u32) -> quick_protobuf::Result<()> {
131 self.dst.put_u32_le(x);
132
133 Ok(())
134 }
135
136 fn pb_write_i32(&mut self, x: i32) -> quick_protobuf::Result<()> {
137 self.dst.put_i32_le(x);
138
139 Ok(())
140 }
141
142 fn pb_write_f32(&mut self, x: f32) -> quick_protobuf::Result<()> {
143 self.dst.put_f32_le(x);
144
145 Ok(())
146 }
147
148 fn pb_write_u64(&mut self, x: u64) -> quick_protobuf::Result<()> {
149 self.dst.put_u64_le(x);
150
151 Ok(())
152 }
153
154 fn pb_write_i64(&mut self, x: i64) -> quick_protobuf::Result<()> {
155 self.dst.put_i64_le(x);
156
157 Ok(())
158 }
159
160 fn pb_write_f64(&mut self, x: f64) -> quick_protobuf::Result<()> {
161 self.dst.put_f64_le(x);
162
163 Ok(())
164 }
165
166 fn pb_write_all(&mut self, buf: &[u8]) -> quick_protobuf::Result<()> {
167 self.dst.put_slice(buf);
168
169 Ok(())
170 }
171}
172
173#[derive(thiserror::Error, Debug)]
174#[error("Failed to encode/decode message")]
175pub struct Error(#[from] io::Error);
176
177impl From<Error> for io::Error {
178 fn from(e: Error) -> Self {
179 e.0
180 }
181}
182
#[cfg(test)]
mod tests {
    use std::error::Error;

    use asynchronous_codec::FramedRead;
    use futures::{io::Cursor, FutureExt, StreamExt};
    use quickcheck::{Arbitrary, Gen, QuickCheck};

    use super::*;

    #[test]
    fn honors_max_message_length() {
        let codec = Codec::<Dummy>::new(1);
        let mut src = varint_zeroes(100);

        let mut read = FramedRead::new(Cursor::new(&mut src), codec);
        let err = read.next().now_or_never().unwrap().unwrap().unwrap_err();

        assert_eq!(
            err.source().unwrap().to_string(),
            "message with 100b exceeds maximum of 1b"
        )
    }

    /// An empty buffer does not even contain a length prefix, so decoding yields `None`.
    #[test]
    fn empty_bytes_mut_does_not_panic() {
        let mut codec = Codec::<Dummy>::new(100);

        let result = codec.decode(&mut BytesMut::new());

        assert!(result.unwrap().is_none());
    }

    /// A complete length prefix followed by only part of the announced body yields
    /// `None` and must leave the buffer untouched.
    #[test]
    fn only_partial_message_in_bytes_mut_does_not_panic() {
        let mut codec = Codec::<Dummy>::new(100);

        let mut src = varint_zeroes(100);
        src.truncate(50);

        let result = codec.decode(&mut src);

        assert!(result.unwrap().is_none());
        assert_eq!(
            src.len(),
            50,
            "to not modify `src` if we cannot read a full message"
        )
    }

    #[test]
    fn handles_arbitrary_initial_capacity() {
        fn prop(message: proto::Message, initial_capacity: u16) {
            let mut buffer = BytesMut::with_capacity(initial_capacity as usize);
            let mut codec = Codec::<proto::Message>::new(u32::MAX as usize);

            codec.encode(message.clone(), &mut buffer).unwrap();
            let decoded = codec.decode(&mut buffer).unwrap().unwrap();

            assert_eq!(message, decoded);
        }

        QuickCheck::new().quickcheck(prop as fn(_, _) -> _)
    }

    /// Builds a frame whose varint prefix announces `length` bytes, followed by
    /// `length` zero bytes.
    fn varint_zeroes(length: usize) -> BytesMut {
        let mut buf = unsigned_varint::encode::usize_buffer();
        let encoded_length = unsigned_varint::encode::usize(length, &mut buf);

        let mut src = BytesMut::new();
        src.extend_from_slice(encoded_length);
        src.extend(std::iter::repeat_n(0, length));
        src
    }

    impl Arbitrary for proto::Message {
        fn arbitrary(g: &mut Gen) -> Self {
            Self {
                data: Vec::arbitrary(g),
            }
        }
    }

    /// Placeholder message type for tests that must never reach body parsing.
    #[derive(Debug)]
    struct Dummy;

    impl<'a> MessageRead<'a> for Dummy {
        fn from_reader(_: &mut BytesReader, _: &'a [u8]) -> quick_protobuf::Result<Self> {
            todo!()
        }
    }
}