1use std::{
22 collections::HashSet,
23 sync::Arc,
24 task::{Context, Poll},
25 time::Duration,
26};
27
28use either::Either;
29use futures::prelude::*;
30use futures_bounded::Timeout;
31use futures_timer::Delay;
32use libp2p_core::{
33 upgrade::{ReadyUpgrade, SelectUpgrade},
34 Multiaddr,
35};
36use libp2p_identity::PeerId;
37use libp2p_swarm::{
38 handler::{
39 ConnectionEvent, DialUpgradeError, FullyNegotiatedInbound, FullyNegotiatedOutbound,
40 ProtocolSupport,
41 },
42 ConnectionHandler, ConnectionHandlerEvent, StreamProtocol, StreamUpgradeError,
43 SubstreamProtocol, SupportedProtocols,
44};
45use smallvec::SmallVec;
46use tracing::Level;
47
48use crate::{
49 behaviour::KeyType,
50 protocol::{self, Info, PushInfo, UpgradeError},
51 PROTOCOL_NAME, PUSH_PROTOCOL_NAME,
52};
53
/// Upper bound on how long a single identify / identify-push stream may run before it is
/// reported to the behaviour as a timeout.
const STREAM_TIMEOUT: Duration = Duration::from_millis(60_000);
/// Maximum number of identify / identify-push streams handled at once on one connection;
/// streams arriving beyond this limit are dropped with a warning.
const MAX_CONCURRENT_STREAMS_PER_CONNECTION: usize = 10usize;
56
57pub struct Handler {
63 remote_peer_id: PeerId,
64 events: SmallVec<
66 [ConnectionHandlerEvent<
67 Either<ReadyUpgrade<StreamProtocol>, ReadyUpgrade<StreamProtocol>>,
68 (),
69 Event,
70 >; 4],
71 >,
72
73 active_streams: futures_bounded::FuturesSet<Result<Success, UpgradeError>>,
74
75 trigger_next_identify: Delay,
77
78 exchanged_one_periodic_identify: bool,
80
81 interval: Duration,
83
84 local_key: Arc<KeyType>,
86
87 protocol_version: String,
90
91 agent_version: String,
94
95 observed_addr: Multiaddr,
97
98 remote_info: Option<Info>,
100
101 local_supported_protocols: SupportedProtocols,
102 remote_supported_protocols: HashSet<StreamProtocol>,
103 external_addresses: HashSet<Multiaddr>,
104}
105
106#[derive(Debug)]
108pub enum InEvent {
109 AddressesChanged(HashSet<Multiaddr>),
110 Push,
111}
112
113#[derive(Debug)]
115#[allow(clippy::large_enum_variant)]
116pub enum Event {
117 Identified(Info),
119 Identification,
121 IdentificationPushed(Info),
123 IdentificationError(StreamUpgradeError<UpgradeError>),
125}
126
127impl Handler {
128 pub(crate) fn new(
130 interval: Duration,
131 remote_peer_id: PeerId,
132 local_key: Arc<KeyType>,
133 protocol_version: String,
134 agent_version: String,
135 observed_addr: Multiaddr,
136 external_addresses: HashSet<Multiaddr>,
137 ) -> Self {
138 Self {
139 remote_peer_id,
140 events: SmallVec::new(),
141 active_streams: futures_bounded::FuturesSet::new(
142 STREAM_TIMEOUT,
143 MAX_CONCURRENT_STREAMS_PER_CONNECTION,
144 ),
145 trigger_next_identify: Delay::new(Duration::ZERO),
146 exchanged_one_periodic_identify: false,
147 interval,
148 local_key,
149 protocol_version,
150 agent_version,
151 observed_addr,
152 local_supported_protocols: SupportedProtocols::default(),
153 remote_supported_protocols: HashSet::default(),
154 remote_info: Default::default(),
155 external_addresses,
156 }
157 }
158
159 fn on_fully_negotiated_inbound(
160 &mut self,
161 FullyNegotiatedInbound {
162 protocol: output, ..
163 }: FullyNegotiatedInbound<<Self as ConnectionHandler>::InboundProtocol>,
164 ) {
165 match output {
166 future::Either::Left(stream) => {
167 let info = self.build_info();
168
169 if self
170 .active_streams
171 .try_push(
172 protocol::send_identify(stream, info).map_ok(|_| Success::SentIdentify),
173 )
174 .is_err()
175 {
176 tracing::warn!("Dropping inbound stream because we are at capacity");
177 } else {
178 self.exchanged_one_periodic_identify = true;
179 }
180 }
181 future::Either::Right(stream) => {
182 if self
183 .active_streams
184 .try_push(protocol::recv_push(stream).map_ok(Success::ReceivedIdentifyPush))
185 .is_err()
186 {
187 tracing::warn!(
188 "Dropping inbound identify push stream because we are at capacity"
189 );
190 }
191 }
192 }
193 }
194
195 fn on_fully_negotiated_outbound(
196 &mut self,
197 FullyNegotiatedOutbound {
198 protocol: output, ..
199 }: FullyNegotiatedOutbound<<Self as ConnectionHandler>::OutboundProtocol>,
200 ) {
201 match output {
202 future::Either::Left(stream) => {
203 if self
204 .active_streams
205 .try_push(protocol::recv_identify(stream).map_ok(Success::ReceivedIdentify))
206 .is_err()
207 {
208 tracing::warn!("Dropping outbound identify stream because we are at capacity");
209 }
210 }
211 future::Either::Right(stream) => {
212 let info = self.build_info();
213
214 if self
215 .active_streams
216 .try_push(
217 protocol::send_identify(stream, info).map_ok(Success::SentIdentifyPush),
218 )
219 .is_err()
220 {
221 tracing::warn!(
222 "Dropping outbound identify push stream because we are at capacity"
223 );
224 }
225 }
226 }
227 }
228
229 fn build_info(&mut self) -> Info {
230 let signed_envelope = match self.local_key.as_ref() {
231 KeyType::PublicKey(_) => None,
232 KeyType::Keypair { keypair, .. } => libp2p_core::PeerRecord::new(
233 keypair,
234 Vec::from_iter(self.external_addresses.iter().cloned()),
235 )
236 .ok()
237 .map(|r| r.into_signed_envelope()),
238 };
239 Info {
240 public_key: self.local_key.public_key().clone(),
241 protocol_version: self.protocol_version.clone(),
242 agent_version: self.agent_version.clone(),
243 listen_addrs: Vec::from_iter(self.external_addresses.iter().cloned()),
244 protocols: Vec::from_iter(self.local_supported_protocols.iter().cloned()),
245 observed_addr: self.observed_addr.clone(),
246 signed_peer_record: signed_envelope,
247 }
248 }
249
250 fn handle_incoming_info(&mut self, info: &Info) -> bool {
252 let derived_peer_id = info.public_key.to_peer_id();
253 if self.remote_peer_id != derived_peer_id {
254 return false;
255 }
256
257 self.remote_info.replace(info.clone());
258
259 self.update_supported_protocols_for_remote(info);
260 true
261 }
262
263 fn update_supported_protocols_for_remote(&mut self, remote_info: &Info) {
264 let new_remote_protocols = HashSet::from_iter(remote_info.protocols.clone());
265
266 let remote_added_protocols = new_remote_protocols
267 .difference(&self.remote_supported_protocols)
268 .cloned()
269 .collect::<HashSet<_>>();
270 let remote_removed_protocols = self
271 .remote_supported_protocols
272 .difference(&new_remote_protocols)
273 .cloned()
274 .collect::<HashSet<_>>();
275
276 if !remote_added_protocols.is_empty() {
277 self.events
278 .push(ConnectionHandlerEvent::ReportRemoteProtocols(
279 ProtocolSupport::Added(remote_added_protocols),
280 ));
281 }
282
283 if !remote_removed_protocols.is_empty() {
284 self.events
285 .push(ConnectionHandlerEvent::ReportRemoteProtocols(
286 ProtocolSupport::Removed(remote_removed_protocols),
287 ));
288 }
289
290 self.remote_supported_protocols = new_remote_protocols;
291 }
292
293 fn local_protocols_to_string(&mut self) -> String {
294 self.local_supported_protocols
295 .iter()
296 .map(|p| p.to_string())
297 .collect::<Vec<_>>()
298 .join(", ")
299 }
300}
301
302impl ConnectionHandler for Handler {
303 type FromBehaviour = InEvent;
304 type ToBehaviour = Event;
305 type InboundProtocol =
306 SelectUpgrade<ReadyUpgrade<StreamProtocol>, ReadyUpgrade<StreamProtocol>>;
307 type OutboundProtocol = Either<ReadyUpgrade<StreamProtocol>, ReadyUpgrade<StreamProtocol>>;
308 type OutboundOpenInfo = ();
309 type InboundOpenInfo = ();
310
311 fn listen_protocol(&self) -> SubstreamProtocol<Self::InboundProtocol> {
312 SubstreamProtocol::new(
313 SelectUpgrade::new(
314 ReadyUpgrade::new(PROTOCOL_NAME),
315 ReadyUpgrade::new(PUSH_PROTOCOL_NAME),
316 ),
317 (),
318 )
319 }
320
321 fn on_behaviour_event(&mut self, event: Self::FromBehaviour) {
322 match event {
323 InEvent::AddressesChanged(addresses) => {
324 self.external_addresses = addresses;
325 }
326 InEvent::Push => {
327 self.events
328 .push(ConnectionHandlerEvent::OutboundSubstreamRequest {
329 protocol: SubstreamProtocol::new(
330 Either::Right(ReadyUpgrade::new(PUSH_PROTOCOL_NAME)),
331 (),
332 ),
333 });
334 }
335 }
336 }
337
338 #[tracing::instrument(level = "trace", name = "ConnectionHandler::poll", skip(self, cx))]
339 fn poll(
340 &mut self,
341 cx: &mut Context<'_>,
342 ) -> Poll<ConnectionHandlerEvent<Self::OutboundProtocol, (), Event>> {
343 if let Some(event) = self.events.pop() {
344 return Poll::Ready(event);
345 }
346
347 if let Poll::Ready(()) = self.trigger_next_identify.poll_unpin(cx) {
349 self.trigger_next_identify.reset(self.interval);
350 let event = ConnectionHandlerEvent::OutboundSubstreamRequest {
351 protocol: SubstreamProtocol::new(
352 Either::Left(ReadyUpgrade::new(PROTOCOL_NAME)),
353 (),
354 ),
355 };
356 return Poll::Ready(event);
357 }
358
359 while let Poll::Ready(ready) = self.active_streams.poll_unpin(cx) {
360 match ready {
361 Ok(Ok(Success::ReceivedIdentify(remote_info))) => {
362 if self.handle_incoming_info(&remote_info) {
363 return Poll::Ready(ConnectionHandlerEvent::NotifyBehaviour(
364 Event::Identified(remote_info),
365 ));
366 } else {
367 tracing::warn!(
368 %self.remote_peer_id,
369 ?remote_info.public_key,
370 derived_peer_id=%remote_info.public_key.to_peer_id(),
371 "Discarding received identify message as public key does not match remote peer ID",
372 );
373 }
374 }
375 Ok(Ok(Success::SentIdentifyPush(info))) => {
376 return Poll::Ready(ConnectionHandlerEvent::NotifyBehaviour(
377 Event::IdentificationPushed(info),
378 ));
379 }
380 Ok(Ok(Success::SentIdentify)) => {
381 return Poll::Ready(ConnectionHandlerEvent::NotifyBehaviour(
382 Event::Identification,
383 ));
384 }
385 Ok(Ok(Success::ReceivedIdentifyPush(remote_push_info))) => {
386 if let Some(mut info) = self.remote_info.clone() {
387 info.merge(remote_push_info);
388
389 if self.handle_incoming_info(&info) {
390 return Poll::Ready(ConnectionHandlerEvent::NotifyBehaviour(
391 Event::Identified(info),
392 ));
393 } else {
394 tracing::warn!(
395 %self.remote_peer_id,
396 ?info.public_key,
397 derived_peer_id=%info.public_key.to_peer_id(),
398 "Discarding received identify message as public key does not match remote peer ID",
399 );
400 }
401 }
402 }
403 Ok(Err(e)) => {
404 return Poll::Ready(ConnectionHandlerEvent::NotifyBehaviour(
405 Event::IdentificationError(StreamUpgradeError::Apply(e)),
406 ));
407 }
408 Err(Timeout { .. }) => {
409 return Poll::Ready(ConnectionHandlerEvent::NotifyBehaviour(
410 Event::IdentificationError(StreamUpgradeError::Timeout),
411 ));
412 }
413 }
414 }
415
416 Poll::Pending
417 }
418
419 fn on_connection_event(
420 &mut self,
421 event: ConnectionEvent<Self::InboundProtocol, Self::OutboundProtocol>,
422 ) {
423 match event {
424 ConnectionEvent::FullyNegotiatedInbound(fully_negotiated_inbound) => {
425 self.on_fully_negotiated_inbound(fully_negotiated_inbound)
426 }
427 ConnectionEvent::FullyNegotiatedOutbound(fully_negotiated_outbound) => {
428 self.on_fully_negotiated_outbound(fully_negotiated_outbound)
429 }
430 ConnectionEvent::DialUpgradeError(DialUpgradeError { error, .. }) => {
431 self.events.push(ConnectionHandlerEvent::NotifyBehaviour(
432 Event::IdentificationError(
433 error.map_upgrade_err(|e| libp2p_core::util::unreachable(e.into_inner())),
434 ),
435 ));
436 self.trigger_next_identify.reset(self.interval);
437 }
438 ConnectionEvent::LocalProtocolsChange(change) => {
439 let before = tracing::enabled!(Level::DEBUG)
440 .then(|| self.local_protocols_to_string())
441 .unwrap_or_default();
442 let protocols_changed = self.local_supported_protocols.on_protocols_change(change);
443 let after = tracing::enabled!(Level::DEBUG)
444 .then(|| self.local_protocols_to_string())
445 .unwrap_or_default();
446
447 if protocols_changed && self.exchanged_one_periodic_identify {
448 tracing::debug!(
449 peer=%self.remote_peer_id,
450 %before,
451 %after,
452 "Supported listen protocols changed, pushing to peer"
453 );
454
455 self.events
456 .push(ConnectionHandlerEvent::OutboundSubstreamRequest {
457 protocol: SubstreamProtocol::new(
458 Either::Right(ReadyUpgrade::new(PUSH_PROTOCOL_NAME)),
459 (),
460 ),
461 });
462 }
463 }
464 _ => {}
465 }
466 }
467}
468
469enum Success {
470 SentIdentify,
471 ReceivedIdentify(Info),
472 SentIdentifyPush(Info),
473 ReceivedIdentifyPush(PushInfo),
474}