libp2p_webtransport_websys/stream.rs

1use std::{
2    io,
3    pin::Pin,
4    task::{ready, Context, Poll},
5};
6
7use futures::{AsyncRead, AsyncWrite, FutureExt};
8use js_sys::Uint8Array;
9use send_wrapper::SendWrapper;
10use web_sys::{ReadableStreamDefaultReader, WritableStreamDefaultWriter};
11
12use crate::{
13    bindings::WebTransportBidirectionalStream,
14    fused_js_promise::FusedJsPromise,
15    utils::{detach_promise, parse_reader_response, to_io_error, to_js_type},
16    Error,
17};
18
19/// A stream on a connection.
20#[derive(Debug)]
21pub struct Stream {
22    // Swarm needs all types to be Send. WASM is single-threaded
23    // and it is safe to use SendWrapper.
24    inner: SendWrapper<StreamInner>,
25}
26
27#[derive(Debug)]
28struct StreamInner {
29    reader: ReadableStreamDefaultReader,
30    reader_read_promise: FusedJsPromise,
31    read_leftovers: Option<Uint8Array>,
32    writer: WritableStreamDefaultWriter,
33    writer_state: StreamState,
34    writer_ready_promise: FusedJsPromise,
35    writer_closed_promise: FusedJsPromise,
36}
37
/// Lifecycle of the writable half of a stream.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum StreamState {
    /// Writes are accepted.
    Open,
    /// `close()` has been issued; waiting for the `closed` promise to settle.
    Closing,
    /// The writer is closed; no further writes are accepted.
    Closed,
}
44
45impl Stream {
46    pub(crate) fn new(bidi_stream: WebTransportBidirectionalStream) -> Result<Self, Error> {
47        let recv_stream = bidi_stream.readable();
48        let send_stream = bidi_stream.writable();
49
50        let reader = to_js_type::<ReadableStreamDefaultReader>(recv_stream.get_reader())?;
51        let writer = send_stream.get_writer().map_err(Error::from_js_value)?;
52
53        Ok(Stream {
54            inner: SendWrapper::new(StreamInner {
55                reader,
56                reader_read_promise: FusedJsPromise::new(),
57                read_leftovers: None,
58                writer,
59                writer_state: StreamState::Open,
60                writer_ready_promise: FusedJsPromise::new(),
61                writer_closed_promise: FusedJsPromise::new(),
62            }),
63        })
64    }
65}
66
67impl StreamInner {
68    fn poll_reader_read(&mut self, cx: &mut Context) -> Poll<io::Result<Option<Uint8Array>>> {
69        let val = ready!(self
70            .reader_read_promise
71            .maybe_init(|| self.reader.read())
72            .poll_unpin(cx))
73        .map_err(to_io_error)?;
74
75        let val = parse_reader_response(&val)
76            .map_err(to_io_error)?
77            .map(Uint8Array::from);
78
79        Poll::Ready(Ok(val))
80    }
81
82    fn poll_read(&mut self, cx: &mut Context<'_>, buf: &mut [u8]) -> Poll<io::Result<usize>> {
83        // If we have leftovers from a previous read, then use them.
84        // Otherwise read new data.
85        let data = match self.read_leftovers.take() {
86            Some(data) => data,
87            None => {
88                match ready!(self.poll_reader_read(cx))? {
89                    Some(data) => data,
90                    // EOF
91                    None => return Poll::Ready(Ok(0)),
92                }
93            }
94        };
95
96        if data.byte_length() == 0 {
97            return Poll::Ready(Ok(0));
98        }
99
100        let out_len = data.byte_length().min(buf.len() as u32);
101        data.slice(0, out_len).copy_to(&mut buf[..out_len as usize]);
102
103        let leftovers = data.slice(out_len, data.byte_length());
104
105        if leftovers.byte_length() > 0 {
106            self.read_leftovers = Some(leftovers);
107        }
108
109        Poll::Ready(Ok(out_len as usize))
110    }
111
112    fn poll_writer_ready(&mut self, cx: &mut Context) -> Poll<io::Result<()>> {
113        if self.writer_state != StreamState::Open {
114            return Poll::Ready(Err(io::ErrorKind::BrokenPipe.into()));
115        }
116
117        let desired_size = self
118            .writer
119            .desired_size()
120            .map_err(to_io_error)?
121            .map(|n| n.trunc() as i64)
122            .unwrap_or(0);
123
124        // We need to poll if the queue is full or if the promise was already activated.
125        //
126        // NOTE: `desired_size` can be negative if we overcommit messages to the queue.
127        if desired_size <= 0 || self.writer_ready_promise.is_active() {
128            ready!(self
129                .writer_ready_promise
130                .maybe_init(|| self.writer.ready())
131                .poll_unpin(cx))
132            .map_err(to_io_error)?;
133        }
134
135        Poll::Ready(Ok(()))
136    }
137
138    fn poll_write(&mut self, cx: &mut Context, buf: &[u8]) -> Poll<io::Result<usize>> {
139        ready!(self.poll_writer_ready(cx))?;
140
141        let len = buf.len() as u32;
142        let data = Uint8Array::new_with_length(len);
143        data.copy_from(buf);
144
145        detach_promise(self.writer.write_with_chunk(&data));
146
147        Poll::Ready(Ok(len as usize))
148    }
149
150    fn poll_flush(&mut self, cx: &mut Context) -> Poll<io::Result<()>> {
151        if self.writer_state == StreamState::Open {
152            // Writer has queue size of 1, so as soon it is ready, self means the
153            // messages were flushed.
154            self.poll_writer_ready(cx)
155        } else {
156            debug_assert!(
157                false,
158                "libp2p_webtransport_websys::Stream: poll_flush called after poll_close"
159            );
160            Poll::Ready(Ok(()))
161        }
162    }
163
164    fn poll_writer_close(&mut self, cx: &mut Context) -> Poll<io::Result<()>> {
165        match self.writer_state {
166            StreamState::Open => {
167                self.writer_state = StreamState::Closing;
168
169                // Initiate close
170                detach_promise(self.writer.close());
171
172                // Assume closed on error
173                let _ = ready!(self
174                    .writer_closed_promise
175                    .maybe_init(|| self.writer.closed())
176                    .poll_unpin(cx));
177
178                self.writer_state = StreamState::Closed;
179            }
180            StreamState::Closing => {
181                // Assume closed on error
182                let _ = ready!(self.writer_closed_promise.poll_unpin(cx));
183                self.writer_state = StreamState::Closed;
184            }
185            StreamState::Closed => {}
186        }
187
188        Poll::Ready(Ok(()))
189    }
190}
191
192impl Drop for StreamInner {
193    fn drop(&mut self) {
194        // Close writer.
195        //
196        // We choose to use `close()` instead of `abort()`, because
197        // abort was causing some side effects on the WebTransport
198        // layer and connection was lost.
199        detach_promise(self.writer.close());
200
201        // Cancel any ongoing reads.
202        detach_promise(self.reader.cancel());
203    }
204}
205
206impl AsyncRead for Stream {
207    fn poll_read(
208        mut self: Pin<&mut Self>,
209        cx: &mut Context<'_>,
210        buf: &mut [u8],
211    ) -> Poll<io::Result<usize>> {
212        self.inner.poll_read(cx, buf)
213    }
214}
215
216impl AsyncWrite for Stream {
217    fn poll_write(
218        mut self: Pin<&mut Self>,
219        cx: &mut Context,
220        buf: &[u8],
221    ) -> Poll<io::Result<usize>> {
222        self.inner.poll_write(cx, buf)
223    }
224
225    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
226        self.inner.poll_flush(cx)
227    }
228
229    fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
230        self.inner.poll_writer_close(cx)
231    }
232}