libp2p_webrtc_websys/stream/poll_data_channel.rs

1use std::{
2    cmp::min,
3    io,
4    pin::Pin,
5    rc::Rc,
6    sync::{
7        atomic::{AtomicBool, Ordering},
8        Mutex,
9    },
10    task::{Context, Poll},
11};
12
13use bytes::BytesMut;
14use futures::{task::AtomicWaker, AsyncRead, AsyncWrite};
15use libp2p_webrtc_utils::MAX_MSG_LEN;
16use wasm_bindgen::prelude::*;
17use web_sys::{Event, MessageEvent, RtcDataChannel, RtcDataChannelEvent, RtcDataChannelState};
18
/// [`PollDataChannel`] is a wrapper around [`RtcDataChannel`] which implements [`AsyncRead`] and
/// [`AsyncWrite`].
///
/// All state (buffers, wakers, flags and event-handler closures) lives behind [`Rc`]s, so clones
/// of a [`PollDataChannel`] share that state and refer to the same underlying channel.
#[derive(Debug, Clone)]
pub(crate) struct PollDataChannel {
    /// The [`RtcDataChannel`] being wrapped.
    inner: RtcDataChannel,

    /// Waker for when we are waiting for new data to be appended to `read_buffer`.
    new_data_waker: Rc<AtomicWaker>,
    /// Bytes received from the remote that have not yet been handed out by [`AsyncRead`].
    read_buffer: Rc<Mutex<BytesMut>>,

    /// Waker for when we are waiting for the DC to be opened.
    open_waker: Rc<AtomicWaker>,

    /// Waker for when we are waiting to write (again) to the DC because we previously exceeded the
    /// [`MAX_MSG_LEN`] threshold.
    write_waker: Rc<AtomicWaker>,

    /// Waker for when we are waiting for the DC to be closed.
    close_waker: Rc<AtomicWaker>,

    /// Whether we've been overloaded with data by the remote.
    ///
    /// This is set to `true` in case `read_buffer` overflows, i.e. the remote is sending us
    /// messages faster than we can read them. In that case, we return an [`std::io::Error`]
    /// from [`AsyncRead`] or [`AsyncWrite`], depending on which one gets called earlier.
    /// Failing these will very likely cause the application developer to drop the stream,
    /// which resets it.
    overloaded: Rc<AtomicBool>,

    // Store the closures for proper garbage collection: the JS event handlers only stay valid
    // while these `Closure`s are alive.
    // These are wrapped in an [`Rc`] so we can implement [`Clone`].
    _on_open_closure: Rc<Closure<dyn FnMut(RtcDataChannelEvent)>>,
    _on_write_closure: Rc<Closure<dyn FnMut(Event)>>,
    _on_close_closure: Rc<Closure<dyn FnMut(Event)>>,
    _on_message_closure: Rc<Closure<dyn FnMut(MessageEvent)>>,
}
55
56impl PollDataChannel {
57    pub(crate) fn new(inner: RtcDataChannel) -> Self {
58        let open_waker = Rc::new(AtomicWaker::new());
59        let on_open_closure = Closure::new({
60            let open_waker = open_waker.clone();
61
62            move |_: RtcDataChannelEvent| {
63                tracing::trace!("DataChannel opened");
64                open_waker.wake();
65            }
66        });
67        inner.set_onopen(Some(on_open_closure.as_ref().unchecked_ref()));
68
69        let write_waker = Rc::new(AtomicWaker::new());
70        inner.set_buffered_amount_low_threshold(0);
71        let on_write_closure = Closure::new({
72            let write_waker = write_waker.clone();
73
74            move |_: Event| {
75                tracing::trace!("DataChannel available for writing (again)");
76                write_waker.wake();
77            }
78        });
79        inner.set_onbufferedamountlow(Some(on_write_closure.as_ref().unchecked_ref()));
80
81        let close_waker = Rc::new(AtomicWaker::new());
82        let on_close_closure = Closure::new({
83            let close_waker = close_waker.clone();
84
85            move |_: Event| {
86                tracing::trace!("DataChannel closed");
87                close_waker.wake();
88            }
89        });
90        inner.set_onclose(Some(on_close_closure.as_ref().unchecked_ref()));
91
92        let new_data_waker = Rc::new(AtomicWaker::new());
93        // We purposely don't use `with_capacity`
94        // so we don't eagerly allocate `MAX_READ_BUFFER` per stream.
95        let read_buffer = Rc::new(Mutex::new(BytesMut::new()));
96        let overloaded = Rc::new(AtomicBool::new(false));
97
98        let on_message_closure = Closure::<dyn FnMut(_)>::new({
99            let new_data_waker = new_data_waker.clone();
100            let read_buffer = read_buffer.clone();
101            let overloaded = overloaded.clone();
102
103            move |ev: MessageEvent| {
104                let data = js_sys::Uint8Array::new(&ev.data());
105
106                let mut read_buffer = read_buffer.lock().unwrap();
107
108                if read_buffer.len() + data.length() as usize > MAX_MSG_LEN {
109                    overloaded.store(true, Ordering::SeqCst);
110                    tracing::warn!("Remote is overloading us with messages, resetting stream",);
111                    return;
112                }
113
114                read_buffer.extend_from_slice(&data.to_vec());
115                new_data_waker.wake();
116            }
117        });
118        inner.set_onmessage(Some(on_message_closure.as_ref().unchecked_ref()));
119
120        Self {
121            inner,
122            new_data_waker,
123            read_buffer,
124            open_waker,
125            write_waker,
126            close_waker,
127            overloaded,
128            _on_open_closure: Rc::new(on_open_closure),
129            _on_write_closure: Rc::new(on_write_closure),
130            _on_close_closure: Rc::new(on_close_closure),
131            _on_message_closure: Rc::new(on_message_closure),
132        }
133    }
134
135    /// Returns the [RtcDataChannelState] of the [RtcDataChannel]
136    fn ready_state(&self) -> RtcDataChannelState {
137        self.inner.ready_state()
138    }
139
140    /// Returns the current [RtcDataChannel] BufferedAmount
141    fn buffered_amount(&self) -> usize {
142        self.inner.buffered_amount() as usize
143    }
144
145    /// Whether the data channel is ready for reading or writing.
146    fn poll_ready(&mut self, cx: &mut Context) -> Poll<io::Result<()>> {
147        match self.ready_state() {
148            RtcDataChannelState::Connecting => {
149                self.open_waker.register(cx.waker());
150                return Poll::Pending;
151            }
152            RtcDataChannelState::Closing | RtcDataChannelState::Closed => {
153                return Poll::Ready(Err(io::ErrorKind::BrokenPipe.into()))
154            }
155            RtcDataChannelState::Open | RtcDataChannelState::__Invalid => {}
156            _ => {}
157        }
158
159        if self.overloaded.load(Ordering::SeqCst) {
160            return Poll::Ready(Err(io::Error::new(
161                io::ErrorKind::BrokenPipe,
162                "remote overloaded us with messages",
163            )));
164        }
165
166        Poll::Ready(Ok(()))
167    }
168}
169
170impl AsyncRead for PollDataChannel {
171    fn poll_read(
172        self: Pin<&mut Self>,
173        cx: &mut Context<'_>,
174        buf: &mut [u8],
175    ) -> Poll<io::Result<usize>> {
176        let this = self.get_mut();
177
178        futures::ready!(this.poll_ready(cx))?;
179
180        let mut read_buffer = this.read_buffer.lock().unwrap();
181
182        if read_buffer.is_empty() {
183            this.new_data_waker.register(cx.waker());
184            return Poll::Pending;
185        }
186
187        // Ensure that we:
188        // - at most return what the caller can read (`buf.len()`)
189        // - at most what we have (`read_buffer.len()`)
190        let split_index = min(buf.len(), read_buffer.len());
191
192        let bytes_to_return = read_buffer.split_to(split_index);
193        let len = bytes_to_return.len();
194        buf[..len].copy_from_slice(&bytes_to_return);
195
196        Poll::Ready(Ok(len))
197    }
198}
199
200impl AsyncWrite for PollDataChannel {
201    fn poll_write(
202        self: Pin<&mut Self>,
203        cx: &mut Context<'_>,
204        buf: &[u8],
205    ) -> Poll<io::Result<usize>> {
206        let this = self.get_mut();
207
208        futures::ready!(this.poll_ready(cx))?;
209
210        debug_assert!(this.buffered_amount() <= MAX_MSG_LEN);
211        let remaining_space = MAX_MSG_LEN - this.buffered_amount();
212
213        if remaining_space == 0 {
214            this.write_waker.register(cx.waker());
215            return Poll::Pending;
216        }
217
218        let bytes_to_send = min(buf.len(), remaining_space);
219
220        if this
221            .inner
222            .send_with_u8_array(&buf[..bytes_to_send])
223            .is_err()
224        {
225            return Poll::Ready(Err(io::ErrorKind::BrokenPipe.into()));
226        }
227
228        Poll::Ready(Ok(bytes_to_send))
229    }
230
231    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
232        if self.buffered_amount() == 0 {
233            return Poll::Ready(Ok(()));
234        }
235
236        self.write_waker.register(cx.waker());
237        Poll::Pending
238    }
239
240    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
241        if self.ready_state() == RtcDataChannelState::Closed {
242            return Poll::Ready(Ok(()));
243        }
244
245        if self.ready_state() != RtcDataChannelState::Closing {
246            self.inner.close();
247        }
248
249        self.close_waker.register(cx.waker());
250        Poll::Pending
251    }
252}