1use std::{
65 collections::{HashSet, VecDeque},
66 convert::Infallible,
67 fmt,
68 task::{Context, Poll, Waker},
69};
70
71use libp2p_core::{transport::PortUse, Endpoint, Multiaddr};
72use libp2p_identity::PeerId;
73use libp2p_swarm::{
74 dummy, CloseConnection, ConnectionDenied, ConnectionId, FromSwarm, NetworkBehaviour, THandler,
75 THandlerInEvent, THandlerOutEvent, ToSwarm,
76};
77
78#[derive(Default, Debug)]
80pub struct Behaviour<S> {
81 state: S,
82 close_connections: VecDeque<PeerId>,
83 waker: Option<Waker>,
84}
85
86#[derive(Default)]
88pub struct AllowedPeers {
89 peers: HashSet<PeerId>,
90}
91
92#[derive(Default)]
94pub struct BlockedPeers {
95 peers: HashSet<PeerId>,
96}
97
98impl Behaviour<AllowedPeers> {
99 pub fn allowed_peers(&self) -> &HashSet<PeerId> {
101 &self.state.peers
102 }
103
104 pub fn allow_peer(&mut self, peer: PeerId) -> bool {
109 let inserted = self.state.peers.insert(peer);
110 if inserted {
111 if let Some(waker) = self.waker.take() {
112 waker.wake()
113 }
114 }
115 inserted
116 }
117
118 pub fn disallow_peer(&mut self, peer: PeerId) -> bool {
125 let removed = self.state.peers.remove(&peer);
126 if removed {
127 self.close_connections.push_back(peer);
128 if let Some(waker) = self.waker.take() {
129 waker.wake()
130 }
131 }
132 removed
133 }
134}
135
136impl Behaviour<BlockedPeers> {
137 pub fn blocked_peers(&self) -> &HashSet<PeerId> {
139 &self.state.peers
140 }
141
142 pub fn block_peer(&mut self, peer: PeerId) -> bool {
149 let inserted = self.state.peers.insert(peer);
150 if inserted {
151 self.close_connections.push_back(peer);
152 if let Some(waker) = self.waker.take() {
153 waker.wake()
154 }
155 }
156 inserted
157 }
158
159 pub fn unblock_peer(&mut self, peer: PeerId) -> bool {
164 let removed = self.state.peers.remove(&peer);
165 if removed {
166 if let Some(waker) = self.waker.take() {
167 waker.wake()
168 }
169 }
170 removed
171 }
172}
173
174#[derive(Debug)]
176pub struct NotAllowed {
177 peer: PeerId,
178}
179
180impl fmt::Display for NotAllowed {
181 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
182 write!(f, "peer {} is not in the allow list", self.peer)
183 }
184}
185
186impl std::error::Error for NotAllowed {}
187
188#[derive(Debug)]
190pub struct Blocked {
191 peer: PeerId,
192}
193
194impl fmt::Display for Blocked {
195 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
196 write!(f, "peer {} is in the block list", self.peer)
197 }
198}
199
200impl std::error::Error for Blocked {}
201
202trait Enforce: 'static {
203 fn enforce(&self, peer: &PeerId) -> Result<(), ConnectionDenied>;
204}
205
206impl Enforce for AllowedPeers {
207 fn enforce(&self, peer: &PeerId) -> Result<(), ConnectionDenied> {
208 if !self.peers.contains(peer) {
209 return Err(ConnectionDenied::new(NotAllowed { peer: *peer }));
210 }
211
212 Ok(())
213 }
214}
215
216impl Enforce for BlockedPeers {
217 fn enforce(&self, peer: &PeerId) -> Result<(), ConnectionDenied> {
218 if self.peers.contains(peer) {
219 return Err(ConnectionDenied::new(Blocked { peer: *peer }));
220 }
221
222 Ok(())
223 }
224}
225
226impl<S> NetworkBehaviour for Behaviour<S>
227where
228 S: Enforce,
229{
230 type ConnectionHandler = dummy::ConnectionHandler;
231 type ToSwarm = Infallible;
232
233 fn handle_established_inbound_connection(
234 &mut self,
235 _: ConnectionId,
236 peer: PeerId,
237 _: &Multiaddr,
238 _: &Multiaddr,
239 ) -> Result<THandler<Self>, ConnectionDenied> {
240 self.state.enforce(&peer)?;
241
242 Ok(dummy::ConnectionHandler)
243 }
244
245 fn handle_pending_outbound_connection(
246 &mut self,
247 _: ConnectionId,
248 peer: Option<PeerId>,
249 _: &[Multiaddr],
250 _: Endpoint,
251 ) -> Result<Vec<Multiaddr>, ConnectionDenied> {
252 if let Some(peer) = peer {
253 self.state.enforce(&peer)?;
254 }
255
256 Ok(vec![])
257 }
258
259 fn handle_established_outbound_connection(
260 &mut self,
261 _: ConnectionId,
262 peer: PeerId,
263 _: &Multiaddr,
264 _: Endpoint,
265 _: PortUse,
266 ) -> Result<THandler<Self>, ConnectionDenied> {
267 self.state.enforce(&peer)?;
268
269 Ok(dummy::ConnectionHandler)
270 }
271
272 fn on_swarm_event(&mut self, _event: FromSwarm) {}
273
274 fn on_connection_handler_event(
275 &mut self,
276 _id: PeerId,
277 _: ConnectionId,
278 event: THandlerOutEvent<Self>,
279 ) {
280 libp2p_core::util::unreachable(event)
281 }
282
283 fn poll(
284 &mut self,
285 cx: &mut Context<'_>,
286 ) -> Poll<ToSwarm<Self::ToSwarm, THandlerInEvent<Self>>> {
287 if let Some(peer) = self.close_connections.pop_front() {
288 return Poll::Ready(ToSwarm::CloseConnection {
289 peer_id: peer,
290 connection: CloseConnection::All,
291 });
292 }
293
294 self.waker = Some(cx.waker().clone());
295 Poll::Pending
296 }
297}
298
#[cfg(test)]
mod tests {
    use libp2p_swarm::{dial_opts::DialOpts, DialError, ListenError, Swarm, SwarmEvent};
    use libp2p_swarm_test::SwarmExt;

    use super::*;

    /// The behaviour under test, configured as a block list.
    type BlockList = Behaviour<BlockedPeers>;
    /// The behaviour under test, configured as an allow list.
    type AllowList = Behaviour<AllowedPeers>;

    /// Dialing a peer on our own block list is denied with [`Blocked`].
    #[tokio::test]
    async fn cannot_dial_blocked_peer() {
        let mut dialer = Swarm::new_ephemeral(|_| BlockList::default());
        let mut listener = Swarm::new_ephemeral(|_| BlockList::default());
        listener.listen().with_memory_addr_external().await;

        let listener_id = *listener.local_peer_id();
        dialer.behaviour_mut().block_peer(listener_id);

        let cause = denied_cause(dial(&mut dialer, &listener).unwrap_err());
        assert!(cause.downcast::<Blocked>().is_ok());
    }

    /// A peer that was blocked and then unblocked can be dialed again.
    #[tokio::test]
    async fn can_dial_unblocked_peer() {
        let mut dialer = Swarm::new_ephemeral(|_| BlockList::default());
        let mut listener = Swarm::new_ephemeral(|_| BlockList::default());
        listener.listen().with_memory_addr_external().await;

        let listener_id = *listener.local_peer_id();
        let block_list = dialer.behaviour_mut();
        block_list.block_peer(listener_id);
        block_list.unblock_peer(listener_id);

        dial(&mut dialer, &listener).unwrap();
    }

    /// The listener rejects an incoming connection from a peer it has blocked.
    #[tokio::test]
    async fn blocked_peer_cannot_dial_us() {
        let mut dialer = Swarm::new_ephemeral(|_| BlockList::default());
        let mut listener = Swarm::new_ephemeral(|_| BlockList::default());
        listener.listen().with_memory_addr_external().await;

        let dialer_id = *dialer.local_peer_id();
        listener.behaviour_mut().block_peer(dialer_id);
        dial(&mut dialer, &listener).unwrap();
        // Keep driving the dialer in the background while we watch the listener.
        tokio::spawn(dialer.loop_on_next());

        let cause = listener
            .wait(|event| {
                if let SwarmEvent::IncomingConnectionError {
                    error: ListenError::Denied { cause },
                    ..
                } = event
                {
                    Some(cause)
                } else {
                    None
                }
            })
            .await;
        assert!(cause.downcast::<Blocked>().is_ok());
    }

    /// Blocking a connected peer closes the connection on both sides.
    #[tokio::test]
    async fn connections_get_closed_upon_blocked() {
        let mut dialer = Swarm::new_ephemeral(|_| BlockList::default());
        let mut listener = Swarm::new_ephemeral(|_| BlockList::default());
        listener.listen().with_memory_addr_external().await;
        dialer.connect(&mut listener).await;

        let dialer_id = *dialer.local_peer_id();
        let listener_id = *listener.local_peer_id();
        dialer.behaviour_mut().block_peer(listener_id);

        let (
            [SwarmEvent::ConnectionClosed {
                peer_id: seen_by_dialer,
                ..
            }],
            [SwarmEvent::ConnectionClosed {
                peer_id: seen_by_listener,
                ..
            }],
        ) = libp2p_swarm_test::drive(&mut dialer, &mut listener).await
        else {
            panic!("unexpected events")
        };
        assert_eq!(seen_by_dialer, listener_id);
        assert_eq!(seen_by_listener, dialer_id);
    }

    /// Dialing is denied with [`NotAllowed`] until the peer is put on the allow list.
    #[tokio::test]
    async fn cannot_dial_peer_unless_allowed() {
        let mut dialer = Swarm::new_ephemeral(|_| AllowList::default());
        let mut listener = Swarm::new_ephemeral(|_| AllowList::default());
        listener.listen().with_memory_addr_external().await;

        let cause = denied_cause(dial(&mut dialer, &listener).unwrap_err());
        assert!(cause.downcast::<NotAllowed>().is_ok());

        let listener_id = *listener.local_peer_id();
        dialer.behaviour_mut().allow_peer(listener_id);
        assert!(dial(&mut dialer, &listener).is_ok());
    }

    /// A peer that was allowed and then disallowed cannot be dialed.
    #[tokio::test]
    async fn cannot_dial_disallowed_peer() {
        let mut dialer = Swarm::new_ephemeral(|_| AllowList::default());
        let mut listener = Swarm::new_ephemeral(|_| AllowList::default());
        listener.listen().with_memory_addr_external().await;

        let listener_id = *listener.local_peer_id();
        let allow_list = dialer.behaviour_mut();
        allow_list.allow_peer(listener_id);
        allow_list.disallow_peer(listener_id);

        let cause = denied_cause(dial(&mut dialer, &listener).unwrap_err());
        assert!(cause.downcast::<NotAllowed>().is_ok());
    }

    /// With empty allow lists on both sides, a dial to an unknown peer id is denied by
    /// the dialer and by the listener.
    #[tokio::test]
    async fn not_allowed_peer_cannot_dial_us() {
        let mut dialer = Swarm::new_ephemeral(|_| AllowList::default());
        let mut listener = Swarm::new_ephemeral(|_| AllowList::default());
        listener.listen().with_memory_addr_external().await;

        let listener_addr = listener.external_addresses().next().cloned().unwrap();
        dialer
            .dial(DialOpts::unknown_peer_id().address(listener_addr).build())
            .unwrap();

        let (
            [SwarmEvent::OutgoingConnectionError {
                error: DialError::Denied { cause: dial_cause },
                ..
            }],
            [_, SwarmEvent::IncomingConnectionError {
                error: ListenError::Denied {
                    cause: listen_cause,
                },
                ..
            }],
        ) = libp2p_swarm_test::drive(&mut dialer, &mut listener).await
        else {
            panic!("unexpected events")
        };
        assert!(dial_cause.downcast::<NotAllowed>().is_ok());
        assert!(listen_cause.downcast::<NotAllowed>().is_ok());
    }

    /// Disallowing a connected peer closes the connection on both sides.
    #[tokio::test]
    async fn connections_get_closed_upon_disallow() {
        let mut dialer = Swarm::new_ephemeral(|_| AllowList::default());
        let mut listener = Swarm::new_ephemeral(|_| AllowList::default());
        listener.listen().with_memory_addr_external().await;

        let dialer_id = *dialer.local_peer_id();
        let listener_id = *listener.local_peer_id();
        dialer.behaviour_mut().allow_peer(listener_id);
        listener.behaviour_mut().allow_peer(dialer_id);

        dialer.connect(&mut listener).await;

        dialer.behaviour_mut().disallow_peer(listener_id);
        let (
            [SwarmEvent::ConnectionClosed {
                peer_id: seen_by_dialer,
                ..
            }],
            [SwarmEvent::ConnectionClosed {
                peer_id: seen_by_listener,
                ..
            }],
        ) = libp2p_swarm_test::drive(&mut dialer, &mut listener).await
        else {
            panic!("unexpected events")
        };
        assert_eq!(seen_by_dialer, listener_id);
        assert_eq!(seen_by_listener, dialer_id);
    }

    /// Extracts the denial cause from a dial error, panicking on any other variant.
    fn denied_cause(error: DialError) -> ConnectionDenied {
        match error {
            DialError::Denied { cause } => cause,
            _ => panic!("unexpected dial error"),
        }
    }

    /// Dials `listener` from `dialer` using the listener's peer id and all of its
    /// external addresses.
    fn dial<S>(
        dialer: &mut Swarm<Behaviour<S>>,
        listener: &Swarm<Behaviour<S>>,
    ) -> Result<(), DialError>
    where
        S: Enforce,
    {
        let target = DialOpts::peer_id(*listener.local_peer_id())
            .addresses(listener.external_addresses().cloned().collect())
            .build();
        dialer.dial(target)
    }
}