1use std::{
22 collections::{HashMap, HashSet},
23 convert::Infallible,
24 fmt,
25 task::{Context, Poll},
26};
27
28use libp2p_core::{transport::PortUse, ConnectedPoint, Endpoint, Multiaddr};
29use libp2p_identity::PeerId;
30use libp2p_swarm::{
31 behaviour::{ConnectionEstablished, DialFailure, ListenFailure},
32 dummy, ConnectionClosed, ConnectionDenied, ConnectionId, FromSwarm, NetworkBehaviour, THandler,
33 THandlerInEvent, THandlerOutEvent, ToSwarm,
34};
35
36pub struct Behaviour {
72 limits: ConnectionLimits,
73 bypass_peer_id: HashSet<PeerId>,
75
76 pending_inbound_connections: HashSet<ConnectionId>,
77 pending_outbound_connections: HashSet<ConnectionId>,
78 established_inbound_connections: HashSet<ConnectionId>,
79 established_outbound_connections: HashSet<ConnectionId>,
80 established_per_peer: HashMap<PeerId, HashSet<ConnectionId>>,
81}
82
83impl Behaviour {
84 pub fn new(limits: ConnectionLimits) -> Self {
85 Self {
86 limits,
87 bypass_peer_id: Default::default(),
88 pending_inbound_connections: Default::default(),
89 pending_outbound_connections: Default::default(),
90 established_inbound_connections: Default::default(),
91 established_outbound_connections: Default::default(),
92 established_per_peer: Default::default(),
93 }
94 }
95
96 pub fn limits_mut(&mut self) -> &mut ConnectionLimits {
99 &mut self.limits
100 }
101
102 pub fn bypass_peer_id(&mut self, peer_id: &PeerId) {
104 self.bypass_peer_id.insert(*peer_id);
105 }
106 pub fn remove_peer_id(&mut self, peer_id: &PeerId) {
108 self.bypass_peer_id.remove(peer_id);
109 }
110 pub fn is_bypassed(&self, remote_peer: &PeerId) -> bool {
112 self.bypass_peer_id.contains(remote_peer)
113 }
114}
115
116fn check_limit(limit: Option<u32>, current: usize, kind: Kind) -> Result<(), ConnectionDenied> {
117 let limit = limit.unwrap_or(u32::MAX);
118 let current = current as u32;
119
120 if current >= limit {
121 return Err(ConnectionDenied::new(Exceeded { limit, kind }));
122 }
123
124 Ok(())
125}
126
127#[derive(Debug, Clone, Copy)]
129pub struct Exceeded {
130 limit: u32,
131 kind: Kind,
132}
133
134impl Exceeded {
135 pub fn limit(&self) -> u32 {
136 self.limit
137 }
138}
139
140impl fmt::Display for Exceeded {
141 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
142 write!(
143 f,
144 "connection limit exceeded: at most {} {} are allowed",
145 self.limit, self.kind
146 )
147 }
148}
149
/// The category of connections a limit applies to.
#[derive(Clone, Copy, Debug)]
enum Kind {
    /// Inbound connections that are not yet established.
    PendingIncoming,
    /// Outbound connections that are not yet established.
    PendingOutgoing,
    /// Established inbound connections.
    EstablishedIncoming,
    /// Established outbound connections.
    EstablishedOutgoing,
    /// Established connections to a single peer.
    EstablishedPerPeer,
    /// All established connections, regardless of direction.
    EstablishedTotal,
}

impl fmt::Display for Kind {
    /// Writes the plural, human-readable name of the connection category.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let label = match *self {
            Kind::PendingIncoming => "pending incoming connections",
            Kind::PendingOutgoing => "pending outgoing connections",
            Kind::EstablishedIncoming => "established incoming connections",
            Kind::EstablishedOutgoing => "established outgoing connections",
            Kind::EstablishedPerPeer => "established connections per peer",
            Kind::EstablishedTotal => "established connections",
        };

        f.write_str(label)
    }
}
172
// `Exceeded` implements `Display` and derives `Debug`; the default `Error`
// methods are used as-is, so the impl body is intentionally empty.
impl std::error::Error for Exceeded {}
174
/// The configurable connection limits.
///
/// Every limit is optional; `None` means no limit is configured for that
/// category. The default value has every limit set to `None`.
#[derive(Clone, Debug, Default)]
pub struct ConnectionLimits {
    /// Maximum number of pending inbound connections.
    max_pending_incoming: Option<u32>,
    /// Maximum number of pending outbound connections.
    max_pending_outgoing: Option<u32>,
    /// Maximum number of established inbound connections.
    max_established_incoming: Option<u32>,
    /// Maximum number of established outbound connections.
    max_established_outgoing: Option<u32>,
    /// Maximum number of established connections per peer.
    max_established_per_peer: Option<u32>,
    /// Maximum number of established connections, inbound and outbound combined.
    max_established_total: Option<u32>,
}

impl ConnectionLimits {
    /// Sets the limit on pending inbound connections and returns the updated limits.
    pub fn with_max_pending_incoming(self, limit: Option<u32>) -> Self {
        Self {
            max_pending_incoming: limit,
            ..self
        }
    }

    /// Sets the limit on pending outbound connections and returns the updated limits.
    pub fn with_max_pending_outgoing(self, limit: Option<u32>) -> Self {
        Self {
            max_pending_outgoing: limit,
            ..self
        }
    }

    /// Sets the limit on established inbound connections and returns the updated limits.
    pub fn with_max_established_incoming(self, limit: Option<u32>) -> Self {
        Self {
            max_established_incoming: limit,
            ..self
        }
    }

    /// Sets the limit on established outbound connections and returns the updated limits.
    pub fn with_max_established_outgoing(self, limit: Option<u32>) -> Self {
        Self {
            max_established_outgoing: limit,
            ..self
        }
    }

    /// Sets the limit on the total number of established connections (inbound
    /// plus outbound) and returns the updated limits.
    pub fn with_max_established(self, limit: Option<u32>) -> Self {
        Self {
            max_established_total: limit,
            ..self
        }
    }

    /// Sets the limit on established connections per peer and returns the updated limits.
    pub fn with_max_established_per_peer(self, limit: Option<u32>) -> Self {
        Self {
            max_established_per_peer: limit,
            ..self
        }
    }
}
229
230impl NetworkBehaviour for Behaviour {
231 type ConnectionHandler = dummy::ConnectionHandler;
232 type ToSwarm = Infallible;
233
234 fn handle_pending_inbound_connection(
235 &mut self,
236 connection_id: ConnectionId,
237 _: &Multiaddr,
238 _: &Multiaddr,
239 ) -> Result<(), ConnectionDenied> {
240 check_limit(
241 self.limits.max_pending_incoming,
242 self.pending_inbound_connections.len(),
243 Kind::PendingIncoming,
244 )?;
245
246 self.pending_inbound_connections.insert(connection_id);
247
248 Ok(())
249 }
250
251 fn handle_established_inbound_connection(
252 &mut self,
253 connection_id: ConnectionId,
254 peer: PeerId,
255 _: &Multiaddr,
256 _: &Multiaddr,
257 ) -> Result<THandler<Self>, ConnectionDenied> {
258 self.pending_inbound_connections.remove(&connection_id);
259
260 if self.is_bypassed(&peer) {
261 return Ok(dummy::ConnectionHandler);
262 }
263 check_limit(
264 self.limits.max_established_incoming,
265 self.established_inbound_connections.len(),
266 Kind::EstablishedIncoming,
267 )?;
268 check_limit(
269 self.limits.max_established_per_peer,
270 self.established_per_peer
271 .get(&peer)
272 .map(|connections| connections.len())
273 .unwrap_or(0),
274 Kind::EstablishedPerPeer,
275 )?;
276 check_limit(
277 self.limits.max_established_total,
278 self.established_inbound_connections.len()
279 + self.established_outbound_connections.len(),
280 Kind::EstablishedTotal,
281 )?;
282
283 Ok(dummy::ConnectionHandler)
284 }
285
286 fn handle_pending_outbound_connection(
287 &mut self,
288 connection_id: ConnectionId,
289 maybe_peer: Option<PeerId>,
290 _: &[Multiaddr],
291 _: Endpoint,
292 ) -> Result<Vec<Multiaddr>, ConnectionDenied> {
293 if maybe_peer.is_some_and(|peer| self.is_bypassed(&peer)) {
294 return Ok(vec![]);
295 }
296 check_limit(
297 self.limits.max_pending_outgoing,
298 self.pending_outbound_connections.len(),
299 Kind::PendingOutgoing,
300 )?;
301
302 self.pending_outbound_connections.insert(connection_id);
303
304 Ok(vec![])
305 }
306
307 fn handle_established_outbound_connection(
308 &mut self,
309 connection_id: ConnectionId,
310 peer: PeerId,
311 _: &Multiaddr,
312 _: Endpoint,
313 _: PortUse,
314 ) -> Result<THandler<Self>, ConnectionDenied> {
315 self.pending_outbound_connections.remove(&connection_id);
316 if self.is_bypassed(&peer) {
317 return Ok(dummy::ConnectionHandler);
318 }
319
320 check_limit(
321 self.limits.max_established_outgoing,
322 self.established_outbound_connections.len(),
323 Kind::EstablishedOutgoing,
324 )?;
325 check_limit(
326 self.limits.max_established_per_peer,
327 self.established_per_peer
328 .get(&peer)
329 .map(|connections| connections.len())
330 .unwrap_or(0),
331 Kind::EstablishedPerPeer,
332 )?;
333 check_limit(
334 self.limits.max_established_total,
335 self.established_inbound_connections.len()
336 + self.established_outbound_connections.len(),
337 Kind::EstablishedTotal,
338 )?;
339
340 Ok(dummy::ConnectionHandler)
341 }
342
343 fn on_swarm_event(&mut self, event: FromSwarm) {
344 match event {
345 FromSwarm::ConnectionClosed(ConnectionClosed {
346 peer_id,
347 connection_id,
348 ..
349 }) => {
350 self.established_inbound_connections.remove(&connection_id);
351 self.established_outbound_connections.remove(&connection_id);
352 self.established_per_peer
353 .entry(peer_id)
354 .or_default()
355 .remove(&connection_id);
356 }
357 FromSwarm::ConnectionEstablished(ConnectionEstablished {
358 peer_id,
359 endpoint,
360 connection_id,
361 ..
362 }) => {
363 match endpoint {
364 ConnectedPoint::Listener { .. } => {
365 self.established_inbound_connections.insert(connection_id);
366 }
367 ConnectedPoint::Dialer { .. } => {
368 self.established_outbound_connections.insert(connection_id);
369 }
370 }
371
372 self.established_per_peer
373 .entry(peer_id)
374 .or_default()
375 .insert(connection_id);
376 }
377 FromSwarm::DialFailure(DialFailure { connection_id, .. }) => {
378 self.pending_outbound_connections.remove(&connection_id);
379 }
380 FromSwarm::ListenFailure(ListenFailure { connection_id, .. }) => {
381 self.pending_inbound_connections.remove(&connection_id);
382 }
383 _ => {}
384 }
385 }
386
387 fn on_connection_handler_event(
388 &mut self,
389 _id: PeerId,
390 _: ConnectionId,
391 event: THandlerOutEvent<Self>,
392 ) {
393 libp2p_core::util::unreachable(event)
394 }
395
396 fn poll(&mut self, _: &mut Context<'_>) -> Poll<ToSwarm<Self::ToSwarm, THandlerInEvent<Self>>> {
397 Poll::Pending
398 }
399}
400
#[cfg(test)]
mod tests {
    use libp2p_swarm::{
        behaviour::toggle::Toggle,
        dial_opts::{DialOpts, PeerCondition},
        DialError, ListenError, Swarm, SwarmEvent,
    };
    use libp2p_swarm_test::SwarmExt;
    use quickcheck::*;
    use tokio::runtime::Runtime;

    use super::*;

    /// Builds a swarm whose only configured limit is `limit` established
    /// incoming connections.
    fn swarm_with_incoming_limit(limit: u32) -> Swarm<Behaviour> {
        Swarm::new_ephemeral(|_| {
            Behaviour::new(ConnectionLimits::default().with_max_established_incoming(Some(limit)))
        })
    }

    /// Creates a swarm with a random pending-outgoing limit and dials the
    /// same peer until that limit is reached.
    ///
    /// Returns the swarm, the dialed address and the chosen limit.
    fn fill_outgoing() -> (Swarm<Behaviour>, Multiaddr, u32) {
        use rand::Rng;

        let outgoing_limit = rand::thread_rng().gen_range(1..10);
        let limits = ConnectionLimits::default().with_max_pending_outgoing(Some(outgoing_limit));

        let mut network = Swarm::new_ephemeral(|_| Behaviour::new(limits));

        let addr: Multiaddr = "/memory/1234".parse().unwrap();
        let target = PeerId::random();

        for _ in 0..outgoing_limit {
            let opts = DialOpts::peer_id(target)
                .condition(PeerCondition::Always)
                .addresses(vec![addr.clone()])
                .build();

            network.dial(opts).expect("Unexpected connection limit.");
        }

        (network, addr, outgoing_limit)
    }

    /// One dial beyond the pending-outgoing limit is denied with `Exceeded`.
    #[test]
    fn max_outgoing() {
        let (mut network, addr, outgoing_limit) = fill_outgoing();

        let opts = DialOpts::peer_id(PeerId::random())
            .condition(PeerCondition::Always)
            .addresses(vec![addr])
            .build();
        let dial_error = network
            .dial(opts)
            .expect_err("Unexpected dialing success.");

        match dial_error {
            DialError::Denied { cause } => {
                let exceeded = cause
                    .downcast::<Exceeded>()
                    .expect("connection denied because of limit");

                assert_eq!(exceeded.limit(), outgoing_limit);
            }
            e => panic!("Unexpected error: {e:?}"),
        }

        let info = network.network_info();
        assert_eq!(info.num_peers(), 0);
        assert_eq!(
            info.connection_counters().num_pending_outgoing(),
            outgoing_limit
        );
    }

    /// A bypassed peer can still be dialed at the limit; other peers cannot.
    #[test]
    fn outgoing_limit_bypass() {
        let (mut network, addr, _) = fill_outgoing();

        let bypassed_peer = PeerId::random();
        network
            .behaviour_mut()
            .limits
            .bypass_peer_id(&bypassed_peer);
        assert!(network.behaviour().limits.is_bypassed(&bypassed_peer));

        // The bypassed dial may fail for other reasons, but never because of
        // the connection limit.
        let bypassed_opts = DialOpts::peer_id(bypassed_peer)
            .addresses(vec![addr.clone()])
            .build();
        if let Err(DialError::Denied { cause }) = network.dial(bypassed_opts) {
            cause
                .downcast::<Exceeded>()
                .expect_err("Unexpected connection denied because of limit");
        }

        let not_bypassed_peer = loop {
            let candidate = PeerId::random();
            if candidate == bypassed_peer {
                continue;
            }
            break candidate;
        };

        let regular_opts = DialOpts::peer_id(not_bypassed_peer)
            .addresses(vec![addr])
            .build();
        let dial_error = network
            .dial(regular_opts)
            .expect_err("Unexpected dialing success.");

        match dial_error {
            DialError::Denied { cause } => {
                cause
                    .downcast::<Exceeded>()
                    .expect("connection denied because of limit");
            }
            e => panic!("Unexpected error: {e:?}"),
        }
    }

    /// Once `limit` inbound connections are established, the next one is
    /// denied with `Exceeded`.
    #[test]
    fn max_established_incoming() {
        fn prop(Limit(limit): Limit) {
            let mut swarm1 = swarm_with_incoming_limit(limit);
            let mut swarm2 = swarm_with_incoming_limit(limit);

            let rt = Runtime::new().unwrap();
            rt.block_on(async {
                let (listen_addr, _) = swarm1.listen().with_memory_addr_external().await;

                for _ in 0..limit {
                    swarm2.connect(&mut swarm1).await;
                }

                swarm2.dial(listen_addr).unwrap();

                tokio::spawn(swarm2.loop_on_next());

                let cause = swarm1
                    .wait(|event| {
                        if let SwarmEvent::IncomingConnectionError {
                            error: ListenError::Denied { cause },
                            ..
                        } = event
                        {
                            Some(cause)
                        } else {
                            None
                        }
                    })
                    .await;

                let exceeded = cause.downcast::<Exceeded>().unwrap();
                assert_eq!(exceeded.limit, limit);
            });
        }

        quickcheck(prop as fn(_));
    }

    /// A bypassed peer gets an inbound connection even when the
    /// established-incoming limit is already used up.
    #[test]
    fn bypass_established_incoming() {
        fn prop(Limit(limit): Limit) {
            let mut swarm1 = swarm_with_incoming_limit(limit);
            let mut swarm2 = swarm_with_incoming_limit(limit);
            let mut swarm3 = swarm_with_incoming_limit(limit);

            let rt = Runtime::new().unwrap();
            let bypassed_peer_id = *swarm3.local_peer_id();
            swarm1
                .behaviour_mut()
                .limits
                .bypass_peer_id(&bypassed_peer_id);

            rt.block_on(async {
                let (listen_addr, _) = swarm1.listen().with_memory_addr_external().await;

                for _ in 0..limit {
                    swarm2.connect(&mut swarm1).await;
                }

                swarm3.dial(listen_addr.clone()).unwrap();

                tokio::spawn(swarm2.loop_on_next());
                tokio::spawn(swarm3.loop_on_next());

                swarm1
                    .wait(|event| match event {
                        SwarmEvent::ConnectionEstablished { peer_id, .. } => {
                            if peer_id == bypassed_peer_id {
                                Some(())
                            } else {
                                None
                            }
                        }
                        SwarmEvent::IncomingConnectionError {
                            error: ListenError::Denied { cause },
                            ..
                        } => {
                            cause
                                .downcast::<Exceeded>()
                                .expect_err("Unexpected connection denied because of limit");
                            None
                        }
                        _ => None,
                    })
                    .await;
            });
        }

        quickcheck(prop as fn(_));
    }

    /// A connection denied by another behaviour is not counted as
    /// established by the limits behaviour.
    #[tokio::test]
    async fn support_other_behaviour_denying_connection() {
        let mut swarm1 = Swarm::new_ephemeral(|_| {
            Behaviour::new_with_connection_denier(ConnectionLimits::default())
        });
        let mut swarm2 = Swarm::new_ephemeral(|_| Behaviour::new(ConnectionLimits::default()));

        let (listen_addr, _) = swarm1.listen().await;
        swarm2.dial(listen_addr).unwrap();
        tokio::spawn(swarm2.loop_on_next());

        let cause = swarm1
            .wait(|event| {
                if let SwarmEvent::IncomingConnectionError {
                    error: ListenError::Denied { cause },
                    ..
                } = event
                {
                    Some(cause)
                } else {
                    None
                }
            })
            .await;

        // The denial must come from `ConnectionDenier`, not from the limits.
        cause.downcast::<std::io::Error>().unwrap();

        let established_inbound = swarm1
            .behaviour_mut()
            .limits
            .established_inbound_connections
            .len();
        assert_eq!(
            0,
            established_inbound,
            "swarm1 connection limit behaviour to not count denied established connection as established connection"
        );
    }

    /// Test behaviour combining the limits behaviour with an optional
    /// behaviour that denies every established connection.
    #[derive(libp2p_swarm_derive::NetworkBehaviour)]
    #[behaviour(prelude = "libp2p_swarm::derive_prelude")]
    struct Behaviour {
        limits: super::Behaviour,
        connection_denier: Toggle<ConnectionDenier>,
    }

    impl Behaviour {
        /// Limits only; the connection denier is disabled.
        fn new(limits: ConnectionLimits) -> Self {
            Self {
                limits: super::Behaviour::new(limits),
                connection_denier: Toggle::from(None),
            }
        }

        /// Limits plus an enabled connection denier.
        fn new_with_connection_denier(limits: ConnectionLimits) -> Self {
            Self {
                limits: super::Behaviour::new(limits),
                connection_denier: Toggle::from(Some(ConnectionDenier {})),
            }
        }
    }

    /// Behaviour that rejects every established connection with an I/O error.
    struct ConnectionDenier {}

    /// The error `ConnectionDenier` answers every established connection with.
    fn denier_error() -> ConnectionDenied {
        ConnectionDenied::new(std::io::Error::new(
            std::io::ErrorKind::Other,
            "ConnectionDenier",
        ))
    }

    impl NetworkBehaviour for ConnectionDenier {
        type ConnectionHandler = dummy::ConnectionHandler;
        type ToSwarm = Infallible;

        fn handle_established_inbound_connection(
            &mut self,
            _connection_id: ConnectionId,
            _peer: PeerId,
            _local_addr: &Multiaddr,
            _remote_addr: &Multiaddr,
        ) -> Result<THandler<Self>, ConnectionDenied> {
            Err(denier_error())
        }

        fn handle_established_outbound_connection(
            &mut self,
            _connection_id: ConnectionId,
            _peer: PeerId,
            _addr: &Multiaddr,
            _role_override: Endpoint,
            _port_use: PortUse,
        ) -> Result<THandler<Self>, ConnectionDenied> {
            Err(denier_error())
        }

        fn on_swarm_event(&mut self, _event: FromSwarm) {}

        fn on_connection_handler_event(
            &mut self,
            _peer_id: PeerId,
            _connection_id: ConnectionId,
            event: THandlerOutEvent<Self>,
        ) {
            libp2p_core::util::unreachable(event)
        }

        fn poll(
            &mut self,
            _: &mut Context<'_>,
        ) -> Poll<ToSwarm<Self::ToSwarm, THandlerInEvent<Self>>> {
            Poll::Pending
        }
    }

    /// Small, non-zero connection limit generated for the property tests.
    #[derive(Debug, Clone)]
    struct Limit(u32);

    impl Arbitrary for Limit {
        fn arbitrary(g: &mut Gen) -> Self {
            let limit = g.gen_range(1..10);
            Self(limit)
        }
    }
}