1use std::{io, io::ErrorKind};
4
5use asynchronous_codec::{Framed, FramedRead, FramedWrite};
6use futures::{AsyncRead, AsyncWrite, SinkExt, StreamExt};
7use libp2p_core::Multiaddr;
8use quick_protobuf_codec::Codec;
9use rand::Rng;
10
11use crate::v2::{generated::structs as proto, Nonce};
12
/// Maximum size handed to the codec that frames `proto::Message` values inside `Coder`.
///
/// The `message_correct_max_size` test pins this value to the length produced by
/// `quick_protobuf::serialize_into_vec` for a `dialDataResponse` message carrying
/// 4096 bytes of data.
const REQUEST_MAX_SIZE: usize = 4104;
/// Inclusive lower bound of the `num_bytes` value chosen by `DialDataRequest::from_rng`.
pub(super) const DATA_LEN_LOWER_BOUND: usize = 30_000u32 as usize;
/// Inclusive upper bound of the `num_bytes` value chosen by `DialDataRequest::from_rng`.
pub(super) const DATA_LEN_UPPER_BOUND: usize = 100_000u32 as usize;
/// Largest `data_count` accepted by `DialDataResponse::new`, i.e. the largest `data`
/// field this module puts into a single `dialDataResponse` message.
pub(super) const DATA_FIELD_LEN_UPPER_BOUND: usize = 4096;
17
/// Builds an [`io::Error`] of kind [`io::ErrorKind::InvalidData`] whose payload is `msg`.
fn new_io_invalid_data_err(msg: impl Into<String>) -> io::Error {
    let description: String = msg.into();
    io::Error::new(io::ErrorKind::InvalidData, description)
}
21
22pub(crate) struct Coder<I> {
23 inner: Framed<I, Codec<proto::Message>>,
24}
25
26impl<I> Coder<I>
27where
28 I: AsyncWrite + AsyncRead + Unpin,
29{
30 pub(crate) fn new(io: I) -> Self {
31 Self {
32 inner: Framed::new(io, Codec::new(REQUEST_MAX_SIZE)),
33 }
34 }
35 pub(crate) async fn close(mut self) -> io::Result<()> {
36 self.inner.close().await?;
37 Ok(())
38 }
39}
40
41impl<I> Coder<I>
42where
43 I: AsyncRead + Unpin,
44{
45 pub(crate) async fn next<M, E>(&mut self) -> io::Result<M>
46 where
47 proto::Message: TryInto<M, Error = E>,
48 io::Error: From<E>,
49 {
50 Ok(self.next_msg().await?.try_into()?)
51 }
52
53 async fn next_msg(&mut self) -> io::Result<proto::Message> {
54 self.inner
55 .next()
56 .await
57 .ok_or(io::Error::new(
58 ErrorKind::UnexpectedEof,
59 "no request to read",
60 ))?
61 .map_err(|e| io::Error::new(ErrorKind::InvalidData, e))
62 }
63}
64
65impl<I> Coder<I>
66where
67 I: AsyncWrite + Unpin,
68{
69 pub(crate) async fn send<M>(&mut self, msg: M) -> io::Result<()>
70 where
71 M: Into<proto::Message>,
72 {
73 self.inner.send(msg.into()).await?;
74 Ok(())
75 }
76}
77
/// Typed view of the protobuf messages accepted by `TryFrom<proto::Message> for Request`.
#[derive(Debug, Clone, PartialEq)]
pub(crate) enum Request {
    /// Decoded from a `dialRequest` protobuf message.
    Dial(DialRequest),
    /// Decoded from a `dialDataResponse` protobuf message.
    Data(DialDataResponse),
}
83
84impl From<DialRequest> for proto::Message {
85 fn from(val: DialRequest) -> Self {
86 let addrs = val.addrs.iter().map(|e| e.to_vec()).collect();
87 let nonce = val.nonce;
88
89 proto::Message {
90 msg: proto::mod_Message::OneOfmsg::dialRequest(proto::DialRequest { addrs, nonce }),
91 }
92 }
93}
94
95impl From<DialDataResponse> for proto::Message {
96 fn from(val: DialDataResponse) -> Self {
97 debug_assert!(
98 val.data_count <= DATA_FIELD_LEN_UPPER_BOUND,
99 "data_count too large"
100 );
101 proto::Message {
102 msg: proto::mod_Message::OneOfmsg::dialDataResponse(proto::DialDataResponse {
103 data: vec![0; val.data_count],
107 }),
108 }
109 }
110}
111
/// Typed counterpart of `proto::DialRequest`, with the addresses parsed into [`Multiaddr`]s.
#[derive(Debug, Clone, PartialEq)]
pub struct DialRequest {
    /// Addresses carried by the request; converted to and from raw bytes at the protobuf
    /// boundary.
    pub(crate) addrs: Vec<Multiaddr>,
    /// Nonce copied verbatim to and from the protobuf `nonce` field.
    pub(crate) nonce: u64,
}
117
118#[derive(Debug, Clone, PartialEq)]
119pub(crate) struct DialDataResponse {
120 data_count: usize,
121}
122
123impl DialDataResponse {
124 pub(crate) fn new(data_count: usize) -> Option<Self> {
125 if data_count <= DATA_FIELD_LEN_UPPER_BOUND {
126 Some(Self { data_count })
127 } else {
128 None
129 }
130 }
131
132 pub(crate) fn get_data_count(&self) -> usize {
133 self.data_count
134 }
135}
136
137impl TryFrom<proto::Message> for Request {
138 type Error = io::Error;
139
140 fn try_from(msg: proto::Message) -> Result<Self, Self::Error> {
141 match msg.msg {
142 proto::mod_Message::OneOfmsg::dialRequest(proto::DialRequest { addrs, nonce }) => {
143 let addrs = addrs
144 .into_iter()
145 .map(|e| e.to_vec())
146 .map(|e| {
147 Multiaddr::try_from(e).map_err(|err| {
148 new_io_invalid_data_err(format!("invalid multiaddr: {}", err))
149 })
150 })
151 .collect::<Result<Vec<_>, io::Error>>()?;
152 Ok(Self::Dial(DialRequest { addrs, nonce }))
153 }
154 proto::mod_Message::OneOfmsg::dialDataResponse(proto::DialDataResponse { data }) => {
155 let data_count = data.len();
156 Ok(Self::Data(DialDataResponse { data_count }))
157 }
158 _ => Err(new_io_invalid_data_err(
159 "expected dialResponse or dialDataRequest",
160 )),
161 }
162 }
163}
164
/// Typed view of the protobuf messages handled by `TryFrom<proto::Message> for Response`
/// and `From<Response> for proto::Message`.
#[derive(Debug, Clone)]
pub(crate) enum Response {
    /// Corresponds to the `dialResponse` protobuf message.
    Dial(DialResponse),
    /// Corresponds to the `dialDataRequest` protobuf message.
    Data(DialDataRequest),
}
170
/// Typed counterpart of `proto::DialDataRequest`.
#[derive(Debug, Clone)]
pub(crate) struct DialDataRequest {
    /// Value of the protobuf `addrIdx` field.
    /// NOTE(review): presumably an index into the addresses of the originating dial
    /// request — confirm against the callers.
    pub(crate) addr_idx: usize,
    /// Value of the protobuf `numBytes` field; `from_rng` picks it from
    /// `DATA_LEN_LOWER_BOUND..=DATA_LEN_UPPER_BOUND`.
    pub(crate) num_bytes: usize,
}
176
/// Typed counterpart of `proto::DialResponse`.
#[derive(Debug, Clone)]
pub(crate) struct DialResponse {
    /// Value of the protobuf `status` field.
    pub(crate) status: proto::mod_DialResponse::ResponseStatus,
    /// Value of the protobuf `addrIdx` field.
    /// NOTE(review): presumably an index into the addresses of the originating dial
    /// request — confirm against the callers.
    pub(crate) addr_idx: usize,
    /// Value of the protobuf `dialStatus` field.
    pub(crate) dial_status: proto::DialStatus,
}
183
184impl TryFrom<proto::Message> for Response {
185 type Error = io::Error;
186
187 fn try_from(msg: proto::Message) -> Result<Self, Self::Error> {
188 match msg.msg {
189 proto::mod_Message::OneOfmsg::dialResponse(proto::DialResponse {
190 status,
191 addrIdx,
192 dialStatus,
193 }) => Ok(Response::Dial(DialResponse {
194 status,
195 addr_idx: addrIdx as usize,
196 dial_status: dialStatus,
197 })),
198 proto::mod_Message::OneOfmsg::dialDataRequest(proto::DialDataRequest {
199 addrIdx,
200 numBytes,
201 }) => Ok(Self::Data(DialDataRequest {
202 addr_idx: addrIdx as usize,
203 num_bytes: numBytes as usize,
204 })),
205 _ => Err(new_io_invalid_data_err(
206 "invalid message type, expected dialResponse or dialDataRequest",
207 )),
208 }
209 }
210}
211
212impl From<Response> for proto::Message {
213 fn from(val: Response) -> Self {
214 match val {
215 Response::Dial(DialResponse {
216 status,
217 addr_idx,
218 dial_status,
219 }) => proto::Message {
220 msg: proto::mod_Message::OneOfmsg::dialResponse(proto::DialResponse {
221 status,
222 addrIdx: addr_idx as u32,
223 dialStatus: dial_status,
224 }),
225 },
226 Response::Data(DialDataRequest {
227 addr_idx,
228 num_bytes,
229 }) => proto::Message {
230 msg: proto::mod_Message::OneOfmsg::dialDataRequest(proto::DialDataRequest {
231 addrIdx: addr_idx as u32,
232 numBytes: num_bytes as u64,
233 }),
234 },
235 }
236 }
237}
238
239impl DialDataRequest {
240 pub(crate) fn from_rng<R: rand_core::RngCore>(addr_idx: usize, mut rng: R) -> Self {
241 let num_bytes = rng.gen_range(DATA_LEN_LOWER_BOUND..=DATA_LEN_UPPER_BOUND);
242 Self {
243 addr_idx,
244 num_bytes,
245 }
246 }
247}
248
/// Maximum size handed to the codecs that frame `proto::DialBack` and
/// `proto::DialBackResponse` messages.
///
/// The `dial_back_correct_size` test checks that a serialized `DialBack` fits within
/// this limit for both `nonce: 0` and `nonce: u64::MAX`.
const DIAL_BACK_MAX_SIZE: usize = 10;
250
251pub(crate) async fn dial_back(stream: impl AsyncWrite + Unpin, nonce: Nonce) -> io::Result<()> {
252 let msg = proto::DialBack { nonce };
253 let mut framed = FramedWrite::new(stream, Codec::<proto::DialBack>::new(DIAL_BACK_MAX_SIZE));
254
255 framed
256 .send(msg)
257 .await
258 .map_err(|e| io::Error::new(io::ErrorKind::Other, e))?;
259
260 Ok(())
261}
262
263pub(crate) async fn recv_dial_back(stream: impl AsyncRead + Unpin) -> io::Result<Nonce> {
264 let framed = &mut FramedRead::new(stream, Codec::<proto::DialBack>::new(DIAL_BACK_MAX_SIZE));
265 let proto::DialBack { nonce } = framed
266 .next()
267 .await
268 .ok_or(io::Error::from(io::ErrorKind::UnexpectedEof))??;
269 Ok(nonce)
270}
271
272pub(crate) async fn dial_back_response(stream: impl AsyncWrite + Unpin) -> io::Result<()> {
273 let msg = proto::DialBackResponse {
274 status: proto::mod_DialBackResponse::DialBackStatus::OK,
275 };
276 let mut framed = FramedWrite::new(
277 stream,
278 Codec::<proto::DialBackResponse>::new(DIAL_BACK_MAX_SIZE),
279 );
280 framed
281 .send(msg)
282 .await
283 .map_err(|e| io::Error::new(io::ErrorKind::Other, e))?;
284
285 Ok(())
286}
287
288pub(crate) async fn recv_dial_back_response(
289 stream: impl AsyncRead + AsyncWrite + Unpin,
290) -> io::Result<()> {
291 let framed = &mut FramedRead::new(
292 stream,
293 Codec::<proto::DialBackResponse>::new(DIAL_BACK_MAX_SIZE),
294 );
295 let proto::DialBackResponse { status } = framed
296 .next()
297 .await
298 .ok_or(io::Error::from(io::ErrorKind::UnexpectedEof))??;
299
300 if proto::mod_DialBackResponse::DialBackStatus::OK == status {
301 Ok(())
302 } else {
303 Err(io::Error::new(
304 io::ErrorKind::InvalidData,
305 "invalid dial back response",
306 ))
307 }
308}
309
#[cfg(test)]
mod tests {
    use crate::v2::generated::structs::{
        mod_Message::OneOfmsg, DialDataResponse as GenDialDataResponse, Message,
    };

    /// A `dialDataResponse` with a 4096-byte data field must serialize to exactly
    /// `REQUEST_MAX_SIZE` bytes.
    #[test]
    fn message_correct_max_size() {
        let largest = Message {
            msg: OneOfmsg::dialDataResponse(GenDialDataResponse {
                data: vec![0; 4096],
            }),
        };
        let encoded = quick_protobuf::serialize_into_vec(&largest).unwrap();
        assert_eq!(encoded.len(), super::REQUEST_MAX_SIZE);
    }

    /// A serialized `DialBack` must fit into `DIAL_BACK_MAX_SIZE` for the smallest and
    /// the largest nonce.
    #[test]
    fn dial_back_correct_size() {
        for nonce in [0, u64::MAX] {
            let encoded =
                quick_protobuf::serialize_into_vec(&super::proto::DialBack { nonce }).unwrap();
            assert!(encoded.len() <= super::DIAL_BACK_MAX_SIZE);
        }
    }
}