1use std::{
22 collections::VecDeque,
23 error, fmt, io,
24 marker::PhantomData,
25 pin::Pin,
26 task::{Context, Poll, Waker},
27};
28
29use either::Either;
30use futures::{channel::oneshot, prelude::*, stream::SelectAll};
31use libp2p_core::{upgrade, ConnectedPoint};
32use libp2p_identity::PeerId;
33use libp2p_swarm::{
34 handler::{ConnectionEvent, FullyNegotiatedInbound, FullyNegotiatedOutbound},
35 ConnectionHandler, ConnectionHandlerEvent, Stream, StreamUpgradeError, SubstreamProtocol,
36 SupportedProtocols,
37};
38
39use crate::{
40 behaviour::Mode,
41 protocol::{
42 KadInStreamSink, KadOutStreamSink, KadPeer, KadRequestMsg, KadResponseMsg, ProtocolConfig,
43 },
44 record::{self, Record},
45 QueryId,
46};
47
48const MAX_NUM_STREAMS: usize = 32;
49
/// Connection handler driving the Kademlia protocol on a single connection.
///
/// It answers inbound requests (only while in [`Mode::Server`], see `listen_protocol`) and
/// opens outbound substreams for requests issued by the behaviour.
pub struct Handler {
    /// Protocol configuration: protocol names, substream timeout, and the upgrade used for both
    /// inbound and outbound substreams.
    protocol_config: ProtocolConfig,

    /// Current mode; decides whether inbound substreams are accepted or denied.
    mode: Mode,

    /// Identifier handed to the next inbound substream; incremented on every new substream.
    next_connec_unique_id: UniqueConnecId,

    /// In-flight outbound requests, tagged with the `QueryId` they belong to. Bounded to
    /// `MAX_NUM_STREAMS` entries, each subject to the configured substream timeout.
    outbound_substreams:
        futures_bounded::FuturesTupleSet<io::Result<Option<KadResponseMsg>>, QueryId>,

    /// Senders waiting for an outbound stream, in request order. The front one is completed on
    /// the next `FullyNegotiatedOutbound` or `DialUpgradeError`.
    pending_streams:
        VecDeque<oneshot::Sender<Result<KadOutStreamSink<Stream>, StreamUpgradeError<io::Error>>>>,

    /// Outbound requests waiting for capacity in `outbound_substreams`.
    pending_messages: VecDeque<(KadRequestMsg, QueryId)>,

    /// State machines of all inbound substreams.
    inbound_substreams: SelectAll<InboundSubstreamState>,

    /// How this connection was established; cloned into protocol-status events.
    endpoint: ConnectedPoint,

    /// Peer on the other side of the connection (used for logging).
    remote_peer_id: PeerId,

    /// Whether the remote supports our protocol; `None` until that is known.
    protocol_status: Option<ProtocolStatus>,

    /// Protocols the remote advertised, tracked from `RemoteProtocolsChange` events.
    remote_supported_protocols: SupportedProtocols,
}
94
/// Whether the remote supports our Kademlia protocol, together with whether that fact has been
/// surfaced to the behaviour yet.
#[derive(Debug, Copy, Clone, PartialEq)]
struct ProtocolStatus {
    /// `true` if the remote was observed to support (one of) our protocol names.
    supported: bool,
    /// `true` once the current `supported` value counts as reported; `Handler::poll` emits a
    /// `ProtocolConfirmed` / `ProtocolNotSupported` event only while this is `false`.
    reported: bool,
}
104
/// State machine of a single inbound Kademlia substream.
enum InboundSubstreamState {
    /// Waiting for the remote to send the next request.
    WaitingMessage {
        /// `true` until this substream has served its first request. Only substreams with
        /// `first == false` are eligible for eviction when the inbound limit is reached.
        first: bool,
        /// Identifier of this substream; becomes the `RequestId` of requests read from it.
        connection_id: UniqueConnecId,
        substream: KadInStreamSink<Stream>,
    },
    /// A request was forwarded to the behaviour and we are waiting for its answer. The waker is
    /// stored so that delivering the answer can get this stream polled again.
    WaitingBehaviour(UniqueConnecId, KadInStreamSink<Stream>, Option<Waker>),
    /// The behaviour's answer is waiting to be written to the substream.
    PendingSend(UniqueConnecId, KadInStreamSink<Stream>, KadResponseMsg),
    /// The answer was written and is being flushed.
    PendingFlush(UniqueConnecId, KadInStreamSink<Stream>),
    /// The substream is being closed.
    Closing(KadInStreamSink<Stream>),
    /// Evicted to make room for a newer substream; ends the stream when next polled.
    Cancelled,

    /// Placeholder left behind while a state is moved out via `std::mem::replace`; never
    /// observable between state transitions.
    /// NOTE(review): the `PhantomData<QueryId>` field carries no data; its purpose is not evident
    /// from this file.
    Poisoned {
        phantom: PhantomData<QueryId>,
    },
}
129
130impl InboundSubstreamState {
131 #[allow(clippy::result_large_err)]
132 fn try_answer_with(
133 &mut self,
134 id: RequestId,
135 msg: KadResponseMsg,
136 ) -> Result<(), KadResponseMsg> {
137 match std::mem::replace(
138 self,
139 InboundSubstreamState::Poisoned {
140 phantom: PhantomData,
141 },
142 ) {
143 InboundSubstreamState::WaitingBehaviour(conn_id, substream, mut waker)
144 if conn_id == id.connec_unique_id =>
145 {
146 *self = InboundSubstreamState::PendingSend(conn_id, substream, msg);
147
148 if let Some(waker) = waker.take() {
149 waker.wake();
150 }
151
152 Ok(())
153 }
154 other => {
155 *self = other;
156
157 Err(msg)
158 }
159 }
160 }
161
162 fn close(&mut self) {
163 match std::mem::replace(
164 self,
165 InboundSubstreamState::Poisoned {
166 phantom: PhantomData,
167 },
168 ) {
169 InboundSubstreamState::WaitingMessage { substream, .. }
170 | InboundSubstreamState::WaitingBehaviour(_, substream, _)
171 | InboundSubstreamState::PendingSend(_, substream, _)
172 | InboundSubstreamState::PendingFlush(_, substream)
173 | InboundSubstreamState::Closing(substream) => {
174 *self = InboundSubstreamState::Closing(substream);
175 }
176 InboundSubstreamState::Cancelled => {
177 *self = InboundSubstreamState::Cancelled;
178 }
179 InboundSubstreamState::Poisoned { .. } => unreachable!(),
180 }
181 }
182}
183
/// Event produced by the Kademlia handler and delivered to the behaviour.
#[derive(Debug)]
pub enum HandlerEvent {
    /// The remote was observed to support our Kademlia protocol: a substream negotiated
    /// successfully, or its advertised protocols include one of ours.
    ProtocolConfirmed { endpoint: ConnectedPoint },
    /// The remote's advertised protocols do not include any of our Kademlia protocol names.
    ProtocolNotSupported { endpoint: ConnectedPoint },

    /// The remote sent a FIND_NODE request. Answer with [`HandlerIn::FindNodeRes`] carrying
    /// the same `request_id`.
    FindNodeReq {
        /// Key from the request.
        key: Vec<u8>,
        /// Identifies the inbound substream awaiting the answer.
        request_id: RequestId,
    },

    /// The remote answered a FIND_NODE request we sent.
    FindNodeRes {
        /// Peers returned by the remote.
        closer_peers: Vec<KadPeer>,
        /// Query the request belonged to.
        query_id: QueryId,
    },

    /// The remote sent a GET_PROVIDERS request. Answer with [`HandlerIn::GetProvidersRes`]
    /// carrying the same `request_id`.
    GetProvidersReq {
        /// Key from the request.
        key: record::Key,
        /// Identifies the inbound substream awaiting the answer.
        request_id: RequestId,
    },

    /// The remote answered a GET_PROVIDERS request we sent.
    GetProvidersRes {
        /// Closer peers returned by the remote.
        closer_peers: Vec<KadPeer>,
        /// Providers returned by the remote.
        provider_peers: Vec<KadPeer>,
        /// Query the request belonged to.
        query_id: QueryId,
    },

    /// An outbound request failed (I/O error, timeout, or unexpected response type).
    QueryError {
        /// What went wrong.
        error: HandlerQueryErr,
        /// Query the request belonged to.
        query_id: QueryId,
    },

    /// The remote sent an ADD_PROVIDER request; no answer is sent back.
    AddProvider {
        /// Key from the request.
        key: record::Key,
        /// Provider announced by the remote.
        provider: KadPeer,
    },

    /// The remote sent a GET_VALUE request. Answer with [`HandlerIn::GetRecordRes`] carrying
    /// the same `request_id`.
    GetRecord {
        /// Key from the request.
        key: record::Key,
        /// Identifies the inbound substream awaiting the answer.
        request_id: RequestId,
    },

    /// The remote answered a GET_VALUE request we sent.
    GetRecordRes {
        /// Record returned by the remote, if any.
        record: Option<Record>,
        /// Closer peers returned by the remote.
        closer_peers: Vec<KadPeer>,
        /// Query the request belonged to.
        query_id: QueryId,
    },

    /// The remote sent a PUT_VALUE request. Answer with [`HandlerIn::PutRecordRes`] carrying
    /// the same `request_id`.
    PutRecord {
        record: Record,
        /// Identifies the inbound substream awaiting the answer.
        request_id: RequestId,
    },

    /// The remote answered a PUT_VALUE request we sent.
    PutRecordRes {
        /// Key echoed back by the remote.
        key: record::Key,
        /// Value echoed back by the remote.
        value: Vec<u8>,
        /// Query the request belonged to.
        query_id: QueryId,
    },
}
281
/// Error of an outbound Kademlia request.
#[derive(Debug)]
pub enum HandlerQueryErr {
    /// The remote answered with a message type we never request (e.g. `Pong`).
    UnexpectedMessage,
    /// I/O failure on the outbound substream. Negotiation failures and timeouts are also
    /// reported through this variant.
    Io(io::Error),
}
290
291impl fmt::Display for HandlerQueryErr {
292 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
293 match self {
294 HandlerQueryErr::UnexpectedMessage => {
295 write!(
296 f,
297 "Remote answered our Kademlia RPC query with the wrong message type"
298 )
299 }
300 HandlerQueryErr::Io(err) => {
301 write!(f, "I/O error during a Kademlia RPC query: {err}")
302 }
303 }
304 }
305}
306
307impl error::Error for HandlerQueryErr {
308 fn source(&self) -> Option<&(dyn error::Error + 'static)> {
309 match self {
310 HandlerQueryErr::UnexpectedMessage => None,
311 HandlerQueryErr::Io(err) => Some(err),
312 }
313 }
314}
315
/// Command sent by the behaviour to the Kademlia handler.
#[derive(Debug)]
pub enum HandlerIn {
    /// Close the inbound substream that is waiting for the behaviour to answer this request,
    /// without sending a response. Has no effect on substreams in any other state.
    Reset(RequestId),

    /// Switch to `new_mode`; affects which inbound protocol `listen_protocol` advertises from
    /// now on.
    ReconfigureMode { new_mode: Mode },

    /// Send a FIND_NODE request. The outcome is reported as [`HandlerEvent::FindNodeRes`] or
    /// [`HandlerEvent::QueryError`] tagged with `query_id`.
    FindNodeReq {
        /// Key to look up.
        key: Vec<u8>,
        /// Query the request belongs to.
        query_id: QueryId,
    },

    /// Answer an inbound [`HandlerEvent::FindNodeReq`].
    FindNodeRes {
        /// Peers to return to the remote.
        closer_peers: Vec<KadPeer>,
        /// `request_id` from the corresponding [`HandlerEvent::FindNodeReq`].
        request_id: RequestId,
    },

    /// Send a GET_PROVIDERS request. The outcome is reported as
    /// [`HandlerEvent::GetProvidersRes`] or [`HandlerEvent::QueryError`] tagged with `query_id`.
    GetProvidersReq {
        /// Key to look up providers for.
        key: record::Key,
        /// Query the request belongs to.
        query_id: QueryId,
    },

    /// Answer an inbound [`HandlerEvent::GetProvidersReq`].
    GetProvidersRes {
        /// Closer peers to return to the remote.
        closer_peers: Vec<KadPeer>,
        /// Known providers to return to the remote.
        provider_peers: Vec<KadPeer>,
        /// `request_id` from the corresponding [`HandlerEvent::GetProvidersReq`].
        request_id: RequestId,
    },

    /// Send an ADD_PROVIDER request. No response is awaited; only failures are reported, as
    /// [`HandlerEvent::QueryError`].
    AddProvider {
        /// Key being provided.
        key: record::Key,
        /// Provider record to announce.
        provider: KadPeer,
        /// Query the request belongs to.
        query_id: QueryId,
    },

    /// Send a GET_VALUE request. The outcome is reported as [`HandlerEvent::GetRecordRes`] or
    /// [`HandlerEvent::QueryError`] tagged with `query_id`.
    GetRecord {
        /// Key of the record to fetch.
        key: record::Key,
        /// Query the request belongs to.
        query_id: QueryId,
    },

    /// Answer an inbound [`HandlerEvent::GetRecord`].
    GetRecordRes {
        /// Record to return, if we have one.
        record: Option<Record>,
        /// Closer peers to return to the remote.
        closer_peers: Vec<KadPeer>,
        /// `request_id` from the corresponding [`HandlerEvent::GetRecord`].
        request_id: RequestId,
    },

    /// Send a PUT_VALUE request. The outcome is reported as [`HandlerEvent::PutRecordRes`] or
    /// [`HandlerEvent::QueryError`] tagged with `query_id`.
    PutRecord {
        record: Record,
        /// Query the request belongs to.
        query_id: QueryId,
    },

    /// Answer an inbound [`HandlerEvent::PutRecord`].
    PutRecordRes {
        /// Key to echo back to the remote.
        key: record::Key,
        /// Value to echo back to the remote.
        value: Vec<u8>,
        /// `request_id` from the corresponding [`HandlerEvent::PutRecord`].
        request_id: RequestId,
    },
}
419
/// Identifies an inbound request so that the behaviour's answer can be routed back to the
/// substream that carried it.
#[derive(Debug, PartialEq, Eq, Copy, Clone)]
pub struct RequestId {
    /// Identifier of the inbound substream the request arrived on.
    connec_unique_id: UniqueConnecId,
}
427
/// Handler-local identifier assigned to each inbound substream, counting up from 0.
///
/// Despite the name, it identifies a substream within this connection, not a connection.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
struct UniqueConnecId(u64);
431
432impl Handler {
433 pub fn new(
434 protocol_config: ProtocolConfig,
435 endpoint: ConnectedPoint,
436 remote_peer_id: PeerId,
437 mode: Mode,
438 ) -> Self {
439 match &endpoint {
440 ConnectedPoint::Dialer { .. } => {
441 tracing::debug!(
442 peer=%remote_peer_id,
443 mode=%mode,
444 "New outbound connection"
445 );
446 }
447 ConnectedPoint::Listener { .. } => {
448 tracing::debug!(
449 peer=%remote_peer_id,
450 mode=%mode,
451 "New inbound connection"
452 );
453 }
454 }
455
456 let substreams_timeout = protocol_config.substreams_timeout_s();
457
458 Handler {
459 protocol_config,
460 mode,
461 endpoint,
462 remote_peer_id,
463 next_connec_unique_id: UniqueConnecId(0),
464 inbound_substreams: Default::default(),
465 outbound_substreams: futures_bounded::FuturesTupleSet::new(
466 substreams_timeout,
467 MAX_NUM_STREAMS,
468 ),
469 pending_streams: Default::default(),
470 pending_messages: Default::default(),
471 protocol_status: None,
472 remote_supported_protocols: Default::default(),
473 }
474 }
475
476 fn on_fully_negotiated_outbound(
477 &mut self,
478 FullyNegotiatedOutbound {
479 protocol: stream,
480 info: (),
481 }: FullyNegotiatedOutbound<<Self as ConnectionHandler>::OutboundProtocol>,
482 ) {
483 if let Some(sender) = self.pending_streams.pop_front() {
484 let _ = sender.send(Ok(stream));
485 }
486
487 if self.protocol_status.is_none() {
488 self.protocol_status = Some(ProtocolStatus {
492 supported: true,
493 reported: false,
494 });
495 }
496 }
497
498 fn on_fully_negotiated_inbound(
499 &mut self,
500 FullyNegotiatedInbound { protocol, .. }: FullyNegotiatedInbound<
501 <Self as ConnectionHandler>::InboundProtocol,
502 >,
503 ) {
504 let protocol = match protocol {
507 future::Either::Left(p) => p,
508 future::Either::Right(p) => libp2p_core::util::unreachable(p),
509 };
510
511 if self.protocol_status.is_none() {
512 self.protocol_status = Some(ProtocolStatus {
516 supported: true,
517 reported: false,
518 });
519 }
520
521 if self.inbound_substreams.len() == MAX_NUM_STREAMS {
522 if let Some(s) = self.inbound_substreams.iter_mut().find(|s| {
523 matches!(
524 s,
525 InboundSubstreamState::WaitingMessage { first: false, .. }
527 )
528 }) {
529 *s = InboundSubstreamState::Cancelled;
530 tracing::debug!(
531 peer=?self.remote_peer_id,
532 "New inbound substream to peer exceeds inbound substream limit. \
533 Removed older substream waiting to be reused."
534 )
535 } else {
536 tracing::warn!(
537 peer=?self.remote_peer_id,
538 "New inbound substream to peer exceeds inbound substream limit. \
539 No older substream waiting to be reused. Dropping new substream."
540 );
541 return;
542 }
543 }
544
545 let connec_unique_id = self.next_connec_unique_id;
546 self.next_connec_unique_id.0 += 1;
547 self.inbound_substreams
548 .push(InboundSubstreamState::WaitingMessage {
549 first: true,
550 connection_id: connec_unique_id,
551 substream: protocol,
552 });
553 }
554
555 fn queue_new_stream(&mut self, id: QueryId, msg: KadRequestMsg) {
558 let (sender, receiver) = oneshot::channel();
559
560 self.pending_streams.push_back(sender);
561 let result = self.outbound_substreams.try_push(
562 async move {
563 let mut stream = receiver
564 .await
565 .map_err(|_| io::Error::from(io::ErrorKind::BrokenPipe))?
566 .map_err(|e| match e {
567 StreamUpgradeError::Timeout => io::ErrorKind::TimedOut.into(),
568 StreamUpgradeError::Apply(e) => e,
569 StreamUpgradeError::NegotiationFailed => io::Error::new(
570 io::ErrorKind::ConnectionRefused,
571 "protocol not supported",
572 ),
573 StreamUpgradeError::Io(e) => e,
574 })?;
575
576 let has_answer = !matches!(msg, KadRequestMsg::AddProvider { .. });
577
578 stream.send(msg).await?;
579 stream.close().await?;
580
581 if !has_answer {
582 return Ok(None);
583 }
584
585 let msg = stream.next().await.ok_or(io::ErrorKind::UnexpectedEof)??;
586
587 Ok(Some(msg))
588 },
589 id,
590 );
591
592 debug_assert!(
593 result.is_ok(),
594 "Expected to not create more streams than allowed"
595 );
596 }
597}
598
599impl ConnectionHandler for Handler {
600 type FromBehaviour = HandlerIn;
601 type ToBehaviour = HandlerEvent;
602 type InboundProtocol = Either<ProtocolConfig, upgrade::DeniedUpgrade>;
603 type OutboundProtocol = ProtocolConfig;
604 type OutboundOpenInfo = ();
605 type InboundOpenInfo = ();
606
607 fn listen_protocol(&self) -> SubstreamProtocol<Self::InboundProtocol> {
608 match self.mode {
609 Mode::Server => SubstreamProtocol::new(Either::Left(self.protocol_config.clone()), ()),
610 Mode::Client => SubstreamProtocol::new(Either::Right(upgrade::DeniedUpgrade), ()),
611 }
612 }
613
614 fn on_behaviour_event(&mut self, message: HandlerIn) {
615 match message {
616 HandlerIn::Reset(request_id) => {
617 if let Some(state) = self
618 .inbound_substreams
619 .iter_mut()
620 .find(|state| match state {
621 InboundSubstreamState::WaitingBehaviour(conn_id, _, _) => {
622 conn_id == &request_id.connec_unique_id
623 }
624 _ => false,
625 })
626 {
627 state.close();
628 }
629 }
630 HandlerIn::FindNodeReq { key, query_id } => {
631 let msg = KadRequestMsg::FindNode { key };
632 self.pending_messages.push_back((msg, query_id));
633 }
634 HandlerIn::FindNodeRes {
635 closer_peers,
636 request_id,
637 } => self.answer_pending_request(request_id, KadResponseMsg::FindNode { closer_peers }),
638 HandlerIn::GetProvidersReq { key, query_id } => {
639 let msg = KadRequestMsg::GetProviders { key };
640 self.pending_messages.push_back((msg, query_id));
641 }
642 HandlerIn::GetProvidersRes {
643 closer_peers,
644 provider_peers,
645 request_id,
646 } => self.answer_pending_request(
647 request_id,
648 KadResponseMsg::GetProviders {
649 closer_peers,
650 provider_peers,
651 },
652 ),
653 HandlerIn::AddProvider {
654 key,
655 provider,
656 query_id,
657 } => {
658 let msg = KadRequestMsg::AddProvider { key, provider };
659 self.pending_messages.push_back((msg, query_id));
660 }
661 HandlerIn::GetRecord { key, query_id } => {
662 let msg = KadRequestMsg::GetValue { key };
663 self.pending_messages.push_back((msg, query_id));
664 }
665 HandlerIn::PutRecord { record, query_id } => {
666 let msg = KadRequestMsg::PutValue { record };
667 self.pending_messages.push_back((msg, query_id));
668 }
669 HandlerIn::GetRecordRes {
670 record,
671 closer_peers,
672 request_id,
673 } => {
674 self.answer_pending_request(
675 request_id,
676 KadResponseMsg::GetValue {
677 record,
678 closer_peers,
679 },
680 );
681 }
682 HandlerIn::PutRecordRes {
683 key,
684 request_id,
685 value,
686 } => {
687 self.answer_pending_request(request_id, KadResponseMsg::PutValue { key, value });
688 }
689 HandlerIn::ReconfigureMode { new_mode } => {
690 let peer = self.remote_peer_id;
691
692 match &self.endpoint {
693 ConnectedPoint::Dialer { .. } => {
694 tracing::debug!(
695 %peer,
696 mode=%new_mode,
697 "Changed mode on outbound connection"
698 )
699 }
700 ConnectedPoint::Listener { local_addr, .. } => {
701 tracing::debug!(
702 %peer,
703 mode=%new_mode,
704 local_address=%local_addr,
705 "Changed mode on inbound connection assuming that one of our external addresses routes to the local address")
706 }
707 }
708
709 self.mode = new_mode;
710 }
711 }
712 }
713
714 #[tracing::instrument(level = "trace", name = "ConnectionHandler::poll", skip(self, cx))]
715 fn poll(
716 &mut self,
717 cx: &mut Context<'_>,
718 ) -> Poll<ConnectionHandlerEvent<Self::OutboundProtocol, (), Self::ToBehaviour>> {
719 loop {
720 match &mut self.protocol_status {
721 Some(status) if !status.reported => {
722 status.reported = true;
723 let event = if status.supported {
724 HandlerEvent::ProtocolConfirmed {
725 endpoint: self.endpoint.clone(),
726 }
727 } else {
728 HandlerEvent::ProtocolNotSupported {
729 endpoint: self.endpoint.clone(),
730 }
731 };
732
733 return Poll::Ready(ConnectionHandlerEvent::NotifyBehaviour(event));
734 }
735 _ => {}
736 }
737
738 match self.outbound_substreams.poll_unpin(cx) {
739 Poll::Ready((Ok(Ok(Some(response))), query_id)) => {
740 return Poll::Ready(ConnectionHandlerEvent::NotifyBehaviour(
741 process_kad_response(response, query_id),
742 ))
743 }
744 Poll::Ready((Ok(Ok(None)), _)) => {
745 continue;
746 }
747 Poll::Ready((Ok(Err(e)), query_id)) => {
748 return Poll::Ready(ConnectionHandlerEvent::NotifyBehaviour(
749 HandlerEvent::QueryError {
750 error: HandlerQueryErr::Io(e),
751 query_id,
752 },
753 ))
754 }
755 Poll::Ready((Err(_timeout), query_id)) => {
756 return Poll::Ready(ConnectionHandlerEvent::NotifyBehaviour(
757 HandlerEvent::QueryError {
758 error: HandlerQueryErr::Io(io::ErrorKind::TimedOut.into()),
759 query_id,
760 },
761 ))
762 }
763 Poll::Pending => {}
764 }
765
766 if let Poll::Ready(Some(event)) = self.inbound_substreams.poll_next_unpin(cx) {
767 return Poll::Ready(event);
768 }
769
770 if self.outbound_substreams.len() < MAX_NUM_STREAMS {
771 if let Some((msg, id)) = self.pending_messages.pop_front() {
772 self.queue_new_stream(id, msg);
773 return Poll::Ready(ConnectionHandlerEvent::OutboundSubstreamRequest {
774 protocol: SubstreamProtocol::new(self.protocol_config.clone(), ()),
775 });
776 }
777 }
778
779 return Poll::Pending;
780 }
781 }
782
783 fn on_connection_event(
784 &mut self,
785 event: ConnectionEvent<Self::InboundProtocol, Self::OutboundProtocol>,
786 ) {
787 match event {
788 ConnectionEvent::FullyNegotiatedOutbound(fully_negotiated_outbound) => {
789 self.on_fully_negotiated_outbound(fully_negotiated_outbound)
790 }
791 ConnectionEvent::FullyNegotiatedInbound(fully_negotiated_inbound) => {
792 self.on_fully_negotiated_inbound(fully_negotiated_inbound)
793 }
794 ConnectionEvent::DialUpgradeError(ev) => {
795 if let Some(sender) = self.pending_streams.pop_front() {
796 let _ = sender.send(Err(ev.error));
797 }
798 }
799 ConnectionEvent::RemoteProtocolsChange(change) => {
800 let dirty = self.remote_supported_protocols.on_protocols_change(change);
801
802 if dirty {
803 let remote_supports_our_kademlia_protocols = self
804 .remote_supported_protocols
805 .iter()
806 .any(|p| self.protocol_config.protocol_names().contains(p));
807
808 self.protocol_status = Some(compute_new_protocol_status(
809 remote_supports_our_kademlia_protocols,
810 self.protocol_status,
811 ))
812 }
813 }
814 _ => {}
815 }
816 }
817}
818
819fn compute_new_protocol_status(
820 now_supported: bool,
821 current_status: Option<ProtocolStatus>,
822) -> ProtocolStatus {
823 let Some(current_status) = current_status else {
824 return ProtocolStatus {
825 supported: now_supported,
826 reported: false,
827 };
828 };
829
830 if now_supported == current_status.supported {
831 return ProtocolStatus {
832 supported: now_supported,
833 reported: true,
834 };
835 }
836
837 if now_supported {
838 tracing::debug!("Remote now supports our kademlia protocol");
839 } else {
840 tracing::debug!("Remote no longer supports our kademlia protocol");
841 }
842
843 ProtocolStatus {
844 supported: now_supported,
845 reported: false,
846 }
847}
848
849impl Handler {
850 fn answer_pending_request(&mut self, request_id: RequestId, mut msg: KadResponseMsg) {
851 for state in self.inbound_substreams.iter_mut() {
852 match state.try_answer_with(request_id, msg) {
853 Ok(()) => return,
854 Err(m) => {
855 msg = m;
856 }
857 }
858 }
859
860 debug_assert!(false, "Cannot find inbound substream for {request_id:?}")
861 }
862}
863
/// Drives one inbound substream. Each yielded item is an event for the behaviour. The stream
/// ends (`None`) when the substream is closed, fails, or was cancelled.
impl futures::Stream for InboundSubstreamState {
    type Item = ConnectionHandlerEvent<ProtocolConfig, (), HandlerEvent>;

    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let this = self.get_mut();

        // Each iteration moves the current state out (leaving `Poisoned`) and writes the next
        // state back. Arms that finish a transition without returning loop again to make
        // progress in the new state.
        loop {
            match std::mem::replace(
                this,
                Self::Poisoned {
                    phantom: PhantomData,
                },
            ) {
                InboundSubstreamState::WaitingMessage {
                    first,
                    connection_id,
                    mut substream,
                } => match substream.poll_next_unpin(cx) {
                    // PING is not answered; the substream is closed instead.
                    Poll::Ready(Some(Ok(KadRequestMsg::Ping))) => {
                        tracing::warn!("Kademlia PING messages are unsupported");

                        *this = InboundSubstreamState::Closing(substream);
                    }
                    // Requests that need an answer park the substream until the behaviour
                    // replies through `try_answer_with`.
                    Poll::Ready(Some(Ok(KadRequestMsg::FindNode { key }))) => {
                        *this =
                            InboundSubstreamState::WaitingBehaviour(connection_id, substream, None);
                        return Poll::Ready(Some(ConnectionHandlerEvent::NotifyBehaviour(
                            HandlerEvent::FindNodeReq {
                                key,
                                request_id: RequestId {
                                    connec_unique_id: connection_id,
                                },
                            },
                        )));
                    }
                    Poll::Ready(Some(Ok(KadRequestMsg::GetProviders { key }))) => {
                        *this =
                            InboundSubstreamState::WaitingBehaviour(connection_id, substream, None);
                        return Poll::Ready(Some(ConnectionHandlerEvent::NotifyBehaviour(
                            HandlerEvent::GetProvidersReq {
                                key,
                                request_id: RequestId {
                                    connec_unique_id: connection_id,
                                },
                            },
                        )));
                    }
                    // ADD_PROVIDER gets no answer: the substream immediately waits for the next
                    // request and counts as reused (`first: false`).
                    Poll::Ready(Some(Ok(KadRequestMsg::AddProvider { key, provider }))) => {
                        *this = InboundSubstreamState::WaitingMessage {
                            first: false,
                            connection_id,
                            substream,
                        };
                        return Poll::Ready(Some(ConnectionHandlerEvent::NotifyBehaviour(
                            HandlerEvent::AddProvider { key, provider },
                        )));
                    }
                    Poll::Ready(Some(Ok(KadRequestMsg::GetValue { key }))) => {
                        *this =
                            InboundSubstreamState::WaitingBehaviour(connection_id, substream, None);
                        return Poll::Ready(Some(ConnectionHandlerEvent::NotifyBehaviour(
                            HandlerEvent::GetRecord {
                                key,
                                request_id: RequestId {
                                    connec_unique_id: connection_id,
                                },
                            },
                        )));
                    }
                    Poll::Ready(Some(Ok(KadRequestMsg::PutValue { record }))) => {
                        *this =
                            InboundSubstreamState::WaitingBehaviour(connection_id, substream, None);
                        return Poll::Ready(Some(ConnectionHandlerEvent::NotifyBehaviour(
                            HandlerEvent::PutRecord {
                                record,
                                request_id: RequestId {
                                    connec_unique_id: connection_id,
                                },
                            },
                        )));
                    }
                    Poll::Pending => {
                        *this = InboundSubstreamState::WaitingMessage {
                            first,
                            connection_id,
                            substream,
                        };
                        return Poll::Pending;
                    }
                    // Terminal arms leave the state `Poisoned`.
                    // NOTE(review): this relies on the owner (`SelectAll`) never polling a stream
                    // again after it yielded `None`.
                    Poll::Ready(None) => {
                        return Poll::Ready(None);
                    }
                    Poll::Ready(Some(Err(e))) => {
                        tracing::trace!("Inbound substream error: {:?}", e);
                        return Poll::Ready(None);
                    }
                },
                // Nothing to do until the behaviour answers. The substream itself is not polled,
                // so the waker is stored for `try_answer_with` to wake.
                InboundSubstreamState::WaitingBehaviour(id, substream, _) => {
                    *this = InboundSubstreamState::WaitingBehaviour(
                        id,
                        substream,
                        Some(cx.waker().clone()),
                    );

                    return Poll::Pending;
                }
                InboundSubstreamState::PendingSend(id, mut substream, msg) => {
                    match substream.poll_ready_unpin(cx) {
                        Poll::Ready(Ok(())) => match substream.start_send_unpin(msg) {
                            Ok(()) => {
                                *this = InboundSubstreamState::PendingFlush(id, substream);
                            }
                            Err(_) => return Poll::Ready(None),
                        },
                        Poll::Pending => {
                            *this = InboundSubstreamState::PendingSend(id, substream, msg);
                            return Poll::Pending;
                        }
                        Poll::Ready(Err(_)) => return Poll::Ready(None),
                    }
                }
                // Once flushed, the substream is reused for further requests.
                InboundSubstreamState::PendingFlush(id, mut substream) => {
                    match substream.poll_flush_unpin(cx) {
                        Poll::Ready(Ok(())) => {
                            *this = InboundSubstreamState::WaitingMessage {
                                first: false,
                                connection_id: id,
                                substream,
                            };
                        }
                        Poll::Pending => {
                            *this = InboundSubstreamState::PendingFlush(id, substream);
                            return Poll::Pending;
                        }
                        Poll::Ready(Err(_)) => return Poll::Ready(None),
                    }
                }
                // Close errors are ignored; the stream ends either way.
                InboundSubstreamState::Closing(mut stream) => match stream.poll_close_unpin(cx) {
                    Poll::Ready(Ok(())) | Poll::Ready(Err(_)) => return Poll::Ready(None),
                    Poll::Pending => {
                        *this = InboundSubstreamState::Closing(stream);
                        return Poll::Pending;
                    }
                },
                InboundSubstreamState::Poisoned { .. } => unreachable!(),
                InboundSubstreamState::Cancelled => return Poll::Ready(None),
            }
        }
    }
}
1014
1015fn process_kad_response(event: KadResponseMsg, query_id: QueryId) -> HandlerEvent {
1017 match event {
1019 KadResponseMsg::Pong => {
1020 HandlerEvent::QueryError {
1022 error: HandlerQueryErr::UnexpectedMessage,
1023 query_id,
1024 }
1025 }
1026 KadResponseMsg::FindNode { closer_peers } => HandlerEvent::FindNodeRes {
1027 closer_peers,
1028 query_id,
1029 },
1030 KadResponseMsg::GetProviders {
1031 closer_peers,
1032 provider_peers,
1033 } => HandlerEvent::GetProvidersRes {
1034 closer_peers,
1035 provider_peers,
1036 query_id,
1037 },
1038 KadResponseMsg::GetValue {
1039 record,
1040 closer_peers,
1041 } => HandlerEvent::GetRecordRes {
1042 record,
1043 closer_peers,
1044 query_id,
1045 },
1046 KadResponseMsg::PutValue { key, value, .. } => HandlerEvent::PutRecordRes {
1047 key,
1048 value,
1049 query_id,
1050 },
1051 }
1052}
1053
#[cfg(test)]
mod tests {
    use quickcheck::{Arbitrary, Gen};
    use tracing_subscriber::EnvFilter;

    use super::*;

    /// Random `ProtocolStatus` covering every `supported` / `reported` combination.
    impl Arbitrary for ProtocolStatus {
        fn arbitrary(g: &mut Gen) -> Self {
            let supported = bool::arbitrary(g);
            let reported = bool::arbitrary(g);
            ProtocolStatus {
                supported,
                reported,
            }
        }
    }

    #[test]
    fn compute_next_protocol_status_test() {
        let _ = tracing_subscriber::fmt()
            .with_env_filter(EnvFilter::from_default_env())
            .try_init();

        /// The new status always reflects the latest observation. It counts as already
        /// reported only if a previous status existed with the same support value.
        fn prop(now_supported: bool, current: Option<ProtocolStatus>) {
            let new = compute_new_protocol_status(now_supported, current);

            let support_unchanged =
                matches!(current, Some(previous) if previous.supported == now_supported);

            assert_eq!(new.reported, support_unchanged);
            assert_eq!(new.supported, now_supported);
        }

        quickcheck::quickcheck(prop as fn(_, _))
    }
}
1098}