1use std::{
22 collections::VecDeque,
23 error, fmt, io,
24 marker::PhantomData,
25 pin::Pin,
26 task::{Context, Poll, Waker},
27 time::Duration,
28};
29
30use either::Either;
31use futures::{channel::oneshot, prelude::*, stream::SelectAll};
32use libp2p_core::{upgrade, ConnectedPoint};
33use libp2p_identity::PeerId;
34use libp2p_swarm::{
35 handler::{ConnectionEvent, FullyNegotiatedInbound, FullyNegotiatedOutbound},
36 ConnectionHandler, ConnectionHandlerEvent, Stream, StreamUpgradeError, SubstreamProtocol,
37 SupportedProtocols,
38};
39
40use crate::{
41 behaviour::Mode,
42 protocol::{
43 KadInStreamSink, KadOutStreamSink, KadPeer, KadRequestMsg, KadResponseMsg, ProtocolConfig,
44 },
45 record::{self, Record},
46 QueryId,
47};
48
/// Upper bound used for both directions of a connection.
///
/// It caps the number of tracked inbound substreams (`Handler::inbound_substreams`). It also
/// caps the capacity of the in-flight outbound request set (`Handler::outbound_substreams`).
const MAX_NUM_STREAMS: usize = 32;
50
51pub struct Handler {
59 protocol_config: ProtocolConfig,
61
62 mode: Mode,
64
65 next_connec_unique_id: UniqueConnecId,
67
68 outbound_substreams:
70 futures_bounded::FuturesTupleSet<io::Result<Option<KadResponseMsg>>, QueryId>,
71
72 pending_streams:
74 VecDeque<oneshot::Sender<Result<KadOutStreamSink<Stream>, StreamUpgradeError<io::Error>>>>,
75
76 pending_messages: VecDeque<(KadRequestMsg, QueryId)>,
79
80 inbound_substreams: SelectAll<InboundSubstreamState>,
82
83 endpoint: ConnectedPoint,
86
87 remote_peer_id: PeerId,
89
90 protocol_status: Option<ProtocolStatus>,
92
93 remote_supported_protocols: SupportedProtocols,
94}
95
/// Knowledge about the remote's support of our Kademlia protocol.
#[derive(Clone, Copy, Debug, PartialEq)]
struct ProtocolStatus {
    /// `true` if the remote is known to support our protocol.
    supported: bool,
    /// `true` once the corresponding `ProtocolConfirmed` / `ProtocolNotSupported` event has
    /// been handed to the behaviour.
    reported: bool,
}
105
106enum InboundSubstreamState {
108 WaitingMessage {
110 first: bool,
112 connection_id: UniqueConnecId,
113 substream: KadInStreamSink<Stream>,
114 },
115 WaitingBehaviour(UniqueConnecId, KadInStreamSink<Stream>, Option<Waker>),
117 PendingSend(UniqueConnecId, KadInStreamSink<Stream>, KadResponseMsg),
119 PendingFlush(UniqueConnecId, KadInStreamSink<Stream>),
121 Closing(KadInStreamSink<Stream>),
123 Cancelled,
125
126 Poisoned {
127 phantom: PhantomData<QueryId>,
128 },
129}
130
131impl InboundSubstreamState {
132 fn try_answer_with(
133 &mut self,
134 id: RequestId,
135 msg: KadResponseMsg,
136 ) -> Result<(), KadResponseMsg> {
137 match std::mem::replace(
138 self,
139 InboundSubstreamState::Poisoned {
140 phantom: PhantomData,
141 },
142 ) {
143 InboundSubstreamState::WaitingBehaviour(conn_id, substream, mut waker)
144 if conn_id == id.connec_unique_id =>
145 {
146 *self = InboundSubstreamState::PendingSend(conn_id, substream, msg);
147
148 if let Some(waker) = waker.take() {
149 waker.wake();
150 }
151
152 Ok(())
153 }
154 other => {
155 *self = other;
156
157 Err(msg)
158 }
159 }
160 }
161
162 fn close(&mut self) {
163 match std::mem::replace(
164 self,
165 InboundSubstreamState::Poisoned {
166 phantom: PhantomData,
167 },
168 ) {
169 InboundSubstreamState::WaitingMessage { substream, .. }
170 | InboundSubstreamState::WaitingBehaviour(_, substream, _)
171 | InboundSubstreamState::PendingSend(_, substream, _)
172 | InboundSubstreamState::PendingFlush(_, substream)
173 | InboundSubstreamState::Closing(substream) => {
174 *self = InboundSubstreamState::Closing(substream);
175 }
176 InboundSubstreamState::Cancelled => {
177 *self = InboundSubstreamState::Cancelled;
178 }
179 InboundSubstreamState::Poisoned { .. } => unreachable!(),
180 }
181 }
182}
183
184#[derive(Debug)]
186pub enum HandlerEvent {
187 ProtocolConfirmed { endpoint: ConnectedPoint },
190 ProtocolNotSupported { endpoint: ConnectedPoint },
193
194 FindNodeReq {
197 key: Vec<u8>,
199 request_id: RequestId,
201 },
202
203 FindNodeRes {
205 closer_peers: Vec<KadPeer>,
207 query_id: QueryId,
209 },
210
211 GetProvidersReq {
214 key: record::Key,
216 request_id: RequestId,
218 },
219
220 GetProvidersRes {
222 closer_peers: Vec<KadPeer>,
224 provider_peers: Vec<KadPeer>,
226 query_id: QueryId,
228 },
229
230 QueryError {
232 error: HandlerQueryErr,
234 query_id: QueryId,
236 },
237
238 AddProvider {
240 key: record::Key,
242 provider: KadPeer,
244 },
245
246 GetRecord {
248 key: record::Key,
250 request_id: RequestId,
252 },
253
254 GetRecordRes {
256 record: Option<Record>,
258 closer_peers: Vec<KadPeer>,
260 query_id: QueryId,
262 },
263
264 PutRecord {
266 record: Record,
267 request_id: RequestId,
269 },
270
271 PutRecordRes {
273 key: record::Key,
275 value: Vec<u8>,
277 query_id: QueryId,
279 },
280}
281
/// Reasons an outbound Kademlia request can fail.
#[derive(Debug)]
pub enum HandlerQueryErr {
    /// The remote answered with a message type that does not match the request.
    UnexpectedMessage,
    /// An I/O error (including timeouts) occurred while performing the request.
    Io(io::Error),
}
290
291impl fmt::Display for HandlerQueryErr {
292 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
293 match self {
294 HandlerQueryErr::UnexpectedMessage => {
295 write!(
296 f,
297 "Remote answered our Kademlia RPC query with the wrong message type"
298 )
299 }
300 HandlerQueryErr::Io(err) => {
301 write!(f, "I/O error during a Kademlia RPC query: {err}")
302 }
303 }
304 }
305}
306
307impl error::Error for HandlerQueryErr {
308 fn source(&self) -> Option<&(dyn error::Error + 'static)> {
309 match self {
310 HandlerQueryErr::UnexpectedMessage => None,
311 HandlerQueryErr::Io(err) => Some(err),
312 }
313 }
314}
315
316#[derive(Debug)]
318pub enum HandlerIn {
319 Reset(RequestId),
327
328 ReconfigureMode { new_mode: Mode },
330
331 FindNodeReq {
334 key: Vec<u8>,
336 query_id: QueryId,
338 },
339
340 FindNodeRes {
342 closer_peers: Vec<KadPeer>,
344 request_id: RequestId,
348 },
349
350 GetProvidersReq {
353 key: record::Key,
355 query_id: QueryId,
357 },
358
359 GetProvidersRes {
361 closer_peers: Vec<KadPeer>,
363 provider_peers: Vec<KadPeer>,
365 request_id: RequestId,
369 },
370
371 AddProvider {
376 key: record::Key,
378 provider: KadPeer,
380 query_id: QueryId,
382 },
383
384 GetRecord {
386 key: record::Key,
388 query_id: QueryId,
390 },
391
392 GetRecordRes {
394 record: Option<Record>,
396 closer_peers: Vec<KadPeer>,
398 request_id: RequestId,
400 },
401
402 PutRecord {
404 record: Record,
405 query_id: QueryId,
407 },
408
409 PutRecordRes {
411 key: record::Key,
413 value: Vec<u8>,
415 request_id: RequestId,
417 },
418}
419
420#[derive(Debug, PartialEq, Eq, Copy, Clone)]
423pub struct RequestId {
424 connec_unique_id: UniqueConnecId,
426}
427
/// Per-handler identifier assigned to each accepted inbound substream.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
struct UniqueConnecId(u64);
431
432impl Handler {
433 pub fn new(
434 protocol_config: ProtocolConfig,
435 endpoint: ConnectedPoint,
436 remote_peer_id: PeerId,
437 mode: Mode,
438 ) -> Self {
439 match &endpoint {
440 ConnectedPoint::Dialer { .. } => {
441 tracing::debug!(
442 peer=%remote_peer_id,
443 mode=%mode,
444 "New outbound connection"
445 );
446 }
447 ConnectedPoint::Listener { .. } => {
448 tracing::debug!(
449 peer=%remote_peer_id,
450 mode=%mode,
451 "New inbound connection"
452 );
453 }
454 }
455
456 Handler {
457 protocol_config,
458 mode,
459 endpoint,
460 remote_peer_id,
461 next_connec_unique_id: UniqueConnecId(0),
462 inbound_substreams: Default::default(),
463 outbound_substreams: futures_bounded::FuturesTupleSet::new(
464 Duration::from_secs(10),
465 MAX_NUM_STREAMS,
466 ),
467 pending_streams: Default::default(),
468 pending_messages: Default::default(),
469 protocol_status: None,
470 remote_supported_protocols: Default::default(),
471 }
472 }
473
474 fn on_fully_negotiated_outbound(
475 &mut self,
476 FullyNegotiatedOutbound {
477 protocol: stream,
478 info: (),
479 }: FullyNegotiatedOutbound<<Self as ConnectionHandler>::OutboundProtocol>,
480 ) {
481 if let Some(sender) = self.pending_streams.pop_front() {
482 let _ = sender.send(Ok(stream));
483 }
484
485 if self.protocol_status.is_none() {
486 self.protocol_status = Some(ProtocolStatus {
490 supported: true,
491 reported: false,
492 });
493 }
494 }
495
496 fn on_fully_negotiated_inbound(
497 &mut self,
498 FullyNegotiatedInbound { protocol, .. }: FullyNegotiatedInbound<
499 <Self as ConnectionHandler>::InboundProtocol,
500 >,
501 ) {
502 let protocol = match protocol {
505 future::Either::Left(p) => p,
506 future::Either::Right(p) => libp2p_core::util::unreachable(p),
507 };
508
509 if self.protocol_status.is_none() {
510 self.protocol_status = Some(ProtocolStatus {
514 supported: true,
515 reported: false,
516 });
517 }
518
519 if self.inbound_substreams.len() == MAX_NUM_STREAMS {
520 if let Some(s) = self.inbound_substreams.iter_mut().find(|s| {
521 matches!(
522 s,
523 InboundSubstreamState::WaitingMessage { first: false, .. }
525 )
526 }) {
527 *s = InboundSubstreamState::Cancelled;
528 tracing::debug!(
529 peer=?self.remote_peer_id,
530 "New inbound substream to peer exceeds inbound substream limit. \
531 Removed older substream waiting to be reused."
532 )
533 } else {
534 tracing::warn!(
535 peer=?self.remote_peer_id,
536 "New inbound substream to peer exceeds inbound substream limit. \
537 No older substream waiting to be reused. Dropping new substream."
538 );
539 return;
540 }
541 }
542
543 let connec_unique_id = self.next_connec_unique_id;
544 self.next_connec_unique_id.0 += 1;
545 self.inbound_substreams
546 .push(InboundSubstreamState::WaitingMessage {
547 first: true,
548 connection_id: connec_unique_id,
549 substream: protocol,
550 });
551 }
552
553 fn queue_new_stream(&mut self, id: QueryId, msg: KadRequestMsg) {
556 let (sender, receiver) = oneshot::channel();
557
558 self.pending_streams.push_back(sender);
559 let result = self.outbound_substreams.try_push(
560 async move {
561 let mut stream = receiver
562 .await
563 .map_err(|_| io::Error::from(io::ErrorKind::BrokenPipe))?
564 .map_err(|e| match e {
565 StreamUpgradeError::Timeout => io::ErrorKind::TimedOut.into(),
566 StreamUpgradeError::Apply(e) => e,
567 StreamUpgradeError::NegotiationFailed => io::Error::new(
568 io::ErrorKind::ConnectionRefused,
569 "protocol not supported",
570 ),
571 StreamUpgradeError::Io(e) => e,
572 })?;
573
574 let has_answer = !matches!(msg, KadRequestMsg::AddProvider { .. });
575
576 stream.send(msg).await?;
577 stream.close().await?;
578
579 if !has_answer {
580 return Ok(None);
581 }
582
583 let msg = stream.next().await.ok_or(io::ErrorKind::UnexpectedEof)??;
584
585 Ok(Some(msg))
586 },
587 id,
588 );
589
590 debug_assert!(
591 result.is_ok(),
592 "Expected to not create more streams than allowed"
593 );
594 }
595}
596
597impl ConnectionHandler for Handler {
598 type FromBehaviour = HandlerIn;
599 type ToBehaviour = HandlerEvent;
600 type InboundProtocol = Either<ProtocolConfig, upgrade::DeniedUpgrade>;
601 type OutboundProtocol = ProtocolConfig;
602 type OutboundOpenInfo = ();
603 type InboundOpenInfo = ();
604
605 fn listen_protocol(&self) -> SubstreamProtocol<Self::InboundProtocol> {
606 match self.mode {
607 Mode::Server => SubstreamProtocol::new(Either::Left(self.protocol_config.clone()), ()),
608 Mode::Client => SubstreamProtocol::new(Either::Right(upgrade::DeniedUpgrade), ()),
609 }
610 }
611
612 fn on_behaviour_event(&mut self, message: HandlerIn) {
613 match message {
614 HandlerIn::Reset(request_id) => {
615 if let Some(state) = self
616 .inbound_substreams
617 .iter_mut()
618 .find(|state| match state {
619 InboundSubstreamState::WaitingBehaviour(conn_id, _, _) => {
620 conn_id == &request_id.connec_unique_id
621 }
622 _ => false,
623 })
624 {
625 state.close();
626 }
627 }
628 HandlerIn::FindNodeReq { key, query_id } => {
629 let msg = KadRequestMsg::FindNode { key };
630 self.pending_messages.push_back((msg, query_id));
631 }
632 HandlerIn::FindNodeRes {
633 closer_peers,
634 request_id,
635 } => self.answer_pending_request(request_id, KadResponseMsg::FindNode { closer_peers }),
636 HandlerIn::GetProvidersReq { key, query_id } => {
637 let msg = KadRequestMsg::GetProviders { key };
638 self.pending_messages.push_back((msg, query_id));
639 }
640 HandlerIn::GetProvidersRes {
641 closer_peers,
642 provider_peers,
643 request_id,
644 } => self.answer_pending_request(
645 request_id,
646 KadResponseMsg::GetProviders {
647 closer_peers,
648 provider_peers,
649 },
650 ),
651 HandlerIn::AddProvider {
652 key,
653 provider,
654 query_id,
655 } => {
656 let msg = KadRequestMsg::AddProvider { key, provider };
657 self.pending_messages.push_back((msg, query_id));
658 }
659 HandlerIn::GetRecord { key, query_id } => {
660 let msg = KadRequestMsg::GetValue { key };
661 self.pending_messages.push_back((msg, query_id));
662 }
663 HandlerIn::PutRecord { record, query_id } => {
664 let msg = KadRequestMsg::PutValue { record };
665 self.pending_messages.push_back((msg, query_id));
666 }
667 HandlerIn::GetRecordRes {
668 record,
669 closer_peers,
670 request_id,
671 } => {
672 self.answer_pending_request(
673 request_id,
674 KadResponseMsg::GetValue {
675 record,
676 closer_peers,
677 },
678 );
679 }
680 HandlerIn::PutRecordRes {
681 key,
682 request_id,
683 value,
684 } => {
685 self.answer_pending_request(request_id, KadResponseMsg::PutValue { key, value });
686 }
687 HandlerIn::ReconfigureMode { new_mode } => {
688 let peer = self.remote_peer_id;
689
690 match &self.endpoint {
691 ConnectedPoint::Dialer { .. } => {
692 tracing::debug!(
693 %peer,
694 mode=%new_mode,
695 "Changed mode on outbound connection"
696 )
697 }
698 ConnectedPoint::Listener { local_addr, .. } => {
699 tracing::debug!(
700 %peer,
701 mode=%new_mode,
702 local_address=%local_addr,
703 "Changed mode on inbound connection assuming that one of our external addresses routes to the local address")
704 }
705 }
706
707 self.mode = new_mode;
708 }
709 }
710 }
711
712 #[tracing::instrument(level = "trace", name = "ConnectionHandler::poll", skip(self, cx))]
713 fn poll(
714 &mut self,
715 cx: &mut Context<'_>,
716 ) -> Poll<ConnectionHandlerEvent<Self::OutboundProtocol, (), Self::ToBehaviour>> {
717 loop {
718 match &mut self.protocol_status {
719 Some(status) if !status.reported => {
720 status.reported = true;
721 let event = if status.supported {
722 HandlerEvent::ProtocolConfirmed {
723 endpoint: self.endpoint.clone(),
724 }
725 } else {
726 HandlerEvent::ProtocolNotSupported {
727 endpoint: self.endpoint.clone(),
728 }
729 };
730
731 return Poll::Ready(ConnectionHandlerEvent::NotifyBehaviour(event));
732 }
733 _ => {}
734 }
735
736 match self.outbound_substreams.poll_unpin(cx) {
737 Poll::Ready((Ok(Ok(Some(response))), query_id)) => {
738 return Poll::Ready(ConnectionHandlerEvent::NotifyBehaviour(
739 process_kad_response(response, query_id),
740 ))
741 }
742 Poll::Ready((Ok(Ok(None)), _)) => {
743 continue;
744 }
745 Poll::Ready((Ok(Err(e)), query_id)) => {
746 return Poll::Ready(ConnectionHandlerEvent::NotifyBehaviour(
747 HandlerEvent::QueryError {
748 error: HandlerQueryErr::Io(e),
749 query_id,
750 },
751 ))
752 }
753 Poll::Ready((Err(_timeout), query_id)) => {
754 return Poll::Ready(ConnectionHandlerEvent::NotifyBehaviour(
755 HandlerEvent::QueryError {
756 error: HandlerQueryErr::Io(io::ErrorKind::TimedOut.into()),
757 query_id,
758 },
759 ))
760 }
761 Poll::Pending => {}
762 }
763
764 if let Poll::Ready(Some(event)) = self.inbound_substreams.poll_next_unpin(cx) {
765 return Poll::Ready(event);
766 }
767
768 if self.outbound_substreams.len() < MAX_NUM_STREAMS {
769 if let Some((msg, id)) = self.pending_messages.pop_front() {
770 self.queue_new_stream(id, msg);
771 return Poll::Ready(ConnectionHandlerEvent::OutboundSubstreamRequest {
772 protocol: SubstreamProtocol::new(self.protocol_config.clone(), ()),
773 });
774 }
775 }
776
777 return Poll::Pending;
778 }
779 }
780
781 fn on_connection_event(
782 &mut self,
783 event: ConnectionEvent<Self::InboundProtocol, Self::OutboundProtocol>,
784 ) {
785 match event {
786 ConnectionEvent::FullyNegotiatedOutbound(fully_negotiated_outbound) => {
787 self.on_fully_negotiated_outbound(fully_negotiated_outbound)
788 }
789 ConnectionEvent::FullyNegotiatedInbound(fully_negotiated_inbound) => {
790 self.on_fully_negotiated_inbound(fully_negotiated_inbound)
791 }
792 ConnectionEvent::DialUpgradeError(ev) => {
793 if let Some(sender) = self.pending_streams.pop_front() {
794 let _ = sender.send(Err(ev.error));
795 }
796 }
797 ConnectionEvent::RemoteProtocolsChange(change) => {
798 let dirty = self.remote_supported_protocols.on_protocols_change(change);
799
800 if dirty {
801 let remote_supports_our_kademlia_protocols = self
802 .remote_supported_protocols
803 .iter()
804 .any(|p| self.protocol_config.protocol_names().contains(p));
805
806 self.protocol_status = Some(compute_new_protocol_status(
807 remote_supports_our_kademlia_protocols,
808 self.protocol_status,
809 ))
810 }
811 }
812 _ => {}
813 }
814 }
815}
816
817fn compute_new_protocol_status(
818 now_supported: bool,
819 current_status: Option<ProtocolStatus>,
820) -> ProtocolStatus {
821 let current_status = match current_status {
822 None => {
823 return ProtocolStatus {
824 supported: now_supported,
825 reported: false,
826 }
827 }
828 Some(current) => current,
829 };
830
831 if now_supported == current_status.supported {
832 return ProtocolStatus {
833 supported: now_supported,
834 reported: true,
835 };
836 }
837
838 if now_supported {
839 tracing::debug!("Remote now supports our kademlia protocol");
840 } else {
841 tracing::debug!("Remote no longer supports our kademlia protocol");
842 }
843
844 ProtocolStatus {
845 supported: now_supported,
846 reported: false,
847 }
848}
849
850impl Handler {
851 fn answer_pending_request(&mut self, request_id: RequestId, mut msg: KadResponseMsg) {
852 for state in self.inbound_substreams.iter_mut() {
853 match state.try_answer_with(request_id, msg) {
854 Ok(()) => return,
855 Err(m) => {
856 msg = m;
857 }
858 }
859 }
860
861 debug_assert!(false, "Cannot find inbound substream for {request_id:?}")
862 }
863}
864
865impl futures::Stream for InboundSubstreamState {
866 type Item = ConnectionHandlerEvent<ProtocolConfig, (), HandlerEvent>;
867
868 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
869 let this = self.get_mut();
870
871 loop {
872 match std::mem::replace(
873 this,
874 Self::Poisoned {
875 phantom: PhantomData,
876 },
877 ) {
878 InboundSubstreamState::WaitingMessage {
879 first,
880 connection_id,
881 mut substream,
882 } => match substream.poll_next_unpin(cx) {
883 Poll::Ready(Some(Ok(KadRequestMsg::Ping))) => {
884 tracing::warn!("Kademlia PING messages are unsupported");
885
886 *this = InboundSubstreamState::Closing(substream);
887 }
888 Poll::Ready(Some(Ok(KadRequestMsg::FindNode { key }))) => {
889 *this =
890 InboundSubstreamState::WaitingBehaviour(connection_id, substream, None);
891 return Poll::Ready(Some(ConnectionHandlerEvent::NotifyBehaviour(
892 HandlerEvent::FindNodeReq {
893 key,
894 request_id: RequestId {
895 connec_unique_id: connection_id,
896 },
897 },
898 )));
899 }
900 Poll::Ready(Some(Ok(KadRequestMsg::GetProviders { key }))) => {
901 *this =
902 InboundSubstreamState::WaitingBehaviour(connection_id, substream, None);
903 return Poll::Ready(Some(ConnectionHandlerEvent::NotifyBehaviour(
904 HandlerEvent::GetProvidersReq {
905 key,
906 request_id: RequestId {
907 connec_unique_id: connection_id,
908 },
909 },
910 )));
911 }
912 Poll::Ready(Some(Ok(KadRequestMsg::AddProvider { key, provider }))) => {
913 *this = InboundSubstreamState::WaitingMessage {
914 first: false,
915 connection_id,
916 substream,
917 };
918 return Poll::Ready(Some(ConnectionHandlerEvent::NotifyBehaviour(
919 HandlerEvent::AddProvider { key, provider },
920 )));
921 }
922 Poll::Ready(Some(Ok(KadRequestMsg::GetValue { key }))) => {
923 *this =
924 InboundSubstreamState::WaitingBehaviour(connection_id, substream, None);
925 return Poll::Ready(Some(ConnectionHandlerEvent::NotifyBehaviour(
926 HandlerEvent::GetRecord {
927 key,
928 request_id: RequestId {
929 connec_unique_id: connection_id,
930 },
931 },
932 )));
933 }
934 Poll::Ready(Some(Ok(KadRequestMsg::PutValue { record }))) => {
935 *this =
936 InboundSubstreamState::WaitingBehaviour(connection_id, substream, None);
937 return Poll::Ready(Some(ConnectionHandlerEvent::NotifyBehaviour(
938 HandlerEvent::PutRecord {
939 record,
940 request_id: RequestId {
941 connec_unique_id: connection_id,
942 },
943 },
944 )));
945 }
946 Poll::Pending => {
947 *this = InboundSubstreamState::WaitingMessage {
948 first,
949 connection_id,
950 substream,
951 };
952 return Poll::Pending;
953 }
954 Poll::Ready(None) => {
955 return Poll::Ready(None);
956 }
957 Poll::Ready(Some(Err(e))) => {
958 tracing::trace!("Inbound substream error: {:?}", e);
959 return Poll::Ready(None);
960 }
961 },
962 InboundSubstreamState::WaitingBehaviour(id, substream, _) => {
963 *this = InboundSubstreamState::WaitingBehaviour(
964 id,
965 substream,
966 Some(cx.waker().clone()),
967 );
968
969 return Poll::Pending;
970 }
971 InboundSubstreamState::PendingSend(id, mut substream, msg) => {
972 match substream.poll_ready_unpin(cx) {
973 Poll::Ready(Ok(())) => match substream.start_send_unpin(msg) {
974 Ok(()) => {
975 *this = InboundSubstreamState::PendingFlush(id, substream);
976 }
977 Err(_) => return Poll::Ready(None),
978 },
979 Poll::Pending => {
980 *this = InboundSubstreamState::PendingSend(id, substream, msg);
981 return Poll::Pending;
982 }
983 Poll::Ready(Err(_)) => return Poll::Ready(None),
984 }
985 }
986 InboundSubstreamState::PendingFlush(id, mut substream) => {
987 match substream.poll_flush_unpin(cx) {
988 Poll::Ready(Ok(())) => {
989 *this = InboundSubstreamState::WaitingMessage {
990 first: false,
991 connection_id: id,
992 substream,
993 };
994 }
995 Poll::Pending => {
996 *this = InboundSubstreamState::PendingFlush(id, substream);
997 return Poll::Pending;
998 }
999 Poll::Ready(Err(_)) => return Poll::Ready(None),
1000 }
1001 }
1002 InboundSubstreamState::Closing(mut stream) => match stream.poll_close_unpin(cx) {
1003 Poll::Ready(Ok(())) | Poll::Ready(Err(_)) => return Poll::Ready(None),
1004 Poll::Pending => {
1005 *this = InboundSubstreamState::Closing(stream);
1006 return Poll::Pending;
1007 }
1008 },
1009 InboundSubstreamState::Poisoned { .. } => unreachable!(),
1010 InboundSubstreamState::Cancelled => return Poll::Ready(None),
1011 }
1012 }
1013 }
1014}
1015
1016fn process_kad_response(event: KadResponseMsg, query_id: QueryId) -> HandlerEvent {
1018 match event {
1020 KadResponseMsg::Pong => {
1021 HandlerEvent::QueryError {
1023 error: HandlerQueryErr::UnexpectedMessage,
1024 query_id,
1025 }
1026 }
1027 KadResponseMsg::FindNode { closer_peers } => HandlerEvent::FindNodeRes {
1028 closer_peers,
1029 query_id,
1030 },
1031 KadResponseMsg::GetProviders {
1032 closer_peers,
1033 provider_peers,
1034 } => HandlerEvent::GetProvidersRes {
1035 closer_peers,
1036 provider_peers,
1037 query_id,
1038 },
1039 KadResponseMsg::GetValue {
1040 record,
1041 closer_peers,
1042 } => HandlerEvent::GetRecordRes {
1043 record,
1044 closer_peers,
1045 query_id,
1046 },
1047 KadResponseMsg::PutValue { key, value, .. } => HandlerEvent::PutRecordRes {
1048 key,
1049 value,
1050 query_id,
1051 },
1052 }
1053}
1054
#[cfg(test)]
mod tests {
    use quickcheck::{Arbitrary, Gen};
    use tracing_subscriber::EnvFilter;

    use super::*;

    impl Arbitrary for ProtocolStatus {
        fn arbitrary(g: &mut Gen) -> Self {
            Self {
                supported: bool::arbitrary(g),
                reported: bool::arbitrary(g),
            }
        }
    }

    /// Property test for `compute_new_protocol_status`.
    ///
    /// The new status always mirrors `now_supported`. It is marked reported exactly when a
    /// previous status existed and its `supported` flag is unchanged.
    #[test]
    fn compute_next_protocol_status_test() {
        let _ = tracing_subscriber::fmt()
            .with_env_filter(EnvFilter::from_default_env())
            .try_init();

        fn prop(now_supported: bool, current: Option<ProtocolStatus>) {
            let new = compute_new_protocol_status(now_supported, current);

            let expected_reported = current.map_or(false, |c| c.supported == now_supported);

            assert_eq!(new.supported, now_supported);
            assert_eq!(new.reported, expected_reported);
        }

        quickcheck::quickcheck(prop as fn(_, _))
    }
}