libp2p_webtransport_websys/
stream.rs1use std::{
2 io,
3 pin::Pin,
4 task::{ready, Context, Poll},
5};
6
7use futures::{AsyncRead, AsyncWrite, FutureExt};
8use js_sys::Uint8Array;
9use send_wrapper::SendWrapper;
10use web_sys::{ReadableStreamDefaultReader, WritableStreamDefaultWriter};
11
12use crate::{
13 bindings::WebTransportBidirectionalStream,
14 fused_js_promise::FusedJsPromise,
15 utils::{detach_promise, parse_reader_response, to_io_error, to_js_type},
16 Error,
17};
18
19#[derive(Debug)]
21pub struct Stream {
22 inner: SendWrapper<StreamInner>,
25}
26
27#[derive(Debug)]
28struct StreamInner {
29 reader: ReadableStreamDefaultReader,
30 reader_read_promise: FusedJsPromise,
31 read_leftovers: Option<Uint8Array>,
32 writer: WritableStreamDefaultWriter,
33 writer_state: StreamState,
34 writer_ready_promise: FusedJsPromise,
35 writer_closed_promise: FusedJsPromise,
36}
37
/// Lifecycle of the sending half of a stream.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum StreamState {
    /// Writes are accepted.
    Open,
    /// `close()` has been issued; waiting for the writer's `closed()` promise to settle.
    Closing,
    /// The writer's `closed()` promise has settled.
    Closed,
}
44
45impl Stream {
46 pub(crate) fn new(bidi_stream: WebTransportBidirectionalStream) -> Result<Self, Error> {
47 let recv_stream = bidi_stream.readable();
48 let send_stream = bidi_stream.writable();
49
50 let reader = to_js_type::<ReadableStreamDefaultReader>(recv_stream.get_reader())?;
51 let writer = send_stream.get_writer().map_err(Error::from_js_value)?;
52
53 Ok(Stream {
54 inner: SendWrapper::new(StreamInner {
55 reader,
56 reader_read_promise: FusedJsPromise::new(),
57 read_leftovers: None,
58 writer,
59 writer_state: StreamState::Open,
60 writer_ready_promise: FusedJsPromise::new(),
61 writer_closed_promise: FusedJsPromise::new(),
62 }),
63 })
64 }
65}
66
67impl StreamInner {
68 fn poll_reader_read(&mut self, cx: &mut Context) -> Poll<io::Result<Option<Uint8Array>>> {
69 let val = ready!(self
70 .reader_read_promise
71 .maybe_init(|| self.reader.read())
72 .poll_unpin(cx))
73 .map_err(to_io_error)?;
74
75 let val = parse_reader_response(&val)
76 .map_err(to_io_error)?
77 .map(Uint8Array::from);
78
79 Poll::Ready(Ok(val))
80 }
81
82 fn poll_read(&mut self, cx: &mut Context<'_>, buf: &mut [u8]) -> Poll<io::Result<usize>> {
83 let data = match self.read_leftovers.take() {
86 Some(data) => data,
87 None => {
88 match ready!(self.poll_reader_read(cx))? {
89 Some(data) => data,
90 None => return Poll::Ready(Ok(0)),
92 }
93 }
94 };
95
96 if data.byte_length() == 0 {
97 return Poll::Ready(Ok(0));
98 }
99
100 let out_len = data.byte_length().min(buf.len() as u32);
101 data.slice(0, out_len).copy_to(&mut buf[..out_len as usize]);
102
103 let leftovers = data.slice(out_len, data.byte_length());
104
105 if leftovers.byte_length() > 0 {
106 self.read_leftovers = Some(leftovers);
107 }
108
109 Poll::Ready(Ok(out_len as usize))
110 }
111
112 fn poll_writer_ready(&mut self, cx: &mut Context) -> Poll<io::Result<()>> {
113 if self.writer_state != StreamState::Open {
114 return Poll::Ready(Err(io::ErrorKind::BrokenPipe.into()));
115 }
116
117 let desired_size = self
118 .writer
119 .desired_size()
120 .map_err(to_io_error)?
121 .map(|n| n.trunc() as i64)
122 .unwrap_or(0);
123
124 if desired_size <= 0 || self.writer_ready_promise.is_active() {
128 ready!(self
129 .writer_ready_promise
130 .maybe_init(|| self.writer.ready())
131 .poll_unpin(cx))
132 .map_err(to_io_error)?;
133 }
134
135 Poll::Ready(Ok(()))
136 }
137
138 fn poll_write(&mut self, cx: &mut Context, buf: &[u8]) -> Poll<io::Result<usize>> {
139 ready!(self.poll_writer_ready(cx))?;
140
141 let len = buf.len() as u32;
142 let data = Uint8Array::new_with_length(len);
143 data.copy_from(buf);
144
145 detach_promise(self.writer.write_with_chunk(&data));
146
147 Poll::Ready(Ok(len as usize))
148 }
149
150 fn poll_flush(&mut self, cx: &mut Context) -> Poll<io::Result<()>> {
151 if self.writer_state == StreamState::Open {
152 self.poll_writer_ready(cx)
155 } else {
156 debug_assert!(
157 false,
158 "libp2p_webtransport_websys::Stream: poll_flush called after poll_close"
159 );
160 Poll::Ready(Ok(()))
161 }
162 }
163
164 fn poll_writer_close(&mut self, cx: &mut Context) -> Poll<io::Result<()>> {
165 match self.writer_state {
166 StreamState::Open => {
167 self.writer_state = StreamState::Closing;
168
169 detach_promise(self.writer.close());
171
172 let _ = ready!(self
174 .writer_closed_promise
175 .maybe_init(|| self.writer.closed())
176 .poll_unpin(cx));
177
178 self.writer_state = StreamState::Closed;
179 }
180 StreamState::Closing => {
181 let _ = ready!(self.writer_closed_promise.poll_unpin(cx));
183 self.writer_state = StreamState::Closed;
184 }
185 StreamState::Closed => {}
186 }
187
188 Poll::Ready(Ok(()))
189 }
190}
191
192impl Drop for StreamInner {
193 fn drop(&mut self) {
194 detach_promise(self.writer.close());
200
201 detach_promise(self.reader.cancel());
203 }
204}
205
206impl AsyncRead for Stream {
207 fn poll_read(
208 mut self: Pin<&mut Self>,
209 cx: &mut Context<'_>,
210 buf: &mut [u8],
211 ) -> Poll<io::Result<usize>> {
212 self.inner.poll_read(cx, buf)
213 }
214}
215
216impl AsyncWrite for Stream {
217 fn poll_write(
218 mut self: Pin<&mut Self>,
219 cx: &mut Context,
220 buf: &[u8],
221 ) -> Poll<io::Result<usize>> {
222 self.inner.poll_write(cx, buf)
223 }
224
225 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
226 self.inner.poll_flush(cx)
227 }
228
229 fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
230 self.inner.poll_writer_close(cx)
231 }
232}