libp2p_kad/
handler.rs

1// Copyright 2018 Parity Technologies (UK) Ltd.
2//
3// Permission is hereby granted, free of charge, to any person obtaining a
4// copy of this software and associated documentation files (the "Software"),
5// to deal in the Software without restriction, including without limitation
6// the rights to use, copy, modify, merge, publish, distribute, sublicense,
7// and/or sell copies of the Software, and to permit persons to whom the
8// Software is furnished to do so, subject to the following conditions:
9//
10// The above copyright notice and this permission notice shall be included in
11// all copies or substantial portions of the Software.
12//
13// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
14// OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
16// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
17// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
18// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
19// DEALINGS IN THE SOFTWARE.
20
21use std::{
22    collections::VecDeque,
23    error, fmt, io,
24    marker::PhantomData,
25    pin::Pin,
26    task::{Context, Poll, Waker},
27};
28
29use either::Either;
30use futures::{channel::oneshot, prelude::*, stream::SelectAll};
31use libp2p_core::{upgrade, ConnectedPoint};
32use libp2p_identity::PeerId;
33use libp2p_swarm::{
34    handler::{ConnectionEvent, FullyNegotiatedInbound, FullyNegotiatedOutbound},
35    ConnectionHandler, ConnectionHandlerEvent, Stream, StreamUpgradeError, SubstreamProtocol,
36    SupportedProtocols,
37};
38
39use crate::{
40    behaviour::Mode,
41    protocol::{
42        KadInStreamSink, KadOutStreamSink, KadPeer, KadRequestMsg, KadResponseMsg, ProtocolConfig,
43    },
44    record::{self, Record},
45    QueryId,
46};
47
/// Cap on concurrently tracked substreams per connection, enforced separately
/// for the inbound and the outbound direction.
const MAX_NUM_STREAMS: usize = 32;
49
50/// Protocol handler that manages substreams for the Kademlia protocol
51/// on a single connection with a peer.
52///
53/// The handler will automatically open a Kademlia substream with the remote for each request we
54/// make.
55///
56/// It also handles requests made by the remote.
57pub struct Handler {
58    /// Configuration of the wire protocol.
59    protocol_config: ProtocolConfig,
60
61    /// In client mode, we don't accept inbound substreams.
62    mode: Mode,
63
64    /// Next unique ID of a connection.
65    next_connec_unique_id: UniqueConnecId,
66
67    /// List of active outbound streams.
68    outbound_substreams:
69        futures_bounded::FuturesTupleSet<io::Result<Option<KadResponseMsg>>, QueryId>,
70
71    /// Contains one [`oneshot::Sender`] per outbound stream that we have requested.
72    pending_streams:
73        VecDeque<oneshot::Sender<Result<KadOutStreamSink<Stream>, StreamUpgradeError<io::Error>>>>,
74
75    /// List of outbound substreams that are waiting to become active next.
76    /// Contains the request we want to send, and the user data if we expect an answer.
77    pending_messages: VecDeque<(KadRequestMsg, QueryId)>,
78
79    /// List of active inbound substreams with the state they are in.
80    inbound_substreams: SelectAll<InboundSubstreamState>,
81
82    /// The connected endpoint of the connection that the handler
83    /// is associated with.
84    endpoint: ConnectedPoint,
85
86    /// The [`PeerId`] of the remote.
87    remote_peer_id: PeerId,
88
89    /// The current state of protocol confirmation.
90    protocol_status: Option<ProtocolStatus>,
91
92    remote_supported_protocols: SupportedProtocols,
93}
94
/// Records whether the remote speaks one of our Kademlia protocols, and whether
/// that fact has already been passed on to the behaviour.
#[derive(Clone, Copy, Debug, PartialEq)]
struct ProtocolStatus {
    /// `true` if the remote supports at least one of our Kademlia protocols.
    supported: bool,
    /// `true` once the current `supported` value has been reported to the behaviour.
    reported: bool,
}
104
/// State of an active inbound substream.
enum InboundSubstreamState {
    /// Waiting for a request from the remote.
    WaitingMessage {
        /// Whether it is the first message to be awaited on this stream.
        /// A substream with `first == false` is one waiting to be reused and may be
        /// cancelled when the inbound substream limit is reached.
        first: bool,
        /// Identifier assigned when the substream was accepted; responses from the
        /// behaviour are matched against it through [`RequestId`].
        connection_id: UniqueConnecId,
        /// The negotiated Kademlia substream.
        substream: KadInStreamSink<Stream>,
    },
    /// Waiting for the behaviour to send a [`HandlerIn`] event containing the response.
    ///
    /// Fields: the substream identifier, the substream, and an optional waker that is
    /// woken once a response is handed over.
    WaitingBehaviour(UniqueConnecId, KadInStreamSink<Stream>, Option<Waker>),
    /// Waiting to send an answer back to the remote.
    PendingSend(UniqueConnecId, KadInStreamSink<Stream>, KadResponseMsg),
    /// Waiting to flush an answer back to the remote.
    PendingFlush(UniqueConnecId, KadInStreamSink<Stream>),
    /// The substream is being closed.
    Closing(KadInStreamSink<Stream>),
    /// The substream was cancelled in favor of a new one.
    Cancelled,

    /// Placeholder left in place while the state is temporarily moved out with
    /// `std::mem::replace` during a transition; `close` treats encountering it
    /// as unreachable.
    Poisoned {
        /// Marker only; the variant carries no data.
        phantom: PhantomData<QueryId>,
    },
}
129
130impl InboundSubstreamState {
131    #[allow(clippy::result_large_err)]
132    fn try_answer_with(
133        &mut self,
134        id: RequestId,
135        msg: KadResponseMsg,
136    ) -> Result<(), KadResponseMsg> {
137        match std::mem::replace(
138            self,
139            InboundSubstreamState::Poisoned {
140                phantom: PhantomData,
141            },
142        ) {
143            InboundSubstreamState::WaitingBehaviour(conn_id, substream, mut waker)
144                if conn_id == id.connec_unique_id =>
145            {
146                *self = InboundSubstreamState::PendingSend(conn_id, substream, msg);
147
148                if let Some(waker) = waker.take() {
149                    waker.wake();
150                }
151
152                Ok(())
153            }
154            other => {
155                *self = other;
156
157                Err(msg)
158            }
159        }
160    }
161
162    fn close(&mut self) {
163        match std::mem::replace(
164            self,
165            InboundSubstreamState::Poisoned {
166                phantom: PhantomData,
167            },
168        ) {
169            InboundSubstreamState::WaitingMessage { substream, .. }
170            | InboundSubstreamState::WaitingBehaviour(_, substream, _)
171            | InboundSubstreamState::PendingSend(_, substream, _)
172            | InboundSubstreamState::PendingFlush(_, substream)
173            | InboundSubstreamState::Closing(substream) => {
174                *self = InboundSubstreamState::Closing(substream);
175            }
176            InboundSubstreamState::Cancelled => {
177                *self = InboundSubstreamState::Cancelled;
178            }
179            InboundSubstreamState::Poisoned { .. } => unreachable!(),
180        }
181    }
182}
183
184/// Event produced by the Kademlia handler.
185#[derive(Debug)]
186pub enum HandlerEvent {
187    /// The configured protocol name has been confirmed by the peer through
188    /// a successfully negotiated substream or by learning the supported protocols of the remote.
189    ProtocolConfirmed { endpoint: ConnectedPoint },
190    /// The configured protocol name(s) are not or no longer supported by the peer on the provided
191    /// connection and it should be removed from the routing table.
192    ProtocolNotSupported { endpoint: ConnectedPoint },
193
194    /// Request for the list of nodes whose IDs are the closest to `key`. The number of nodes
195    /// returned is not specified, but should be around 20.
196    FindNodeReq {
197        /// The key for which to locate the closest nodes.
198        key: Vec<u8>,
199        /// Identifier of the request. Needs to be passed back when answering.
200        request_id: RequestId,
201    },
202
203    /// Response to an `HandlerIn::FindNodeReq`.
204    FindNodeRes {
205        /// Results of the request.
206        closer_peers: Vec<KadPeer>,
207        /// The user data passed to the `FindNodeReq`.
208        query_id: QueryId,
209    },
210
211    /// Same as `FindNodeReq`, but should also return the entries of the local providers list for
212    /// this key.
213    GetProvidersReq {
214        /// The key for which providers are requested.
215        key: record::Key,
216        /// Identifier of the request. Needs to be passed back when answering.
217        request_id: RequestId,
218    },
219
220    /// Response to an `HandlerIn::GetProvidersReq`.
221    GetProvidersRes {
222        /// Nodes closest to the key.
223        closer_peers: Vec<KadPeer>,
224        /// Known providers for this key.
225        provider_peers: Vec<KadPeer>,
226        /// The user data passed to the `GetProvidersReq`.
227        query_id: QueryId,
228    },
229
230    /// An error happened when performing a query.
231    QueryError {
232        /// The error that happened.
233        error: HandlerQueryErr,
234        /// The user data passed to the query.
235        query_id: QueryId,
236    },
237
238    /// The peer announced itself as a provider of a key.
239    AddProvider {
240        /// The key for which the peer is a provider of the associated value.
241        key: record::Key,
242        /// The peer that is the provider of the value for `key`.
243        provider: KadPeer,
244    },
245
246    /// Request to get a value from the dht records
247    GetRecord {
248        /// Key for which we should look in the dht
249        key: record::Key,
250        /// Identifier of the request. Needs to be passed back when answering.
251        request_id: RequestId,
252    },
253
254    /// Response to a `HandlerIn::GetRecord`.
255    GetRecordRes {
256        /// The result is present if the key has been found
257        record: Option<Record>,
258        /// Nodes closest to the key.
259        closer_peers: Vec<KadPeer>,
260        /// The user data passed to the `GetValue`.
261        query_id: QueryId,
262    },
263
264    /// Request to put a value in the dht records
265    PutRecord {
266        record: Record,
267        /// Identifier of the request. Needs to be passed back when answering.
268        request_id: RequestId,
269    },
270
271    /// Response to a request to store a record.
272    PutRecordRes {
273        /// The key of the stored record.
274        key: record::Key,
275        /// The value of the stored record.
276        value: Vec<u8>,
277        /// The user data passed to the `PutValue`.
278        query_id: QueryId,
279    },
280}
281
/// Failure of an outbound Kademlia RPC query issued by this handler.
#[derive(Debug)]
pub enum HandlerQueryErr {
    /// The remote replied with a message that does not match the request.
    UnexpectedMessage,
    /// The substream failed with an I/O error (timeouts are reported this way too).
    Io(io::Error),
}

impl fmt::Display for HandlerQueryErr {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::Io(err) => write!(f, "I/O error during a Kademlia RPC query: {err}"),
            Self::UnexpectedMessage => {
                f.write_str("Remote answered our Kademlia RPC query with the wrong message type")
            }
        }
    }
}

impl error::Error for HandlerQueryErr {
    /// Exposes the wrapped I/O error, if any, as the underlying cause.
    fn source(&self) -> Option<&(dyn error::Error + 'static)> {
        match self {
            Self::Io(inner) => Some(inner),
            Self::UnexpectedMessage => None,
        }
    }
}
315
/// Event to send to the handler.
#[derive(Debug)]
pub enum HandlerIn {
    /// Resets the (sub)stream associated with the given request ID,
    /// thus signaling an error to the remote.
    ///
    /// Explicitly resetting the (sub)stream associated with a request
    /// can be used as an alternative to letting requests simply time
    /// out on the remote peer, thus potentially avoiding some delay
    /// for the query on the remote.
    ///
    /// Only a substream that is still waiting for the behaviour's answer is affected.
    Reset(RequestId),

    /// Change the connection to the specified mode.
    ///
    /// The new mode decides which protocol the handler offers for inbound substreams
    /// negotiated from now on (client mode denies them).
    ReconfigureMode { new_mode: Mode },

    /// Request for the list of nodes whose IDs are the closest to `key`. The number of nodes
    /// returned is not specified, but should be around 20.
    FindNodeReq {
        /// The key whose closest nodes are requested.
        key: Vec<u8>,
        /// ID of the query that generated this request.
        query_id: QueryId,
    },

    /// Response to a `FindNodeReq`.
    FindNodeRes {
        /// Results of the request.
        closer_peers: Vec<KadPeer>,
        /// Identifier of the request that was made by the remote.
        ///
        /// It is a logic error to use an id of the handler of a different node.
        request_id: RequestId,
    },

    /// Same as `FindNodeReq`, but should also return the entries of the local providers list for
    /// this key.
    GetProvidersReq {
        /// Identifier being searched.
        key: record::Key,
        /// ID of the query that generated this request.
        query_id: QueryId,
    },

    /// Response to a `GetProvidersReq`.
    GetProvidersRes {
        /// Nodes closest to the key.
        closer_peers: Vec<KadPeer>,
        /// Known providers for this key.
        provider_peers: Vec<KadPeer>,
        /// Identifier of the request that was made by the remote.
        ///
        /// It is a logic error to use an id of the handler of a different node.
        request_id: RequestId,
    },

    /// Indicates that this provider is known for this key.
    ///
    /// No response is expected from the remote, so no event is emitted on success.
    /// I/O errors and timeouts while sending it are still reported as
    /// `HandlerEvent::QueryError` for `query_id`.
    AddProvider {
        /// Key for which we should add providers.
        key: record::Key,
        /// Known provider for this key.
        provider: KadPeer,
        /// ID of the query that generated this request.
        query_id: QueryId,
    },

    /// Request to retrieve a record from the DHT.
    GetRecord {
        /// The key of the record.
        key: record::Key,
        /// ID of the query that generated this request.
        query_id: QueryId,
    },

    /// Response to a `GetRecord` request.
    GetRecordRes {
        /// The value that might have been found in our storage.
        record: Option<Record>,
        /// Nodes that are closer to the key we were searching for.
        closer_peers: Vec<KadPeer>,
        /// Identifier of the request that was made by the remote.
        ///
        /// It is a logic error to use an id of the handler of a different node.
        request_id: RequestId,
    },

    /// Put a value into the dht records.
    PutRecord {
        /// The record to store on the remote.
        record: Record,
        /// ID of the query that generated this request.
        query_id: QueryId,
    },

    /// Response to a `PutRecord`.
    PutRecordRes {
        /// Key of the value that was put.
        key: record::Key,
        /// Value that was put.
        value: Vec<u8>,
        /// Identifier of the request that was made by the remote.
        ///
        /// It is a logic error to use an id of the handler of a different node.
        request_id: RequestId,
    },
}
419
/// Handle for a request received from the remote. It must be handed back when
/// answering, so that the response reaches the inbound substream it belongs to.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct RequestId {
    /// Identifier of the inbound substream the request arrived on.
    connec_unique_id: UniqueConnecId,
}

/// Per-handler counter value that tells inbound substreams apart.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
struct UniqueConnecId(u64);
431
432impl Handler {
433    pub fn new(
434        protocol_config: ProtocolConfig,
435        endpoint: ConnectedPoint,
436        remote_peer_id: PeerId,
437        mode: Mode,
438    ) -> Self {
439        match &endpoint {
440            ConnectedPoint::Dialer { .. } => {
441                tracing::debug!(
442                    peer=%remote_peer_id,
443                    mode=%mode,
444                    "New outbound connection"
445                );
446            }
447            ConnectedPoint::Listener { .. } => {
448                tracing::debug!(
449                    peer=%remote_peer_id,
450                    mode=%mode,
451                    "New inbound connection"
452                );
453            }
454        }
455
456        let substreams_timeout = protocol_config.substreams_timeout_s();
457
458        Handler {
459            protocol_config,
460            mode,
461            endpoint,
462            remote_peer_id,
463            next_connec_unique_id: UniqueConnecId(0),
464            inbound_substreams: Default::default(),
465            outbound_substreams: futures_bounded::FuturesTupleSet::new(
466                substreams_timeout,
467                MAX_NUM_STREAMS,
468            ),
469            pending_streams: Default::default(),
470            pending_messages: Default::default(),
471            protocol_status: None,
472            remote_supported_protocols: Default::default(),
473        }
474    }
475
476    fn on_fully_negotiated_outbound(
477        &mut self,
478        FullyNegotiatedOutbound {
479            protocol: stream,
480            info: (),
481        }: FullyNegotiatedOutbound<<Self as ConnectionHandler>::OutboundProtocol>,
482    ) {
483        if let Some(sender) = self.pending_streams.pop_front() {
484            let _ = sender.send(Ok(stream));
485        }
486
487        if self.protocol_status.is_none() {
488            // Upon the first successfully negotiated substream, we know that the
489            // remote is configured with the same protocol name and we want
490            // the behaviour to add this peer to the routing table, if possible.
491            self.protocol_status = Some(ProtocolStatus {
492                supported: true,
493                reported: false,
494            });
495        }
496    }
497
498    fn on_fully_negotiated_inbound(
499        &mut self,
500        FullyNegotiatedInbound { protocol, .. }: FullyNegotiatedInbound<
501            <Self as ConnectionHandler>::InboundProtocol,
502        >,
503    ) {
504        // If `self.allow_listening` is false, then we produced a `DeniedUpgrade` and `protocol`
505        // is a `Infallible`.
506        let protocol = match protocol {
507            future::Either::Left(p) => p,
508            future::Either::Right(p) => libp2p_core::util::unreachable(p),
509        };
510
511        if self.protocol_status.is_none() {
512            // Upon the first successfully negotiated substream, we know that the
513            // remote is configured with the same protocol name and we want
514            // the behaviour to add this peer to the routing table, if possible.
515            self.protocol_status = Some(ProtocolStatus {
516                supported: true,
517                reported: false,
518            });
519        }
520
521        if self.inbound_substreams.len() == MAX_NUM_STREAMS {
522            if let Some(s) = self.inbound_substreams.iter_mut().find(|s| {
523                matches!(
524                    s,
525                    // An inbound substream waiting to be reused.
526                    InboundSubstreamState::WaitingMessage { first: false, .. }
527                )
528            }) {
529                *s = InboundSubstreamState::Cancelled;
530                tracing::debug!(
531                    peer=?self.remote_peer_id,
532                    "New inbound substream to peer exceeds inbound substream limit. \
533                    Removed older substream waiting to be reused."
534                )
535            } else {
536                tracing::warn!(
537                    peer=?self.remote_peer_id,
538                    "New inbound substream to peer exceeds inbound substream limit. \
539                     No older substream waiting to be reused. Dropping new substream."
540                );
541                return;
542            }
543        }
544
545        let connec_unique_id = self.next_connec_unique_id;
546        self.next_connec_unique_id.0 += 1;
547        self.inbound_substreams
548            .push(InboundSubstreamState::WaitingMessage {
549                first: true,
550                connection_id: connec_unique_id,
551                substream: protocol,
552            });
553    }
554
555    /// Takes the given [`KadRequestMsg`] and composes it into an outbound request-response protocol
556    /// handshake using a [`oneshot::channel`].
557    fn queue_new_stream(&mut self, id: QueryId, msg: KadRequestMsg) {
558        let (sender, receiver) = oneshot::channel();
559
560        self.pending_streams.push_back(sender);
561        let result = self.outbound_substreams.try_push(
562            async move {
563                let mut stream = receiver
564                    .await
565                    .map_err(|_| io::Error::from(io::ErrorKind::BrokenPipe))?
566                    .map_err(|e| match e {
567                        StreamUpgradeError::Timeout => io::ErrorKind::TimedOut.into(),
568                        StreamUpgradeError::Apply(e) => e,
569                        StreamUpgradeError::NegotiationFailed => io::Error::new(
570                            io::ErrorKind::ConnectionRefused,
571                            "protocol not supported",
572                        ),
573                        StreamUpgradeError::Io(e) => e,
574                    })?;
575
576                let has_answer = !matches!(msg, KadRequestMsg::AddProvider { .. });
577
578                stream.send(msg).await?;
579                stream.close().await?;
580
581                if !has_answer {
582                    return Ok(None);
583                }
584
585                let msg = stream.next().await.ok_or(io::ErrorKind::UnexpectedEof)??;
586
587                Ok(Some(msg))
588            },
589            id,
590        );
591
592        debug_assert!(
593            result.is_ok(),
594            "Expected to not create more streams than allowed"
595        );
596    }
597}
598
599impl ConnectionHandler for Handler {
600    type FromBehaviour = HandlerIn;
601    type ToBehaviour = HandlerEvent;
602    type InboundProtocol = Either<ProtocolConfig, upgrade::DeniedUpgrade>;
603    type OutboundProtocol = ProtocolConfig;
604    type OutboundOpenInfo = ();
605    type InboundOpenInfo = ();
606
607    fn listen_protocol(&self) -> SubstreamProtocol<Self::InboundProtocol> {
608        match self.mode {
609            Mode::Server => SubstreamProtocol::new(Either::Left(self.protocol_config.clone()), ()),
610            Mode::Client => SubstreamProtocol::new(Either::Right(upgrade::DeniedUpgrade), ()),
611        }
612    }
613
614    fn on_behaviour_event(&mut self, message: HandlerIn) {
615        match message {
616            HandlerIn::Reset(request_id) => {
617                if let Some(state) = self
618                    .inbound_substreams
619                    .iter_mut()
620                    .find(|state| match state {
621                        InboundSubstreamState::WaitingBehaviour(conn_id, _, _) => {
622                            conn_id == &request_id.connec_unique_id
623                        }
624                        _ => false,
625                    })
626                {
627                    state.close();
628                }
629            }
630            HandlerIn::FindNodeReq { key, query_id } => {
631                let msg = KadRequestMsg::FindNode { key };
632                self.pending_messages.push_back((msg, query_id));
633            }
634            HandlerIn::FindNodeRes {
635                closer_peers,
636                request_id,
637            } => self.answer_pending_request(request_id, KadResponseMsg::FindNode { closer_peers }),
638            HandlerIn::GetProvidersReq { key, query_id } => {
639                let msg = KadRequestMsg::GetProviders { key };
640                self.pending_messages.push_back((msg, query_id));
641            }
642            HandlerIn::GetProvidersRes {
643                closer_peers,
644                provider_peers,
645                request_id,
646            } => self.answer_pending_request(
647                request_id,
648                KadResponseMsg::GetProviders {
649                    closer_peers,
650                    provider_peers,
651                },
652            ),
653            HandlerIn::AddProvider {
654                key,
655                provider,
656                query_id,
657            } => {
658                let msg = KadRequestMsg::AddProvider { key, provider };
659                self.pending_messages.push_back((msg, query_id));
660            }
661            HandlerIn::GetRecord { key, query_id } => {
662                let msg = KadRequestMsg::GetValue { key };
663                self.pending_messages.push_back((msg, query_id));
664            }
665            HandlerIn::PutRecord { record, query_id } => {
666                let msg = KadRequestMsg::PutValue { record };
667                self.pending_messages.push_back((msg, query_id));
668            }
669            HandlerIn::GetRecordRes {
670                record,
671                closer_peers,
672                request_id,
673            } => {
674                self.answer_pending_request(
675                    request_id,
676                    KadResponseMsg::GetValue {
677                        record,
678                        closer_peers,
679                    },
680                );
681            }
682            HandlerIn::PutRecordRes {
683                key,
684                request_id,
685                value,
686            } => {
687                self.answer_pending_request(request_id, KadResponseMsg::PutValue { key, value });
688            }
689            HandlerIn::ReconfigureMode { new_mode } => {
690                let peer = self.remote_peer_id;
691
692                match &self.endpoint {
693                    ConnectedPoint::Dialer { .. } => {
694                        tracing::debug!(
695                            %peer,
696                            mode=%new_mode,
697                            "Changed mode on outbound connection"
698                        )
699                    }
700                    ConnectedPoint::Listener { local_addr, .. } => {
701                        tracing::debug!(
702                            %peer,
703                            mode=%new_mode,
704                            local_address=%local_addr,
705                            "Changed mode on inbound connection assuming that one of our external addresses routes to the local address")
706                    }
707                }
708
709                self.mode = new_mode;
710            }
711        }
712    }
713
714    #[tracing::instrument(level = "trace", name = "ConnectionHandler::poll", skip(self, cx))]
715    fn poll(
716        &mut self,
717        cx: &mut Context<'_>,
718    ) -> Poll<ConnectionHandlerEvent<Self::OutboundProtocol, (), Self::ToBehaviour>> {
719        loop {
720            match &mut self.protocol_status {
721                Some(status) if !status.reported => {
722                    status.reported = true;
723                    let event = if status.supported {
724                        HandlerEvent::ProtocolConfirmed {
725                            endpoint: self.endpoint.clone(),
726                        }
727                    } else {
728                        HandlerEvent::ProtocolNotSupported {
729                            endpoint: self.endpoint.clone(),
730                        }
731                    };
732
733                    return Poll::Ready(ConnectionHandlerEvent::NotifyBehaviour(event));
734                }
735                _ => {}
736            }
737
738            match self.outbound_substreams.poll_unpin(cx) {
739                Poll::Ready((Ok(Ok(Some(response))), query_id)) => {
740                    return Poll::Ready(ConnectionHandlerEvent::NotifyBehaviour(
741                        process_kad_response(response, query_id),
742                    ))
743                }
744                Poll::Ready((Ok(Ok(None)), _)) => {
745                    continue;
746                }
747                Poll::Ready((Ok(Err(e)), query_id)) => {
748                    return Poll::Ready(ConnectionHandlerEvent::NotifyBehaviour(
749                        HandlerEvent::QueryError {
750                            error: HandlerQueryErr::Io(e),
751                            query_id,
752                        },
753                    ))
754                }
755                Poll::Ready((Err(_timeout), query_id)) => {
756                    return Poll::Ready(ConnectionHandlerEvent::NotifyBehaviour(
757                        HandlerEvent::QueryError {
758                            error: HandlerQueryErr::Io(io::ErrorKind::TimedOut.into()),
759                            query_id,
760                        },
761                    ))
762                }
763                Poll::Pending => {}
764            }
765
766            if let Poll::Ready(Some(event)) = self.inbound_substreams.poll_next_unpin(cx) {
767                return Poll::Ready(event);
768            }
769
770            if self.outbound_substreams.len() < MAX_NUM_STREAMS {
771                if let Some((msg, id)) = self.pending_messages.pop_front() {
772                    self.queue_new_stream(id, msg);
773                    return Poll::Ready(ConnectionHandlerEvent::OutboundSubstreamRequest {
774                        protocol: SubstreamProtocol::new(self.protocol_config.clone(), ()),
775                    });
776                }
777            }
778
779            return Poll::Pending;
780        }
781    }
782
783    fn on_connection_event(
784        &mut self,
785        event: ConnectionEvent<Self::InboundProtocol, Self::OutboundProtocol>,
786    ) {
787        match event {
788            ConnectionEvent::FullyNegotiatedOutbound(fully_negotiated_outbound) => {
789                self.on_fully_negotiated_outbound(fully_negotiated_outbound)
790            }
791            ConnectionEvent::FullyNegotiatedInbound(fully_negotiated_inbound) => {
792                self.on_fully_negotiated_inbound(fully_negotiated_inbound)
793            }
794            ConnectionEvent::DialUpgradeError(ev) => {
795                if let Some(sender) = self.pending_streams.pop_front() {
796                    let _ = sender.send(Err(ev.error));
797                }
798            }
799            ConnectionEvent::RemoteProtocolsChange(change) => {
800                let dirty = self.remote_supported_protocols.on_protocols_change(change);
801
802                if dirty {
803                    let remote_supports_our_kademlia_protocols = self
804                        .remote_supported_protocols
805                        .iter()
806                        .any(|p| self.protocol_config.protocol_names().contains(p));
807
808                    self.protocol_status = Some(compute_new_protocol_status(
809                        remote_supports_our_kademlia_protocols,
810                        self.protocol_status,
811                    ))
812                }
813            }
814            _ => {}
815        }
816    }
817}
818
819fn compute_new_protocol_status(
820    now_supported: bool,
821    current_status: Option<ProtocolStatus>,
822) -> ProtocolStatus {
823    let Some(current_status) = current_status else {
824        return ProtocolStatus {
825            supported: now_supported,
826            reported: false,
827        };
828    };
829
830    if now_supported == current_status.supported {
831        return ProtocolStatus {
832            supported: now_supported,
833            reported: true,
834        };
835    }
836
837    if now_supported {
838        tracing::debug!("Remote now supports our kademlia protocol");
839    } else {
840        tracing::debug!("Remote no longer supports our kademlia protocol");
841    }
842
843    ProtocolStatus {
844        supported: now_supported,
845        reported: false,
846    }
847}
848
849impl Handler {
850    fn answer_pending_request(&mut self, request_id: RequestId, mut msg: KadResponseMsg) {
851        for state in self.inbound_substreams.iter_mut() {
852            match state.try_answer_with(request_id, msg) {
853                Ok(()) => return,
854                Err(m) => {
855                    msg = m;
856                }
857            }
858        }
859
860        debug_assert!(false, "Cannot find inbound substream for {request_id:?}")
861    }
862}
863
864impl futures::Stream for InboundSubstreamState {
865    type Item = ConnectionHandlerEvent<ProtocolConfig, (), HandlerEvent>;
866
867    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
868        let this = self.get_mut();
869
870        loop {
871            match std::mem::replace(
872                this,
873                Self::Poisoned {
874                    phantom: PhantomData,
875                },
876            ) {
877                InboundSubstreamState::WaitingMessage {
878                    first,
879                    connection_id,
880                    mut substream,
881                } => match substream.poll_next_unpin(cx) {
882                    Poll::Ready(Some(Ok(KadRequestMsg::Ping))) => {
883                        tracing::warn!("Kademlia PING messages are unsupported");
884
885                        *this = InboundSubstreamState::Closing(substream);
886                    }
887                    Poll::Ready(Some(Ok(KadRequestMsg::FindNode { key }))) => {
888                        *this =
889                            InboundSubstreamState::WaitingBehaviour(connection_id, substream, None);
890                        return Poll::Ready(Some(ConnectionHandlerEvent::NotifyBehaviour(
891                            HandlerEvent::FindNodeReq {
892                                key,
893                                request_id: RequestId {
894                                    connec_unique_id: connection_id,
895                                },
896                            },
897                        )));
898                    }
899                    Poll::Ready(Some(Ok(KadRequestMsg::GetProviders { key }))) => {
900                        *this =
901                            InboundSubstreamState::WaitingBehaviour(connection_id, substream, None);
902                        return Poll::Ready(Some(ConnectionHandlerEvent::NotifyBehaviour(
903                            HandlerEvent::GetProvidersReq {
904                                key,
905                                request_id: RequestId {
906                                    connec_unique_id: connection_id,
907                                },
908                            },
909                        )));
910                    }
911                    Poll::Ready(Some(Ok(KadRequestMsg::AddProvider { key, provider }))) => {
912                        *this = InboundSubstreamState::WaitingMessage {
913                            first: false,
914                            connection_id,
915                            substream,
916                        };
917                        return Poll::Ready(Some(ConnectionHandlerEvent::NotifyBehaviour(
918                            HandlerEvent::AddProvider { key, provider },
919                        )));
920                    }
921                    Poll::Ready(Some(Ok(KadRequestMsg::GetValue { key }))) => {
922                        *this =
923                            InboundSubstreamState::WaitingBehaviour(connection_id, substream, None);
924                        return Poll::Ready(Some(ConnectionHandlerEvent::NotifyBehaviour(
925                            HandlerEvent::GetRecord {
926                                key,
927                                request_id: RequestId {
928                                    connec_unique_id: connection_id,
929                                },
930                            },
931                        )));
932                    }
933                    Poll::Ready(Some(Ok(KadRequestMsg::PutValue { record }))) => {
934                        *this =
935                            InboundSubstreamState::WaitingBehaviour(connection_id, substream, None);
936                        return Poll::Ready(Some(ConnectionHandlerEvent::NotifyBehaviour(
937                            HandlerEvent::PutRecord {
938                                record,
939                                request_id: RequestId {
940                                    connec_unique_id: connection_id,
941                                },
942                            },
943                        )));
944                    }
945                    Poll::Pending => {
946                        *this = InboundSubstreamState::WaitingMessage {
947                            first,
948                            connection_id,
949                            substream,
950                        };
951                        return Poll::Pending;
952                    }
953                    Poll::Ready(None) => {
954                        return Poll::Ready(None);
955                    }
956                    Poll::Ready(Some(Err(e))) => {
957                        tracing::trace!("Inbound substream error: {:?}", e);
958                        return Poll::Ready(None);
959                    }
960                },
961                InboundSubstreamState::WaitingBehaviour(id, substream, _) => {
962                    *this = InboundSubstreamState::WaitingBehaviour(
963                        id,
964                        substream,
965                        Some(cx.waker().clone()),
966                    );
967
968                    return Poll::Pending;
969                }
970                InboundSubstreamState::PendingSend(id, mut substream, msg) => {
971                    match substream.poll_ready_unpin(cx) {
972                        Poll::Ready(Ok(())) => match substream.start_send_unpin(msg) {
973                            Ok(()) => {
974                                *this = InboundSubstreamState::PendingFlush(id, substream);
975                            }
976                            Err(_) => return Poll::Ready(None),
977                        },
978                        Poll::Pending => {
979                            *this = InboundSubstreamState::PendingSend(id, substream, msg);
980                            return Poll::Pending;
981                        }
982                        Poll::Ready(Err(_)) => return Poll::Ready(None),
983                    }
984                }
985                InboundSubstreamState::PendingFlush(id, mut substream) => {
986                    match substream.poll_flush_unpin(cx) {
987                        Poll::Ready(Ok(())) => {
988                            *this = InboundSubstreamState::WaitingMessage {
989                                first: false,
990                                connection_id: id,
991                                substream,
992                            };
993                        }
994                        Poll::Pending => {
995                            *this = InboundSubstreamState::PendingFlush(id, substream);
996                            return Poll::Pending;
997                        }
998                        Poll::Ready(Err(_)) => return Poll::Ready(None),
999                    }
1000                }
1001                InboundSubstreamState::Closing(mut stream) => match stream.poll_close_unpin(cx) {
1002                    Poll::Ready(Ok(())) | Poll::Ready(Err(_)) => return Poll::Ready(None),
1003                    Poll::Pending => {
1004                        *this = InboundSubstreamState::Closing(stream);
1005                        return Poll::Pending;
1006                    }
1007                },
1008                InboundSubstreamState::Poisoned { .. } => unreachable!(),
1009                InboundSubstreamState::Cancelled => return Poll::Ready(None),
1010            }
1011        }
1012    }
1013}
1014
1015/// Process a Kademlia message that's supposed to be a response to one of our requests.
1016fn process_kad_response(event: KadResponseMsg, query_id: QueryId) -> HandlerEvent {
1017    // TODO: must check that the response corresponds to the request
1018    match event {
1019        KadResponseMsg::Pong => {
1020            // We never send out pings.
1021            HandlerEvent::QueryError {
1022                error: HandlerQueryErr::UnexpectedMessage,
1023                query_id,
1024            }
1025        }
1026        KadResponseMsg::FindNode { closer_peers } => HandlerEvent::FindNodeRes {
1027            closer_peers,
1028            query_id,
1029        },
1030        KadResponseMsg::GetProviders {
1031            closer_peers,
1032            provider_peers,
1033        } => HandlerEvent::GetProvidersRes {
1034            closer_peers,
1035            provider_peers,
1036            query_id,
1037        },
1038        KadResponseMsg::GetValue {
1039            record,
1040            closer_peers,
1041        } => HandlerEvent::GetRecordRes {
1042            record,
1043            closer_peers,
1044            query_id,
1045        },
1046        KadResponseMsg::PutValue { key, value, .. } => HandlerEvent::PutRecordRes {
1047            key,
1048            value,
1049            query_id,
1050        },
1051    }
1052}
1053
#[cfg(test)]
mod tests {
    use quickcheck::{Arbitrary, Gen};
    use tracing_subscriber::EnvFilter;

    use super::*;

    /// Produces a `ProtocolStatus` with random `supported` / `reported` flags.
    impl Arbitrary for ProtocolStatus {
        fn arbitrary(g: &mut Gen) -> Self {
            let supported = bool::arbitrary(g);
            let reported = bool::arbitrary(g);
            Self {
                supported,
                reported,
            }
        }
    }

    #[test]
    fn compute_next_protocol_status_test() {
        let _ = tracing_subscriber::fmt()
            .with_env_filter(EnvFilter::from_default_env())
            .try_init();

        // The new status always mirrors `now_supported`. It counts as reported only when a
        // previous status exists and support did not change.
        fn prop(now_supported: bool, current: Option<ProtocolStatus>) {
            let new = compute_new_protocol_status(now_supported, current);
            let expect_reported =
                matches!(current, Some(previous) if previous.supported == now_supported);

            assert_eq!(new.supported, now_supported);
            assert_eq!(new.reported, expect_reported);
        }

        quickcheck::quickcheck(prop as fn(_, _))
    }
}