libp2p_webrtc_websys/stream/
poll_data_channel.rs1use std::{
2 cmp::min,
3 io,
4 pin::Pin,
5 rc::Rc,
6 sync::{
7 atomic::{AtomicBool, Ordering},
8 Mutex,
9 },
10 task::{Context, Poll},
11};
12
13use bytes::BytesMut;
14use futures::{task::AtomicWaker, AsyncRead, AsyncWrite};
15use libp2p_webrtc_utils::MAX_MSG_LEN;
16use wasm_bindgen::prelude::*;
17use web_sys::{Event, MessageEvent, RtcDataChannel, RtcDataChannelEvent, RtcDataChannelState};
18
19#[derive(Debug, Clone)]
22pub(crate) struct PollDataChannel {
23 inner: RtcDataChannel,
25
26 new_data_waker: Rc<AtomicWaker>,
27 read_buffer: Rc<Mutex<BytesMut>>,
28
29 open_waker: Rc<AtomicWaker>,
31
32 write_waker: Rc<AtomicWaker>,
35
36 close_waker: Rc<AtomicWaker>,
38
39 overloaded: Rc<AtomicBool>,
47
48 _on_open_closure: Rc<Closure<dyn FnMut(RtcDataChannelEvent)>>,
51 _on_write_closure: Rc<Closure<dyn FnMut(Event)>>,
52 _on_close_closure: Rc<Closure<dyn FnMut(Event)>>,
53 _on_message_closure: Rc<Closure<dyn FnMut(MessageEvent)>>,
54}
55
56impl PollDataChannel {
57 pub(crate) fn new(inner: RtcDataChannel) -> Self {
58 let open_waker = Rc::new(AtomicWaker::new());
59 let on_open_closure = Closure::new({
60 let open_waker = open_waker.clone();
61
62 move |_: RtcDataChannelEvent| {
63 tracing::trace!("DataChannel opened");
64 open_waker.wake();
65 }
66 });
67 inner.set_onopen(Some(on_open_closure.as_ref().unchecked_ref()));
68
69 let write_waker = Rc::new(AtomicWaker::new());
70 inner.set_buffered_amount_low_threshold(0);
71 let on_write_closure = Closure::new({
72 let write_waker = write_waker.clone();
73
74 move |_: Event| {
75 tracing::trace!("DataChannel available for writing (again)");
76 write_waker.wake();
77 }
78 });
79 inner.set_onbufferedamountlow(Some(on_write_closure.as_ref().unchecked_ref()));
80
81 let close_waker = Rc::new(AtomicWaker::new());
82 let on_close_closure = Closure::new({
83 let close_waker = close_waker.clone();
84
85 move |_: Event| {
86 tracing::trace!("DataChannel closed");
87 close_waker.wake();
88 }
89 });
90 inner.set_onclose(Some(on_close_closure.as_ref().unchecked_ref()));
91
92 let new_data_waker = Rc::new(AtomicWaker::new());
93 let read_buffer = Rc::new(Mutex::new(BytesMut::new()));
96 let overloaded = Rc::new(AtomicBool::new(false));
97
98 let on_message_closure = Closure::<dyn FnMut(_)>::new({
99 let new_data_waker = new_data_waker.clone();
100 let read_buffer = read_buffer.clone();
101 let overloaded = overloaded.clone();
102
103 move |ev: MessageEvent| {
104 let data = js_sys::Uint8Array::new(&ev.data());
105
106 let mut read_buffer = read_buffer.lock().unwrap();
107
108 if read_buffer.len() + data.length() as usize > MAX_MSG_LEN {
109 overloaded.store(true, Ordering::SeqCst);
110 tracing::warn!("Remote is overloading us with messages, resetting stream",);
111 return;
112 }
113
114 read_buffer.extend_from_slice(&data.to_vec());
115 new_data_waker.wake();
116 }
117 });
118 inner.set_onmessage(Some(on_message_closure.as_ref().unchecked_ref()));
119
120 Self {
121 inner,
122 new_data_waker,
123 read_buffer,
124 open_waker,
125 write_waker,
126 close_waker,
127 overloaded,
128 _on_open_closure: Rc::new(on_open_closure),
129 _on_write_closure: Rc::new(on_write_closure),
130 _on_close_closure: Rc::new(on_close_closure),
131 _on_message_closure: Rc::new(on_message_closure),
132 }
133 }
134
135 fn ready_state(&self) -> RtcDataChannelState {
137 self.inner.ready_state()
138 }
139
140 fn buffered_amount(&self) -> usize {
142 self.inner.buffered_amount() as usize
143 }
144
145 fn poll_ready(&mut self, cx: &mut Context) -> Poll<io::Result<()>> {
147 match self.ready_state() {
148 RtcDataChannelState::Connecting => {
149 self.open_waker.register(cx.waker());
150 return Poll::Pending;
151 }
152 RtcDataChannelState::Closing | RtcDataChannelState::Closed => {
153 return Poll::Ready(Err(io::ErrorKind::BrokenPipe.into()))
154 }
155 RtcDataChannelState::Open | RtcDataChannelState::__Invalid => {}
156 _ => {}
157 }
158
159 if self.overloaded.load(Ordering::SeqCst) {
160 return Poll::Ready(Err(io::Error::new(
161 io::ErrorKind::BrokenPipe,
162 "remote overloaded us with messages",
163 )));
164 }
165
166 Poll::Ready(Ok(()))
167 }
168}
169
170impl AsyncRead for PollDataChannel {
171 fn poll_read(
172 self: Pin<&mut Self>,
173 cx: &mut Context<'_>,
174 buf: &mut [u8],
175 ) -> Poll<io::Result<usize>> {
176 let this = self.get_mut();
177
178 futures::ready!(this.poll_ready(cx))?;
179
180 let mut read_buffer = this.read_buffer.lock().unwrap();
181
182 if read_buffer.is_empty() {
183 this.new_data_waker.register(cx.waker());
184 return Poll::Pending;
185 }
186
187 let split_index = min(buf.len(), read_buffer.len());
191
192 let bytes_to_return = read_buffer.split_to(split_index);
193 let len = bytes_to_return.len();
194 buf[..len].copy_from_slice(&bytes_to_return);
195
196 Poll::Ready(Ok(len))
197 }
198}
199
200impl AsyncWrite for PollDataChannel {
201 fn poll_write(
202 self: Pin<&mut Self>,
203 cx: &mut Context<'_>,
204 buf: &[u8],
205 ) -> Poll<io::Result<usize>> {
206 let this = self.get_mut();
207
208 futures::ready!(this.poll_ready(cx))?;
209
210 debug_assert!(this.buffered_amount() <= MAX_MSG_LEN);
211 let remaining_space = MAX_MSG_LEN - this.buffered_amount();
212
213 if remaining_space == 0 {
214 this.write_waker.register(cx.waker());
215 return Poll::Pending;
216 }
217
218 let bytes_to_send = min(buf.len(), remaining_space);
219
220 if this
221 .inner
222 .send_with_u8_array(&buf[..bytes_to_send])
223 .is_err()
224 {
225 return Poll::Ready(Err(io::ErrorKind::BrokenPipe.into()));
226 }
227
228 Poll::Ready(Ok(bytes_to_send))
229 }
230
231 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
232 if self.buffered_amount() == 0 {
233 return Poll::Ready(Ok(()));
234 }
235
236 self.write_waker.register(cx.waker());
237 Poll::Pending
238 }
239
240 fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
241 if self.ready_state() == RtcDataChannelState::Closed {
242 return Poll::Ready(Ok(()));
243 }
244
245 if self.ready_state() != RtcDataChannelState::Closing {
246 self.inner.close();
247 }
248
249 self.close_waker.register(cx.waker());
250 Poll::Pending
251 }
252}