1use std::{
24 collections::{HashMap, HashSet, VecDeque},
25 convert::Infallible,
26 task::{Context, Poll},
27};
28
29use either::Either;
30use hashlink::LruCache;
31use libp2p_core::{
32 connection::ConnectedPoint, multiaddr::Protocol, transport::PortUse, Endpoint, Multiaddr,
33};
34use libp2p_identity::PeerId;
35use libp2p_swarm::{
36 behaviour::{ConnectionClosed, DialFailure, FromSwarm},
37 dial_opts::{self, DialOpts},
38 dummy, ConnectionDenied, ConnectionHandler, ConnectionId, NetworkBehaviour,
39 NewExternalAddrCandidate, NotifyHandler, THandler, THandlerInEvent, THandlerOutEvent, ToSwarm,
40};
41use thiserror::Error;
42
43use crate::{handler, protocol};
44
/// Upper bound on hole-punch dials issued for a single relayed connection (the attempts
/// counted in `outgoing_direct_connection_attempts`) before the upgrade is abandoned with
/// an "attempts exceeded" error.
pub(crate) const MAX_NUMBER_OF_UPGRADE_ATTEMPTS: u8 = 3;
46
47#[derive(Debug)]
49pub struct Event {
50 pub remote_peer_id: PeerId,
51 pub result: Result<ConnectionId, Error>,
52}
53
54#[derive(Debug, Error)]
55#[error("Failed to hole-punch connection: {inner}")]
56pub struct Error {
57 inner: InnerError,
58}
59
60#[derive(Debug, Error)]
61enum InnerError {
62 #[error("Giving up after {0} dial attempts")]
63 AttemptsExceeded(u8),
64 #[error("Inbound stream error: {0}")]
65 InboundError(protocol::inbound::Error),
66 #[error("Outbound stream error: {0}")]
67 OutboundError(protocol::outbound::Error),
68}
69
70pub struct Behaviour {
71 queued_events: VecDeque<ToSwarm<Event, Either<handler::relayed::Command, Infallible>>>,
73
74 direct_connections: HashMap<PeerId, HashSet<ConnectionId>>,
76
77 address_candidates: Candidates,
78
79 direct_to_relayed_connections: HashMap<ConnectionId, ConnectionId>,
80
81 outgoing_direct_connection_attempts: HashMap<(ConnectionId, PeerId), u8>,
84}
85
86impl Behaviour {
87 pub fn new(local_peer_id: PeerId) -> Self {
88 Behaviour {
89 queued_events: Default::default(),
90 direct_connections: Default::default(),
91 address_candidates: Candidates::new(local_peer_id),
92 direct_to_relayed_connections: Default::default(),
93 outgoing_direct_connection_attempts: Default::default(),
94 }
95 }
96
97 fn observed_addresses(&self) -> Vec<Multiaddr> {
98 self.address_candidates.iter().cloned().collect()
99 }
100
101 fn on_dial_failure(
102 &mut self,
103 DialFailure {
104 peer_id,
105 connection_id: failed_direct_connection,
106 ..
107 }: DialFailure,
108 ) {
109 let Some(peer_id) = peer_id else {
110 return;
111 };
112
113 let Some(relayed_connection_id) = self
114 .direct_to_relayed_connections
115 .get(&failed_direct_connection)
116 else {
117 return;
118 };
119
120 let Some(attempt) = self
121 .outgoing_direct_connection_attempts
122 .get(&(*relayed_connection_id, peer_id))
123 else {
124 return;
125 };
126
127 if *attempt < MAX_NUMBER_OF_UPGRADE_ATTEMPTS {
128 self.queued_events.push_back(ToSwarm::NotifyHandler {
129 handler: NotifyHandler::One(*relayed_connection_id),
130 peer_id,
131 event: Either::Left(handler::relayed::Command::Connect),
132 })
133 } else {
134 self.queued_events.extend([ToSwarm::GenerateEvent(Event {
135 remote_peer_id: peer_id,
136 result: Err(Error {
137 inner: InnerError::AttemptsExceeded(MAX_NUMBER_OF_UPGRADE_ATTEMPTS),
138 }),
139 })]);
140 }
141 }
142
143 fn on_connection_closed(
144 &mut self,
145 ConnectionClosed {
146 peer_id,
147 connection_id,
148 endpoint: connected_point,
149 ..
150 }: ConnectionClosed,
151 ) {
152 if !connected_point.is_relayed() {
153 let connections = self
154 .direct_connections
155 .get_mut(&peer_id)
156 .expect("Peer of direct connection to be tracked.");
157 connections
158 .remove(&connection_id)
159 .then_some(())
160 .expect("Direct connection to be tracked.");
161 if connections.is_empty() {
162 self.direct_connections.remove(&peer_id);
163 }
164 }
165 }
166}
167
168impl NetworkBehaviour for Behaviour {
169 type ConnectionHandler = Either<handler::relayed::Handler, dummy::ConnectionHandler>;
170 type ToSwarm = Event;
171
172 fn handle_established_inbound_connection(
173 &mut self,
174 connection_id: ConnectionId,
175 peer: PeerId,
176 local_addr: &Multiaddr,
177 remote_addr: &Multiaddr,
178 ) -> Result<THandler<Self>, ConnectionDenied> {
179 if is_relayed(local_addr) {
180 let connected_point = ConnectedPoint::Listener {
181 local_addr: local_addr.clone(),
182 send_back_addr: remote_addr.clone(),
183 };
184 let mut handler =
185 handler::relayed::Handler::new(connected_point, self.observed_addresses());
186 handler.on_behaviour_event(handler::relayed::Command::Connect);
187
188 return Ok(Either::Left(handler));
190 }
191 self.direct_connections
192 .entry(peer)
193 .or_default()
194 .insert(connection_id);
195
196 assert!(
197 !self
198 .direct_to_relayed_connections
199 .contains_key(&connection_id),
200 "state mismatch"
201 );
202
203 Ok(Either::Right(dummy::ConnectionHandler))
204 }
205
206 fn handle_established_outbound_connection(
207 &mut self,
208 connection_id: ConnectionId,
209 peer: PeerId,
210 addr: &Multiaddr,
211 role_override: Endpoint,
212 port_use: PortUse,
213 ) -> Result<THandler<Self>, ConnectionDenied> {
214 if is_relayed(addr) {
215 return Ok(Either::Left(handler::relayed::Handler::new(
216 ConnectedPoint::Dialer {
217 address: addr.clone(),
218 role_override,
219 port_use,
220 },
221 self.observed_addresses(),
222 ))); }
225
226 self.direct_connections
227 .entry(peer)
228 .or_default()
229 .insert(connection_id);
230
231 if let Some(&relayed_connection_id) = self.direct_to_relayed_connections.get(&connection_id)
233 {
234 if role_override == Endpoint::Listener {
235 assert!(
236 self.outgoing_direct_connection_attempts
237 .remove(&(relayed_connection_id, peer))
238 .is_some(),
239 "state mismatch"
240 );
241 }
242
243 self.queued_events.extend([ToSwarm::GenerateEvent(Event {
244 remote_peer_id: peer,
245 result: Ok(connection_id),
246 })]);
247 }
248 Ok(Either::Right(dummy::ConnectionHandler))
249 }
250
251 fn on_connection_handler_event(
252 &mut self,
253 event_source: PeerId,
254 connection_id: ConnectionId,
255 handler_event: THandlerOutEvent<Self>,
256 ) {
257 let relayed_connection_id = match handler_event.as_ref() {
258 Either::Left(_) => connection_id,
259 Either::Right(_) => match self.direct_to_relayed_connections.get(&connection_id) {
260 None => {
261 return;
264 }
265 Some(relayed_connection_id) => *relayed_connection_id,
266 },
267 };
268
269 match handler_event {
270 Either::Left(handler::relayed::Event::InboundConnectNegotiated { remote_addrs }) => {
271 tracing::debug!(target=%event_source, addresses=?remote_addrs, "Attempting to hole-punch as dialer");
272
273 let opts = DialOpts::peer_id(event_source)
274 .addresses(remote_addrs)
275 .condition(dial_opts::PeerCondition::Always)
276 .build();
277
278 let maybe_direct_connection_id = opts.connection_id();
279
280 self.direct_to_relayed_connections
281 .insert(maybe_direct_connection_id, relayed_connection_id);
282 self.queued_events.push_back(ToSwarm::Dial { opts });
283 }
284 Either::Left(handler::relayed::Event::InboundConnectFailed { error }) => {
285 self.queued_events.push_back(ToSwarm::GenerateEvent(Event {
286 remote_peer_id: event_source,
287 result: Err(Error {
288 inner: InnerError::InboundError(error),
289 }),
290 }));
291 }
292 Either::Left(handler::relayed::Event::OutboundConnectFailed { error }) => {
293 self.queued_events.push_back(ToSwarm::GenerateEvent(Event {
294 remote_peer_id: event_source,
295 result: Err(Error {
296 inner: InnerError::OutboundError(error),
297 }),
298 }));
299
300 }
302 Either::Left(handler::relayed::Event::OutboundConnectNegotiated { remote_addrs }) => {
303 tracing::debug!(target=%event_source, addresses=?remote_addrs, "Attempting to hole-punch as listener");
304
305 let opts = DialOpts::peer_id(event_source)
306 .condition(dial_opts::PeerCondition::Always)
307 .addresses(remote_addrs)
308 .override_role()
309 .build();
310
311 let maybe_direct_connection_id = opts.connection_id();
312
313 self.direct_to_relayed_connections
314 .insert(maybe_direct_connection_id, relayed_connection_id);
315 *self
316 .outgoing_direct_connection_attempts
317 .entry((relayed_connection_id, event_source))
318 .or_default() += 1;
319 self.queued_events.push_back(ToSwarm::Dial { opts });
320 }
321 Either::Right(never) => libp2p_core::util::unreachable(never),
322 };
323 }
324
325 #[tracing::instrument(level = "trace", name = "NetworkBehaviour::poll", skip(self))]
326 fn poll(&mut self, _: &mut Context<'_>) -> Poll<ToSwarm<Self::ToSwarm, THandlerInEvent<Self>>> {
327 if let Some(event) = self.queued_events.pop_front() {
328 return Poll::Ready(event);
329 }
330
331 Poll::Pending
332 }
333
334 fn on_swarm_event(&mut self, event: FromSwarm) {
335 match event {
336 FromSwarm::ConnectionClosed(connection_closed) => {
337 self.on_connection_closed(connection_closed)
338 }
339 FromSwarm::DialFailure(dial_failure) => self.on_dial_failure(dial_failure),
340 FromSwarm::NewExternalAddrCandidate(NewExternalAddrCandidate { addr }) => {
341 self.address_candidates.add(addr.clone());
342 }
343 _ => {}
344 }
345 }
346}
347
348struct Candidates {
356 inner: LruCache<Multiaddr, ()>,
357 me: PeerId,
358}
359
360impl Candidates {
361 fn new(me: PeerId) -> Self {
362 Self {
363 inner: LruCache::new(20),
364 me,
365 }
366 }
367
368 fn add(&mut self, mut address: Multiaddr) {
369 if is_relayed(&address) {
370 return;
371 }
372
373 if address.iter().last() != Some(Protocol::P2p(self.me)) {
374 address.push(Protocol::P2p(self.me));
375 }
376
377 self.inner.insert(address, ());
378 }
379
380 fn iter(&self) -> impl Iterator<Item = &Multiaddr> {
381 self.inner.iter().map(|(a, _)| a)
382 }
383}
384
385fn is_relayed(addr: &Multiaddr) -> bool {
386 addr.iter().any(|p| p == Protocol::P2pCircuit)
387}