1use std::{
22    collections::{HashMap, VecDeque},
23    iter,
24    task::{Context, Poll},
25    time::Duration,
26};
27
28use futures::{
29    future::{BoxFuture, FutureExt},
30    stream::{FuturesUnordered, StreamExt},
31};
32use libp2p_core::{transport::PortUse, Endpoint, Multiaddr, PeerRecord};
33use libp2p_identity::{Keypair, PeerId, SigningError};
34use libp2p_request_response::{OutboundRequestId, ProtocolSupport};
35use libp2p_swarm::{
36    ConnectionDenied, ConnectionId, ExternalAddresses, FromSwarm, NetworkBehaviour, THandler,
37    THandlerInEvent, THandlerOutEvent, ToSwarm,
38};
39
40use crate::codec::{
41    Cookie, ErrorCode, Message, Message::*, Namespace, NewRegistration, Registration, Ttl,
42};
43
44pub struct Behaviour {
45    events: VecDeque<ToSwarm<<Self as NetworkBehaviour>::ToSwarm, THandlerInEvent<Self>>>,
46
47    inner: libp2p_request_response::Behaviour<crate::codec::Codec>,
48
49    keypair: Keypair,
50
51    waiting_for_register: HashMap<OutboundRequestId, (PeerId, Namespace)>,
52    waiting_for_discovery: HashMap<OutboundRequestId, (PeerId, Option<Namespace>)>,
53
54    discovered_peers: HashMap<PeerId, HashMap<Namespace, Vec<Multiaddr>>>,
59
60    registered_namespaces: HashMap<(PeerId, Namespace), Ttl>,
61
62    expiring_registrations: FuturesUnordered<BoxFuture<'static, (PeerId, Namespace)>>,
65
66    external_addresses: ExternalAddresses,
67}
68
69impl Behaviour {
70    pub fn new(keypair: Keypair) -> Self {
72        Self {
73            events: Default::default(),
74            inner: libp2p_request_response::Behaviour::with_codec(
75                crate::codec::Codec::default(),
76                iter::once((crate::PROTOCOL_IDENT, ProtocolSupport::Outbound)),
77                libp2p_request_response::Config::default(),
78            ),
79            keypair,
80            waiting_for_register: Default::default(),
81            waiting_for_discovery: Default::default(),
82            discovered_peers: Default::default(),
83            registered_namespaces: Default::default(),
84            expiring_registrations: FuturesUnordered::from_iter(vec![
85                futures::future::pending().boxed()
86            ]),
87            external_addresses: Default::default(),
88        }
89    }
90
91    pub fn register(
97        &mut self,
98        namespace: Namespace,
99        rendezvous_node: PeerId,
100        ttl: Option<Ttl>,
101    ) -> Result<(), RegisterError> {
102        let external_addresses = self.external_addresses.iter().cloned().collect::<Vec<_>>();
103        if external_addresses.is_empty() {
104            return Err(RegisterError::NoExternalAddresses);
105        }
106
107        let peer_record = PeerRecord::new(&self.keypair, external_addresses)?;
108        let req_id = self.inner.send_request(
109            &rendezvous_node,
110            Register(NewRegistration::new(namespace.clone(), peer_record, ttl)),
111        );
112        self.waiting_for_register
113            .insert(req_id, (rendezvous_node, namespace));
114
115        Ok(())
116    }
117
118    pub fn unregister(&mut self, namespace: Namespace, rendezvous_node: PeerId) {
120        self.registered_namespaces
121            .retain(|(rz_node, ns), _| rz_node.ne(&rendezvous_node) && ns.ne(&namespace));
122
123        self.inner
124            .send_request(&rendezvous_node, Unregister(namespace));
125    }
126
127    pub fn discover(
135        &mut self,
136        namespace: Option<Namespace>,
137        cookie: Option<Cookie>,
138        limit: Option<u64>,
139        rendezvous_node: PeerId,
140    ) {
141        let req_id = self.inner.send_request(
142            &rendezvous_node,
143            Discover {
144                namespace: namespace.clone(),
145                cookie,
146                limit,
147            },
148        );
149
150        self.waiting_for_discovery
151            .insert(req_id, (rendezvous_node, namespace));
152    }
153}
154
155#[derive(Debug, thiserror::Error)]
156pub enum RegisterError {
157    #[error("We don't know about any externally reachable addresses of ours")]
158    NoExternalAddresses,
159    #[error("Failed to make a new PeerRecord")]
160    FailedToMakeRecord(#[from] SigningError),
161}
162
163#[derive(Debug)]
164#[allow(clippy::large_enum_variant)]
165pub enum Event {
166    Discovered {
168        rendezvous_node: PeerId,
169        registrations: Vec<Registration>,
170        cookie: Cookie,
171    },
172    DiscoverFailed {
174        rendezvous_node: PeerId,
175        namespace: Option<Namespace>,
176        error: ErrorCode,
177    },
178    Registered {
180        rendezvous_node: PeerId,
181        ttl: Ttl,
182        namespace: Namespace,
183    },
184    RegisterFailed {
186        rendezvous_node: PeerId,
187        namespace: Namespace,
188        error: ErrorCode,
189    },
190    Expired { peer: PeerId },
192}
193
194impl NetworkBehaviour for Behaviour {
195    type ConnectionHandler = <libp2p_request_response::Behaviour<
196        crate::codec::Codec,
197    > as NetworkBehaviour>::ConnectionHandler;
198
199    type ToSwarm = Event;
200
201    fn handle_established_inbound_connection(
202        &mut self,
203        connection_id: ConnectionId,
204        peer: PeerId,
205        local_addr: &Multiaddr,
206        remote_addr: &Multiaddr,
207    ) -> Result<THandler<Self>, ConnectionDenied> {
208        self.inner.handle_established_inbound_connection(
209            connection_id,
210            peer,
211            local_addr,
212            remote_addr,
213        )
214    }
215
216    fn handle_established_outbound_connection(
217        &mut self,
218        connection_id: ConnectionId,
219        peer: PeerId,
220        addr: &Multiaddr,
221        role_override: Endpoint,
222        port_use: PortUse,
223    ) -> Result<THandler<Self>, ConnectionDenied> {
224        self.inner.handle_established_outbound_connection(
225            connection_id,
226            peer,
227            addr,
228            role_override,
229            port_use,
230        )
231    }
232
233    fn on_connection_handler_event(
234        &mut self,
235        peer_id: PeerId,
236        connection_id: ConnectionId,
237        event: THandlerOutEvent<Self>,
238    ) {
239        self.inner
240            .on_connection_handler_event(peer_id, connection_id, event);
241    }
242
243    fn on_swarm_event(&mut self, event: FromSwarm) {
244        let changed = self.external_addresses.on_swarm_event(&event);
245
246        self.inner.on_swarm_event(event);
247
248        if changed && self.external_addresses.iter().count() > 0 {
249            let registered = self.registered_namespaces.clone();
250            for ((rz_node, ns), ttl) in registered {
251                if let Err(e) = self.register(ns, rz_node, Some(ttl)) {
252                    tracing::warn!("refreshing registration failed: {e}")
253                }
254            }
255        }
256    }
257
258    #[tracing::instrument(level = "trace", name = "NetworkBehaviour::poll", skip(self, cx))]
259    fn poll(
260        &mut self,
261        cx: &mut Context<'_>,
262    ) -> Poll<ToSwarm<Self::ToSwarm, THandlerInEvent<Self>>> {
263        use libp2p_request_response as req_res;
264        loop {
265            if let Some(event) = self.events.pop_front() {
266                return Poll::Ready(event);
267            }
268
269            match self.inner.poll(cx) {
270                Poll::Ready(ToSwarm::GenerateEvent(req_res::Event::Message {
271                    message:
272                        req_res::Message::Response {
273                            request_id,
274                            response,
275                        },
276                    ..
277                })) => {
278                    if let Some(event) = self.handle_response(&request_id, response) {
279                        return Poll::Ready(ToSwarm::GenerateEvent(event));
280                    }
281
282                    continue; }
284                Poll::Ready(ToSwarm::GenerateEvent(req_res::Event::OutboundFailure {
285                    request_id,
286                    ..
287                })) => {
288                    if let Some(event) = self.event_for_outbound_failure(&request_id) {
289                        return Poll::Ready(ToSwarm::GenerateEvent(event));
290                    }
291
292                    continue; }
294                Poll::Ready(ToSwarm::GenerateEvent(
295                    req_res::Event::InboundFailure { .. }
296                    | req_res::Event::ResponseSent { .. }
297                    | req_res::Event::Message {
298                        message: req_res::Message::Request { .. },
299                        ..
300                    },
301                )) => {
302                    unreachable!("rendezvous clients never receive requests")
303                }
304                Poll::Ready(other) => {
305                    let new_to_swarm =
306                        other.map_out(|_| unreachable!("we manually map `GenerateEvent` variants"));
307
308                    return Poll::Ready(new_to_swarm);
309                }
310                Poll::Pending => {}
311            }
312
313            if let Poll::Ready(Some((peer, expired_registration))) =
314                self.expiring_registrations.poll_next_unpin(cx)
315            {
316                let Some(registrations) = self.discovered_peers.get_mut(&peer) else {
317                    continue;
318                };
319                registrations.remove(&expired_registration);
320                if registrations.is_empty() {
321                    self.discovered_peers.remove(&peer);
322                }
323                return Poll::Ready(ToSwarm::GenerateEvent(Event::Expired { peer }));
324            }
325
326            return Poll::Pending;
327        }
328    }
329
330    fn handle_pending_outbound_connection(
331        &mut self,
332        _connection_id: ConnectionId,
333        maybe_peer: Option<PeerId>,
334        _addresses: &[Multiaddr],
335        _effective_role: Endpoint,
336    ) -> Result<Vec<Multiaddr>, ConnectionDenied> {
337        let addrs = maybe_peer
338            .map(|peer| self.discovered_peer_addrs(&peer).cloned().collect())
339            .unwrap_or_default();
340        Ok(addrs)
341    }
342}
343
344impl Behaviour {
345    fn event_for_outbound_failure(&mut self, req_id: &OutboundRequestId) -> Option<Event> {
346        if let Some((rendezvous_node, namespace)) = self.waiting_for_register.remove(req_id) {
347            return Some(Event::RegisterFailed {
348                rendezvous_node,
349                namespace,
350                error: ErrorCode::Unavailable,
351            });
352        };
353
354        if let Some((rendezvous_node, namespace)) = self.waiting_for_discovery.remove(req_id) {
355            return Some(Event::DiscoverFailed {
356                rendezvous_node,
357                namespace,
358                error: ErrorCode::Unavailable,
359            });
360        };
361
362        None
363    }
364
365    fn handle_response(
366        &mut self,
367        request_id: &OutboundRequestId,
368        response: Message,
369    ) -> Option<Event> {
370        match response {
371            RegisterResponse(Ok(ttl)) => {
372                let (rendezvous_node, namespace) = self.waiting_for_register.remove(request_id)?;
373                self.registered_namespaces
374                    .insert((rendezvous_node, namespace.clone()), ttl);
375
376                Some(Event::Registered {
377                    rendezvous_node,
378                    ttl,
379                    namespace,
380                })
381            }
382            RegisterResponse(Err(error_code)) => {
383                let (rendezvous_node, namespace) = self.waiting_for_register.remove(request_id)?;
384                Some(Event::RegisterFailed {
385                    rendezvous_node,
386                    namespace,
387                    error: error_code,
388                })
389            }
390            DiscoverResponse(Ok((registrations, cookie))) => {
391                let (rendezvous_node, _ns) = self.waiting_for_discovery.remove(request_id)?;
392                registrations.iter().for_each(|registration| {
393                    let peer_id = registration.record.peer_id();
394                    let addresses = registration.record.addresses();
395                    let namespace = registration.namespace.clone();
396                    let ttl = registration.ttl;
397
398                    let new_addr_events = addresses
400                        .iter()
401                        .filter_map(|address| {
402                            if self.discovered_peer_addrs(&peer_id).any(|a| a == address) {
403                                return None;
404                            }
405                            Some(ToSwarm::NewExternalAddrOfPeer {
406                                peer_id,
407                                address: address.clone(),
408                            })
409                        })
410                        .collect::<Vec<_>>();
411                    self.events.extend(new_addr_events);
412
413                    self.discovered_peers
415                        .entry(peer_id)
416                        .or_default()
417                        .insert(namespace.clone(), addresses.to_owned());
418
419                    self.expiring_registrations.push(
421                        async move {
422                            futures_timer::Delay::new(Duration::from_secs(ttl)).await;
424                            (peer_id, namespace)
425                        }
426                        .boxed(),
427                    );
428                });
429
430                Some(Event::Discovered {
431                    rendezvous_node,
432                    registrations,
433                    cookie,
434                })
435            }
436            DiscoverResponse(Err(error_code)) => {
437                let (rendezvous_node, ns) = self.waiting_for_discovery.remove(request_id)?;
438                Some(Event::DiscoverFailed {
439                    rendezvous_node,
440                    namespace: ns,
441                    error: error_code,
442                })
443            }
444            _ => unreachable!("rendezvous clients never receive requests"),
445        }
446    }
447
448    fn discovered_peer_addrs(&self, peer: &PeerId) -> impl Iterator<Item = &Multiaddr> {
449        self.discovered_peers
450            .get(peer)
451            .map(|addrs| addrs.values().flatten())
452            .unwrap_or_default()
453    }
454}