libp2p_webrtc/tokio/
udp_mux.rs

1// Copyright 2022 Parity Technologies (UK) Ltd.
2//
3// Permission is hereby granted, free of charge, to any person obtaining a
4// copy of this software and associated documentation files (the "Software"),
5// to deal in the Software without restriction, including without limitation
6// the rights to use, copy, modify, merge, publish, distribute, sublicense,
7// and/or sell copies of the Software, and to permit persons to whom the
8// Software is furnished to do so, subject to the following conditions:
9//
10// The above copyright notice and this permission notice shall be included in
11// all copies or substantial portions of the Software.
12//
13// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
14// OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
16// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
17// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
18// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
19// DEALINGS IN THE SOFTWARE.
20
21use std::{
22    collections::{HashMap, HashSet},
23    io,
24    io::ErrorKind,
25    net::SocketAddr,
26    sync::Arc,
27    task::{Context, Poll},
28};
29
30use async_trait::async_trait;
31use futures::{
32    channel::oneshot,
33    future::{BoxFuture, FutureExt, OptionFuture},
34    stream::FuturesUnordered,
35    StreamExt,
36};
37use stun::{
38    attributes::ATTR_USERNAME,
39    message::{is_message as is_stun_message, Message as STUNMessage},
40};
41use thiserror::Error;
42use tokio::{io::ReadBuf, net::UdpSocket};
43use webrtc::{
44    ice::udp_mux::{UDPMux, UDPMuxConn, UDPMuxConnParams, UDPMuxWriter},
45    util::{Conn, Error},
46};
47
48use crate::tokio::req_res_chan;
49
/// Size in bytes of the stack buffer that [`UDPMuxNewAddr::poll`] uses for a single read from
/// the UDP socket.
const RECEIVE_MTU: usize = 8192;
51
/// A previously unseen address of a remote which has sent us an ICE binding request.
#[derive(Debug)]
pub(crate) struct NewAddr {
    /// Source address the STUN packet was received from.
    pub(crate) addr: SocketAddr,
    /// The ufrag taken from the packet's STUN USERNAME attribute (the segment after the last
    /// `:`).
    pub(crate) ufrag: String,
}
58
/// An event emitted by [`UDPMuxNewAddr`] when it's polled.
#[derive(Debug)]
pub(crate) enum UDPMuxEvent {
    /// Connection error. UDP mux should be stopped.
    ///
    /// Emitted when reading from the UDP socket fails with an error other than `TimedOut` or
    /// `ConnectionReset`.
    Error(std::io::Error),
    /// Got a [`NewAddr`] from the socket.
    ///
    /// Emitted for a packet from an address without a matching connection, provided a ufrag
    /// could be extracted from it and the address has not been reported already.
    NewAddr(NewAddr),
}
67
/// A modified version of [`webrtc::ice::udp_mux::UDPMuxDefault`].
///
/// - It has been rewritten to work without locks, using channels instead.
/// - It reports previously unseen addresses instead of ignoring them.
pub(crate) struct UDPMuxNewAddr {
    /// Socket on which all datagrams are sent and received in [`UDPMuxNewAddr::poll`].
    udp_sock: UdpSocket,

    /// Local address of `udp_sock`, captured once in [`UDPMuxNewAddr::listen_on`].
    listen_addr: SocketAddr,

    /// Maps from ufrag to the underlying connection.
    conns: HashMap<String, UDPMuxConn>,

    /// Maps from socket address to the underlying connection.
    address_map: HashMap<SocketAddr, UDPMuxConn>,

    /// Set of the new addresses to avoid sending the same address multiple times.
    new_addrs: HashSet<SocketAddr>,

    /// `true` when UDP mux is closed.
    is_closed: bool,

    /// Outgoing datagram that is waiting to be written to the socket, together with its target
    /// and the channel on which the send result is reported back.
    send_buffer: Option<(Vec<u8>, SocketAddr, oneshot::Sender<Result<usize, Error>>)>,

    /// One future per muxed connection; each waits for the connection's close signal and then
    /// asks for the connection to be removed by ufrag.
    close_futures: FuturesUnordered<BoxFuture<'static, ()>>,
    /// In-flight write of a received packet into a muxed connection. The socket is only read
    /// again once this is empty.
    write_future: OptionFuture<BoxFuture<'static, ()>>,

    // Receiving halves of the command channels; the sending halves live in `UdpMuxHandle`
    // (close / get_conn / remove_conn) and `UdpMuxWriterHandle` (registration / send).
    close_command: req_res_chan::Receiver<(), Result<(), Error>>,
    get_conn_command: req_res_chan::Receiver<String, Result<Arc<dyn Conn + Send + Sync>, Error>>,
    remove_conn_command: req_res_chan::Receiver<String, ()>,
    registration_command: req_res_chan::Receiver<(UDPMuxConn, SocketAddr), ()>,
    send_command: req_res_chan::Receiver<(Vec<u8>, SocketAddr), Result<usize, Error>>,

    /// Handle exposed through [`UDPMuxNewAddr::udp_mux_handle`]; also used by the close futures.
    udp_mux_handle: Arc<UdpMuxHandle>,
    /// Handle given (as a weak reference) to every muxed connection this mux creates.
    udp_mux_writer_handle: Arc<UdpMuxWriterHandle>,
}
103
104impl UDPMuxNewAddr {
105    pub(crate) fn listen_on(addr: SocketAddr) -> Result<Self, io::Error> {
106        let std_sock = std::net::UdpSocket::bind(addr)?;
107        std_sock.set_nonblocking(true)?;
108
109        let tokio_socket = UdpSocket::from_std(std_sock)?;
110        let listen_addr = tokio_socket.local_addr()?;
111
112        let (udp_mux_handle, close_command, get_conn_command, remove_conn_command) =
113            UdpMuxHandle::new();
114        let (udp_mux_writer_handle, registration_command, send_command) = UdpMuxWriterHandle::new();
115
116        Ok(Self {
117            udp_sock: tokio_socket,
118            listen_addr,
119            conns: HashMap::default(),
120            address_map: HashMap::default(),
121            new_addrs: HashSet::default(),
122            is_closed: false,
123            send_buffer: None,
124            close_futures: FuturesUnordered::default(),
125            write_future: OptionFuture::default(),
126            close_command,
127            get_conn_command,
128            remove_conn_command,
129            registration_command,
130            send_command,
131            udp_mux_handle: Arc::new(udp_mux_handle),
132            udp_mux_writer_handle: Arc::new(udp_mux_writer_handle),
133        })
134    }
135
136    pub(crate) fn listen_addr(&self) -> SocketAddr {
137        self.listen_addr
138    }
139
140    pub(crate) fn udp_mux_handle(&self) -> Arc<UdpMuxHandle> {
141        self.udp_mux_handle.clone()
142    }
143
144    /// Create a muxed connection for a given ufrag.
145    fn create_muxed_conn(&self, ufrag: &str) -> Result<UDPMuxConn, Error> {
146        let local_addr = self.udp_sock.local_addr()?;
147
148        let params = UDPMuxConnParams {
149            local_addr,
150            key: ufrag.into(),
151            udp_mux: Arc::downgrade(
152                &(self.udp_mux_writer_handle.clone() as Arc<dyn UDPMuxWriter + Send + Sync>),
153            ),
154        };
155
156        Ok(UDPMuxConn::new(params))
157    }
158
159    /// Returns a muxed connection if the `ufrag` from the given STUN message matches an existing
160    /// connection.
161    fn conn_from_stun_message(
162        &self,
163        buffer: &[u8],
164        addr: &SocketAddr,
165    ) -> Option<Result<UDPMuxConn, ConnQueryError>> {
166        match ufrag_from_stun_message(buffer, true) {
167            Ok(ufrag) => {
168                if let Some(conn) = self.conns.get(&ufrag) {
169                    let associated_addrs = conn.get_addresses();
170                    // This basically ensures only one address is registered per ufrag.
171                    if associated_addrs.is_empty() || associated_addrs.contains(addr) {
172                        return Some(Ok(conn.clone()));
173                    } else {
174                        return Some(Err(ConnQueryError::UfragAlreadyTaken { associated_addrs }));
175                    }
176                }
177                None
178            }
179            Err(e) => {
180                tracing::debug!(address=%addr, "{}", e);
181                None
182            }
183        }
184    }
185
186    /// Reads from the underlying UDP socket and either reports a new address or proxies data to the
187    /// muxed connection.
188    pub(crate) fn poll(&mut self, cx: &mut Context) -> Poll<UDPMuxEvent> {
189        let mut recv_buf = [0u8; RECEIVE_MTU];
190
191        loop {
192            // => Send data to target
193            match self.send_buffer.take() {
194                None => {
195                    if let Poll::Ready(Some(((buf, target), response))) =
196                        self.send_command.poll_next_unpin(cx)
197                    {
198                        self.send_buffer = Some((buf, target, response));
199                        continue;
200                    }
201                }
202                Some((buf, target, response)) => {
203                    match self.udp_sock.poll_send_to(cx, &buf, target) {
204                        Poll::Ready(result) => {
205                            let _ = response.send(result.map_err(|e| Error::Io(e.into())));
206                            continue;
207                        }
208                        Poll::Pending => {
209                            self.send_buffer = Some((buf, target, response));
210                        }
211                    }
212                }
213            }
214
215            // => Register a new connection
216            if let Poll::Ready(Some(((conn, addr), response))) =
217                self.registration_command.poll_next_unpin(cx)
218            {
219                let key = conn.key();
220
221                self.address_map
222                    .entry(addr)
223                    .and_modify(|e| {
224                        if e.key() != key {
225                            e.remove_address(&addr);
226                            *e = conn.clone();
227                        }
228                    })
229                    .or_insert_with(|| conn.clone());
230
231                // remove addr from new_addrs once conn is established
232                self.new_addrs.remove(&addr);
233
234                let _ = response.send(());
235
236                continue;
237            }
238
239            // => Get connection with the given ufrag
240            if let Poll::Ready(Some((ufrag, response))) = self.get_conn_command.poll_next_unpin(cx)
241            {
242                if self.is_closed {
243                    let _ = response.send(Err(Error::ErrUseClosedNetworkConn));
244                    continue;
245                }
246
247                if let Some(conn) = self.conns.get(&ufrag).cloned() {
248                    let _ = response.send(Ok(Arc::new(conn)));
249                    continue;
250                }
251
252                let muxed_conn = match self.create_muxed_conn(&ufrag) {
253                    Ok(conn) => conn,
254                    Err(e) => {
255                        let _ = response.send(Err(e));
256                        continue;
257                    }
258                };
259                let mut close_rx = muxed_conn.close_rx();
260
261                self.close_futures.push({
262                    let ufrag = ufrag.clone();
263                    let udp_mux_handle = self.udp_mux_handle.clone();
264
265                    Box::pin(async move {
266                        let _ = close_rx.changed().await;
267                        udp_mux_handle.remove_conn_by_ufrag(&ufrag).await;
268                    })
269                });
270
271                self.conns.insert(ufrag, muxed_conn.clone());
272
273                let _ = response.send(Ok(Arc::new(muxed_conn) as Arc<dyn Conn + Send + Sync>));
274
275                continue;
276            }
277
278            // => Close UDPMux
279            if let Poll::Ready(Some(((), response))) = self.close_command.poll_next_unpin(cx) {
280                if self.is_closed {
281                    let _ = response.send(Err(Error::ErrAlreadyClosed));
282                    continue;
283                }
284
285                for (_, conn) in self.conns.drain() {
286                    conn.close();
287                }
288
289                // NOTE: This is important, we need to drop all instances of `UDPMuxConn` to
290                // avoid a retain cycle due to the use of [`std::sync::Arc`] on both sides.
291                self.address_map.clear();
292
293                // NOTE: This is important, we need to drop all instances of `UDPMuxConn` to
294                // avoid a retain cycle due to the use of [`std::sync::Arc`] on both sides.
295                self.new_addrs.clear();
296
297                let _ = response.send(Ok(()));
298
299                self.is_closed = true;
300
301                continue;
302            }
303
304            // => Remove connection with the given ufrag
305            if let Poll::Ready(Some((ufrag, response))) =
306                self.remove_conn_command.poll_next_unpin(cx)
307            {
308                // Pion's ice implementation has both `RemoveConnByFrag` and `RemoveConn`, but since
309                // `conns` is keyed on `ufrag` their implementation is equivalent.
310
311                if let Some(removed_conn) = self.conns.remove(&ufrag) {
312                    for address in removed_conn.get_addresses() {
313                        self.address_map.remove(&address);
314                    }
315                }
316
317                let _ = response.send(());
318
319                continue;
320            }
321
322            // => Remove closed connections
323            let _ = self.close_futures.poll_next_unpin(cx);
324
325            // => Write previously received data to local connections
326            match self.write_future.poll_unpin(cx) {
327                Poll::Ready(Some(())) => {
328                    self.write_future = OptionFuture::default();
329                    continue;
330                }
331                Poll::Ready(None) => {
332                    // => Read from the socket
333                    let mut read = ReadBuf::new(&mut recv_buf);
334
335                    match self.udp_sock.poll_recv_from(cx, &mut read) {
336                        Poll::Ready(Ok(addr)) => {
337                            // Find connection based on previously having seen this source address
338                            let conn = self.address_map.get(&addr);
339
340                            let conn = match conn {
341                                // If we couldn't find the connection based on source address, see
342                                // if this is a STUN message and if
343                                // so if we can find the connection based on ufrag.
344                                None if is_stun_message(read.filled()) => {
345                                    match self.conn_from_stun_message(read.filled(), &addr) {
346                                        Some(Ok(s)) => Some(s),
347                                        Some(Err(e)) => {
348                                            tracing::debug!(address=%&addr, "Error when querying existing connections: {}", e);
349                                            continue;
350                                        }
351                                        None => None,
352                                    }
353                                }
354                                Some(s) => Some(s.to_owned()),
355                                _ => None,
356                            };
357
358                            match conn {
359                                None => {
360                                    if !self.new_addrs.contains(&addr) {
361                                        match ufrag_from_stun_message(read.filled(), false) {
362                                            Ok(ufrag) => {
363                                                tracing::trace!(
364                                                    address=%&addr,
365                                                    %ufrag,
366                                                    "Notifying about new address from ufrag",
367                                                );
368                                                self.new_addrs.insert(addr);
369                                                return Poll::Ready(UDPMuxEvent::NewAddr(
370                                                    NewAddr { addr, ufrag },
371                                                ));
372                                            }
373                                            Err(e) => {
374                                                tracing::debug!(
375                                                    address=%&addr,
376                                                    "Unknown address (non STUN packet: {})",
377                                                    e
378                                                );
379                                            }
380                                        }
381                                    }
382                                }
383                                Some(conn) => {
384                                    let mut packet = vec![0u8; read.filled().len()];
385                                    packet.copy_from_slice(read.filled());
386                                    self.write_future = OptionFuture::from(Some(
387                                        async move {
388                                            if let Err(err) = conn.write_packet(&packet, addr).await
389                                            {
390                                                tracing::error!(
391                                                    address=%addr,
392                                                    "Failed to write packet: {}",
393                                                    err,
394                                                );
395                                            }
396                                        }
397                                        .boxed(),
398                                    ));
399                                }
400                            }
401
402                            continue;
403                        }
404                        Poll::Pending => {}
405                        Poll::Ready(Err(err)) if err.kind() == ErrorKind::TimedOut => {}
406                        Poll::Ready(Err(err)) if err.kind() == ErrorKind::ConnectionReset => {
407                            tracing::debug!("ConnectionReset by remote client {err:?}")
408                        }
409                        Poll::Ready(Err(err)) => {
410                            tracing::error!("Could not read udp packet: {}", err);
411                            return Poll::Ready(UDPMuxEvent::Error(err));
412                        }
413                    }
414                }
415                Poll::Pending => {}
416            }
417
418            return Poll::Pending;
419        }
420    }
421}
422
/// Handle which utilizes [`req_res_chan`] to transmit commands (e.g. remove connection) from the
/// WebRTC ICE agent to [`UDPMuxNewAddr::poll`].
pub(crate) struct UdpMuxHandle {
    /// Sends the "close the mux" request; the reply is the result of closing.
    close_sender: req_res_chan::Sender<(), Result<(), Error>>,
    /// Sends a ufrag; the reply is the muxed connection for it or an error.
    get_conn_sender: req_res_chan::Sender<String, Result<Arc<dyn Conn + Send + Sync>, Error>>,
    /// Sends the ufrag of the connection that should be removed.
    remove_sender: req_res_chan::Sender<String, ()>,
}
430
431impl UdpMuxHandle {
432    /// Returns a new `UdpMuxHandle` and `close`, `get_conn` and `remove` receivers.
433    pub(crate) fn new() -> (
434        Self,
435        req_res_chan::Receiver<(), Result<(), Error>>,
436        req_res_chan::Receiver<String, Result<Arc<dyn Conn + Send + Sync>, Error>>,
437        req_res_chan::Receiver<String, ()>,
438    ) {
439        let (sender1, receiver1) = req_res_chan::new(1);
440        let (sender2, receiver2) = req_res_chan::new(1);
441        let (sender3, receiver3) = req_res_chan::new(1);
442
443        let this = Self {
444            close_sender: sender1,
445            get_conn_sender: sender2,
446            remove_sender: sender3,
447        };
448
449        (this, receiver1, receiver2, receiver3)
450    }
451}
452
453#[async_trait]
454impl UDPMux for UdpMuxHandle {
455    async fn close(&self) -> Result<(), Error> {
456        self.close_sender
457            .send(())
458            .await
459            .map_err(|e| Error::Io(e.into()))??;
460
461        Ok(())
462    }
463
464    async fn get_conn(self: Arc<Self>, ufrag: &str) -> Result<Arc<dyn Conn + Send + Sync>, Error> {
465        let conn = self
466            .get_conn_sender
467            .send(ufrag.to_owned())
468            .await
469            .map_err(|e| Error::Io(e.into()))??;
470
471        Ok(conn)
472    }
473
474    async fn remove_conn_by_ufrag(&self, ufrag: &str) {
475        if let Err(e) = self.remove_sender.send(ufrag.to_owned()).await {
476            tracing::debug!("Failed to send message through channel: {:?}", e);
477        }
478    }
479}
480
/// Handle which utilizes [`req_res_chan`] to transmit commands from [`UDPMuxConn`] connections to
/// [`UDPMuxNewAddr::poll`].
pub(crate) struct UdpMuxWriterHandle {
    /// Sends a connection together with the remote address that should be mapped to it.
    registration_channel: req_res_chan::Sender<(UDPMuxConn, SocketAddr), ()>,
    /// Sends a datagram and its target; the reply is the result of the send.
    send_channel: req_res_chan::Sender<(Vec<u8>, SocketAddr), Result<usize, Error>>,
}
487
488impl UdpMuxWriterHandle {
489    /// Returns a new `UdpMuxWriterHandle` and `registration`, `send` receivers.
490    fn new() -> (
491        Self,
492        req_res_chan::Receiver<(UDPMuxConn, SocketAddr), ()>,
493        req_res_chan::Receiver<(Vec<u8>, SocketAddr), Result<usize, Error>>,
494    ) {
495        let (sender1, receiver1) = req_res_chan::new(1);
496        let (sender2, receiver2) = req_res_chan::new(1);
497
498        let this = Self {
499            registration_channel: sender1,
500            send_channel: sender2,
501        };
502
503        (this, receiver1, receiver2)
504    }
505}
506
507#[async_trait]
508impl UDPMuxWriter for UdpMuxWriterHandle {
509    async fn register_conn_for_address(&self, conn: &UDPMuxConn, addr: SocketAddr) {
510        match self
511            .registration_channel
512            .send((conn.to_owned(), addr))
513            .await
514        {
515            Ok(()) => {}
516            Err(e) => {
517                tracing::debug!("Failed to send message through channel: {:?}", e);
518                return;
519            }
520        }
521
522        tracing::debug!(address=%addr, connection=%conn.key(), "Registered address for connection");
523    }
524
525    async fn send_to(&self, buf: &[u8], target: &SocketAddr) -> Result<usize, Error> {
526        let bytes_written = self
527            .send_channel
528            .send((buf.to_owned(), target.to_owned()))
529            .await
530            .map_err(|e| Error::Io(e.into()))??;
531
532        Ok(bytes_written)
533    }
534}
535
536/// Gets the ufrag from the given STUN message or returns an error, if failed to decode or the
537/// username attribute is not present.
538fn ufrag_from_stun_message(buffer: &[u8], local_ufrag: bool) -> Result<String, Error> {
539    let (result, message) = {
540        let mut m = STUNMessage::new();
541
542        (m.unmarshal_binary(buffer), m)
543    };
544
545    if let Err(err) = result {
546        Err(Error::Other(format!("failed to handle decode ICE: {err}")))
547    } else {
548        let (attr, found) = message.attributes.get(ATTR_USERNAME);
549        if !found {
550            return Err(Error::Other("no username attribute in STUN message".into()));
551        }
552
553        match String::from_utf8(attr.value) {
554            // Per the RFC this shouldn't happen
555            // https://datatracker.ietf.org/doc/html/rfc5389#section-15.3
556            Err(err) => Err(Error::Other(format!(
557                "failed to decode USERNAME from STUN message as UTF-8: {err}"
558            ))),
559            Ok(s) => {
560                // s is a combination of the local_ufrag and the remote ufrag separated by `:`.
561                let res = if local_ufrag {
562                    s.split(':').next()
563                } else {
564                    s.split(':').next_back()
565                };
566                match res {
567                    Some(s) => Ok(s.to_owned()),
568                    None => Err(Error::Other("can't get ufrag from username".into())),
569                }
570            }
571        }
572    }
573}
574
/// Reasons why a STUN message could not be matched to an existing muxed connection.
#[derive(Error, Debug)]
enum ConnQueryError {
    /// A connection exists for the ufrag, but it is associated with other addresses only
    /// (listed in `associated_addrs`) and not with the sender's address.
    #[error("ufrag is already taken (associated_addrs={associated_addrs:?})")]
    UfragAlreadyTaken { associated_addrs: Vec<SocketAddr> },
}