1use std::io;
22
23use async_trait::async_trait;
24use asynchronous_codec::{FramedRead, FramedWrite};
25use futures::{
26 io::{AsyncRead, AsyncWrite},
27 SinkExt, StreamExt,
28};
29use libp2p_core::Multiaddr;
30use libp2p_identity::PeerId;
31use libp2p_request_response::{self as request_response};
32use libp2p_swarm::StreamProtocol;
33
34use crate::proto;
35
36pub const DEFAULT_PROTOCOL_NAME: StreamProtocol = StreamProtocol::new("/libp2p/autonat/1.0.0");
38
39#[derive(Clone)]
40pub struct AutoNatCodec;
41
42#[async_trait]
43impl request_response::Codec for AutoNatCodec {
44 type Protocol = StreamProtocol;
45 type Request = DialRequest;
46 type Response = DialResponse;
47
48 async fn read_request<T>(&mut self, _: &StreamProtocol, io: &mut T) -> io::Result<Self::Request>
49 where
50 T: AsyncRead + Send + Unpin,
51 {
52 let message = FramedRead::new(io, codec())
53 .next()
54 .await
55 .ok_or(io::ErrorKind::UnexpectedEof)??;
56 let request = DialRequest::from_proto(message)?;
57
58 Ok(request)
59 }
60
61 async fn read_response<T>(
62 &mut self,
63 _: &StreamProtocol,
64 io: &mut T,
65 ) -> io::Result<Self::Response>
66 where
67 T: AsyncRead + Send + Unpin,
68 {
69 let message = FramedRead::new(io, codec())
70 .next()
71 .await
72 .ok_or(io::ErrorKind::UnexpectedEof)??;
73 let response = DialResponse::from_proto(message)?;
74
75 Ok(response)
76 }
77
78 async fn write_request<T>(
79 &mut self,
80 _: &StreamProtocol,
81 io: &mut T,
82 data: Self::Request,
83 ) -> io::Result<()>
84 where
85 T: AsyncWrite + Send + Unpin,
86 {
87 let mut framed = FramedWrite::new(io, codec());
88 framed.send(data.into_proto()).await?;
89 framed.close().await?;
90
91 Ok(())
92 }
93
94 async fn write_response<T>(
95 &mut self,
96 _: &StreamProtocol,
97 io: &mut T,
98 data: Self::Response,
99 ) -> io::Result<()>
100 where
101 T: AsyncWrite + Send + Unpin,
102 {
103 let mut framed = FramedWrite::new(io, codec());
104 framed.send(data.into_proto()).await?;
105 framed.close().await?;
106
107 Ok(())
108 }
109}
110
111fn codec() -> quick_protobuf_codec::Codec<proto::Message> {
112 quick_protobuf_codec::Codec::<proto::Message>::new(1024)
113}
114
115#[derive(Clone, Debug, Eq, PartialEq)]
116pub struct DialRequest {
117 pub peer_id: PeerId,
118 pub addresses: Vec<Multiaddr>,
119}
120
121impl DialRequest {
122 pub fn from_proto(msg: proto::Message) -> Result<Self, io::Error> {
123 if msg.type_pb != Some(proto::MessageType::DIAL) {
124 return Err(io::Error::new(io::ErrorKind::InvalidData, "invalid type"));
125 }
126
127 let peer_id_result = msg.dial.and_then(|dial| {
128 dial.peer
129 .and_then(|peer_info| peer_info.id.map(|peer_id| (peer_id, peer_info.addrs)))
130 });
131
132 let (peer_id, addrs) = peer_id_result
133 .ok_or_else(|| io::Error::new(io::ErrorKind::InvalidData, "invalid dial message"))?;
134
135 let peer_id = {
136 PeerId::try_from(peer_id.to_vec())
137 .map_err(|_| io::Error::new(io::ErrorKind::InvalidData, "invalid peer id"))?
138 };
139
140 let addrs = addrs
141 .into_iter()
142 .filter_map(|a| match Multiaddr::try_from(a.to_vec()) {
143 Ok(a) => Some(a),
144 Err(e) => {
145 tracing::debug!("Unable to parse multiaddr: {e}");
146 None
147 }
148 })
149 .collect();
150 Ok(Self {
151 peer_id,
152 addresses: addrs,
153 })
154 }
155
156 pub fn into_proto(self) -> proto::Message {
157 let peer_id = self.peer_id.to_bytes();
158 let addrs = self
159 .addresses
160 .into_iter()
161 .map(|addr| addr.to_vec())
162 .collect();
163
164 proto::Message {
165 type_pb: Some(proto::MessageType::DIAL),
166 dial: Some(proto::Dial {
167 peer: Some(proto::PeerInfo {
168 id: Some(peer_id.to_vec()),
169 addrs,
170 }),
171 }),
172 dialResponse: None,
173 }
174 }
175}
176
177#[derive(Clone, Debug, Eq, PartialEq)]
178pub enum ResponseError {
179 DialError,
180 DialRefused,
181 BadRequest,
182 InternalError,
183}
184
185impl From<ResponseError> for proto::ResponseStatus {
186 fn from(t: ResponseError) -> Self {
187 match t {
188 ResponseError::DialError => proto::ResponseStatus::E_DIAL_ERROR,
189 ResponseError::DialRefused => proto::ResponseStatus::E_DIAL_REFUSED,
190 ResponseError::BadRequest => proto::ResponseStatus::E_BAD_REQUEST,
191 ResponseError::InternalError => proto::ResponseStatus::E_INTERNAL_ERROR,
192 }
193 }
194}
195
196impl TryFrom<proto::ResponseStatus> for ResponseError {
197 type Error = io::Error;
198
199 fn try_from(value: proto::ResponseStatus) -> Result<Self, Self::Error> {
200 match value {
201 proto::ResponseStatus::E_DIAL_ERROR => Ok(ResponseError::DialError),
202 proto::ResponseStatus::E_DIAL_REFUSED => Ok(ResponseError::DialRefused),
203 proto::ResponseStatus::E_BAD_REQUEST => Ok(ResponseError::BadRequest),
204 proto::ResponseStatus::E_INTERNAL_ERROR => Ok(ResponseError::InternalError),
205 proto::ResponseStatus::OK => {
206 tracing::debug!("Received response with status code OK but expected error");
207 Err(io::Error::new(
208 io::ErrorKind::InvalidData,
209 "invalid response error type",
210 ))
211 }
212 }
213 }
214}
215
216#[derive(Clone, Debug, Eq, PartialEq)]
217pub struct DialResponse {
218 pub status_text: Option<String>,
219 pub result: Result<Multiaddr, ResponseError>,
220}
221
222impl DialResponse {
223 pub fn from_proto(msg: proto::Message) -> Result<Self, io::Error> {
224 if msg.type_pb != Some(proto::MessageType::DIAL_RESPONSE) {
225 return Err(io::Error::new(io::ErrorKind::InvalidData, "invalid type"));
226 }
227
228 Ok(match msg.dialResponse {
229 Some(proto::DialResponse {
230 status: Some(proto::ResponseStatus::OK),
231 statusText,
232 addr: Some(addr),
233 }) => {
234 let addr = Multiaddr::try_from(addr.to_vec())
235 .map_err(|err| io::Error::new(io::ErrorKind::InvalidData, err))?;
236 Self {
237 status_text: statusText,
238 result: Ok(addr),
239 }
240 }
241 Some(proto::DialResponse {
242 status: Some(status),
243 statusText,
244 addr: None,
245 }) => Self {
246 status_text: statusText,
247 result: Err(ResponseError::try_from(status)?),
248 },
249 _ => {
250 tracing::debug!("Received malformed response message");
251 return Err(io::Error::new(
252 io::ErrorKind::InvalidData,
253 "invalid dial response message",
254 ));
255 }
256 })
257 }
258
259 pub fn into_proto(self) -> proto::Message {
260 let dial_response = match self.result {
261 Ok(addr) => proto::DialResponse {
262 status: Some(proto::ResponseStatus::OK),
263 statusText: self.status_text,
264 addr: Some(addr.to_vec()),
265 },
266 Err(error) => proto::DialResponse {
267 status: Some(error.into()),
268 statusText: self.status_text,
269 addr: None,
270 },
271 };
272
273 proto::Message {
274 type_pb: Some(proto::MessageType::DIAL_RESPONSE),
275 dial: None,
276 dialResponse: Some(dial_response),
277 }
278 }
279}
280
#[cfg(test)]
mod tests {
    use super::*;

    /// A request converted to its wire message and back is unchanged.
    #[test]
    fn test_request_encode_decode() {
        let addresses: Vec<Multiaddr> = vec![
            "/ip4/8.8.8.8/tcp/30333".parse().unwrap(),
            "/ip4/192.168.1.42/tcp/30333".parse().unwrap(),
        ];
        let original = DialRequest {
            peer_id: PeerId::random(),
            addresses,
        };

        let decoded = DialRequest::from_proto(original.clone().into_proto()).unwrap();

        assert_eq!(original, decoded);
    }

    /// A successful response converted to its wire message and back is unchanged.
    #[test]
    fn test_response_ok_encode_decode() {
        let original = DialResponse {
            result: Ok("/ip4/8.8.8.8/tcp/30333".parse().unwrap()),
            status_text: None,
        };

        let decoded = DialResponse::from_proto(original.clone().into_proto()).unwrap();

        assert_eq!(original, decoded);
    }

    /// An error response, including its status text, survives the same trip.
    #[test]
    fn test_response_err_encode_decode() {
        let original = DialResponse {
            result: Err(ResponseError::DialError),
            status_text: Some("dial failed".to_string()),
        };

        let decoded = DialResponse::from_proto(original.clone().into_proto()).unwrap();

        assert_eq!(original, decoded);
    }

    /// Address bytes that are not a valid multiaddr are dropped, not fatal.
    #[test]
    fn test_skip_unparsable_multiaddr() {
        let valid_multiaddr: Multiaddr = "/ip6/2001:db8::/tcp/1234".parse().unwrap();

        // Sanity-check that the garbage bytes really are rejected by the parser.
        let invalid_multiaddr = vec![255; 8];
        assert!(Multiaddr::try_from(invalid_multiaddr.clone()).is_err());

        let peer = proto::PeerInfo {
            id: Some(PeerId::random().to_bytes()),
            addrs: vec![valid_multiaddr.to_vec(), invalid_multiaddr],
        };
        let msg = proto::Message {
            type_pb: Some(proto::MessageType::DIAL),
            dial: Some(proto::Dial { peer: Some(peer) }),
            dialResponse: None,
        };

        let request = DialRequest::from_proto(msg).expect("not to fail");

        assert_eq!(request.addresses, vec![valid_multiaddr])
    }
}