libp2p_autonat/v2/
protocol.rs

1// change to quick-protobuf-codec
2
3use std::{io, io::ErrorKind};
4
5use asynchronous_codec::{Framed, FramedRead, FramedWrite};
6use futures::{AsyncRead, AsyncWrite, SinkExt, StreamExt};
7use libp2p_core::Multiaddr;
8use quick_protobuf_codec::Codec;
9use rand::Rng;
10
11use crate::v2::{generated::structs as proto, Nonce};
12
13const REQUEST_MAX_SIZE: usize = 4104;
14pub(super) const DATA_LEN_LOWER_BOUND: usize = 30_000u32 as usize;
15pub(super) const DATA_LEN_UPPER_BOUND: usize = 100_000u32 as usize;
16pub(super) const DATA_FIELD_LEN_UPPER_BOUND: usize = 4096;
17
/// Builds an [`io::Error`] of kind [`io::ErrorKind::InvalidData`] whose
/// message is `msg`.
fn new_io_invalid_data_err(msg: impl Into<String>) -> io::Error {
    let message: String = msg.into();
    io::Error::new(ErrorKind::InvalidData, message)
}
21
22pub(crate) struct Coder<I> {
23    inner: Framed<I, Codec<proto::Message>>,
24}
25
26impl<I> Coder<I>
27where
28    I: AsyncWrite + AsyncRead + Unpin,
29{
30    pub(crate) fn new(io: I) -> Self {
31        Self {
32            inner: Framed::new(io, Codec::new(REQUEST_MAX_SIZE)),
33        }
34    }
35    pub(crate) async fn close(mut self) -> io::Result<()> {
36        self.inner.close().await?;
37        Ok(())
38    }
39}
40
41impl<I> Coder<I>
42where
43    I: AsyncRead + Unpin,
44{
45    pub(crate) async fn next<M, E>(&mut self) -> io::Result<M>
46    where
47        proto::Message: TryInto<M, Error = E>,
48        io::Error: From<E>,
49    {
50        Ok(self.next_msg().await?.try_into()?)
51    }
52
53    async fn next_msg(&mut self) -> io::Result<proto::Message> {
54        self.inner
55            .next()
56            .await
57            .ok_or(io::Error::new(
58                ErrorKind::UnexpectedEof,
59                "no request to read",
60            ))?
61            .map_err(|e| io::Error::new(ErrorKind::InvalidData, e))
62    }
63}
64
65impl<I> Coder<I>
66where
67    I: AsyncWrite + Unpin,
68{
69    pub(crate) async fn send<M>(&mut self, msg: M) -> io::Result<()>
70    where
71        M: Into<proto::Message>,
72    {
73        self.inner.send(msg.into()).await?;
74        Ok(())
75    }
76}
77
78#[derive(Debug, Clone, PartialEq)]
79pub(crate) enum Request {
80    Dial(DialRequest),
81    Data(DialDataResponse),
82}
83
84impl From<DialRequest> for proto::Message {
85    fn from(val: DialRequest) -> Self {
86        let addrs = val.addrs.iter().map(|e| e.to_vec()).collect();
87        let nonce = val.nonce;
88
89        proto::Message {
90            msg: proto::mod_Message::OneOfmsg::dialRequest(proto::DialRequest { addrs, nonce }),
91        }
92    }
93}
94
95impl From<DialDataResponse> for proto::Message {
96    fn from(val: DialDataResponse) -> Self {
97        debug_assert!(
98            val.data_count <= DATA_FIELD_LEN_UPPER_BOUND,
99            "data_count too large"
100        );
101        proto::Message {
102            msg: proto::mod_Message::OneOfmsg::dialDataResponse(proto::DialDataResponse {
103                // One could use Cow::Borrowed here, but it will
104                // require a modification of the generated code
105                // and that will fail the CI
106                data: vec![0; val.data_count],
107            }),
108        }
109    }
110}
111
112#[derive(Debug, Clone, PartialEq)]
113pub struct DialRequest {
114    pub(crate) addrs: Vec<Multiaddr>,
115    pub(crate) nonce: u64,
116}
117
118#[derive(Debug, Clone, PartialEq)]
119pub(crate) struct DialDataResponse {
120    data_count: usize,
121}
122
123impl DialDataResponse {
124    pub(crate) fn new(data_count: usize) -> Option<Self> {
125        if data_count <= DATA_FIELD_LEN_UPPER_BOUND {
126            Some(Self { data_count })
127        } else {
128            None
129        }
130    }
131
132    pub(crate) fn get_data_count(&self) -> usize {
133        self.data_count
134    }
135}
136
137impl TryFrom<proto::Message> for Request {
138    type Error = io::Error;
139
140    fn try_from(msg: proto::Message) -> Result<Self, Self::Error> {
141        match msg.msg {
142            proto::mod_Message::OneOfmsg::dialRequest(proto::DialRequest { addrs, nonce }) => {
143                let addrs = addrs
144                    .into_iter()
145                    .map(|e| e.to_vec())
146                    .map(|e| {
147                        Multiaddr::try_from(e).map_err(|err| {
148                            new_io_invalid_data_err(format!("invalid multiaddr: {}", err))
149                        })
150                    })
151                    .collect::<Result<Vec<_>, io::Error>>()?;
152                Ok(Self::Dial(DialRequest { addrs, nonce }))
153            }
154            proto::mod_Message::OneOfmsg::dialDataResponse(proto::DialDataResponse { data }) => {
155                let data_count = data.len();
156                Ok(Self::Data(DialDataResponse { data_count }))
157            }
158            _ => Err(new_io_invalid_data_err(
159                "expected dialResponse or dialDataRequest",
160            )),
161        }
162    }
163}
164
165#[derive(Debug, Clone)]
166pub(crate) enum Response {
167    Dial(DialResponse),
168    Data(DialDataRequest),
169}
170
171#[derive(Debug, Clone)]
172pub(crate) struct DialDataRequest {
173    pub(crate) addr_idx: usize,
174    pub(crate) num_bytes: usize,
175}
176
177#[derive(Debug, Clone)]
178pub(crate) struct DialResponse {
179    pub(crate) status: proto::mod_DialResponse::ResponseStatus,
180    pub(crate) addr_idx: usize,
181    pub(crate) dial_status: proto::DialStatus,
182}
183
184impl TryFrom<proto::Message> for Response {
185    type Error = io::Error;
186
187    fn try_from(msg: proto::Message) -> Result<Self, Self::Error> {
188        match msg.msg {
189            proto::mod_Message::OneOfmsg::dialResponse(proto::DialResponse {
190                status,
191                addrIdx,
192                dialStatus,
193            }) => Ok(Response::Dial(DialResponse {
194                status,
195                addr_idx: addrIdx as usize,
196                dial_status: dialStatus,
197            })),
198            proto::mod_Message::OneOfmsg::dialDataRequest(proto::DialDataRequest {
199                addrIdx,
200                numBytes,
201            }) => Ok(Self::Data(DialDataRequest {
202                addr_idx: addrIdx as usize,
203                num_bytes: numBytes as usize,
204            })),
205            _ => Err(new_io_invalid_data_err(
206                "invalid message type, expected dialResponse or dialDataRequest",
207            )),
208        }
209    }
210}
211
212impl From<Response> for proto::Message {
213    fn from(val: Response) -> Self {
214        match val {
215            Response::Dial(DialResponse {
216                status,
217                addr_idx,
218                dial_status,
219            }) => proto::Message {
220                msg: proto::mod_Message::OneOfmsg::dialResponse(proto::DialResponse {
221                    status,
222                    addrIdx: addr_idx as u32,
223                    dialStatus: dial_status,
224                }),
225            },
226            Response::Data(DialDataRequest {
227                addr_idx,
228                num_bytes,
229            }) => proto::Message {
230                msg: proto::mod_Message::OneOfmsg::dialDataRequest(proto::DialDataRequest {
231                    addrIdx: addr_idx as u32,
232                    numBytes: num_bytes as u64,
233                }),
234            },
235        }
236    }
237}
238
239impl DialDataRequest {
240    pub(crate) fn from_rng<R: rand_core::RngCore>(addr_idx: usize, mut rng: R) -> Self {
241        let num_bytes = rng.gen_range(DATA_LEN_LOWER_BOUND..=DATA_LEN_UPPER_BOUND);
242        Self {
243            addr_idx,
244            num_bytes,
245        }
246    }
247}
248
249const DIAL_BACK_MAX_SIZE: usize = 10;
250
251pub(crate) async fn dial_back(stream: impl AsyncWrite + Unpin, nonce: Nonce) -> io::Result<()> {
252    let msg = proto::DialBack { nonce };
253    let mut framed = FramedWrite::new(stream, Codec::<proto::DialBack>::new(DIAL_BACK_MAX_SIZE));
254
255    framed
256        .send(msg)
257        .await
258        .map_err(|e| io::Error::new(io::ErrorKind::Other, e))?;
259
260    Ok(())
261}
262
263pub(crate) async fn recv_dial_back(stream: impl AsyncRead + Unpin) -> io::Result<Nonce> {
264    let framed = &mut FramedRead::new(stream, Codec::<proto::DialBack>::new(DIAL_BACK_MAX_SIZE));
265    let proto::DialBack { nonce } = framed
266        .next()
267        .await
268        .ok_or(io::Error::from(io::ErrorKind::UnexpectedEof))??;
269    Ok(nonce)
270}
271
272pub(crate) async fn dial_back_response(stream: impl AsyncWrite + Unpin) -> io::Result<()> {
273    let msg = proto::DialBackResponse {
274        status: proto::mod_DialBackResponse::DialBackStatus::OK,
275    };
276    let mut framed = FramedWrite::new(
277        stream,
278        Codec::<proto::DialBackResponse>::new(DIAL_BACK_MAX_SIZE),
279    );
280    framed
281        .send(msg)
282        .await
283        .map_err(|e| io::Error::new(io::ErrorKind::Other, e))?;
284
285    Ok(())
286}
287
288pub(crate) async fn recv_dial_back_response(
289    stream: impl AsyncRead + AsyncWrite + Unpin,
290) -> io::Result<()> {
291    let framed = &mut FramedRead::new(
292        stream,
293        Codec::<proto::DialBackResponse>::new(DIAL_BACK_MAX_SIZE),
294    );
295    let proto::DialBackResponse { status } = framed
296        .next()
297        .await
298        .ok_or(io::Error::from(io::ErrorKind::UnexpectedEof))??;
299
300    if proto::mod_DialBackResponse::DialBackStatus::OK == status {
301        Ok(())
302    } else {
303        Err(io::Error::new(
304            io::ErrorKind::InvalidData,
305            "invalid dial back response",
306        ))
307    }
308}
309
#[cfg(test)]
mod tests {
    use crate::v2::generated::structs::{
        mod_Message::OneOfmsg, DialDataResponse as GenDialDataResponse, Message,
    };

    use super::{proto::DialBack, DIAL_BACK_MAX_SIZE, REQUEST_MAX_SIZE};

    /// A `dialDataResponse` with a 4096-byte payload must serialize to exactly
    /// `REQUEST_MAX_SIZE` bytes.
    #[test]
    fn message_correct_max_size() {
        let largest = Message {
            msg: OneOfmsg::dialDataResponse(GenDialDataResponse {
                data: vec![0; 4096],
            }),
        };
        let encoded = quick_protobuf::serialize_into_vec(&largest).unwrap();
        assert_eq!(encoded.len(), REQUEST_MAX_SIZE);
    }

    /// `DialBack` frames fit within `DIAL_BACK_MAX_SIZE` for both the smallest
    /// and the largest nonce.
    #[test]
    fn dial_back_correct_size() {
        for &nonce in &[0u64, u64::MAX] {
            let encoded = quick_protobuf::serialize_into_vec(&DialBack { nonce }).unwrap();
            assert!(encoded.len() <= DIAL_BACK_MAX_SIZE);
        }
    }
}