libp2p_request_response/
cbor.rs

1// Copyright 2023 Protocol Labs
2//
3// Permission is hereby granted, free of charge, to any person obtaining a
4// copy of this software and associated documentation files (the "Software"),
5// to deal in the Software without restriction, including without limitation
6// the rights to use, copy, modify, merge, publish, distribute, sublicense,
7// and/or sell copies of the Software, and to permit persons to whom the
8// Software is furnished to do so, subject to the following conditions:
9//
10// The above copyright notice and this permission notice shall be included in
11// all copies or substantial portions of the Software.
12//
13// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
14// OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
16// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
17// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
18// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
19// DEALINGS IN THE SOFTWARE.
20
21/// A request-response behaviour using [`cbor4ii::serde`] for serializing and
22/// deserializing the messages.
23///
24/// # Default Size Limits
25///
26/// The codec uses the following default size limits:
27/// - Maximum request size: 1,048,576 bytes (1 MiB)
28/// - Maximum response size: 10,485,760 bytes (10 MiB)
29///
30/// These limits can be customized with [`codec::Codec::set_request_size_maximum`]
31/// and [`codec::Codec::set_response_size_maximum`].
32///
33/// # Example
34///
35/// ```
36/// # use libp2p_request_response::{cbor, ProtocolSupport, self as request_response};
37/// # use libp2p_swarm::StreamProtocol;
38/// #[derive(Debug, serde::Serialize, serde::Deserialize)]
39/// struct GreetRequest {
40///     name: String,
41/// }
42///
43/// #[derive(Debug, serde::Serialize, serde::Deserialize)]
44/// struct GreetResponse {
45///     message: String,
46/// }
47///
48/// let behaviour = cbor::Behaviour::<GreetRequest, GreetResponse>::new(
49///     [(
50///         StreamProtocol::new("/my-cbor-protocol"),
51///         ProtocolSupport::Full,
52///     )],
53///     request_response::Config::default(),
54/// );
55/// ```
56pub type Behaviour<Req, Resp> = crate::Behaviour<codec::Codec<Req, Resp>>;
57
58pub mod codec {
59    use std::{collections::TryReserveError, convert::Infallible, io, marker::PhantomData};
60
61    use async_trait::async_trait;
62    use cbor4ii::core::error::DecodeError;
63    use futures::prelude::*;
64    use libp2p_swarm::StreamProtocol;
65    use serde::{de::DeserializeOwned, Serialize};
66
67    pub struct Codec<Req, Resp> {
68        /// Max request size in bytes.
69        request_size_maximum: u64,
70        /// Max response size in bytes.
71        response_size_maximum: u64,
72        phantom: PhantomData<(Req, Resp)>,
73    }
74
75    impl<Req, Resp> Default for Codec<Req, Resp> {
76        fn default() -> Self {
77            Codec {
78                request_size_maximum: 1024 * 1024,
79                response_size_maximum: 10 * 1024 * 1024,
80                phantom: PhantomData,
81            }
82        }
83    }
84
85    impl<Req, Resp> Clone for Codec<Req, Resp> {
86        fn clone(&self) -> Self {
87            Self {
88                request_size_maximum: self.request_size_maximum,
89                response_size_maximum: self.response_size_maximum,
90                phantom: PhantomData,
91            }
92        }
93    }
94
95    impl<Req, Resp> Codec<Req, Resp> {
96        /// Sets the limit for request size in bytes.
97        pub fn set_request_size_maximum(mut self, request_size_maximum: u64) -> Self {
98            self.request_size_maximum = request_size_maximum;
99            self
100        }
101
102        /// Sets the limit for response size in bytes.
103        pub fn set_response_size_maximum(mut self, response_size_maximum: u64) -> Self {
104            self.response_size_maximum = response_size_maximum;
105            self
106        }
107    }
108
109    #[async_trait]
110    impl<Req, Resp> crate::Codec for Codec<Req, Resp>
111    where
112        Req: Send + Serialize + DeserializeOwned,
113        Resp: Send + Serialize + DeserializeOwned,
114    {
115        type Protocol = StreamProtocol;
116        type Request = Req;
117        type Response = Resp;
118
119        async fn read_request<T>(&mut self, _: &Self::Protocol, io: &mut T) -> io::Result<Req>
120        where
121            T: AsyncRead + Unpin + Send,
122        {
123            let mut vec = Vec::new();
124
125            io.take(self.request_size_maximum)
126                .read_to_end(&mut vec)
127                .await?;
128
129            cbor4ii::serde::from_slice(vec.as_slice()).map_err(decode_into_io_error)
130        }
131
132        async fn read_response<T>(&mut self, _: &Self::Protocol, io: &mut T) -> io::Result<Resp>
133        where
134            T: AsyncRead + Unpin + Send,
135        {
136            let mut vec = Vec::new();
137
138            io.take(self.response_size_maximum)
139                .read_to_end(&mut vec)
140                .await?;
141
142            cbor4ii::serde::from_slice(vec.as_slice()).map_err(decode_into_io_error)
143        }
144
145        async fn write_request<T>(
146            &mut self,
147            _: &Self::Protocol,
148            io: &mut T,
149            req: Self::Request,
150        ) -> io::Result<()>
151        where
152            T: AsyncWrite + Unpin + Send,
153        {
154            let data: Vec<u8> =
155                cbor4ii::serde::to_vec(Vec::new(), &req).map_err(encode_into_io_error)?;
156
157            io.write_all(data.as_ref()).await?;
158
159            Ok(())
160        }
161
162        async fn write_response<T>(
163            &mut self,
164            _: &Self::Protocol,
165            io: &mut T,
166            resp: Self::Response,
167        ) -> io::Result<()>
168        where
169            T: AsyncWrite + Unpin + Send,
170        {
171            let data: Vec<u8> =
172                cbor4ii::serde::to_vec(Vec::new(), &resp).map_err(encode_into_io_error)?;
173
174            io.write_all(data.as_ref()).await?;
175
176            Ok(())
177        }
178    }
179
180    fn decode_into_io_error(err: cbor4ii::serde::DecodeError<Infallible>) -> io::Error {
181        match err {
182            cbor4ii::serde::DecodeError::Core(DecodeError::Read(e)) => {
183                io::Error::new(io::ErrorKind::Other, e)
184            }
185            cbor4ii::serde::DecodeError::Core(e @ DecodeError::Unsupported { .. }) => {
186                io::Error::new(io::ErrorKind::Unsupported, e)
187            }
188            cbor4ii::serde::DecodeError::Core(e @ DecodeError::Eof { .. }) => {
189                io::Error::new(io::ErrorKind::UnexpectedEof, e)
190            }
191            cbor4ii::serde::DecodeError::Core(e) => io::Error::new(io::ErrorKind::InvalidData, e),
192            cbor4ii::serde::DecodeError::Custom(e) => {
193                io::Error::new(io::ErrorKind::Other, e.to_string())
194            }
195        }
196    }
197
198    fn encode_into_io_error(err: cbor4ii::serde::EncodeError<TryReserveError>) -> io::Error {
199        io::Error::new(io::ErrorKind::Other, err)
200    }
201}
202
#[cfg(test)]
mod tests {
    use futures::AsyncWriteExt;
    use futures_ringbuf::Endpoint;
    use libp2p_swarm::StreamProtocol;
    use serde::{Deserialize, Serialize};

    use crate::{cbor::codec::Codec, Codec as _};

    #[async_std::test]
    async fn test_codec() {
        let expected_request = TestRequest {
            payload: "test_payload".to_string(),
        };
        let expected_response = TestResponse {
            payload: "test_payload".to_string(),
        };
        let protocol = StreamProtocol::new("/test_cbor/1");
        let mut codec = Codec::default();

        let (mut a, mut b) = Endpoint::pair(124, 124);
        codec
            .write_request(&protocol, &mut a, expected_request.clone())
            .await
            .expect("Should write request");
        a.close().await.unwrap();

        let actual_request = codec
            .read_request(&protocol, &mut b)
            .await
            .expect("Should read request");
        b.close().await.unwrap();

        assert_eq!(actual_request, expected_request);

        let (mut a, mut b) = Endpoint::pair(124, 124);
        codec
            .write_response(&protocol, &mut a, expected_response.clone())
            .await
            .expect("Should write response");
        a.close().await.unwrap();

        let actual_response = codec
            .read_response(&protocol, &mut b)
            .await
            .expect("Should read response");
        b.close().await.unwrap();

        assert_eq!(actual_response, expected_response);
    }

    /// A request whose encoding is exactly as large as the limit must still be accepted.
    #[async_std::test]
    async fn request_exactly_at_size_limit_is_accepted() {
        let request = TestRequest {
            payload: "test_payload".to_string(),
        };
        let mut codec = Codec::<TestRequest, TestResponse>::default()
            .set_request_size_maximum(encoded_len(&request));

        let actual = send_request(&mut codec, request.clone())
            .await
            .expect("Request of exactly the maximum size should be accepted");

        assert_eq!(actual, request);
    }

    /// A request one byte larger than the limit must be rejected.
    #[async_std::test]
    async fn request_over_size_limit_is_rejected() {
        let request = TestRequest {
            payload: "test_payload".to_string(),
        };
        let mut codec = Codec::<TestRequest, TestResponse>::default()
            .set_request_size_maximum(encoded_len(&request) - 1);

        assert!(send_request(&mut codec, request).await.is_err());
    }

    /// CBOR-encoded length of `request`, used to place size limits right at the boundary.
    fn encoded_len(request: &TestRequest) -> u64 {
        let bytes = cbor4ii::serde::to_vec(Vec::new(), request).expect("Should encode request");
        bytes.len() as u64
    }

    /// Writes `request` into one end of an in-memory pipe and reads it back from the other end.
    async fn send_request(
        codec: &mut Codec<TestRequest, TestResponse>,
        request: TestRequest,
    ) -> std::io::Result<TestRequest> {
        let protocol = StreamProtocol::new("/test_cbor/1");
        let (mut a, mut b) = Endpoint::pair(124, 124);
        codec.write_request(&protocol, &mut a, request).await?;
        a.close().await?;
        codec.read_request(&protocol, &mut b).await
    }

    #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
    struct TestRequest {
        payload: String,
    }

    #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
    struct TestResponse {
        payload: String,
    }
}