1use std::{io, marker::PhantomData, time::Duration};
30
31use asynchronous_codec::{Decoder, Encoder, Framed};
32use bytes::BytesMut;
33use futures::prelude::*;
34use libp2p_core::{
35 upgrade::{InboundUpgrade, OutboundUpgrade, UpgradeInfo},
36 Multiaddr,
37};
38use libp2p_identity::PeerId;
39use libp2p_swarm::StreamProtocol;
40use tracing::debug;
41use web_time::Instant;
42
43use crate::{
44 proto,
45 record::{self, Record},
46};
47
48pub(crate) const DEFAULT_PROTO_NAME: StreamProtocol = StreamProtocol::new("/ipfs/kad/1.0.0");
50pub(crate) const DEFAULT_MAX_PACKET_SIZE: usize = 16 * 1024;
52const DEFAULT_OUTBOUND_SUBSTREAMS_TIMEOUT_S: Duration = Duration::from_secs(10);
/// Connection status of a peer, as carried in Kademlia protocol messages.
///
/// Each variant maps one-to-one onto a `proto::ConnectionType` value via the
/// `From` impls in this module; the discriminants are pinned explicitly.
#[derive(Copy, Clone, PartialEq, Eq, Debug, Hash)]
pub enum ConnectionType {
    /// Corresponds to `proto::ConnectionType::NOT_CONNECTED`.
    NotConnected = 0,
    /// Corresponds to `proto::ConnectionType::CONNECTED`.
    Connected = 1,
    /// Corresponds to `proto::ConnectionType::CAN_CONNECT`.
    CanConnect = 2,
    /// Corresponds to `proto::ConnectionType::CANNOT_CONNECT`.
    CannotConnect = 3,
}
66
67impl From<proto::ConnectionType> for ConnectionType {
68 fn from(raw: proto::ConnectionType) -> ConnectionType {
69 use proto::ConnectionType::*;
70 match raw {
71 NOT_CONNECTED => ConnectionType::NotConnected,
72 CONNECTED => ConnectionType::Connected,
73 CAN_CONNECT => ConnectionType::CanConnect,
74 CANNOT_CONNECT => ConnectionType::CannotConnect,
75 }
76 }
77}
78
79impl From<ConnectionType> for proto::ConnectionType {
80 fn from(val: ConnectionType) -> Self {
81 use proto::ConnectionType::*;
82 match val {
83 ConnectionType::NotConnected => NOT_CONNECTED,
84 ConnectionType::Connected => CONNECTED,
85 ConnectionType::CanConnect => CAN_CONNECT,
86 ConnectionType::CannotConnect => CANNOT_CONNECT,
87 }
88 }
89}
90
/// A peer as described inside a Kademlia message.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct KadPeer {
    /// Identity of the peer.
    pub node_id: PeerId,
    /// Multiaddresses listed for the peer. When decoded from a `proto::Peer`,
    /// each retained address ends with `/p2p/<node_id>`; addresses that fail
    /// to parse or that name a different peer are dropped.
    pub multiaddrs: Vec<Multiaddr>,
    /// Connection status reported for the peer.
    pub connection_ty: ConnectionType,
}
101
102impl TryFrom<proto::Peer> for KadPeer {
104 type Error = io::Error;
105
106 fn try_from(peer: proto::Peer) -> Result<KadPeer, Self::Error> {
107 let node_id = PeerId::from_bytes(&peer.id).map_err(|_| invalid_data("invalid peer id"))?;
110
111 let mut addrs = Vec::with_capacity(peer.addrs.len());
112 for addr in peer.addrs.into_iter() {
113 match Multiaddr::try_from(addr).map(|addr| addr.with_p2p(node_id)) {
114 Ok(Ok(a)) => addrs.push(a),
115 Ok(Err(a)) => {
116 debug!("Unable to parse multiaddr: {a} is not compatible with {node_id}")
117 }
118 Err(e) => debug!("Unable to parse multiaddr: {e}"),
119 };
120 }
121
122 Ok(KadPeer {
123 node_id,
124 multiaddrs: addrs,
125 connection_ty: peer.connection.into(),
126 })
127 }
128}
129
130impl From<KadPeer> for proto::Peer {
131 fn from(peer: KadPeer) -> Self {
132 proto::Peer {
133 id: peer.node_id.to_bytes(),
134 addrs: peer.multiaddrs.into_iter().map(|a| a.to_vec()).collect(),
135 connection: peer.connection_ty.into(),
136 }
137 }
138}
139
/// Configuration for the Kademlia substream protocol.
///
/// Implements `InboundUpgrade` / `OutboundUpgrade`, turning a negotiated
/// stream into a framed stream/sink of Kademlia messages.
#[derive(Debug, Clone)]
pub struct ProtocolConfig {
    /// Protocol names offered during negotiation (see `protocol_info`).
    protocol_names: Vec<StreamProtocol>,
    /// Size limit passed to the protobuf codec for every upgraded stream.
    max_packet_size: usize,
    /// Stored outbound substream timeout. Despite the `_s` suffix this is a
    /// full `Duration`, not a count of seconds.
    outbound_substreams_timeout_s: Duration,
}
153
154impl ProtocolConfig {
155 pub fn new(protocol_name: StreamProtocol) -> Self {
157 ProtocolConfig {
158 protocol_names: vec![protocol_name],
159 max_packet_size: DEFAULT_MAX_PACKET_SIZE,
160 outbound_substreams_timeout_s: DEFAULT_OUTBOUND_SUBSTREAMS_TIMEOUT_S,
161 }
162 }
163
164 pub fn protocol_names(&self) -> &[StreamProtocol] {
166 &self.protocol_names
167 }
168
169 pub fn set_max_packet_size(&mut self, size: usize) {
171 self.max_packet_size = size;
172 }
173
174 pub fn set_outbound_substreams_timeout(&mut self, timeout: Duration) {
176 self.outbound_substreams_timeout_s = timeout;
177 }
178
179 pub fn outbound_substreams_timeout_s(&self) -> Duration {
181 self.outbound_substreams_timeout_s
182 }
183}
184
185impl UpgradeInfo for ProtocolConfig {
186 type Info = StreamProtocol;
187 type InfoIter = std::vec::IntoIter<Self::Info>;
188
189 fn protocol_info(&self) -> Self::InfoIter {
190 self.protocol_names.clone().into_iter()
191 }
192}
193
/// Frames Kademlia messages on a stream.
///
/// Wraps a `quick_protobuf_codec::Codec<proto::Message>`. `A` is the type the
/// `Encoder` impl accepts (it must convert into `proto::Message`); `B` is the
/// type the `Decoder` impl yields (fallibly built from `proto::Message`).
/// Neither is stored, hence the `PhantomData`.
pub struct Codec<A, B> {
    codec: quick_protobuf_codec::Codec<proto::Message>,
    __phantom: PhantomData<(A, B)>,
}
199impl<A, B> Codec<A, B> {
200 fn new(max_packet_size: usize) -> Self {
201 Codec {
202 codec: quick_protobuf_codec::Codec::new(max_packet_size),
203 __phantom: PhantomData,
204 }
205 }
206}
207
208impl<A: Into<proto::Message>, B> Encoder for Codec<A, B> {
209 type Error = io::Error;
210 type Item<'a> = A;
211
212 fn encode(&mut self, item: Self::Item<'_>, dst: &mut BytesMut) -> Result<(), Self::Error> {
213 Ok(self.codec.encode(item.into(), dst)?)
214 }
215}
216impl<A, B: TryFrom<proto::Message, Error = io::Error>> Decoder for Codec<A, B> {
217 type Error = io::Error;
218 type Item = B;
219
220 fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
221 self.codec.decode(src)?.map(B::try_from).transpose()
222 }
223}
224
/// Framed inbound substream: sends [`KadResponseMsg`]s and receives [`KadRequestMsg`]s.
pub(crate) type KadInStreamSink<S> = Framed<S, Codec<KadResponseMsg, KadRequestMsg>>;
/// Framed outbound substream: sends [`KadRequestMsg`]s and receives [`KadResponseMsg`]s.
pub(crate) type KadOutStreamSink<S> = Framed<S, Codec<KadRequestMsg, KadResponseMsg>>;
229
230impl<C> InboundUpgrade<C> for ProtocolConfig
231where
232 C: AsyncRead + AsyncWrite + Unpin,
233{
234 type Output = KadInStreamSink<C>;
235 type Future = future::Ready<Result<Self::Output, io::Error>>;
236 type Error = io::Error;
237
238 fn upgrade_inbound(self, incoming: C, _: Self::Info) -> Self::Future {
239 let codec = Codec::new(self.max_packet_size);
240
241 future::ok(Framed::new(incoming, codec))
242 }
243}
244
245impl<C> OutboundUpgrade<C> for ProtocolConfig
246where
247 C: AsyncRead + AsyncWrite + Unpin,
248{
249 type Output = KadOutStreamSink<C>;
250 type Future = future::Ready<Result<Self::Output, io::Error>>;
251 type Error = io::Error;
252
253 fn upgrade_outbound(self, incoming: C, _: Self::Info) -> Self::Future {
254 let codec = Codec::new(self.max_packet_size);
255
256 future::ok(Framed::new(incoming, codec))
257 }
258}
259
/// A request, sent on outbound Kademlia substreams (see `KadOutStreamSink`).
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum KadRequestMsg {
    /// Encoded as a `PING` message with no payload.
    Ping,

    /// Encoded as `FIND_NODE`.
    FindNode {
        /// Raw key bytes, sent unchanged as the message key.
        key: Vec<u8>,
    },

    /// Encoded as `GET_PROVIDERS`.
    GetProviders {
        /// Key to look up, sent as the message key.
        key: record::Key,
    },

    /// Encoded as `ADD_PROVIDER`.
    AddProvider {
        /// Key, sent as the message key.
        key: record::Key,
        /// Provider, sent as the single entry of `providerPeers`.
        provider: KadPeer,
    },

    /// Encoded as `GET_VALUE`.
    GetValue {
        /// Key of the requested record, sent as the message key.
        key: record::Key,
    },

    /// Encoded as `PUT_VALUE`; the message key is taken from `record.key`.
    PutValue { record: Record },
}
297
/// A response, sent on inbound Kademlia substreams (see `KadInStreamSink`).
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum KadResponseMsg {
    /// Encoded as a `PING` message with no payload.
    Pong,

    /// Encoded as `FIND_NODE`.
    FindNode {
        /// Peers carried in `closerPeers`.
        closer_peers: Vec<KadPeer>,
    },

    /// Encoded as `GET_PROVIDERS`.
    GetProviders {
        /// Peers carried in `closerPeers`.
        closer_peers: Vec<KadPeer>,
        /// Peers carried in `providerPeers`.
        provider_peers: Vec<KadPeer>,
    },

    /// Encoded as `GET_VALUE`.
    GetValue {
        /// The record, if the message carries one.
        record: Option<Record>,
        /// Peers carried in `closerPeers`.
        closer_peers: Vec<KadPeer>,
    },

    /// Encoded as `PUT_VALUE`.
    PutValue {
        /// Key, sent as both the message key and the record key.
        key: record::Key,
        /// Value, sent as the record value.
        value: Vec<u8>,
    },
}
334
335impl From<KadRequestMsg> for proto::Message {
336 fn from(kad_msg: KadRequestMsg) -> Self {
337 req_msg_to_proto(kad_msg)
338 }
339}
340impl From<KadResponseMsg> for proto::Message {
341 fn from(kad_msg: KadResponseMsg) -> Self {
342 resp_msg_to_proto(kad_msg)
343 }
344}
345impl TryFrom<proto::Message> for KadRequestMsg {
346 type Error = io::Error;
347
348 fn try_from(message: proto::Message) -> Result<Self, Self::Error> {
349 proto_to_req_msg(message)
350 }
351}
352impl TryFrom<proto::Message> for KadResponseMsg {
353 type Error = io::Error;
354
355 fn try_from(message: proto::Message) -> Result<Self, Self::Error> {
356 proto_to_resp_msg(message)
357 }
358}
359
360fn req_msg_to_proto(kad_msg: KadRequestMsg) -> proto::Message {
362 match kad_msg {
363 KadRequestMsg::Ping => proto::Message {
364 type_pb: proto::MessageType::PING,
365 ..proto::Message::default()
366 },
367 KadRequestMsg::FindNode { key } => proto::Message {
368 type_pb: proto::MessageType::FIND_NODE,
369 key,
370 clusterLevelRaw: 10,
371 ..proto::Message::default()
372 },
373 KadRequestMsg::GetProviders { key } => proto::Message {
374 type_pb: proto::MessageType::GET_PROVIDERS,
375 key: key.to_vec(),
376 clusterLevelRaw: 10,
377 ..proto::Message::default()
378 },
379 KadRequestMsg::AddProvider { key, provider } => proto::Message {
380 type_pb: proto::MessageType::ADD_PROVIDER,
381 clusterLevelRaw: 10,
382 key: key.to_vec(),
383 providerPeers: vec![provider.into()],
384 ..proto::Message::default()
385 },
386 KadRequestMsg::GetValue { key } => proto::Message {
387 type_pb: proto::MessageType::GET_VALUE,
388 clusterLevelRaw: 10,
389 key: key.to_vec(),
390 ..proto::Message::default()
391 },
392 KadRequestMsg::PutValue { record } => proto::Message {
393 type_pb: proto::MessageType::PUT_VALUE,
394 key: record.key.to_vec(),
395 record: Some(record_to_proto(record)),
396 ..proto::Message::default()
397 },
398 }
399}
400
401fn resp_msg_to_proto(kad_msg: KadResponseMsg) -> proto::Message {
403 match kad_msg {
404 KadResponseMsg::Pong => proto::Message {
405 type_pb: proto::MessageType::PING,
406 ..proto::Message::default()
407 },
408 KadResponseMsg::FindNode { closer_peers } => proto::Message {
409 type_pb: proto::MessageType::FIND_NODE,
410 clusterLevelRaw: 9,
411 closerPeers: closer_peers.into_iter().map(KadPeer::into).collect(),
412 ..proto::Message::default()
413 },
414 KadResponseMsg::GetProviders {
415 closer_peers,
416 provider_peers,
417 } => proto::Message {
418 type_pb: proto::MessageType::GET_PROVIDERS,
419 clusterLevelRaw: 9,
420 closerPeers: closer_peers.into_iter().map(KadPeer::into).collect(),
421 providerPeers: provider_peers.into_iter().map(KadPeer::into).collect(),
422 ..proto::Message::default()
423 },
424 KadResponseMsg::GetValue {
425 record,
426 closer_peers,
427 } => proto::Message {
428 type_pb: proto::MessageType::GET_VALUE,
429 clusterLevelRaw: 9,
430 closerPeers: closer_peers.into_iter().map(KadPeer::into).collect(),
431 record: record.map(record_to_proto),
432 ..proto::Message::default()
433 },
434 KadResponseMsg::PutValue { key, value } => proto::Message {
435 type_pb: proto::MessageType::PUT_VALUE,
436 key: key.to_vec(),
437 record: Some(proto::Record {
438 key: key.to_vec(),
439 value,
440 ..proto::Record::default()
441 }),
442 ..proto::Message::default()
443 },
444 }
445}
446
447fn proto_to_req_msg(message: proto::Message) -> Result<KadRequestMsg, io::Error> {
451 match message.type_pb {
452 proto::MessageType::PING => Ok(KadRequestMsg::Ping),
453 proto::MessageType::PUT_VALUE => {
454 let record = record_from_proto(message.record.unwrap_or_default())?;
455 Ok(KadRequestMsg::PutValue { record })
456 }
457 proto::MessageType::GET_VALUE => Ok(KadRequestMsg::GetValue {
458 key: record::Key::from(message.key),
459 }),
460 proto::MessageType::FIND_NODE => Ok(KadRequestMsg::FindNode { key: message.key }),
461 proto::MessageType::GET_PROVIDERS => Ok(KadRequestMsg::GetProviders {
462 key: record::Key::from(message.key),
463 }),
464 proto::MessageType::ADD_PROVIDER => {
465 let provider = message
469 .providerPeers
470 .into_iter()
471 .find_map(|peer| KadPeer::try_from(peer).ok());
472
473 if let Some(provider) = provider {
474 let key = record::Key::from(message.key);
475 Ok(KadRequestMsg::AddProvider { key, provider })
476 } else {
477 Err(invalid_data("AddProvider message with no valid peer."))
478 }
479 }
480 }
481}
482
483fn proto_to_resp_msg(message: proto::Message) -> Result<KadResponseMsg, io::Error> {
487 match message.type_pb {
488 proto::MessageType::PING => Ok(KadResponseMsg::Pong),
489 proto::MessageType::GET_VALUE => {
490 let record = if let Some(r) = message.record {
491 Some(record_from_proto(r)?)
492 } else {
493 None
494 };
495
496 let closer_peers = message
497 .closerPeers
498 .into_iter()
499 .filter_map(|peer| KadPeer::try_from(peer).ok())
500 .collect();
501
502 Ok(KadResponseMsg::GetValue {
503 record,
504 closer_peers,
505 })
506 }
507
508 proto::MessageType::FIND_NODE => {
509 let closer_peers = message
510 .closerPeers
511 .into_iter()
512 .filter_map(|peer| KadPeer::try_from(peer).ok())
513 .collect();
514
515 Ok(KadResponseMsg::FindNode { closer_peers })
516 }
517
518 proto::MessageType::GET_PROVIDERS => {
519 let closer_peers = message
520 .closerPeers
521 .into_iter()
522 .filter_map(|peer| KadPeer::try_from(peer).ok())
523 .collect();
524
525 let provider_peers = message
526 .providerPeers
527 .into_iter()
528 .filter_map(|peer| KadPeer::try_from(peer).ok())
529 .collect();
530
531 Ok(KadResponseMsg::GetProviders {
532 closer_peers,
533 provider_peers,
534 })
535 }
536
537 proto::MessageType::PUT_VALUE => {
538 let key = record::Key::from(message.key);
539 let rec = message
540 .record
541 .ok_or_else(|| invalid_data("received PutValue message with no record"))?;
542
543 Ok(KadResponseMsg::PutValue {
544 key,
545 value: rec.value,
546 })
547 }
548
549 proto::MessageType::ADD_PROVIDER => {
550 Err(invalid_data("received an unexpected AddProvider message"))
551 }
552 }
553}
554
555fn record_from_proto(record: proto::Record) -> Result<Record, io::Error> {
556 let key = record::Key::from(record.key);
557 let value = record.value;
558
559 let publisher = if !record.publisher.is_empty() {
560 PeerId::from_bytes(&record.publisher)
561 .map(Some)
562 .map_err(|_| invalid_data("Invalid publisher peer ID."))?
563 } else {
564 None
565 };
566
567 let expires = if record.ttl > 0 {
568 Some(Instant::now() + Duration::from_secs(record.ttl as u64))
569 } else {
570 None
571 };
572
573 Ok(Record {
574 key,
575 value,
576 publisher,
577 expires,
578 })
579}
580
581fn record_to_proto(record: Record) -> proto::Record {
582 proto::Record {
583 key: record.key.to_vec(),
584 value: record.value,
585 publisher: record.publisher.map(|id| id.to_bytes()).unwrap_or_default(),
586 ttl: record
587 .expires
588 .map(|t| {
589 let now = Instant::now();
590 if t > now {
591 (t - now).as_secs() as u32
592 } else {
593 1 }
595 })
596 .unwrap_or(0),
597 timeReceived: String::new(),
598 }
599}
600
/// Wraps `e` in an [`io::Error`] of kind [`io::ErrorKind::InvalidData`].
///
/// Used throughout this module to report malformed protocol messages.
fn invalid_data<E>(e: E) -> io::Error
where
    E: Into<Box<dyn std::error::Error + Send + Sync>>,
{
    let source: Box<dyn std::error::Error + Send + Sync> = e.into();
    io::Error::new(io::ErrorKind::InvalidData, source)
}
608
#[cfg(test)]
mod tests {
    use super::*;

    /// A decoded peer's addresses get `/p2p/<peer id>` appended.
    #[test]
    fn append_p2p() {
        let peer_id = PeerId::random();
        let addr: Multiaddr = "/ip6/2001:db8::/tcp/1234".parse().unwrap();

        let peer = KadPeer::try_from(proto::Peer {
            id: peer_id.to_bytes(),
            addrs: vec![addr.to_vec()],
            connection: proto::ConnectionType::CAN_CONNECT,
        })
        .unwrap();

        let expected = vec![addr.with_p2p(peer_id).unwrap()];
        assert_eq!(peer.multiaddrs, expected)
    }

    /// Unparseable addresses and addresses naming another peer are dropped,
    /// while valid ones are kept.
    #[test]
    fn skip_invalid_multiaddr() {
        let peer_id = PeerId::random();
        let base: Multiaddr = "/ip6/2001:db8::/tcp/1234".parse().unwrap();

        let valid = base.clone().with_p2p(peer_id).unwrap();

        let other_peer_id = PeerId::random();
        assert_ne!(peer_id, other_peer_id);
        let foreign = base.with_p2p(other_peer_id).unwrap();

        let garbage = vec![255u8; 8];
        assert!(Multiaddr::try_from(garbage.clone()).is_err());

        let peer = KadPeer::try_from(proto::Peer {
            id: peer_id.to_bytes(),
            addrs: vec![valid.to_vec(), foreign.to_vec(), garbage],
            connection: proto::ConnectionType::CAN_CONNECT,
        })
        .unwrap();

        assert_eq!(peer.multiaddrs, vec![valid])
    }
}