libp2p_request_response/
cbor.rs1pub type Behaviour<Req, Resp> = crate::Behaviour<codec::Codec<Req, Resp>>;
57
58pub mod codec {
59 use std::{collections::TryReserveError, convert::Infallible, io, marker::PhantomData};
60
61 use async_trait::async_trait;
62 use cbor4ii::core::error::DecodeError;
63 use futures::prelude::*;
64 use libp2p_swarm::StreamProtocol;
65 use serde::{de::DeserializeOwned, Serialize};
66
67 pub struct Codec<Req, Resp> {
68 request_size_maximum: u64,
70 response_size_maximum: u64,
72 phantom: PhantomData<(Req, Resp)>,
73 }
74
75 impl<Req, Resp> Default for Codec<Req, Resp> {
76 fn default() -> Self {
77 Codec {
78 request_size_maximum: 1024 * 1024,
79 response_size_maximum: 10 * 1024 * 1024,
80 phantom: PhantomData,
81 }
82 }
83 }
84
85 impl<Req, Resp> Clone for Codec<Req, Resp> {
86 fn clone(&self) -> Self {
87 Self {
88 request_size_maximum: self.request_size_maximum,
89 response_size_maximum: self.response_size_maximum,
90 phantom: PhantomData,
91 }
92 }
93 }
94
95 impl<Req, Resp> Codec<Req, Resp> {
96 pub fn set_request_size_maximum(mut self, request_size_maximum: u64) -> Self {
98 self.request_size_maximum = request_size_maximum;
99 self
100 }
101
102 pub fn set_response_size_maximum(mut self, response_size_maximum: u64) -> Self {
104 self.response_size_maximum = response_size_maximum;
105 self
106 }
107 }
108
109 #[async_trait]
110 impl<Req, Resp> crate::Codec for Codec<Req, Resp>
111 where
112 Req: Send + Serialize + DeserializeOwned,
113 Resp: Send + Serialize + DeserializeOwned,
114 {
115 type Protocol = StreamProtocol;
116 type Request = Req;
117 type Response = Resp;
118
119 async fn read_request<T>(&mut self, _: &Self::Protocol, io: &mut T) -> io::Result<Req>
120 where
121 T: AsyncRead + Unpin + Send,
122 {
123 let mut vec = Vec::new();
124
125 io.take(self.request_size_maximum)
126 .read_to_end(&mut vec)
127 .await?;
128
129 cbor4ii::serde::from_slice(vec.as_slice()).map_err(decode_into_io_error)
130 }
131
132 async fn read_response<T>(&mut self, _: &Self::Protocol, io: &mut T) -> io::Result<Resp>
133 where
134 T: AsyncRead + Unpin + Send,
135 {
136 let mut vec = Vec::new();
137
138 io.take(self.response_size_maximum)
139 .read_to_end(&mut vec)
140 .await?;
141
142 cbor4ii::serde::from_slice(vec.as_slice()).map_err(decode_into_io_error)
143 }
144
145 async fn write_request<T>(
146 &mut self,
147 _: &Self::Protocol,
148 io: &mut T,
149 req: Self::Request,
150 ) -> io::Result<()>
151 where
152 T: AsyncWrite + Unpin + Send,
153 {
154 let data: Vec<u8> =
155 cbor4ii::serde::to_vec(Vec::new(), &req).map_err(encode_into_io_error)?;
156
157 io.write_all(data.as_ref()).await?;
158
159 Ok(())
160 }
161
162 async fn write_response<T>(
163 &mut self,
164 _: &Self::Protocol,
165 io: &mut T,
166 resp: Self::Response,
167 ) -> io::Result<()>
168 where
169 T: AsyncWrite + Unpin + Send,
170 {
171 let data: Vec<u8> =
172 cbor4ii::serde::to_vec(Vec::new(), &resp).map_err(encode_into_io_error)?;
173
174 io.write_all(data.as_ref()).await?;
175
176 Ok(())
177 }
178 }
179
180 fn decode_into_io_error(err: cbor4ii::serde::DecodeError<Infallible>) -> io::Error {
181 match err {
182 cbor4ii::serde::DecodeError::Core(DecodeError::Read(e)) => {
183 io::Error::new(io::ErrorKind::Other, e)
184 }
185 cbor4ii::serde::DecodeError::Core(e @ DecodeError::Unsupported { .. }) => {
186 io::Error::new(io::ErrorKind::Unsupported, e)
187 }
188 cbor4ii::serde::DecodeError::Core(e @ DecodeError::Eof { .. }) => {
189 io::Error::new(io::ErrorKind::UnexpectedEof, e)
190 }
191 cbor4ii::serde::DecodeError::Core(e) => io::Error::new(io::ErrorKind::InvalidData, e),
192 cbor4ii::serde::DecodeError::Custom(e) => {
193 io::Error::new(io::ErrorKind::Other, e.to_string())
194 }
195 }
196 }
197
198 fn encode_into_io_error(err: cbor4ii::serde::EncodeError<TryReserveError>) -> io::Error {
199 io::Error::new(io::ErrorKind::Other, err)
200 }
201}
202
#[cfg(test)]
mod tests {
    use futures::AsyncWriteExt;
    use futures_ringbuf::Endpoint;
    use libp2p_swarm::StreamProtocol;
    use serde::{Deserialize, Serialize};

    use crate::{cbor::codec::Codec, Codec as _};

    /// Round-trips one request and one response through the CBOR codec over
    /// an in-memory ring-buffer stream pair.
    #[async_std::test]
    async fn test_codec() {
        let expected_request = TestRequest {
            payload: "test_payload".to_string(),
        };
        let expected_response = TestResponse {
            payload: "test_payload".to_string(),
        };
        let protocol = StreamProtocol::new("/test_cbor/1");
        let mut codec = Codec::default();

        // The writer closes its end so the reader observes end-of-stream.
        let (mut a, mut b) = Endpoint::pair(124, 124);
        codec
            .write_request(&protocol, &mut a, expected_request.clone())
            .await
            .expect("Should write request");
        a.close().await.unwrap();

        let actual_request = codec
            .read_request(&protocol, &mut b)
            .await
            .expect("Should read request");
        b.close().await.unwrap();

        assert_eq!(actual_request, expected_request);

        let (mut a, mut b) = Endpoint::pair(124, 124);
        codec
            .write_response(&protocol, &mut a, expected_response.clone())
            .await
            .expect("Should write response");
        a.close().await.unwrap();

        let actual_response = codec
            .read_response(&protocol, &mut b)
            .await
            .expect("Should read response");
        b.close().await.unwrap();

        assert_eq!(actual_response, expected_response);
    }

    /// A request whose encoding exceeds the configured maximum must be rejected.
    #[async_std::test]
    async fn test_request_size_maximum() {
        let protocol = StreamProtocol::new("/test_cbor/1");
        // The encoded request fits in the ring buffer (as in `test_codec`), so
        // the write finishes without a concurrent reader, but it is larger
        // than the 8-byte limit configured here.
        let mut codec =
            Codec::<TestRequest, TestResponse>::default().set_request_size_maximum(8);

        let (mut a, mut b) = Endpoint::pair(124, 124);
        codec
            .write_request(
                &protocol,
                &mut a,
                TestRequest {
                    payload: "test_payload".to_string(),
                },
            )
            .await
            .expect("Should write request");
        a.close().await.unwrap();

        let result = codec.read_request(&protocol, &mut b).await;
        assert!(result.is_err(), "Oversized request should be rejected");
    }

    /// A response whose encoding exceeds the configured maximum must be rejected.
    #[async_std::test]
    async fn test_response_size_maximum() {
        let protocol = StreamProtocol::new("/test_cbor/1");
        let mut codec =
            Codec::<TestRequest, TestResponse>::default().set_response_size_maximum(8);

        let (mut a, mut b) = Endpoint::pair(124, 124);
        codec
            .write_response(
                &protocol,
                &mut a,
                TestResponse {
                    payload: "test_payload".to_string(),
                },
            )
            .await
            .expect("Should write response");
        a.close().await.unwrap();

        let result = codec.read_response(&protocol, &mut b).await;
        assert!(result.is_err(), "Oversized response should be rejected");
    }

    #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
    struct TestRequest {
        payload: String,
    }

    #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
    struct TestResponse {
        payload: String,
    }
}