libp2p_request_response/
/// Alias for [`crate::Behaviour`] parameterised with the CBOR [`codec::Codec`] for the
/// message types `Req` and `Resp`.
cbor.rs1pub type Behaviour<Req, Resp> = crate::Behaviour<codec::Codec<Req, Resp>>;
57
58pub mod codec {
59 use std::{collections::TryReserveError, convert::Infallible, io, marker::PhantomData};
60
61 use async_trait::async_trait;
62 use cbor4ii::core::error::DecodeError;
63 use futures::prelude::*;
64 use libp2p_swarm::StreamProtocol;
65 use serde::{de::DeserializeOwned, Serialize};
66
/// CBOR codec for request-response messages.
///
/// Holds an upper bound on how many bytes are read from a stream for one inbound request
/// and for one inbound response.
pub struct Codec<Req, Resp> {
    /// Maximum number of bytes read from the stream when receiving a request.
    request_size_maximum: u64,
    /// Maximum number of bytes read from the stream when receiving a response.
    response_size_maximum: u64,
    /// Ties the codec to its message types; no `Req` or `Resp` value is stored.
    phantom: PhantomData<(Req, Resp)>,
}

// Hand-written rather than derived: a derive would add `Req: Debug` and `Resp: Debug`
// bounds even though the codec never stores a message. Only the two limits are printed.
impl<Req, Resp> std::fmt::Debug for Codec<Req, Resp> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("Codec")
            .field("request_size_maximum", &self.request_size_maximum)
            .field("response_size_maximum", &self.response_size_maximum)
            .finish()
    }
}
74
75 impl<Req, Resp> Default for Codec<Req, Resp> {
76 fn default() -> Self {
77 Codec {
78 request_size_maximum: 1024 * 1024,
79 response_size_maximum: 10 * 1024 * 1024,
80 phantom: PhantomData,
81 }
82 }
83 }
84
85 impl<Req, Resp> Clone for Codec<Req, Resp> {
86 fn clone(&self) -> Self {
87 Self {
88 request_size_maximum: self.request_size_maximum,
89 response_size_maximum: self.response_size_maximum,
90 phantom: PhantomData,
91 }
92 }
93 }
94
95 impl<Req, Resp> Codec<Req, Resp> {
96 pub fn set_request_size_maximum(mut self, request_size_maximum: u64) -> Self {
98 self.request_size_maximum = request_size_maximum;
99 self
100 }
101
102 pub fn set_response_size_maximum(mut self, response_size_maximum: u64) -> Self {
104 self.response_size_maximum = response_size_maximum;
105 self
106 }
107 }
108
109 #[async_trait]
110 impl<Req, Resp> crate::Codec for Codec<Req, Resp>
111 where
112 Req: Send + Serialize + DeserializeOwned,
113 Resp: Send + Serialize + DeserializeOwned,
114 {
115 type Protocol = StreamProtocol;
116 type Request = Req;
117 type Response = Resp;
118
119 async fn read_request<T>(&mut self, _: &Self::Protocol, io: &mut T) -> io::Result<Req>
120 where
121 T: AsyncRead + Unpin + Send,
122 {
123 let mut vec = Vec::new();
124
125 io.take(self.request_size_maximum)
126 .read_to_end(&mut vec)
127 .await?;
128
129 cbor4ii::serde::from_slice(vec.as_slice()).map_err(decode_into_io_error)
130 }
131
132 async fn read_response<T>(&mut self, _: &Self::Protocol, io: &mut T) -> io::Result<Resp>
133 where
134 T: AsyncRead + Unpin + Send,
135 {
136 let mut vec = Vec::new();
137
138 io.take(self.response_size_maximum)
139 .read_to_end(&mut vec)
140 .await?;
141
142 cbor4ii::serde::from_slice(vec.as_slice()).map_err(decode_into_io_error)
143 }
144
145 async fn write_request<T>(
146 &mut self,
147 _: &Self::Protocol,
148 io: &mut T,
149 req: Self::Request,
150 ) -> io::Result<()>
151 where
152 T: AsyncWrite + Unpin + Send,
153 {
154 let data: Vec<u8> =
155 cbor4ii::serde::to_vec(Vec::new(), &req).map_err(encode_into_io_error)?;
156
157 io.write_all(data.as_ref()).await?;
158
159 Ok(())
160 }
161
162 async fn write_response<T>(
163 &mut self,
164 _: &Self::Protocol,
165 io: &mut T,
166 resp: Self::Response,
167 ) -> io::Result<()>
168 where
169 T: AsyncWrite + Unpin + Send,
170 {
171 let data: Vec<u8> =
172 cbor4ii::serde::to_vec(Vec::new(), &resp).map_err(encode_into_io_error)?;
173
174 io.write_all(data.as_ref()).await?;
175
176 Ok(())
177 }
178 }
179
180 fn decode_into_io_error(err: cbor4ii::serde::DecodeError<Infallible>) -> io::Error {
181 match err {
182 cbor4ii::serde::DecodeError::Core(DecodeError::Read(e)) => io::Error::other(e),
183 cbor4ii::serde::DecodeError::Core(e @ DecodeError::Unsupported { .. }) => {
184 io::Error::new(io::ErrorKind::Unsupported, e)
185 }
186 cbor4ii::serde::DecodeError::Core(e @ DecodeError::Eof { .. }) => {
187 io::Error::new(io::ErrorKind::UnexpectedEof, e)
188 }
189 cbor4ii::serde::DecodeError::Core(e) => io::Error::new(io::ErrorKind::InvalidData, e),
190 cbor4ii::serde::DecodeError::Custom(e) => io::Error::other(e.to_string()),
191 }
192 }
193
/// Wraps a CBOR encoding failure in an [`io::Error`] created with [`io::Error::other`],
/// keeping the original error as the payload.
194 fn encode_into_io_error(err: cbor4ii::serde::EncodeError<TryReserveError>) -> io::Error {
195 io::Error::other(err)
196 }
197}
198
#[cfg(test)]
mod tests {
    use futures::AsyncWriteExt;
    use futures_ringbuf::Endpoint;
    use libp2p_swarm::StreamProtocol;
    use serde::{Deserialize, Serialize};

    use crate::{cbor::codec::Codec, Codec as _};

    /// Request fixture carrying a single string field.
    #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
    struct TestRequest {
        payload: String,
    }

    /// Response fixture carrying a single string field.
    #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
    struct TestResponse {
        payload: String,
    }

    /// Writes one request and one response through the codec, reads each back from the
    /// other end of an `Endpoint` pair, and checks that both arrive unchanged.
    #[tokio::test]
    async fn test_codec() {
        let protocol = StreamProtocol::new("/test_cbor/1");
        let mut codec = Codec::default();

        let sent_request = TestRequest {
            payload: String::from("test_payload"),
        };
        let sent_response = TestResponse {
            payload: String::from("test_payload"),
        };

        // Request direction. The writing end is closed before reading — presumably so the
        // codec's read-to-end sees EOF; confirm against futures_ringbuf.
        let (mut writer, mut reader) = Endpoint::pair(124, 124);
        codec
            .write_request(&protocol, &mut writer, sent_request.clone())
            .await
            .expect("Should write request");
        writer.close().await.unwrap();

        let received_request = codec
            .read_request(&protocol, &mut reader)
            .await
            .expect("Should read request");
        reader.close().await.unwrap();

        assert_eq!(received_request, sent_request);

        // Response direction, over a fresh endpoint pair.
        let (mut writer, mut reader) = Endpoint::pair(124, 124);
        codec
            .write_response(&protocol, &mut writer, sent_response.clone())
            .await
            .expect("Should write response");
        writer.close().await.unwrap();

        let received_response = codec
            .read_response(&protocol, &mut reader)
            .await
            .expect("Should read response");
        reader.close().await.unwrap();

        assert_eq!(received_response, sent_response);
    }
}