hole_punching_tests/
main.rs

1// Copyright 2023 Protocol Labs.
2//
3// Permission is hereby granted, free of charge, to any person obtaining a
4// copy of this software and associated documentation files (the "Software"),
5// to deal in the Software without restriction, including without limitation
6// the rights to use, copy, modify, merge, publish, distribute, sublicense,
7// and/or sell copies of the Software, and to permit persons to whom the
8// Software is furnished to do so, subject to the following conditions:
9//
10// The above copyright notice and this permission notice shall be included in
11// all copies or substantial portions of the Software.
12//
13// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
14// OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
16// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
17// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
18// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
19// DEALINGS IN THE SOFTWARE.
20
21use std::{
22    collections::HashMap,
23    fmt, io,
24    net::{IpAddr, Ipv4Addr},
25    str::FromStr,
26    time::Duration,
27};
28
29use anyhow::{Context, Result};
30use either::Either;
31use futures::stream::StreamExt;
32use libp2p::{
33    core::{
34        multiaddr::{Multiaddr, Protocol},
35        transport::ListenerId,
36    },
37    dcutr, identify, noise, ping, relay,
38    swarm::{dial_opts::DialOpts, ConnectionId, NetworkBehaviour, SwarmEvent},
39    tcp, yamux, Swarm,
40};
41use redis::AsyncCommands;
42
/// Redis list key holding the relay's TCP listen address; clients using TCP pop it at startup.
const RELAY_TCP_ADDRESS: &'static str = "RELAY_TCP_ADDRESS";
/// Redis list key holding the relay's QUIC listen address; clients using QUIC pop it at startup.
const RELAY_QUIC_ADDRESS: &'static str = "RELAY_QUIC_ADDRESS";
/// Redis list key the listening client pushes its `PeerId` to, and the dialing client pops from.
const LISTEN_CLIENT_PEER_ID: &'static str = "LISTEN_CLIENT_PEER_ID";
49
50#[tokio::main]
51async fn main() -> Result<()> {
52    env_logger::builder()
53        .parse_filters("debug,netlink_proto=warn,rustls=warn,multistream_select=warn,libp2p_core::transport::choice=off,libp2p_swarm::connection=warn,libp2p_quic=trace")
54        .parse_default_env()
55        .init();
56
57    let mode = get_env("MODE")?;
58    let transport = get_env("TRANSPORT")?;
59
60    let mut redis = RedisClient::new("redis", 6379).await?;
61
62    let relay_addr = match transport {
63        TransportProtocol::Tcp => redis.pop::<Multiaddr>(RELAY_TCP_ADDRESS).await?,
64        TransportProtocol::Quic => redis.pop::<Multiaddr>(RELAY_QUIC_ADDRESS).await?,
65    };
66
67    let mut swarm = libp2p::SwarmBuilder::with_new_identity()
68        .with_tokio()
69        .with_tcp(
70            tcp::Config::new().nodelay(true),
71            noise::Config::new,
72            yamux::Config::default,
73        )?
74        .with_quic()
75        .with_relay_client(noise::Config::new, yamux::Config::default)?
76        .with_behaviour(|key, relay_client| {
77            Ok(Behaviour {
78                relay_client,
79                identify: identify::Behaviour::new(identify::Config::new(
80                    "/hole-punch-tests/1".to_owned(),
81                    key.public(),
82                )),
83                dcutr: dcutr::Behaviour::new(key.public().to_peer_id()),
84                ping: ping::Behaviour::new(
85                    ping::Config::default().with_interval(Duration::from_secs(1)),
86                ),
87            })
88        })?
89        .with_swarm_config(|c| c.with_idle_connection_timeout(Duration::from_secs(60)))
90        .build();
91
92    client_listen_on_transport(&mut swarm, transport).await?;
93    let id = client_setup(&mut swarm, &mut redis, relay_addr.clone(), mode).await?;
94
95    let mut hole_punched_peer_connection = None;
96
97    loop {
98        match (
99            swarm.next().await.unwrap(),
100            hole_punched_peer_connection,
101            id,
102        ) {
103            (
104                SwarmEvent::Behaviour(BehaviourEvent::RelayClient(
105                    relay::client::Event::ReservationReqAccepted { .. },
106                )),
107                _,
108                _,
109            ) => {
110                tracing::info!("Relay accepted our reservation request.");
111
112                redis
113                    .push(LISTEN_CLIENT_PEER_ID, swarm.local_peer_id())
114                    .await?;
115            }
116            (
117                SwarmEvent::Behaviour(BehaviourEvent::Dcutr(dcutr::Event {
118                    remote_peer_id,
119                    result: Ok(connection_id),
120                })),
121                _,
122                _,
123            ) => {
124                tracing::info!("Successfully hole-punched to {remote_peer_id}");
125
126                hole_punched_peer_connection = Some(connection_id);
127            }
128            (
129                SwarmEvent::Behaviour(BehaviourEvent::Ping(ping::Event {
130                    connection,
131                    result: Ok(rtt),
132                    ..
133                })),
134                Some(hole_punched_connection),
135                _,
136            ) if mode == Mode::Dial && connection == hole_punched_connection => {
137                println!("{}", serde_json::to_string(&Report::new(rtt))?);
138
139                return Ok(());
140            }
141            (
142                SwarmEvent::Behaviour(BehaviourEvent::Dcutr(dcutr::Event {
143                    remote_peer_id,
144                    result: Err(error),
145                    ..
146                })),
147                _,
148                _,
149            ) => {
150                tracing::info!("Failed to hole-punched to {remote_peer_id}");
151                return Err(anyhow::Error::new(error));
152            }
153            (
154                SwarmEvent::ListenerClosed {
155                    listener_id,
156                    reason: Err(e),
157                    ..
158                },
159                _,
160                Either::Left(reservation),
161            ) if listener_id == reservation => {
162                anyhow::bail!("Reservation on relay failed: {e}");
163            }
164            (
165                SwarmEvent::OutgoingConnectionError {
166                    connection_id,
167                    error,
168                    ..
169                },
170                _,
171                Either::Right(circuit),
172            ) if connection_id == circuit => {
173                anyhow::bail!("Circuit request relay failed: {error}");
174            }
175            _ => {}
176        }
177    }
178}
179
180#[derive(serde::Serialize)]
181struct Report {
182    rtt_to_holepunched_peer_millis: u128,
183}
184
185impl Report {
186    fn new(rtt: Duration) -> Self {
187        Self {
188            rtt_to_holepunched_peer_millis: rtt.as_millis(),
189        }
190    }
191}
192
193fn get_env<T>(key: &'static str) -> Result<T>
194where
195    T: FromStr,
196    T::Err: std::error::Error + Send + Sync + 'static,
197{
198    let val = std::env::var(key)
199        .with_context(|| format!("Missing env var `{key}`"))?
200        .parse()
201        .with_context(|| format!("Failed to parse `{key}`)"))?;
202
203    Ok(val)
204}
205
206async fn client_listen_on_transport(
207    swarm: &mut Swarm<Behaviour>,
208    transport: TransportProtocol,
209) -> Result<()> {
210    let listen_addr = match transport {
211        TransportProtocol::Tcp => tcp_addr(Ipv4Addr::UNSPECIFIED.into()),
212        TransportProtocol::Quic => quic_addr(Ipv4Addr::UNSPECIFIED.into()),
213    };
214    let expected_listener_id = swarm
215        .listen_on(listen_addr)
216        .context("Failed to listen on address")?;
217
218    let mut listen_addresses = 0;
219
220    // We should have at least two listen addresses, one for localhost and the actual interface.
221    while listen_addresses < 2 {
222        if let SwarmEvent::NewListenAddr {
223            listener_id,
224            address,
225        } = swarm.next().await.unwrap()
226        {
227            if listener_id == expected_listener_id {
228                listen_addresses += 1;
229            }
230
231            tracing::info!("Listening on {address}");
232        }
233    }
234    Ok(())
235}
236
237async fn client_setup(
238    swarm: &mut Swarm<Behaviour>,
239    redis: &mut RedisClient,
240    relay_addr: Multiaddr,
241    mode: Mode,
242) -> Result<Either<ListenerId, ConnectionId>> {
243    let either = match mode {
244        Mode::Listen => {
245            let id = swarm.listen_on(relay_addr.with(Protocol::P2pCircuit))?;
246
247            Either::Left(id)
248        }
249        Mode::Dial => {
250            let remote_peer_id = redis.pop(LISTEN_CLIENT_PEER_ID).await?;
251
252            let opts = DialOpts::from(
253                relay_addr
254                    .with(Protocol::P2pCircuit)
255                    .with(Protocol::P2p(remote_peer_id)),
256            );
257            let id = opts.connection_id();
258
259            swarm.dial(opts)?;
260
261            Either::Right(id)
262        }
263    };
264
265    Ok(either)
266}
267
268fn tcp_addr(addr: IpAddr) -> Multiaddr {
269    Multiaddr::empty().with(addr.into()).with(Protocol::Tcp(0))
270}
271
272fn quic_addr(addr: IpAddr) -> Multiaddr {
273    Multiaddr::empty()
274        .with(addr.into())
275        .with(Protocol::Udp(0))
276        .with(Protocol::QuicV1)
277}
278
279struct RedisClient {
280    inner: redis::aio::Connection,
281}
282
283impl RedisClient {
284    async fn new(host: &str, port: u16) -> Result<Self> {
285        let client = redis::Client::open(format!("redis://{host}:{port}/"))
286            .context("Bad redis server URL")?;
287        let connection = client
288            .get_async_connection()
289            .await
290            .context("Failed to connect to redis server")?;
291
292        Ok(Self { inner: connection })
293    }
294
295    async fn push(&mut self, key: &str, value: impl ToString) -> Result<()> {
296        let value = value.to_string();
297
298        tracing::debug!("Pushing {key}={value} to redis");
299
300        self.inner.rpush(key, value).await.map_err(Into::into)
301    }
302
303    async fn pop<V>(&mut self, key: &str) -> Result<V>
304    where
305        V: FromStr + fmt::Display,
306        V::Err: std::error::Error + Send + Sync + 'static,
307    {
308        tracing::debug!("Fetching {key} from redis");
309
310        let value = self
311            .inner
312            .blpop::<_, HashMap<String, String>>(key, 0.0)
313            .await?
314            .remove(key)
315            .with_context(|| format!("Failed to get value for {key} from redis"))?
316            .parse()?;
317
318        tracing::debug!("{key}={value}");
319
320        Ok(value)
321    }
322}
323
/// Transport used to listen locally and to reach the relay, selected via `TRANSPORT`.
#[derive(Clone, Copy, Debug, PartialEq)]
enum TransportProtocol {
    /// Plain TCP.
    Tcp,
    /// QUIC (v1) over UDP.
    Quic,
}

impl FromStr for TransportProtocol {
    type Err = io::Error;

    /// Accepts exactly `tcp` or `quic` (case-sensitive); anything else is an error.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        let protocol = match s {
            "tcp" => TransportProtocol::Tcp,
            "quic" => TransportProtocol::Quic,
            _ => {
                return Err(io::Error::new(
                    io::ErrorKind::Other,
                    "Expected either 'tcp' or 'quic'",
                ))
            }
        };

        Ok(protocol)
    }
}
343
/// Role of this client in the hole-punching test, selected via `MODE`.
#[derive(Clone, Copy, Debug, PartialEq)]
enum Mode {
    /// Dials the listening client through the relay circuit and prints the RTT report.
    Dial,
    /// Listens via the relay circuit and publishes its `PeerId` once the reservation is accepted.
    Listen,
}

impl FromStr for Mode {
    type Err = io::Error;

    /// Accepts exactly `dial` or `listen` (case-sensitive); anything else is an error.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        Ok(match s {
            "dial" => Mode::Dial,
            "listen" => Mode::Listen,
            _ => {
                return Err(io::Error::new(
                    io::ErrorKind::Other,
                    "Expected either 'dial' or 'listen'",
                ))
            }
        })
    }
}
363
/// Composite network behaviour of the test client.
///
/// The derive generates the `BehaviourEvent` enum (one variant per field) that `main` matches
/// on. Field names and order are significant to the derive, so they must not be changed lightly.
#[derive(NetworkBehaviour)]
struct Behaviour {
    /// Relay client: used for listening on and dialing via `/p2p-circuit` addresses.
    relay_client: relay::client::Behaviour,
    /// Identify protocol; `main` configures it with protocol version `/hole-punch-tests/1`.
    identify: identify::Behaviour,
    /// DCUtR; its success/failure events decide whether the hole punch worked.
    dcutr: dcutr::Behaviour,
    /// Ping; `main` configures a 1 second interval. A successful ping on the hole-punched
    /// connection yields the reported RTT.
    ping: ping::Behaviour,
}