libp2p_quic/
transport.rs

1// Copyright 2017-2020 Parity Technologies (UK) Ltd.
2//
3// Permission is hereby granted, free of charge, to any person obtaining a
4// copy of this software and associated documentation files (the "Software"),
5// to deal in the Software without restriction, including without limitation
6// the rights to use, copy, modify, merge, publish, distribute, sublicense,
7// and/or sell copies of the Software, and to permit persons to whom the
8// Software is furnished to do so, subject to the following conditions:
9//
10// The above copyright notice and this permission notice shall be included in
11// all copies or substantial portions of the Software.
12//
13// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
14// OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
16// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
17// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
18// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
19// DEALINGS IN THE SOFTWARE.
20
21use std::{
22    collections::{
23        hash_map::{DefaultHasher, Entry},
24        HashMap, HashSet,
25    },
26    fmt,
27    hash::{Hash, Hasher},
28    io,
29    net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, UdpSocket},
30    pin::Pin,
31    task::{Context, Poll, Waker},
32    time::Duration,
33};
34
35use futures::{
36    channel::oneshot,
37    future::{BoxFuture, Either},
38    prelude::*,
39    ready,
40    stream::{SelectAll, StreamExt},
41};
42use if_watch::IfEvent;
43use libp2p_core::{
44    multiaddr::{Multiaddr, Protocol},
45    transport::{DialOpts, ListenerId, PortUse, TransportError, TransportEvent},
46    Endpoint, Transport,
47};
48use libp2p_identity::PeerId;
49use socket2::{Domain, Socket, Type};
50
51use crate::{
52    config::{Config, QuinnConfig},
53    hole_punching::hole_puncher,
54    provider::Provider,
55    ConnectError, Connecting, Connection, Error,
56};
57
58/// Implementation of the [`Transport`] trait for QUIC.
59///
60/// By default only QUIC Version 1 (RFC 9000) is supported. In the [`Multiaddr`] this maps to
61/// [`libp2p_core::multiaddr::Protocol::QuicV1`].
62/// The [`libp2p_core::multiaddr::Protocol::Quic`] codepoint is interpreted as QUIC version
63/// draft-29 and only supported if [`Config::support_draft_29`] is set to `true`.
64/// Note that in that case servers support both version an all QUIC listening addresses.
65///
66/// Version draft-29 should only be used to connect to nodes from other libp2p implementations
67/// that do not support `QuicV1` yet. Support for it will be removed long-term.
68/// See <https://github.com/multiformats/multiaddr/issues/145>.
69#[derive(Debug)]
70pub struct GenTransport<P: Provider> {
71    /// Config for the inner [`quinn`] structs.
72    quinn_config: QuinnConfig,
73    /// Timeout for the [`Connecting`] future.
74    handshake_timeout: Duration,
75    /// Whether draft-29 is supported for dialing and listening.
76    support_draft_29: bool,
77    /// Streams of active [`Listener`]s.
78    listeners: SelectAll<Listener<P>>,
79    /// Dialer for each socket family if no matching listener exists.
80    dialer: HashMap<SocketFamily, quinn::Endpoint>,
81    /// Waker to poll the transport again when a new dialer or listener is added.
82    waker: Option<Waker>,
83    /// Holepunching attempts
84    hole_punch_attempts: HashMap<SocketAddr, oneshot::Sender<Connecting>>,
85}
86
87#[expect(deprecated)]
88impl<P: Provider> GenTransport<P> {
89    /// Create a new [`GenTransport`] with the given [`Config`].
90    pub fn new(config: Config) -> Self {
91        let handshake_timeout = config.handshake_timeout;
92        let support_draft_29 = config.support_draft_29;
93        let quinn_config = config.into();
94        Self {
95            listeners: SelectAll::new(),
96            quinn_config,
97            handshake_timeout,
98            dialer: HashMap::new(),
99            waker: None,
100            support_draft_29,
101            hole_punch_attempts: Default::default(),
102        }
103    }
104
105    /// Create a new [`quinn::Endpoint`] with the given configs.
106    fn new_endpoint(
107        endpoint_config: quinn::EndpointConfig,
108        server_config: Option<quinn::ServerConfig>,
109        socket: UdpSocket,
110    ) -> Result<quinn::Endpoint, Error> {
111        use crate::provider::Runtime;
112        match P::runtime() {
113            #[cfg(feature = "tokio")]
114            Runtime::Tokio => {
115                let runtime = std::sync::Arc::new(quinn::TokioRuntime);
116                let endpoint =
117                    quinn::Endpoint::new(endpoint_config, server_config, socket, runtime)?;
118                Ok(endpoint)
119            }
120            Runtime::Dummy => {
121                let _ = endpoint_config;
122                let _ = server_config;
123                let _ = socket;
124                let err = std::io::Error::other("no async runtime found");
125                Err(Error::Io(err))
126            }
127        }
128    }
129
130    /// Extract the addr, quic version and peer id from the given [`Multiaddr`].
131    fn remote_multiaddr_to_socketaddr(
132        &self,
133        addr: Multiaddr,
134        check_unspecified_addr: bool,
135    ) -> Result<
136        (SocketAddr, ProtocolVersion, Option<PeerId>),
137        TransportError<<Self as Transport>::Error>,
138    > {
139        let (socket_addr, version, peer_id) = multiaddr_to_socketaddr(&addr, self.support_draft_29)
140            .ok_or_else(|| TransportError::MultiaddrNotSupported(addr.clone()))?;
141        if check_unspecified_addr && (socket_addr.port() == 0 || socket_addr.ip().is_unspecified())
142        {
143            return Err(TransportError::MultiaddrNotSupported(addr));
144        }
145        Ok((socket_addr, version, peer_id))
146    }
147
148    /// Pick any listener to use for dialing.
149    fn eligible_listener(&mut self, socket_addr: &SocketAddr) -> Option<&mut Listener<P>> {
150        let mut listeners: Vec<_> = self
151            .listeners
152            .iter_mut()
153            .filter(|l| {
154                if l.is_closed {
155                    return false;
156                }
157                SocketFamily::is_same(&l.socket_addr().ip(), &socket_addr.ip())
158            })
159            .filter(|l| {
160                if socket_addr.ip().is_loopback() {
161                    l.listening_addresses
162                        .iter()
163                        .any(|ip_addr| ip_addr.is_loopback())
164                } else {
165                    true
166                }
167            })
168            .collect();
169        match listeners.len() {
170            0 => None,
171            1 => listeners.pop(),
172            _ => {
173                // Pick any listener to use for dialing.
174                // We hash the socket address to achieve determinism.
175                let mut hasher = DefaultHasher::new();
176                socket_addr.hash(&mut hasher);
177                let index = hasher.finish() as usize % listeners.len();
178                Some(listeners.swap_remove(index))
179            }
180        }
181    }
182
183    fn create_socket(&self, socket_addr: SocketAddr) -> io::Result<UdpSocket> {
184        let socket = Socket::new(
185            Domain::for_address(socket_addr),
186            Type::DGRAM,
187            Some(socket2::Protocol::UDP),
188        )?;
189        if socket_addr.is_ipv6() {
190            socket.set_only_v6(true)?;
191        }
192
193        socket.bind(&socket_addr.into())?;
194
195        Ok(socket.into())
196    }
197
198    fn bound_socket(&mut self, socket_addr: SocketAddr) -> Result<quinn::Endpoint, Error> {
199        let socket_family = socket_addr.ip().into();
200        if let Some(waker) = self.waker.take() {
201            waker.wake();
202        }
203        let listen_socket_addr = match socket_family {
204            SocketFamily::Ipv4 => SocketAddr::new(Ipv4Addr::UNSPECIFIED.into(), 0),
205            SocketFamily::Ipv6 => SocketAddr::new(Ipv6Addr::UNSPECIFIED.into(), 0),
206        };
207        let socket = UdpSocket::bind(listen_socket_addr)?;
208        let endpoint_config = self.quinn_config.endpoint_config.clone();
209        let endpoint = Self::new_endpoint(endpoint_config, None, socket)?;
210        Ok(endpoint)
211    }
212}
213
214impl<P: Provider> Transport for GenTransport<P> {
215    type Output = (PeerId, Connection);
216    type Error = Error;
217    type ListenerUpgrade = Connecting;
218    type Dial = BoxFuture<'static, Result<Self::Output, Self::Error>>;
219
220    fn listen_on(
221        &mut self,
222        listener_id: ListenerId,
223        addr: Multiaddr,
224    ) -> Result<(), TransportError<Self::Error>> {
225        let (socket_addr, version, _peer_id) = self.remote_multiaddr_to_socketaddr(addr, false)?;
226        let endpoint_config = self.quinn_config.endpoint_config.clone();
227        let server_config = self.quinn_config.server_config.clone();
228        let socket = self.create_socket(socket_addr).map_err(Self::Error::from)?;
229
230        let socket_c = socket.try_clone().map_err(Self::Error::from)?;
231        let endpoint = Self::new_endpoint(endpoint_config, Some(server_config), socket)?;
232        let listener = Listener::new(
233            listener_id,
234            socket_c,
235            endpoint,
236            self.handshake_timeout,
237            version,
238        )?;
239        self.listeners.push(listener);
240
241        if let Some(waker) = self.waker.take() {
242            waker.wake();
243        }
244
245        // Remove dialer endpoint so that the endpoint is dropped once the last
246        // connection that uses it is closed.
247        // New outbound connections will use the bidirectional (listener) endpoint.
248        self.dialer.remove(&socket_addr.ip().into());
249
250        Ok(())
251    }
252
253    fn remove_listener(&mut self, id: ListenerId) -> bool {
254        if let Some(listener) = self.listeners.iter_mut().find(|l| l.listener_id == id) {
255            // Close the listener, which will eventually finish its stream.
256            // `SelectAll` removes streams once they are finished.
257            listener.close(Ok(()));
258            true
259        } else {
260            false
261        }
262    }
263
264    fn dial(
265        &mut self,
266        addr: Multiaddr,
267        dial_opts: DialOpts,
268    ) -> Result<Self::Dial, TransportError<Self::Error>> {
269        let (socket_addr, version, peer_id) =
270            self.remote_multiaddr_to_socketaddr(addr.clone(), true)?;
271
272        match (dial_opts.role, dial_opts.port_use) {
273            (Endpoint::Dialer, _) | (Endpoint::Listener, PortUse::Reuse) => {
274                let endpoint = if let Some(listener) = dial_opts
275                    .port_use
276                    .eq(&PortUse::Reuse)
277                    .then(|| self.eligible_listener(&socket_addr))
278                    .flatten()
279                {
280                    listener.endpoint.clone()
281                } else {
282                    let socket_family = socket_addr.ip().into();
283                    let dialer = if dial_opts.port_use == PortUse::Reuse {
284                        if let Some(occupied) = self.dialer.get(&socket_family) {
285                            occupied.clone()
286                        } else {
287                            let endpoint = self.bound_socket(socket_addr)?;
288                            self.dialer.insert(socket_family, endpoint.clone());
289                            endpoint
290                        }
291                    } else {
292                        self.bound_socket(socket_addr)?
293                    };
294                    dialer
295                };
296                let handshake_timeout = self.handshake_timeout;
297                let mut client_config = self.quinn_config.client_config.clone();
298                if version == ProtocolVersion::Draft29 {
299                    client_config.version(0xff00_001d);
300                }
301                Ok(Box::pin(async move {
302                    // This `"l"` seems necessary because an empty string is an invalid domain
303                    // name. While we don't use domain names, the underlying rustls library
304                    // is based upon the assumption that we do.
305                    let connecting = endpoint
306                        .connect_with(client_config, socket_addr, "l")
307                        .map_err(ConnectError)?;
308                    Connecting::new(connecting, handshake_timeout).await
309                }))
310            }
311            (Endpoint::Listener, _) => {
312                let peer_id = peer_id.ok_or(TransportError::MultiaddrNotSupported(addr.clone()))?;
313
314                let socket = self
315                    .eligible_listener(&socket_addr)
316                    .ok_or(TransportError::Other(
317                        Error::NoActiveListenerForDialAsListener,
318                    ))?
319                    .try_clone_socket()
320                    .map_err(Self::Error::from)?;
321
322                tracing::debug!("Preparing for hole-punch from {addr}");
323
324                let hole_puncher = hole_puncher::<P>(socket, socket_addr, self.handshake_timeout);
325
326                let (sender, receiver) = oneshot::channel();
327
328                match self.hole_punch_attempts.entry(socket_addr) {
329                    Entry::Occupied(mut sender_entry) => {
330                        // Stale senders, i.e. from failed hole punches are not removed.
331                        // Thus, we can just overwrite a stale sender.
332                        if !sender_entry.get().is_canceled() {
333                            return Err(TransportError::Other(Error::HolePunchInProgress(
334                                socket_addr,
335                            )));
336                        }
337                        sender_entry.insert(sender);
338                    }
339                    Entry::Vacant(entry) => {
340                        entry.insert(sender);
341                    }
342                };
343
344                Ok(Box::pin(async move {
345                    futures::pin_mut!(hole_puncher);
346                    match futures::future::select(receiver, hole_puncher).await {
347                        Either::Left((message, _)) => {
348                            let (inbound_peer_id, connection) = message
349                                .expect(
350                                    "hole punch connection sender is never dropped before receiver",
351                                )
352                                .await?;
353                            if inbound_peer_id != peer_id {
354                                tracing::warn!(
355                                    peer=%peer_id,
356                                    inbound_peer=%inbound_peer_id,
357                                    socket_address=%socket_addr,
358                                    "expected inbound connection from socket_address to resolve to peer but got inbound peer"
359                                );
360                            }
361                            Ok((inbound_peer_id, connection))
362                        }
363                        Either::Right((hole_punch_err, _)) => Err(hole_punch_err),
364                    }
365                }))
366            }
367        }
368    }
369
370    fn poll(
371        mut self: Pin<&mut Self>,
372        cx: &mut Context<'_>,
373    ) -> Poll<TransportEvent<Self::ListenerUpgrade, Self::Error>> {
374        while let Poll::Ready(Some(ev)) = self.listeners.poll_next_unpin(cx) {
375            match ev {
376                TransportEvent::Incoming {
377                    listener_id,
378                    mut upgrade,
379                    local_addr,
380                    send_back_addr,
381                } => {
382                    let socket_addr =
383                        multiaddr_to_socketaddr(&send_back_addr, self.support_draft_29)
384                            .unwrap()
385                            .0;
386
387                    if let Some(sender) = self.hole_punch_attempts.remove(&socket_addr) {
388                        match sender.send(upgrade) {
389                            Ok(()) => continue,
390                            Err(timed_out_holepunch) => {
391                                upgrade = timed_out_holepunch;
392                            }
393                        }
394                    }
395
396                    return Poll::Ready(TransportEvent::Incoming {
397                        listener_id,
398                        upgrade,
399                        local_addr,
400                        send_back_addr,
401                    });
402                }
403                _ => return Poll::Ready(ev),
404            }
405        }
406
407        self.waker = Some(cx.waker().clone());
408        Poll::Pending
409    }
410}
411
412impl From<Error> for TransportError<Error> {
413    fn from(err: Error) -> Self {
414        TransportError::Other(err)
415    }
416}
417
418/// Listener for incoming connections.
419struct Listener<P: Provider> {
420    /// Id of the listener.
421    listener_id: ListenerId,
422
423    /// Version of the supported quic protocol.
424    version: ProtocolVersion,
425
426    /// Endpoint
427    endpoint: quinn::Endpoint,
428
429    /// An underlying copy of the socket to be able to hole punch with
430    socket: UdpSocket,
431
432    /// A future to poll new incoming connections.
433    accept: BoxFuture<'static, Option<quinn::Incoming>>,
434    /// Timeout for connection establishment on inbound connections.
435    handshake_timeout: Duration,
436
437    /// Watcher for network interface changes.
438    ///
439    /// None if we are only listening on a single interface.
440    if_watcher: Option<P::IfWatcher>,
441
442    /// Whether the listener was closed and the stream should terminate.
443    is_closed: bool,
444
445    /// Pending event to reported.
446    pending_event: Option<<Self as Stream>::Item>,
447
448    /// The stream must be awaken after it has been closed to deliver the last event.
449    close_listener_waker: Option<Waker>,
450
451    listening_addresses: HashSet<IpAddr>,
452}
453
454impl<P: Provider> Listener<P> {
455    fn new(
456        listener_id: ListenerId,
457        socket: UdpSocket,
458        endpoint: quinn::Endpoint,
459        handshake_timeout: Duration,
460        version: ProtocolVersion,
461    ) -> Result<Self, Error> {
462        let if_watcher;
463        let pending_event;
464        let mut listening_addresses = HashSet::new();
465        let local_addr = socket.local_addr()?;
466        if local_addr.ip().is_unspecified() {
467            if_watcher = Some(P::new_if_watcher()?);
468            pending_event = None;
469        } else {
470            if_watcher = None;
471            listening_addresses.insert(local_addr.ip());
472            let ma = socketaddr_to_multiaddr(&local_addr, version);
473            pending_event = Some(TransportEvent::NewAddress {
474                listener_id,
475                listen_addr: ma,
476            })
477        }
478
479        let endpoint_c = endpoint.clone();
480        let accept = async move { endpoint_c.accept().await }.boxed();
481
482        Ok(Listener {
483            endpoint,
484            socket,
485            accept,
486            listener_id,
487            version,
488            handshake_timeout,
489            if_watcher,
490            is_closed: false,
491            pending_event,
492            close_listener_waker: None,
493            listening_addresses,
494        })
495    }
496
497    /// Report the listener as closed in a [`TransportEvent::ListenerClosed`] and
498    /// terminate the stream.
499    fn close(&mut self, reason: Result<(), Error>) {
500        if self.is_closed {
501            return;
502        }
503        self.endpoint.close(From::from(0u32), &[]);
504        self.pending_event = Some(TransportEvent::ListenerClosed {
505            listener_id: self.listener_id,
506            reason,
507        });
508        self.is_closed = true;
509
510        // Wake the stream to deliver the last event.
511        if let Some(waker) = self.close_listener_waker.take() {
512            waker.wake();
513        }
514    }
515
516    /// Clone underlying socket (for hole punching).
517    fn try_clone_socket(&self) -> std::io::Result<UdpSocket> {
518        self.socket.try_clone()
519    }
520
521    fn socket_addr(&self) -> SocketAddr {
522        self.socket
523            .local_addr()
524            .expect("Cannot fail because the socket is bound")
525    }
526
527    /// Poll for a next If Event.
528    fn poll_if_addr(&mut self, cx: &mut Context<'_>) -> Poll<<Self as Stream>::Item> {
529        let endpoint_addr = self.socket_addr();
530        let Some(if_watcher) = self.if_watcher.as_mut() else {
531            return Poll::Pending;
532        };
533        loop {
534            match ready!(P::poll_if_event(if_watcher, cx)) {
535                Ok(IfEvent::Up(inet)) => {
536                    if let Some(listen_addr) =
537                        ip_to_listenaddr(&endpoint_addr, inet.addr(), self.version)
538                    {
539                        tracing::debug!(
540                            address=%listen_addr,
541                            "New listen address"
542                        );
543                        self.listening_addresses.insert(inet.addr());
544                        return Poll::Ready(TransportEvent::NewAddress {
545                            listener_id: self.listener_id,
546                            listen_addr,
547                        });
548                    }
549                }
550                Ok(IfEvent::Down(inet)) => {
551                    if let Some(listen_addr) =
552                        ip_to_listenaddr(&endpoint_addr, inet.addr(), self.version)
553                    {
554                        tracing::debug!(
555                            address=%listen_addr,
556                            "Expired listen address"
557                        );
558                        self.listening_addresses.remove(&inet.addr());
559                        return Poll::Ready(TransportEvent::AddressExpired {
560                            listener_id: self.listener_id,
561                            listen_addr,
562                        });
563                    }
564                }
565                Err(err) => {
566                    return Poll::Ready(TransportEvent::ListenerError {
567                        listener_id: self.listener_id,
568                        error: err.into(),
569                    })
570                }
571            }
572        }
573    }
574}
575
576impl<P: Provider> Stream for Listener<P> {
577    type Item = TransportEvent<<GenTransport<P> as Transport>::ListenerUpgrade, Error>;
578    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
579        loop {
580            if let Some(event) = self.pending_event.take() {
581                return Poll::Ready(Some(event));
582            }
583            if self.is_closed {
584                return Poll::Ready(None);
585            }
586            if let Poll::Ready(event) = self.poll_if_addr(cx) {
587                return Poll::Ready(Some(event));
588            }
589
590            match self.accept.poll_unpin(cx) {
591                Poll::Ready(Some(incoming)) => {
592                    let endpoint = self.endpoint.clone();
593                    self.accept = async move { endpoint.accept().await }.boxed();
594
595                    let connecting = match incoming.accept() {
596                        Ok(connecting) => connecting,
597                        Err(error) => {
598                            return Poll::Ready(Some(TransportEvent::ListenerError {
599                                listener_id: self.listener_id,
600                                error: Error::Connection(crate::ConnectionError(error)),
601                            }))
602                        }
603                    };
604
605                    let local_addr = socketaddr_to_multiaddr(&self.socket_addr(), self.version);
606                    let remote_addr = connecting.remote_address();
607                    let send_back_addr = socketaddr_to_multiaddr(&remote_addr, self.version);
608
609                    let event = TransportEvent::Incoming {
610                        upgrade: Connecting::new(connecting, self.handshake_timeout),
611                        local_addr,
612                        send_back_addr,
613                        listener_id: self.listener_id,
614                    };
615                    return Poll::Ready(Some(event));
616                }
617                Poll::Ready(None) => {
618                    self.close(Ok(()));
619                    continue;
620                }
621                Poll::Pending => {}
622            };
623
624            self.close_listener_waker = Some(cx.waker().clone());
625
626            return Poll::Pending;
627        }
628    }
629}
630
631impl<P: Provider> fmt::Debug for Listener<P> {
632    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
633        f.debug_struct("Listener")
634            .field("listener_id", &self.listener_id)
635            .field("handshake_timeout", &self.handshake_timeout)
636            .field("is_closed", &self.is_closed)
637            .field("pending_event", &self.pending_event)
638            .finish()
639    }
640}
641
/// QUIC version used by a listener or a dial.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) enum ProtocolVersion {
    /// QUIC version 1; written as `/quic-v1` in a multiaddr.
    V1, // i.e. RFC9000
    /// QUIC draft-29; written as `/quic` in a multiaddr.
    Draft29,
}
647
/// IP address family of a socket.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub(crate) enum SocketFamily {
    Ipv4,
    Ipv6,
}

impl SocketFamily {
    /// Whether `a` and `b` are both IPv4 or both IPv6.
    fn is_same(a: &IpAddr, b: &IpAddr) -> bool {
        SocketFamily::from(*a) == SocketFamily::from(*b)
    }
}

impl From<IpAddr> for SocketFamily {
    fn from(ip: IpAddr) -> Self {
        if ip.is_ipv4() {
            Self::Ipv4
        } else {
            Self::Ipv6
        }
    }
}
671
672/// Turn an [`IpAddr`] reported by the interface watcher into a
673/// listen-address for the endpoint.
674///
675/// For this, the `ip` is combined with the port that the endpoint
676/// is actually bound.
677///
678/// Returns `None` if the `ip` is not the same socket family as the
679/// address that the endpoint is bound to.
680fn ip_to_listenaddr(
681    endpoint_addr: &SocketAddr,
682    ip: IpAddr,
683    version: ProtocolVersion,
684) -> Option<Multiaddr> {
685    // True if either both addresses are Ipv4 or both Ipv6.
686    if !SocketFamily::is_same(&endpoint_addr.ip(), &ip) {
687        return None;
688    }
689    let socket_addr = SocketAddr::new(ip, endpoint_addr.port());
690    Some(socketaddr_to_multiaddr(&socket_addr, version))
691}
692
693/// Tries to turn a QUIC multiaddress into a UDP [`SocketAddr`]. Returns None if the format
694/// of the multiaddr is wrong.
695fn multiaddr_to_socketaddr(
696    addr: &Multiaddr,
697    support_draft_29: bool,
698) -> Option<(SocketAddr, ProtocolVersion, Option<PeerId>)> {
699    let mut iter = addr.iter();
700    let proto1 = iter.next()?;
701    let proto2 = iter.next()?;
702    let proto3 = iter.next()?;
703
704    let mut peer_id = None;
705    for proto in iter {
706        match proto {
707            Protocol::P2p(id) => {
708                peer_id = Some(id);
709            }
710            _ => return None,
711        }
712    }
713    let version = match proto3 {
714        Protocol::QuicV1 => ProtocolVersion::V1,
715        Protocol::Quic if support_draft_29 => ProtocolVersion::Draft29,
716        _ => return None,
717    };
718
719    match (proto1, proto2) {
720        (Protocol::Ip4(ip), Protocol::Udp(port)) => {
721            Some((SocketAddr::new(ip.into(), port), version, peer_id))
722        }
723        (Protocol::Ip6(ip), Protocol::Udp(port)) => {
724            Some((SocketAddr::new(ip.into(), port), version, peer_id))
725        }
726        _ => None,
727    }
728}
729
730/// Turns an IP address and port into the corresponding QUIC multiaddr.
731fn socketaddr_to_multiaddr(socket_addr: &SocketAddr, version: ProtocolVersion) -> Multiaddr {
732    let quic_proto = match version {
733        ProtocolVersion::V1 => Protocol::QuicV1,
734        ProtocolVersion::Draft29 => Protocol::Quic,
735    };
736    Multiaddr::empty()
737        .with(socket_addr.ip().into())
738        .with(Protocol::Udp(socket_addr.port()))
739        .with(quic_proto)
740}
741
742#[cfg(test)]
743#[cfg(feature = "tokio")]
744mod tests {
745    use futures::future::poll_fn;
746
747    use super::*;
748
749    #[test]
750    fn multiaddr_to_udp_conversion() {
751        assert!(multiaddr_to_socketaddr(
752            &"/ip4/127.0.0.1/udp/1234".parse::<Multiaddr>().unwrap(),
753            true
754        )
755        .is_none());
756
757        assert_eq!(
758            multiaddr_to_socketaddr(
759                &"/ip4/127.0.0.1/udp/12345/quic-v1"
760                    .parse::<Multiaddr>()
761                    .unwrap(),
762                false
763            ),
764            Some((
765                SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 12345,),
766                ProtocolVersion::V1,
767                None
768            ))
769        );
770        assert_eq!(
771            multiaddr_to_socketaddr(
772                &"/ip4/255.255.255.255/udp/8080/quic-v1"
773                    .parse::<Multiaddr>()
774                    .unwrap(),
775                false
776            ),
777            Some((
778                SocketAddr::new(IpAddr::V4(Ipv4Addr::new(255, 255, 255, 255)), 8080,),
779                ProtocolVersion::V1,
780                None
781            ))
782        );
783        assert_eq!(
784            multiaddr_to_socketaddr(
785                &"/ip4/127.0.0.1/udp/55148/quic-v1/p2p/12D3KooW9xk7Zp1gejwfwNpfm6L9zH5NL4Bx5rm94LRYJJHJuARZ"
786                    .parse::<Multiaddr>()
787                    .unwrap(), false
788            ),
789            Some((SocketAddr::new(
790                IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)),
791                55148,
792            ), ProtocolVersion::V1, Some("12D3KooW9xk7Zp1gejwfwNpfm6L9zH5NL4Bx5rm94LRYJJHJuARZ".parse().unwrap())))
793        );
794        assert_eq!(
795            multiaddr_to_socketaddr(
796                &"/ip6/::1/udp/12345/quic-v1".parse::<Multiaddr>().unwrap(),
797                false
798            ),
799            Some((
800                SocketAddr::new(IpAddr::V6(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1)), 12345,),
801                ProtocolVersion::V1,
802                None
803            ))
804        );
805        assert_eq!(
806            multiaddr_to_socketaddr(
807                &"/ip6/ffff:ffff:ffff:ffff:ffff:ffff:ffff:ffff/udp/8080/quic-v1"
808                    .parse::<Multiaddr>()
809                    .unwrap(),
810                false
811            ),
812            Some((
813                SocketAddr::new(
814                    IpAddr::V6(Ipv6Addr::new(
815                        65535, 65535, 65535, 65535, 65535, 65535, 65535, 65535,
816                    )),
817                    8080,
818                ),
819                ProtocolVersion::V1,
820                None
821            ))
822        );
823
824        assert!(multiaddr_to_socketaddr(
825            &"/ip4/127.0.0.1/udp/1234/quic".parse::<Multiaddr>().unwrap(),
826            false
827        )
828        .is_none());
829        assert_eq!(
830            multiaddr_to_socketaddr(
831                &"/ip4/127.0.0.1/udp/1234/quic".parse::<Multiaddr>().unwrap(),
832                true
833            ),
834            Some((
835                SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 1234,),
836                ProtocolVersion::Draft29,
837                None
838            ))
839        );
840    }
841
842    #[cfg(feature = "tokio")]
843    #[tokio::test]
844    async fn test_close_listener() {
845        let keypair = libp2p_identity::Keypair::generate_ed25519();
846        let config = Config::new(&keypair);
847        let mut transport = crate::tokio::Transport::new(config);
848        assert!(poll_fn(|cx| Pin::new(&mut transport).as_mut().poll(cx))
849            .now_or_never()
850            .is_none());
851
852        // Run test twice to check that there is no unexpected behaviour if `Transport.listener`
853        // is temporarily empty.
854        for _ in 0..2 {
855            let id = ListenerId::next();
856            transport
857                .listen_on(id, "/ip4/0.0.0.0/udp/0/quic-v1".parse().unwrap())
858                .unwrap();
859
860            match poll_fn(|cx| Pin::new(&mut transport).as_mut().poll(cx)).await {
861                TransportEvent::NewAddress {
862                    listener_id,
863                    listen_addr,
864                } => {
865                    assert_eq!(listener_id, id);
866                    assert!(
867                        matches!(listen_addr.iter().next(), Some(Protocol::Ip4(a)) if !a.is_unspecified())
868                    );
869                    assert!(
870                        matches!(listen_addr.iter().nth(1), Some(Protocol::Udp(port)) if port != 0)
871                    );
872                    assert!(matches!(listen_addr.iter().nth(2), Some(Protocol::QuicV1)));
873                }
874                e => panic!("Unexpected event: {e:?}"),
875            }
876            assert!(transport.remove_listener(id), "Expect listener to exist.");
877            match poll_fn(|cx| Pin::new(&mut transport).as_mut().poll(cx)).await {
878                TransportEvent::ListenerClosed {
879                    listener_id,
880                    reason: Ok(()),
881                } => {
882                    assert_eq!(listener_id, id);
883                }
884                e => panic!("Unexpected event: {e:?}"),
885            }
886            // Poll once again so that the listener has the chance to return `Poll::Ready(None)` and
887            // be removed from the list of listeners.
888            assert!(poll_fn(|cx| Pin::new(&mut transport).as_mut().poll(cx))
889                .now_or_never()
890                .is_none());
891            assert!(transport.listeners.is_empty());
892        }
893    }
894
895    #[cfg(feature = "tokio")]
896    #[tokio::test]
897    async fn test_dialer_drop() {
898        let keypair = libp2p_identity::Keypair::generate_ed25519();
899        let config = Config::new(&keypair);
900        let mut transport = crate::tokio::Transport::new(config);
901
902        let _dial = transport
903            .dial(
904                "/ip4/123.45.67.8/udp/1234/quic-v1".parse().unwrap(),
905                DialOpts {
906                    role: Endpoint::Dialer,
907                    port_use: PortUse::Reuse,
908                },
909            )
910            .unwrap();
911
912        assert!(transport.dialer.contains_key(&SocketFamily::Ipv4));
913        assert!(!transport.dialer.contains_key(&SocketFamily::Ipv6));
914
915        // Start listening so that the dialer and driver are dropped.
916        transport
917            .listen_on(
918                ListenerId::next(),
919                "/ip4/0.0.0.0/udp/0/quic-v1".parse().unwrap(),
920            )
921            .unwrap();
922        assert!(!transport.dialer.contains_key(&SocketFamily::Ipv4));
923    }
924
925    #[cfg(feature = "tokio")]
926    #[tokio::test]
927    async fn test_listens_ipv4_ipv6_separately() {
928        let keypair = libp2p_identity::Keypair::generate_ed25519();
929        let config = Config::new(&keypair);
930        let mut transport = crate::tokio::Transport::new(config);
931        let port = {
932            let socket = UdpSocket::bind("127.0.0.1:0").unwrap();
933            socket.local_addr().unwrap().port()
934        };
935
936        transport
937            .listen_on(
938                ListenerId::next(),
939                format!("/ip4/0.0.0.0/udp/{port}/quic-v1").parse().unwrap(),
940            )
941            .unwrap();
942        transport
943            .listen_on(
944                ListenerId::next(),
945                format!("/ip6/::/udp/{port}/quic-v1").parse().unwrap(),
946            )
947            .unwrap();
948    }
949}