libp2p_autonat/v1/
protocol.rs

1// Copyright 2021 Protocol Labs.
2//
3// Permission is hereby granted, free of charge, to any person obtaining a
4// copy of this software and associated documentation files (the "Software"),
5// to deal in the Software without restriction, including without limitation
6// the rights to use, copy, modify, merge, publish, distribute, sublicense,
7// and/or sell copies of the Software, and to permit persons to whom the
8// Software is furnished to do so, subject to the following conditions:
9//
10// The above copyright notice and this permission notice shall be included in
11// all copies or substantial portions of the Software.
12//
13// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
14// OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
16// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
17// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
18// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
19// DEALINGS IN THE SOFTWARE.
20
21use std::io;
22
23use async_trait::async_trait;
24use asynchronous_codec::{FramedRead, FramedWrite};
25use futures::{
26    io::{AsyncRead, AsyncWrite},
27    SinkExt, StreamExt,
28};
29use libp2p_core::Multiaddr;
30use libp2p_identity::PeerId;
31use libp2p_request_response::{self as request_response};
32use libp2p_swarm::StreamProtocol;
33
34use crate::proto;
35
/// The protocol name used for negotiating with multistream-select.
///
/// This is the AutoNAT v1 protocol identifier, `/libp2p/autonat/1.0.0`.
pub const DEFAULT_PROTOCOL_NAME: StreamProtocol = StreamProtocol::new("/libp2p/autonat/1.0.0");
38
/// Stateless [`request_response::Codec`] for AutoNAT v1.
///
/// It frames [`DialRequest`]s and [`DialResponse`]s as protobuf `proto::Message`s.
/// The type carries no data, so it is `Copy` and `Default`, and it is `Debug`
/// like the other public types in this module.
#[derive(Clone, Copy, Debug, Default)]
pub struct AutoNatCodec;
41
42#[async_trait]
43impl request_response::Codec for AutoNatCodec {
44    type Protocol = StreamProtocol;
45    type Request = DialRequest;
46    type Response = DialResponse;
47
48    async fn read_request<T>(&mut self, _: &StreamProtocol, io: &mut T) -> io::Result<Self::Request>
49    where
50        T: AsyncRead + Send + Unpin,
51    {
52        let message = FramedRead::new(io, codec())
53            .next()
54            .await
55            .ok_or(io::ErrorKind::UnexpectedEof)??;
56        let request = DialRequest::from_proto(message)?;
57
58        Ok(request)
59    }
60
61    async fn read_response<T>(
62        &mut self,
63        _: &StreamProtocol,
64        io: &mut T,
65    ) -> io::Result<Self::Response>
66    where
67        T: AsyncRead + Send + Unpin,
68    {
69        let message = FramedRead::new(io, codec())
70            .next()
71            .await
72            .ok_or(io::ErrorKind::UnexpectedEof)??;
73        let response = DialResponse::from_proto(message)?;
74
75        Ok(response)
76    }
77
78    async fn write_request<T>(
79        &mut self,
80        _: &StreamProtocol,
81        io: &mut T,
82        data: Self::Request,
83    ) -> io::Result<()>
84    where
85        T: AsyncWrite + Send + Unpin,
86    {
87        let mut framed = FramedWrite::new(io, codec());
88        framed.send(data.into_proto()).await?;
89        framed.close().await?;
90
91        Ok(())
92    }
93
94    async fn write_response<T>(
95        &mut self,
96        _: &StreamProtocol,
97        io: &mut T,
98        data: Self::Response,
99    ) -> io::Result<()>
100    where
101        T: AsyncWrite + Send + Unpin,
102    {
103        let mut framed = FramedWrite::new(io, codec());
104        framed.send(data.into_proto()).await?;
105        framed.close().await?;
106
107        Ok(())
108    }
109}
110
111fn codec() -> quick_protobuf_codec::Codec<proto::Message> {
112    quick_protobuf_codec::Codec::<proto::Message>::new(1024)
113}
114
/// A decoded AutoNAT `DIAL` message.
///
/// Converted to and from [`proto::Message`] by `DialRequest::into_proto` and
/// `DialRequest::from_proto`.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct DialRequest {
    /// Peer id carried in the message's `PeerInfo.id` field.
    pub peer_id: PeerId,
    /// Addresses carried in `PeerInfo.addrs`. When decoding, entries that do not
    /// parse as a [`Multiaddr`] are dropped rather than rejected.
    pub addresses: Vec<Multiaddr>,
}
120
121impl DialRequest {
122    pub fn from_proto(msg: proto::Message) -> Result<Self, io::Error> {
123        if msg.type_pb != Some(proto::MessageType::DIAL) {
124            return Err(io::Error::new(io::ErrorKind::InvalidData, "invalid type"));
125        }
126
127        let peer_id_result = msg.dial.and_then(|dial| {
128            dial.peer
129                .and_then(|peer_info| peer_info.id.map(|peer_id| (peer_id, peer_info.addrs)))
130        });
131
132        let (peer_id, addrs) = peer_id_result
133            .ok_or_else(|| io::Error::new(io::ErrorKind::InvalidData, "invalid dial message"))?;
134
135        let peer_id = {
136            PeerId::try_from(peer_id.to_vec())
137                .map_err(|_| io::Error::new(io::ErrorKind::InvalidData, "invalid peer id"))?
138        };
139
140        let addrs = addrs
141            .into_iter()
142            .filter_map(|a| match Multiaddr::try_from(a.to_vec()) {
143                Ok(a) => Some(a),
144                Err(e) => {
145                    tracing::debug!("Unable to parse multiaddr: {e}");
146                    None
147                }
148            })
149            .collect();
150        Ok(Self {
151            peer_id,
152            addresses: addrs,
153        })
154    }
155
156    pub fn into_proto(self) -> proto::Message {
157        let peer_id = self.peer_id.to_bytes();
158        let addrs = self
159            .addresses
160            .into_iter()
161            .map(|addr| addr.to_vec())
162            .collect();
163
164        proto::Message {
165            type_pb: Some(proto::MessageType::DIAL),
166            dial: Some(proto::Dial {
167                peer: Some(proto::PeerInfo {
168                    id: Some(peer_id.to_vec()),
169                    addrs,
170                }),
171            }),
172            dialResponse: None,
173        }
174    }
175}
176
/// Failure status carried by a [`DialResponse`].
///
/// Each variant corresponds one-to-one to a non-`OK` `proto::ResponseStatus`.
/// See the `From`/`TryFrom` conversions in this module.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum ResponseError {
    /// Wire status `E_DIAL_ERROR`.
    DialError,
    /// Wire status `E_DIAL_REFUSED`.
    DialRefused,
    /// Wire status `E_BAD_REQUEST`.
    BadRequest,
    /// Wire status `E_INTERNAL_ERROR`.
    InternalError,
}
184
185impl From<ResponseError> for proto::ResponseStatus {
186    fn from(t: ResponseError) -> Self {
187        match t {
188            ResponseError::DialError => proto::ResponseStatus::E_DIAL_ERROR,
189            ResponseError::DialRefused => proto::ResponseStatus::E_DIAL_REFUSED,
190            ResponseError::BadRequest => proto::ResponseStatus::E_BAD_REQUEST,
191            ResponseError::InternalError => proto::ResponseStatus::E_INTERNAL_ERROR,
192        }
193    }
194}
195
196impl TryFrom<proto::ResponseStatus> for ResponseError {
197    type Error = io::Error;
198
199    fn try_from(value: proto::ResponseStatus) -> Result<Self, Self::Error> {
200        match value {
201            proto::ResponseStatus::E_DIAL_ERROR => Ok(ResponseError::DialError),
202            proto::ResponseStatus::E_DIAL_REFUSED => Ok(ResponseError::DialRefused),
203            proto::ResponseStatus::E_BAD_REQUEST => Ok(ResponseError::BadRequest),
204            proto::ResponseStatus::E_INTERNAL_ERROR => Ok(ResponseError::InternalError),
205            proto::ResponseStatus::OK => {
206                tracing::debug!("Received response with status code OK but expected error");
207                Err(io::Error::new(
208                    io::ErrorKind::InvalidData,
209                    "invalid response error type",
210                ))
211            }
212        }
213    }
214}
215
/// A decoded AutoNAT `DIAL_RESPONSE` message.
///
/// Converted to and from [`proto::Message`] by `DialResponse::into_proto` and
/// `DialResponse::from_proto`.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct DialResponse {
    /// Optional `statusText` carried alongside the status.
    pub status_text: Option<String>,
    /// `Ok(addr)` for an `OK` status that carries an address; otherwise the
    /// [`ResponseError`] matching the (non-`OK`) status.
    pub result: Result<Multiaddr, ResponseError>,
}
221
222impl DialResponse {
223    pub fn from_proto(msg: proto::Message) -> Result<Self, io::Error> {
224        if msg.type_pb != Some(proto::MessageType::DIAL_RESPONSE) {
225            return Err(io::Error::new(io::ErrorKind::InvalidData, "invalid type"));
226        }
227
228        Ok(match msg.dialResponse {
229            Some(proto::DialResponse {
230                status: Some(proto::ResponseStatus::OK),
231                statusText,
232                addr: Some(addr),
233            }) => {
234                let addr = Multiaddr::try_from(addr.to_vec())
235                    .map_err(|err| io::Error::new(io::ErrorKind::InvalidData, err))?;
236                Self {
237                    status_text: statusText,
238                    result: Ok(addr),
239                }
240            }
241            Some(proto::DialResponse {
242                status: Some(status),
243                statusText,
244                addr: None,
245            }) => Self {
246                status_text: statusText,
247                result: Err(ResponseError::try_from(status)?),
248            },
249            _ => {
250                tracing::debug!("Received malformed response message");
251                return Err(io::Error::new(
252                    io::ErrorKind::InvalidData,
253                    "invalid dial response message",
254                ));
255            }
256        })
257    }
258
259    pub fn into_proto(self) -> proto::Message {
260        let dial_response = match self.result {
261            Ok(addr) => proto::DialResponse {
262                status: Some(proto::ResponseStatus::OK),
263                statusText: self.status_text,
264                addr: Some(addr.to_vec()),
265            },
266            Err(error) => proto::DialResponse {
267                status: Some(error.into()),
268                statusText: self.status_text,
269                addr: None,
270            },
271        };
272
273        proto::Message {
274            type_pb: Some(proto::MessageType::DIAL_RESPONSE),
275            dial: None,
276            dialResponse: Some(dial_response),
277        }
278    }
279}
280
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `DIAL_RESPONSE` message from raw fields, bypassing
    /// `DialResponse::into_proto` so that malformed combinations can be tested.
    fn raw_dial_response(
        status: Option<proto::ResponseStatus>,
        addr: Option<Vec<u8>>,
    ) -> proto::Message {
        proto::Message {
            type_pb: Some(proto::MessageType::DIAL_RESPONSE),
            dial: None,
            dialResponse: Some(proto::DialResponse {
                status,
                statusText: None,
                addr,
            }),
        }
    }

    #[test]
    fn test_request_encode_decode() {
        let request = DialRequest {
            peer_id: PeerId::random(),
            addresses: vec![
                "/ip4/8.8.8.8/tcp/30333".parse().unwrap(),
                "/ip4/192.168.1.42/tcp/30333".parse().unwrap(),
            ],
        };
        let proto = request.clone().into_proto();
        let request2 = DialRequest::from_proto(proto).unwrap();
        assert_eq!(request, request2);
    }

    #[test]
    fn test_response_ok_encode_decode() {
        let response = DialResponse {
            result: Ok("/ip4/8.8.8.8/tcp/30333".parse().unwrap()),
            status_text: None,
        };
        let proto = response.clone().into_proto();
        let response2 = DialResponse::from_proto(proto).unwrap();
        assert_eq!(response, response2);
    }

    #[test]
    fn test_response_err_encode_decode() {
        let response = DialResponse {
            result: Err(ResponseError::DialError),
            status_text: Some("dial failed".to_string()),
        };
        let proto = response.clone().into_proto();
        let response2 = DialResponse::from_proto(proto).unwrap();
        assert_eq!(response, response2);
    }

    #[test]
    fn test_response_every_error_encode_decode() {
        for error in [
            ResponseError::DialError,
            ResponseError::DialRefused,
            ResponseError::BadRequest,
            ResponseError::InternalError,
        ] {
            let response = DialResponse {
                result: Err(error),
                status_text: None,
            };
            let proto = response.clone().into_proto();
            assert_eq!(DialResponse::from_proto(proto).unwrap(), response);
        }
    }

    #[test]
    fn test_skip_unparsable_multiaddr() {
        let valid_multiaddr: Multiaddr = "/ip6/2001:db8::/tcp/1234".parse().unwrap();
        let valid_multiaddr_bytes = valid_multiaddr.to_vec();

        let invalid_multiaddr = {
            let a = vec![255; 8];
            assert!(Multiaddr::try_from(a.clone()).is_err());
            a
        };

        let msg = proto::Message {
            type_pb: Some(proto::MessageType::DIAL),
            dial: Some(proto::Dial {
                peer: Some(proto::PeerInfo {
                    id: Some(PeerId::random().to_bytes()),
                    addrs: vec![valid_multiaddr_bytes, invalid_multiaddr],
                }),
            }),
            dialResponse: None,
        };

        let request = DialRequest::from_proto(msg).expect("not to fail");

        assert_eq!(request.addresses, vec![valid_multiaddr])
    }

    #[test]
    fn test_request_rejects_wrong_type() {
        let mut msg = DialRequest {
            peer_id: PeerId::random(),
            addresses: vec![],
        }
        .into_proto();
        msg.type_pb = Some(proto::MessageType::DIAL_RESPONSE);

        let err = DialRequest::from_proto(msg).unwrap_err();
        assert_eq!(err.kind(), io::ErrorKind::InvalidData);
    }

    #[test]
    fn test_request_rejects_missing_peer() {
        let msg = proto::Message {
            type_pb: Some(proto::MessageType::DIAL),
            dial: Some(proto::Dial { peer: None }),
            dialResponse: None,
        };

        let err = DialRequest::from_proto(msg).unwrap_err();
        assert_eq!(err.kind(), io::ErrorKind::InvalidData);
    }

    #[test]
    fn test_request_rejects_invalid_peer_id() {
        let msg = proto::Message {
            type_pb: Some(proto::MessageType::DIAL),
            dial: Some(proto::Dial {
                peer: Some(proto::PeerInfo {
                    id: Some(Vec::new()),
                    addrs: Vec::new(),
                }),
            }),
            dialResponse: None,
        };

        let err = DialRequest::from_proto(msg).unwrap_err();
        assert_eq!(err.kind(), io::ErrorKind::InvalidData);
    }

    #[test]
    fn test_response_rejects_wrong_type() {
        let mut msg = DialResponse {
            result: Err(ResponseError::DialError),
            status_text: None,
        }
        .into_proto();
        msg.type_pb = Some(proto::MessageType::DIAL);

        let err = DialResponse::from_proto(msg).unwrap_err();
        assert_eq!(err.kind(), io::ErrorKind::InvalidData);
    }

    #[test]
    fn test_response_rejects_ok_without_addr() {
        let msg = raw_dial_response(Some(proto::ResponseStatus::OK), None);

        let err = DialResponse::from_proto(msg).unwrap_err();
        assert_eq!(err.kind(), io::ErrorKind::InvalidData);
    }

    #[test]
    fn test_response_rejects_error_with_addr() {
        let addr: Multiaddr = "/ip4/8.8.8.8/tcp/30333".parse().unwrap();
        let msg = raw_dial_response(
            Some(proto::ResponseStatus::E_DIAL_ERROR),
            Some(addr.to_vec()),
        );

        let err = DialResponse::from_proto(msg).unwrap_err();
        assert_eq!(err.kind(), io::ErrorKind::InvalidData);
    }

    #[test]
    fn test_response_rejects_missing_status() {
        let msg = raw_dial_response(None, None);

        let err = DialResponse::from_proto(msg).unwrap_err();
        assert_eq!(err.kind(), io::ErrorKind::InvalidData);
    }

    #[test]
    fn test_ok_status_is_not_a_response_error() {
        assert!(ResponseError::try_from(proto::ResponseStatus::OK).is_err());
    }
}