libp2p_webrtc/tokio/
connection.rs

1// Copyright 2022 Parity Technologies (UK) Ltd.
2//
3// Permission is hereby granted, free of charge, to any person obtaining a
4// copy of this software and associated documentation files (the "Software"),
5// to deal in the Software without restriction, including without limitation
6// the rights to use, copy, modify, merge, publish, distribute, sublicense,
7// and/or sell copies of the Software, and to permit persons to whom the
8// Software is furnished to do so, subject to the following conditions:
9//
10// The above copyright notice and this permission notice shall be included in
11// all copies or substantial portions of the Software.
12//
13// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
14// OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
16// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
17// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
18// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
19// DEALINGS IN THE SOFTWARE.
20
21use std::{
22    pin::Pin,
23    sync::Arc,
24    task::{Context, Poll, Waker},
25};
26
27use futures::{
28    channel::{
29        mpsc,
30        oneshot::{self, Sender},
31    },
32    future::BoxFuture,
33    lock::Mutex as FutMutex,
34    ready,
35    stream::FuturesUnordered,
36    StreamExt,
37};
38use libp2p_core::muxing::{StreamMuxer, StreamMuxerEvent};
39use webrtc::{
40    data::data_channel::DataChannel as DetachedDataChannel, data_channel::RTCDataChannel,
41    peer_connection::RTCPeerConnection,
42};
43
44use crate::tokio::{error::Error, stream, stream::Stream};
45
/// Buffer size of the channel that carries remotely opened (and already detached) data
/// channels from the `on_data_channel` handler to [`Connection::poll_inbound`].
///
/// Data channels that open while this buffer is full are closed instead of being queued.
const MAX_DATA_CHANNELS_IN_FLIGHT: usize = 10;
49
50/// A WebRTC connection, wrapping [`RTCPeerConnection`] and implementing [`StreamMuxer`] trait.
51pub struct Connection {
52    /// [`RTCPeerConnection`] to the remote peer.
53    ///
54    /// Uses futures mutex because used in async code (see poll_outbound and poll_close).
55    peer_conn: Arc<FutMutex<RTCPeerConnection>>,
56
57    /// Channel onto which incoming data channels are put.
58    incoming_data_channels_rx: mpsc::Receiver<Arc<DetachedDataChannel>>,
59
60    /// Future, which, once polled, will result in an outbound stream.
61    outbound_fut: Option<BoxFuture<'static, Result<Arc<DetachedDataChannel>, Error>>>,
62
63    /// Future, which, once polled, will result in closing the entire connection.
64    close_fut: Option<BoxFuture<'static, Result<(), Error>>>,
65
66    /// A list of futures, which, once completed, signal that a [`Stream`] has been dropped.
67    drop_listeners: FuturesUnordered<stream::DropListener>,
68    no_drop_listeners_waker: Option<Waker>,
69}
70
71impl Unpin for Connection {}
72
73impl Connection {
74    /// Creates a new connection.
75    pub(crate) async fn new(rtc_conn: RTCPeerConnection) -> Self {
76        let (data_channel_tx, data_channel_rx) = mpsc::channel(MAX_DATA_CHANNELS_IN_FLIGHT);
77
78        Connection::register_incoming_data_channels_handler(
79            &rtc_conn,
80            Arc::new(FutMutex::new(data_channel_tx)),
81        )
82        .await;
83
84        Self {
85            peer_conn: Arc::new(FutMutex::new(rtc_conn)),
86            incoming_data_channels_rx: data_channel_rx,
87            outbound_fut: None,
88            close_fut: None,
89            drop_listeners: FuturesUnordered::default(),
90            no_drop_listeners_waker: None,
91        }
92    }
93
94    /// Registers a handler for incoming data channels.
95    ///
96    /// NOTE: `mpsc::Sender` is wrapped in `Arc` because cloning a raw sender would make the channel
97    /// unbounded. "The channel’s capacity is equal to buffer + num-senders. In other words, each
98    /// sender gets a guaranteed slot in the channel capacity..."
99    /// See <https://docs.rs/futures/latest/futures/channel/mpsc/fn.channel.html>
100    async fn register_incoming_data_channels_handler(
101        rtc_conn: &RTCPeerConnection,
102        tx: Arc<FutMutex<mpsc::Sender<Arc<DetachedDataChannel>>>>,
103    ) {
104        rtc_conn.on_data_channel(Box::new(move |data_channel: Arc<RTCDataChannel>| {
105            tracing::debug!(channel=%data_channel.id(), "Incoming data channel");
106
107            let tx = tx.clone();
108
109            Box::pin(async move {
110                data_channel.on_open({
111                    let data_channel = data_channel.clone();
112                    Box::new(move || {
113                        tracing::debug!(channel=%data_channel.id(), "Data channel open");
114
115                        Box::pin(async move {
116                            let data_channel = data_channel.clone();
117                            let id = data_channel.id();
118                            match data_channel.detach().await {
119                                Ok(detached) => {
120                                    let mut tx = tx.lock().await;
121                                    if let Err(e) = tx.try_send(detached.clone()) {
122                                        tracing::error!(channel=%id, "Can't send data channel: {}", e);
123                                        // We're not accepting data channels fast enough =>
124                                        // close this channel.
125                                        //
126                                        // Ideally we'd refuse to accept a data channel
127                                        // during the negotiation process, but it's not
128                                        // possible with the current API.
129                                        if let Err(e) = detached.close().await {
130                                            tracing::error!(
131                                                channel=%id,
132                                                "Failed to close data channel: {}",
133                                                e
134                                            );
135                                        }
136                                    }
137                                }
138                                Err(e) => {
139                                    tracing::error!(channel=%id, "Can't detach data channel: {}", e);
140                                }
141                            };
142                        })
143                    })
144                });
145            })
146        }));
147    }
148}
149
150impl StreamMuxer for Connection {
151    type Substream = Stream;
152    type Error = Error;
153
154    fn poll_inbound(
155        mut self: Pin<&mut Self>,
156        cx: &mut Context<'_>,
157    ) -> Poll<Result<Self::Substream, Self::Error>> {
158        match ready!(self.incoming_data_channels_rx.poll_next_unpin(cx)) {
159            Some(detached) => {
160                tracing::trace!(stream=%detached.stream_identifier(), "Incoming stream");
161
162                let (stream, drop_listener) = Stream::new(detached);
163                self.drop_listeners.push(drop_listener);
164                if let Some(waker) = self.no_drop_listeners_waker.take() {
165                    waker.wake()
166                }
167
168                Poll::Ready(Ok(stream))
169            }
170            None => {
171                debug_assert!(
172                    false,
173                    "Sender-end of channel should be owned by `RTCPeerConnection`"
174                );
175
176                // Return `Pending` without registering a waker: If the channel is
177                // closed, we don't need to be called anymore.
178                Poll::Pending
179            }
180        }
181    }
182
183    fn poll(
184        mut self: Pin<&mut Self>,
185        cx: &mut Context<'_>,
186    ) -> Poll<Result<StreamMuxerEvent, Self::Error>> {
187        loop {
188            match ready!(self.drop_listeners.poll_next_unpin(cx)) {
189                Some(Ok(())) => {}
190                Some(Err(e)) => {
191                    tracing::debug!("a DropListener failed: {e}")
192                }
193                None => {
194                    self.no_drop_listeners_waker = Some(cx.waker().clone());
195                    return Poll::Pending;
196                }
197            }
198        }
199    }
200
201    fn poll_outbound(
202        mut self: Pin<&mut Self>,
203        cx: &mut Context<'_>,
204    ) -> Poll<Result<Self::Substream, Self::Error>> {
205        let peer_conn = self.peer_conn.clone();
206        let fut = self.outbound_fut.get_or_insert(Box::pin(async move {
207            let peer_conn = peer_conn.lock().await;
208
209            let data_channel = peer_conn.create_data_channel("", None).await?;
210
211            // No need to hold the lock during the DTLS handshake.
212            drop(peer_conn);
213
214            tracing::trace!(channel=%data_channel.id(), "Opening data channel");
215
216            let (tx, rx) = oneshot::channel::<Arc<DetachedDataChannel>>();
217
218            // Wait until the data channel is opened and detach it.
219            register_data_channel_open_handler(data_channel, tx).await;
220
221            // Wait until data channel is opened and ready to use
222            match rx.await {
223                Ok(detached) => Ok(detached),
224                Err(e) => Err(Error::Internal(e.to_string())),
225            }
226        }));
227
228        match ready!(fut.as_mut().poll(cx)) {
229            Ok(detached) => {
230                self.outbound_fut = None;
231
232                tracing::trace!(stream=%detached.stream_identifier(), "Outbound stream");
233
234                let (stream, drop_listener) = Stream::new(detached);
235                self.drop_listeners.push(drop_listener);
236                if let Some(waker) = self.no_drop_listeners_waker.take() {
237                    waker.wake()
238                }
239
240                Poll::Ready(Ok(stream))
241            }
242            Err(e) => {
243                self.outbound_fut = None;
244                Poll::Ready(Err(e))
245            }
246        }
247    }
248
249    fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
250        tracing::debug!("Closing connection");
251
252        let peer_conn = self.peer_conn.clone();
253        let fut = self.close_fut.get_or_insert(Box::pin(async move {
254            let peer_conn = peer_conn.lock().await;
255            peer_conn.close().await?;
256
257            Ok(())
258        }));
259
260        match ready!(fut.as_mut().poll(cx)) {
261            Ok(()) => {
262                self.incoming_data_channels_rx.close();
263                self.close_fut = None;
264                Poll::Ready(Ok(()))
265            }
266            Err(e) => {
267                self.close_fut = None;
268                Poll::Ready(Err(e))
269            }
270        }
271    }
272}
273
274pub(crate) async fn register_data_channel_open_handler(
275    data_channel: Arc<RTCDataChannel>,
276    data_channel_tx: Sender<Arc<DetachedDataChannel>>,
277) {
278    data_channel.on_open({
279        let data_channel = data_channel.clone();
280        Box::new(move || {
281            tracing::debug!(channel=%data_channel.id(), "Data channel open");
282
283            Box::pin(async move {
284                let data_channel = data_channel.clone();
285                let id = data_channel.id();
286                match data_channel.detach().await {
287                    Ok(detached) => {
288                        if let Err(e) = data_channel_tx.send(detached.clone()) {
289                            tracing::error!(channel=%id, "Can't send data channel: {:?}", e);
290                            if let Err(e) = detached.close().await {
291                                tracing::error!(channel=%id, "Failed to close data channel: {}", e);
292                            }
293                        }
294                    }
295                    Err(e) => {
296                        tracing::error!(channel=%id, "Can't detach data channel: {}", e);
297                    }
298                };
299            })
300        })
301    });
302}