libp2p_kad/
handler.rs

// Copyright 2018 Parity Technologies (UK) Ltd.
//
// Permission is hereby granted, free of charge, to any person obtaining a
// copy of this software and associated documentation files (the "Software"),
// to deal in the Software without restriction, including without limitation
// the rights to use, copy, modify, merge, publish, distribute, sublicense,
// and/or sell copies of the Software, and to permit persons to whom the
// Software is furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
// OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
// DEALINGS IN THE SOFTWARE.

use std::{
    collections::VecDeque,
    error, fmt, io,
    marker::PhantomData,
    pin::Pin,
    task::{Context, Poll, Waker},
    time::Duration,
};

use either::Either;
use futures::{channel::oneshot, prelude::*, stream::SelectAll};
use libp2p_core::{upgrade, ConnectedPoint};
use libp2p_identity::PeerId;
use libp2p_swarm::{
    handler::{ConnectionEvent, FullyNegotiatedInbound, FullyNegotiatedOutbound},
    ConnectionHandler, ConnectionHandlerEvent, Stream, StreamUpgradeError, SubstreamProtocol,
    SupportedProtocols,
};

use crate::{
    behaviour::Mode,
    protocol::{
        KadInStreamSink, KadOutStreamSink, KadPeer, KadRequestMsg, KadResponseMsg, ProtocolConfig,
    },
    record::{self, Record},
    QueryId,
};

const MAX_NUM_STREAMS: usize = 32;

/// Protocol handler that manages substreams for the Kademlia protocol
/// on a single connection with a peer.
///
/// The handler will automatically open a Kademlia substream with the remote for each request we
/// make.
///
/// It also handles requests made by the remote.
pub struct Handler {
    /// Configuration of the wire protocol.
    protocol_config: ProtocolConfig,

    /// In client mode, we don't accept inbound substreams.
    mode: Mode,

    /// Next unique ID assigned to an inbound substream.
    next_connec_unique_id: UniqueConnecId,

    /// List of active outbound streams.
    outbound_substreams:
        futures_bounded::FuturesTupleSet<io::Result<Option<KadResponseMsg>>, QueryId>,

    /// Contains one [`oneshot::Sender`] per outbound stream that we have requested.
    pending_streams:
        VecDeque<oneshot::Sender<Result<KadOutStreamSink<Stream>, StreamUpgradeError<io::Error>>>>,

    /// List of outbound substreams that are waiting to become active next.
    /// Contains the request we want to send, and the user data if we expect an answer.
    pending_messages: VecDeque<(KadRequestMsg, QueryId)>,

    /// List of active inbound substreams with the state they are in.
    inbound_substreams: SelectAll<InboundSubstreamState>,

    /// The connected endpoint of the connection that the handler
    /// is associated with.
    endpoint: ConnectedPoint,

    /// The [`PeerId`] of the remote.
    remote_peer_id: PeerId,

    /// The current state of protocol confirmation.
    protocol_status: Option<ProtocolStatus>,

    remote_supported_protocols: SupportedProtocols,
}

/// The states of protocol confirmation that a connection
/// handler transitions through.
#[derive(Debug, Copy, Clone, PartialEq)]
struct ProtocolStatus {
    /// Whether the remote node supports one of our kademlia protocols.
    supported: bool,
    /// Whether we reported the state to the behaviour.
    reported: bool,
}

/// State of an active inbound substream.
enum InboundSubstreamState {
    /// Waiting for a request from the remote.
    WaitingMessage {
        /// Whether it is the first message to be awaited on this stream.
        first: bool,
        connection_id: UniqueConnecId,
        substream: KadInStreamSink<Stream>,
    },
    /// Waiting for the behaviour to send a [`HandlerIn`] event containing the response.
    WaitingBehaviour(UniqueConnecId, KadInStreamSink<Stream>, Option<Waker>),
    /// Waiting to send an answer back to the remote.
    PendingSend(UniqueConnecId, KadInStreamSink<Stream>, KadResponseMsg),
    /// Waiting to flush an answer back to the remote.
    PendingFlush(UniqueConnecId, KadInStreamSink<Stream>),
    /// The substream is being closed.
    Closing(KadInStreamSink<Stream>),
    /// The substream was cancelled in favor of a new one.
    Cancelled,

    Poisoned {
        phantom: PhantomData<QueryId>,
    },
}

impl InboundSubstreamState {
    fn try_answer_with(
        &mut self,
        id: RequestId,
        msg: KadResponseMsg,
    ) -> Result<(), KadResponseMsg> {
        match std::mem::replace(
            self,
            InboundSubstreamState::Poisoned {
                phantom: PhantomData,
            },
        ) {
            InboundSubstreamState::WaitingBehaviour(conn_id, substream, mut waker)
                if conn_id == id.connec_unique_id =>
            {
                *self = InboundSubstreamState::PendingSend(conn_id, substream, msg);

                if let Some(waker) = waker.take() {
                    waker.wake();
                }

                Ok(())
            }
            other => {
                *self = other;

                Err(msg)
            }
        }
    }

    fn close(&mut self) {
        match std::mem::replace(
            self,
            InboundSubstreamState::Poisoned {
                phantom: PhantomData,
            },
        ) {
            InboundSubstreamState::WaitingMessage { substream, .. }
            | InboundSubstreamState::WaitingBehaviour(_, substream, _)
            | InboundSubstreamState::PendingSend(_, substream, _)
            | InboundSubstreamState::PendingFlush(_, substream)
            | InboundSubstreamState::Closing(substream) => {
                *self = InboundSubstreamState::Closing(substream);
            }
            InboundSubstreamState::Cancelled => {
                *self = InboundSubstreamState::Cancelled;
            }
            InboundSubstreamState::Poisoned { .. } => unreachable!(),
        }
    }
}

/// Event produced by the Kademlia handler.
#[derive(Debug)]
pub enum HandlerEvent {
    /// The configured protocol name has been confirmed by the peer through
    /// a successfully negotiated substream or by learning the supported protocols of the remote.
    ProtocolConfirmed { endpoint: ConnectedPoint },
    /// The configured protocol name(s) are not, or are no longer, supported by the peer on the
    /// provided connection, and the peer should be removed from the routing table.
    ProtocolNotSupported { endpoint: ConnectedPoint },

    /// Request for the list of nodes whose IDs are the closest to `key`. The number of nodes
    /// returned is not specified, but should be around 20.
    FindNodeReq {
        /// The key for which to locate the closest nodes.
        key: Vec<u8>,
        /// Identifier of the request. Needs to be passed back when answering.
        request_id: RequestId,
    },

    /// Response to a `HandlerIn::FindNodeReq`.
    FindNodeRes {
        /// Results of the request.
        closer_peers: Vec<KadPeer>,
        /// The user data passed to the `FindNodeReq`.
        query_id: QueryId,
    },

    /// Same as `FindNodeReq`, but should also return the entries of the local providers list for
    /// this key.
    GetProvidersReq {
        /// The key for which providers are requested.
        key: record::Key,
        /// Identifier of the request. Needs to be passed back when answering.
        request_id: RequestId,
    },

    /// Response to a `HandlerIn::GetProvidersReq`.
    GetProvidersRes {
        /// Nodes closest to the key.
        closer_peers: Vec<KadPeer>,
        /// Known providers for this key.
        provider_peers: Vec<KadPeer>,
        /// The user data passed to the `GetProvidersReq`.
        query_id: QueryId,
    },

    /// An error happened when performing a query.
    QueryError {
        /// The error that happened.
        error: HandlerQueryErr,
        /// The user data passed to the query.
        query_id: QueryId,
    },

    /// The peer announced itself as a provider of a key.
    AddProvider {
        /// The key for which the peer is a provider of the associated value.
        key: record::Key,
        /// The peer that is the provider of the value for `key`.
        provider: KadPeer,
    },

    /// Request to get a value from the DHT records.
    GetRecord {
        /// Key for which we should look in the DHT.
        key: record::Key,
        /// Identifier of the request. Needs to be passed back when answering.
        request_id: RequestId,
    },

    /// Response to a `HandlerIn::GetRecord`.
    GetRecordRes {
        /// The result is present if the key has been found.
        record: Option<Record>,
        /// Nodes closest to the key.
        closer_peers: Vec<KadPeer>,
        /// The user data passed to the `GetRecord`.
        query_id: QueryId,
    },

    /// Request to put a value into the DHT records.
    PutRecord {
        record: Record,
        /// Identifier of the request. Needs to be passed back when answering.
        request_id: RequestId,
    },

    /// Response to a request to store a record.
    PutRecordRes {
        /// The key of the stored record.
        key: record::Key,
        /// The value of the stored record.
        value: Vec<u8>,
        /// The user data passed to the `PutRecord`.
        query_id: QueryId,
    },
}

/// Error that can happen when performing an RPC query.
#[derive(Debug)]
pub enum HandlerQueryErr {
    /// Received an answer that doesn't correspond to the request.
    UnexpectedMessage,
    /// I/O error in the substream.
    Io(io::Error),
}

impl fmt::Display for HandlerQueryErr {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            HandlerQueryErr::UnexpectedMessage => {
                write!(
                    f,
                    "Remote answered our Kademlia RPC query with the wrong message type"
                )
            }
            HandlerQueryErr::Io(err) => {
                write!(f, "I/O error during a Kademlia RPC query: {err}")
            }
        }
    }
}

impl error::Error for HandlerQueryErr {
    fn source(&self) -> Option<&(dyn error::Error + 'static)> {
        match self {
            HandlerQueryErr::UnexpectedMessage => None,
            HandlerQueryErr::Io(err) => Some(err),
        }
    }
}

/// Event to send to the handler.
#[derive(Debug)]
pub enum HandlerIn {
    /// Resets the (sub)stream associated with the given request ID,
    /// thus signaling an error to the remote.
    ///
    /// Explicitly resetting the (sub)stream associated with a request
    /// can be used as an alternative to letting requests simply time
    /// out on the remote peer, thus potentially avoiding some delay
    /// for the query on the remote.
    Reset(RequestId),

    /// Change the connection to the specified mode.
    ReconfigureMode { new_mode: Mode },

    /// Request for the list of nodes whose IDs are the closest to `key`. The number of nodes
    /// returned is not specified, but should be around 20.
    FindNodeReq {
        /// Identifier of the node.
        key: Vec<u8>,
        /// ID of the query that generated this request.
        query_id: QueryId,
    },

    /// Response to a `FindNodeReq`.
    FindNodeRes {
        /// Results of the request.
        closer_peers: Vec<KadPeer>,
        /// Identifier of the request that was made by the remote.
        ///
        /// It is a logic error to use an ID issued by the handler of a different node.
        request_id: RequestId,
    },

    /// Same as `FindNodeReq`, but should also return the entries of the local providers list for
    /// this key.
    GetProvidersReq {
        /// Identifier being searched.
        key: record::Key,
        /// ID of the query that generated this request.
        query_id: QueryId,
    },

    /// Response to a `GetProvidersReq`.
    GetProvidersRes {
        /// Nodes closest to the key.
        closer_peers: Vec<KadPeer>,
        /// Known providers for this key.
        provider_peers: Vec<KadPeer>,
        /// Identifier of the request that was made by the remote.
        ///
        /// It is a logic error to use an ID issued by the handler of a different node.
        request_id: RequestId,
    },

    /// Indicates that this provider is known for this key.
    ///
    /// The API of the handler doesn't expose any event that allows you to know whether this
    /// succeeded.
    AddProvider {
        /// Key for which we should add providers.
        key: record::Key,
        /// Known provider for this key.
        provider: KadPeer,
        /// ID of the query that generated this request.
        query_id: QueryId,
    },

    /// Request to retrieve a record from the DHT.
    GetRecord {
        /// The key of the record.
        key: record::Key,
        /// ID of the query that generated this request.
        query_id: QueryId,
    },

    /// Response to a `GetRecord` request.
    GetRecordRes {
        /// The value that might have been found in our storage.
        record: Option<Record>,
        /// Nodes that are closer to the key we were searching for.
        closer_peers: Vec<KadPeer>,
        /// Identifier of the request that was made by the remote.
        request_id: RequestId,
    },

    /// Put a value into the DHT records.
    PutRecord {
        record: Record,
        /// ID of the query that generated this request.
        query_id: QueryId,
    },

    /// Response to a `PutRecord`.
    PutRecordRes {
        /// Key of the value that was put.
        key: record::Key,
        /// Value that was put.
        value: Vec<u8>,
        /// Identifier of the request that was made by the remote.
        request_id: RequestId,
    },
}
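
// The request/response pairing above relies on echoing the `RequestId` of an incoming
// `HandlerEvent` request back in the matching `HandlerIn` response. The sketch below is
// illustrative only and not part of the public API: the module name, the `answer` helper
// and the empty peer lists are placeholders that merely demonstrate the expected pairing.
#[cfg(test)]
mod request_id_echo_sketch {
    use super::*;

    /// Maps a request event to the `HandlerIn` message a behaviour would answer it with,
    /// preserving the `RequestId`. Only two variants are shown.
    fn answer(event: HandlerEvent) -> Option<HandlerIn> {
        match event {
            HandlerEvent::FindNodeReq { request_id, .. } => Some(HandlerIn::FindNodeRes {
                closer_peers: Vec::new(),
                request_id,
            }),
            HandlerEvent::GetProvidersReq { request_id, .. } => Some(HandlerIn::GetProvidersRes {
                closer_peers: Vec::new(),
                provider_peers: Vec::new(),
                request_id,
            }),
            // Requests without a response (e.g. `AddProvider`) and responses to our own
            // queries need no `HandlerIn` answer.
            _ => None,
        }
    }

    #[test]
    fn find_node_response_echoes_request_id() {
        let request_id = RequestId {
            connec_unique_id: UniqueConnecId(7),
        };
        let event = HandlerEvent::FindNodeReq {
            key: b"example-key".to_vec(),
            request_id,
        };

        match answer(event) {
            Some(HandlerIn::FindNodeRes { request_id: id, .. }) => assert_eq!(id, request_id),
            other => panic!("unexpected answer: {other:?}"),
        }
    }
}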

/// Unique identifier for a request. Must be passed back in order to answer a request from
/// the remote.
#[derive(Debug, PartialEq, Eq, Copy, Clone)]
pub struct RequestId {
    /// Unique identifier of the inbound substream on which the request arrived.
    connec_unique_id: UniqueConnecId,
}

/// Unique identifier assigned to each inbound substream.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
struct UniqueConnecId(u64);

impl Handler {
    pub fn new(
        protocol_config: ProtocolConfig,
        endpoint: ConnectedPoint,
        remote_peer_id: PeerId,
        mode: Mode,
    ) -> Self {
        match &endpoint {
            ConnectedPoint::Dialer { .. } => {
                tracing::debug!(
                    peer=%remote_peer_id,
                    mode=%mode,
                    "New outbound connection"
                );
            }
            ConnectedPoint::Listener { .. } => {
                tracing::debug!(
                    peer=%remote_peer_id,
                    mode=%mode,
                    "New inbound connection"
                );
            }
        }

        Handler {
            protocol_config,
            mode,
            endpoint,
            remote_peer_id,
            next_connec_unique_id: UniqueConnecId(0),
            inbound_substreams: Default::default(),
            outbound_substreams: futures_bounded::FuturesTupleSet::new(
                Duration::from_secs(10),
                MAX_NUM_STREAMS,
            ),
            pending_streams: Default::default(),
            pending_messages: Default::default(),
            protocol_status: None,
            remote_supported_protocols: Default::default(),
        }
    }

    fn on_fully_negotiated_outbound(
        &mut self,
        FullyNegotiatedOutbound {
            protocol: stream,
            info: (),
        }: FullyNegotiatedOutbound<<Self as ConnectionHandler>::OutboundProtocol>,
    ) {
        if let Some(sender) = self.pending_streams.pop_front() {
            let _ = sender.send(Ok(stream));
        }

        if self.protocol_status.is_none() {
            // Upon the first successfully negotiated substream, we know that the
            // remote is configured with the same protocol name and we want
            // the behaviour to add this peer to the routing table, if possible.
            self.protocol_status = Some(ProtocolStatus {
                supported: true,
                reported: false,
            });
        }
    }

    fn on_fully_negotiated_inbound(
        &mut self,
        FullyNegotiatedInbound { protocol, .. }: FullyNegotiatedInbound<
            <Self as ConnectionHandler>::InboundProtocol,
        >,
    ) {
        // In client mode we advertise a `DeniedUpgrade` (see `listen_protocol`), so `protocol`
        // can only ever be the left variant; the right variant carries an `Infallible`.
        let protocol = match protocol {
            future::Either::Left(p) => p,
            future::Either::Right(p) => libp2p_core::util::unreachable(p),
        };

        if self.protocol_status.is_none() {
            // Upon the first successfully negotiated substream, we know that the
            // remote is configured with the same protocol name and we want
            // the behaviour to add this peer to the routing table, if possible.
            self.protocol_status = Some(ProtocolStatus {
                supported: true,
                reported: false,
            });
        }

        if self.inbound_substreams.len() == MAX_NUM_STREAMS {
            if let Some(s) = self.inbound_substreams.iter_mut().find(|s| {
                matches!(
                    s,
                    // An inbound substream waiting to be reused.
                    InboundSubstreamState::WaitingMessage { first: false, .. }
                )
            }) {
                *s = InboundSubstreamState::Cancelled;
                tracing::debug!(
                    peer=?self.remote_peer_id,
                    "New inbound substream to peer exceeds inbound substream limit. \
                    Removed older substream waiting to be reused."
                )
            } else {
                tracing::warn!(
                    peer=?self.remote_peer_id,
                    "New inbound substream to peer exceeds inbound substream limit. \
                     No older substream waiting to be reused. Dropping new substream."
                );
                return;
            }
        }

        let connec_unique_id = self.next_connec_unique_id;
        self.next_connec_unique_id.0 += 1;
        self.inbound_substreams
            .push(InboundSubstreamState::WaitingMessage {
                first: true,
                connection_id: connec_unique_id,
                substream: protocol,
            });
    }

    /// Takes the given [`KadRequestMsg`] and composes it into an outbound request-response protocol
    /// handshake using a [`oneshot::channel`].
    fn queue_new_stream(&mut self, id: QueryId, msg: KadRequestMsg) {
        let (sender, receiver) = oneshot::channel();

        self.pending_streams.push_back(sender);
        let result = self.outbound_substreams.try_push(
            async move {
                let mut stream = receiver
                    .await
                    .map_err(|_| io::Error::from(io::ErrorKind::BrokenPipe))?
                    .map_err(|e| match e {
                        StreamUpgradeError::Timeout => io::ErrorKind::TimedOut.into(),
                        StreamUpgradeError::Apply(e) => e,
                        StreamUpgradeError::NegotiationFailed => io::Error::new(
                            io::ErrorKind::ConnectionRefused,
                            "protocol not supported",
                        ),
                        StreamUpgradeError::Io(e) => e,
                    })?;

                let has_answer = !matches!(msg, KadRequestMsg::AddProvider { .. });

                stream.send(msg).await?;
                stream.close().await?;

                if !has_answer {
                    return Ok(None);
                }

                let msg = stream.next().await.ok_or(io::ErrorKind::UnexpectedEof)??;

                Ok(Some(msg))
            },
            id,
        );

        debug_assert!(
            result.is_ok(),
            "Expected to not create more streams than allowed"
        );
    }
}
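
// `queue_new_stream` above couples a `oneshot::channel` with a future stored in the bounded
// `FuturesTupleSet`: the future parks on the receiver until `on_fully_negotiated_outbound`
// delivers the negotiated stream (or a `DialUpgradeError` delivers the failure). The test
// below is a minimal, self-contained sketch of that hand-off using only the oneshot channel;
// the string payload stands in for the negotiated substream and the module name is made up
// for illustration.
#[cfg(test)]
mod pending_stream_handoff_sketch {
    use futures::{channel::oneshot, FutureExt};

    #[test]
    fn receiver_resolves_once_sender_is_used() {
        let (sender, receiver) = oneshot::channel::<Result<&'static str, &'static str>>();

        // In the handler this happens when the outbound substream finishes negotiating
        // (or fails to); here we resolve it immediately.
        sender.send(Ok("negotiated stream")).unwrap();

        // A oneshot receiver is ready as soon as the sender has been used, so a single
        // poll suffices.
        let received = receiver
            .now_or_never()
            .expect("receiver is ready")
            .expect("sender was not dropped")
            .expect("upgrade succeeded");
        assert_eq!(received, "negotiated stream");
    }
}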

impl ConnectionHandler for Handler {
    type FromBehaviour = HandlerIn;
    type ToBehaviour = HandlerEvent;
    type InboundProtocol = Either<ProtocolConfig, upgrade::DeniedUpgrade>;
    type OutboundProtocol = ProtocolConfig;
    type OutboundOpenInfo = ();
    type InboundOpenInfo = ();

    fn listen_protocol(&self) -> SubstreamProtocol<Self::InboundProtocol> {
        match self.mode {
            Mode::Server => SubstreamProtocol::new(Either::Left(self.protocol_config.clone()), ()),
            Mode::Client => SubstreamProtocol::new(Either::Right(upgrade::DeniedUpgrade), ()),
        }
    }

    fn on_behaviour_event(&mut self, message: HandlerIn) {
        match message {
            HandlerIn::Reset(request_id) => {
                if let Some(state) = self
                    .inbound_substreams
                    .iter_mut()
                    .find(|state| match state {
                        InboundSubstreamState::WaitingBehaviour(conn_id, _, _) => {
                            conn_id == &request_id.connec_unique_id
                        }
                        _ => false,
                    })
                {
                    state.close();
                }
            }
            HandlerIn::FindNodeReq { key, query_id } => {
                let msg = KadRequestMsg::FindNode { key };
                self.pending_messages.push_back((msg, query_id));
            }
            HandlerIn::FindNodeRes {
                closer_peers,
                request_id,
            } => self.answer_pending_request(request_id, KadResponseMsg::FindNode { closer_peers }),
            HandlerIn::GetProvidersReq { key, query_id } => {
                let msg = KadRequestMsg::GetProviders { key };
                self.pending_messages.push_back((msg, query_id));
            }
            HandlerIn::GetProvidersRes {
                closer_peers,
                provider_peers,
                request_id,
            } => self.answer_pending_request(
                request_id,
                KadResponseMsg::GetProviders {
                    closer_peers,
                    provider_peers,
                },
            ),
            HandlerIn::AddProvider {
                key,
                provider,
                query_id,
            } => {
                let msg = KadRequestMsg::AddProvider { key, provider };
                self.pending_messages.push_back((msg, query_id));
            }
            HandlerIn::GetRecord { key, query_id } => {
                let msg = KadRequestMsg::GetValue { key };
                self.pending_messages.push_back((msg, query_id));
            }
            HandlerIn::PutRecord { record, query_id } => {
                let msg = KadRequestMsg::PutValue { record };
                self.pending_messages.push_back((msg, query_id));
            }
            HandlerIn::GetRecordRes {
                record,
                closer_peers,
                request_id,
            } => {
                self.answer_pending_request(
                    request_id,
                    KadResponseMsg::GetValue {
                        record,
                        closer_peers,
                    },
                );
            }
            HandlerIn::PutRecordRes {
                key,
                request_id,
                value,
            } => {
                self.answer_pending_request(request_id, KadResponseMsg::PutValue { key, value });
            }
            HandlerIn::ReconfigureMode { new_mode } => {
                let peer = self.remote_peer_id;

                match &self.endpoint {
                    ConnectedPoint::Dialer { .. } => {
                        tracing::debug!(
                            %peer,
                            mode=%new_mode,
                            "Changed mode on outbound connection"
                        )
                    }
                    ConnectedPoint::Listener { local_addr, .. } => {
                        tracing::debug!(
                            %peer,
                            mode=%new_mode,
                            local_address=%local_addr,
                            "Changed mode on inbound connection assuming that one of our external addresses routes to the local address")
                    }
                }

                self.mode = new_mode;
            }
        }
    }

    #[tracing::instrument(level = "trace", name = "ConnectionHandler::poll", skip(self, cx))]
    fn poll(
        &mut self,
        cx: &mut Context<'_>,
    ) -> Poll<ConnectionHandlerEvent<Self::OutboundProtocol, (), Self::ToBehaviour>> {
        loop {
            match &mut self.protocol_status {
                Some(status) if !status.reported => {
                    status.reported = true;
                    let event = if status.supported {
                        HandlerEvent::ProtocolConfirmed {
                            endpoint: self.endpoint.clone(),
                        }
                    } else {
                        HandlerEvent::ProtocolNotSupported {
                            endpoint: self.endpoint.clone(),
                        }
                    };

                    return Poll::Ready(ConnectionHandlerEvent::NotifyBehaviour(event));
                }
                _ => {}
            }

            match self.outbound_substreams.poll_unpin(cx) {
                Poll::Ready((Ok(Ok(Some(response))), query_id)) => {
                    return Poll::Ready(ConnectionHandlerEvent::NotifyBehaviour(
                        process_kad_response(response, query_id),
                    ))
                }
                Poll::Ready((Ok(Ok(None)), _)) => {
                    continue;
                }
                Poll::Ready((Ok(Err(e)), query_id)) => {
                    return Poll::Ready(ConnectionHandlerEvent::NotifyBehaviour(
                        HandlerEvent::QueryError {
                            error: HandlerQueryErr::Io(e),
                            query_id,
                        },
                    ))
                }
                Poll::Ready((Err(_timeout), query_id)) => {
                    return Poll::Ready(ConnectionHandlerEvent::NotifyBehaviour(
                        HandlerEvent::QueryError {
                            error: HandlerQueryErr::Io(io::ErrorKind::TimedOut.into()),
                            query_id,
                        },
                    ))
                }
                Poll::Pending => {}
            }

            if let Poll::Ready(Some(event)) = self.inbound_substreams.poll_next_unpin(cx) {
                return Poll::Ready(event);
            }

            if self.outbound_substreams.len() < MAX_NUM_STREAMS {
                if let Some((msg, id)) = self.pending_messages.pop_front() {
                    self.queue_new_stream(id, msg);
                    return Poll::Ready(ConnectionHandlerEvent::OutboundSubstreamRequest {
                        protocol: SubstreamProtocol::new(self.protocol_config.clone(), ()),
                    });
                }
            }

            return Poll::Pending;
        }
    }

    fn on_connection_event(
        &mut self,
        event: ConnectionEvent<Self::InboundProtocol, Self::OutboundProtocol>,
    ) {
        match event {
            ConnectionEvent::FullyNegotiatedOutbound(fully_negotiated_outbound) => {
                self.on_fully_negotiated_outbound(fully_negotiated_outbound)
            }
            ConnectionEvent::FullyNegotiatedInbound(fully_negotiated_inbound) => {
                self.on_fully_negotiated_inbound(fully_negotiated_inbound)
            }
            ConnectionEvent::DialUpgradeError(ev) => {
                if let Some(sender) = self.pending_streams.pop_front() {
                    let _ = sender.send(Err(ev.error));
                }
            }
            ConnectionEvent::RemoteProtocolsChange(change) => {
                let dirty = self.remote_supported_protocols.on_protocols_change(change);

                if dirty {
                    let remote_supports_our_kademlia_protocols = self
                        .remote_supported_protocols
                        .iter()
                        .any(|p| self.protocol_config.protocol_names().contains(p));

                    self.protocol_status = Some(compute_new_protocol_status(
                        remote_supports_our_kademlia_protocols,
                        self.protocol_status,
                    ))
                }
            }
            _ => {}
        }
    }
}

fn compute_new_protocol_status(
    now_supported: bool,
    current_status: Option<ProtocolStatus>,
) -> ProtocolStatus {
    let current_status = match current_status {
        None => {
            return ProtocolStatus {
                supported: now_supported,
                reported: false,
            }
        }
        Some(current) => current,
    };

    if now_supported == current_status.supported {
        return ProtocolStatus {
            supported: now_supported,
            reported: true,
        };
    }

    if now_supported {
        tracing::debug!("Remote now supports our kademlia protocol");
    } else {
        tracing::debug!("Remote no longer supports our kademlia protocol");
    }

    ProtocolStatus {
        supported: now_supported,
        reported: false,
    }
}

impl Handler {
    fn answer_pending_request(&mut self, request_id: RequestId, mut msg: KadResponseMsg) {
        for state in self.inbound_substreams.iter_mut() {
            match state.try_answer_with(request_id, msg) {
                Ok(()) => return,
                Err(m) => {
                    msg = m;
                }
            }
        }

        debug_assert!(false, "Cannot find inbound substream for {request_id:?}")
    }
}

impl futures::Stream for InboundSubstreamState {
    type Item = ConnectionHandlerEvent<ProtocolConfig, (), HandlerEvent>;

    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let this = self.get_mut();

        loop {
            match std::mem::replace(
                this,
                Self::Poisoned {
                    phantom: PhantomData,
                },
            ) {
                InboundSubstreamState::WaitingMessage {
                    first,
                    connection_id,
                    mut substream,
                } => match substream.poll_next_unpin(cx) {
                    Poll::Ready(Some(Ok(KadRequestMsg::Ping))) => {
                        tracing::warn!("Kademlia PING messages are unsupported");

                        *this = InboundSubstreamState::Closing(substream);
                    }
                    Poll::Ready(Some(Ok(KadRequestMsg::FindNode { key }))) => {
                        *this =
                            InboundSubstreamState::WaitingBehaviour(connection_id, substream, None);
                        return Poll::Ready(Some(ConnectionHandlerEvent::NotifyBehaviour(
                            HandlerEvent::FindNodeReq {
                                key,
                                request_id: RequestId {
                                    connec_unique_id: connection_id,
                                },
                            },
                        )));
                    }
                    Poll::Ready(Some(Ok(KadRequestMsg::GetProviders { key }))) => {
                        *this =
                            InboundSubstreamState::WaitingBehaviour(connection_id, substream, None);
                        return Poll::Ready(Some(ConnectionHandlerEvent::NotifyBehaviour(
                            HandlerEvent::GetProvidersReq {
                                key,
                                request_id: RequestId {
                                    connec_unique_id: connection_id,
                                },
                            },
                        )));
                    }
                    Poll::Ready(Some(Ok(KadRequestMsg::AddProvider { key, provider }))) => {
                        *this = InboundSubstreamState::WaitingMessage {
                            first: false,
                            connection_id,
                            substream,
                        };
                        return Poll::Ready(Some(ConnectionHandlerEvent::NotifyBehaviour(
                            HandlerEvent::AddProvider { key, provider },
                        )));
                    }
                    Poll::Ready(Some(Ok(KadRequestMsg::GetValue { key }))) => {
                        *this =
                            InboundSubstreamState::WaitingBehaviour(connection_id, substream, None);
                        return Poll::Ready(Some(ConnectionHandlerEvent::NotifyBehaviour(
                            HandlerEvent::GetRecord {
                                key,
                                request_id: RequestId {
                                    connec_unique_id: connection_id,
                                },
                            },
                        )));
                    }
                    Poll::Ready(Some(Ok(KadRequestMsg::PutValue { record }))) => {
                        *this =
                            InboundSubstreamState::WaitingBehaviour(connection_id, substream, None);
                        return Poll::Ready(Some(ConnectionHandlerEvent::NotifyBehaviour(
                            HandlerEvent::PutRecord {
                                record,
                                request_id: RequestId {
                                    connec_unique_id: connection_id,
                                },
                            },
                        )));
                    }
                    Poll::Pending => {
                        *this = InboundSubstreamState::WaitingMessage {
                            first,
                            connection_id,
                            substream,
                        };
                        return Poll::Pending;
                    }
                    Poll::Ready(None) => {
                        return Poll::Ready(None);
                    }
                    Poll::Ready(Some(Err(e))) => {
                        tracing::trace!("Inbound substream error: {:?}", e);
                        return Poll::Ready(None);
                    }
                },
                InboundSubstreamState::WaitingBehaviour(id, substream, _) => {
                    *this = InboundSubstreamState::WaitingBehaviour(
                        id,
                        substream,
                        Some(cx.waker().clone()),
                    );

                    return Poll::Pending;
                }
                InboundSubstreamState::PendingSend(id, mut substream, msg) => {
                    match substream.poll_ready_unpin(cx) {
                        Poll::Ready(Ok(())) => match substream.start_send_unpin(msg) {
                            Ok(()) => {
                                *this = InboundSubstreamState::PendingFlush(id, substream);
                            }
                            Err(_) => return Poll::Ready(None),
                        },
                        Poll::Pending => {
                            *this = InboundSubstreamState::PendingSend(id, substream, msg);
                            return Poll::Pending;
                        }
                        Poll::Ready(Err(_)) => return Poll::Ready(None),
                    }
                }
                InboundSubstreamState::PendingFlush(id, mut substream) => {
                    match substream.poll_flush_unpin(cx) {
                        Poll::Ready(Ok(())) => {
                            *this = InboundSubstreamState::WaitingMessage {
                                first: false,
                                connection_id: id,
                                substream,
                            };
                        }
                        Poll::Pending => {
                            *this = InboundSubstreamState::PendingFlush(id, substream);
                            return Poll::Pending;
                        }
                        Poll::Ready(Err(_)) => return Poll::Ready(None),
                    }
                }
                InboundSubstreamState::Closing(mut stream) => match stream.poll_close_unpin(cx) {
                    Poll::Ready(Ok(())) | Poll::Ready(Err(_)) => return Poll::Ready(None),
                    Poll::Pending => {
                        *this = InboundSubstreamState::Closing(stream);
                        return Poll::Pending;
                    }
                },
                InboundSubstreamState::Poisoned { .. } => unreachable!(),
                InboundSubstreamState::Cancelled => return Poll::Ready(None),
            }
        }
    }
}

/// Process a Kademlia message that's supposed to be a response to one of our requests.
fn process_kad_response(event: KadResponseMsg, query_id: QueryId) -> HandlerEvent {
    // TODO: must check that the response corresponds to the request
    match event {
        KadResponseMsg::Pong => {
            // We never send out pings.
            HandlerEvent::QueryError {
                error: HandlerQueryErr::UnexpectedMessage,
                query_id,
            }
        }
        KadResponseMsg::FindNode { closer_peers } => HandlerEvent::FindNodeRes {
            closer_peers,
            query_id,
        },
        KadResponseMsg::GetProviders {
            closer_peers,
            provider_peers,
        } => HandlerEvent::GetProvidersRes {
            closer_peers,
            provider_peers,
            query_id,
        },
        KadResponseMsg::GetValue {
            record,
            closer_peers,
        } => HandlerEvent::GetRecordRes {
            record,
            closer_peers,
            query_id,
        },
        KadResponseMsg::PutValue { key, value, .. } => HandlerEvent::PutRecordRes {
            key,
            value,
            query_id,
        },
    }
}

#[cfg(test)]
mod tests {
    use quickcheck::{Arbitrary, Gen};
    use tracing_subscriber::EnvFilter;

    use super::*;

    impl Arbitrary for ProtocolStatus {
        fn arbitrary(g: &mut Gen) -> Self {
            Self {
                supported: bool::arbitrary(g),
                reported: bool::arbitrary(g),
            }
        }
    }

    #[test]
    fn compute_next_protocol_status_test() {
        let _ = tracing_subscriber::fmt()
            .with_env_filter(EnvFilter::from_default_env())
            .try_init();

        fn prop(now_supported: bool, current: Option<ProtocolStatus>) {
            let new = compute_new_protocol_status(now_supported, current);

            match current {
                None => {
                    assert!(!new.reported);
                    assert_eq!(new.supported, now_supported);
                }
                Some(current) => {
                    if current.supported == now_supported {
                        assert!(new.reported);
                    } else {
                        assert!(!new.reported);
                    }

                    assert_eq!(new.supported, now_supported);
                }
            }
        }

        quickcheck::quickcheck(prop as fn(_, _))
    }
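
    // A worked example of the transitions covered by the property test above, spelled out
    // for three cases: the first observation, unchanged support, and lost support. The test
    // name is an added illustration of `compute_new_protocol_status`.
    #[test]
    fn compute_new_protocol_status_examples() {
        // First observation: record the support state; it still has to be reported.
        let first = compute_new_protocol_status(true, None);
        assert_eq!(
            first,
            ProtocolStatus {
                supported: true,
                reported: false
            }
        );

        // Unchanged support: nothing new for the behaviour, so it counts as reported.
        let unchanged = compute_new_protocol_status(true, Some(first));
        assert_eq!(
            unchanged,
            ProtocolStatus {
                supported: true,
                reported: true
            }
        );

        // Support was lost: the change must be reported to the behaviour again.
        let lost = compute_new_protocol_status(false, Some(unchanged));
        assert_eq!(
            lost,
            ProtocolStatus {
                supported: false,
                reported: false
            }
        );
    }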
1099}