// libp2p_webrtc_utils/stream.rs
stream.rs1use std::{
23 io,
24 pin::Pin,
25 task::{Context, Poll},
26};
27
28use bytes::Bytes;
29use futures::{channel::oneshot, prelude::*, ready};
30
31use crate::{
32 proto::{Flag, Message},
33 stream::{
34 drop_listener::GracefullyClosed,
35 framed_dc::FramedDc,
36 state::{Closing, State},
37 },
38};
39
40mod drop_listener;
41mod framed_dc;
42mod state;
43
/// Upper bound on the size of one encoded frame on the data channel, in bytes (16 KiB).
///
/// Presumably chosen to match the 16 KB per-message limit that RFC 8831 §6.6 recommends
/// while message interleaving is unsupported.
pub const MAX_MSG_LEN: usize = 16 * 1024;
/// Bytes reserved for the varint prefix written in front of each frame (presumably the
/// length prefix added by the framing codec).
const VARINT_LEN: usize = 2;
/// Bytes of protobuf envelope overhead around the payload.
const PROTO_OVERHEAD: usize = 5;
/// Largest payload that still fits in a single `MAX_MSG_LEN` frame once the varint prefix
/// and the protobuf overhead are subtracted. The `max_data_len` test pins this relation.
const MAX_DATA_LEN: usize = MAX_MSG_LEN - (VARINT_LEN + PROTO_OVERHEAD);
56
57pub use drop_listener::DropListener;
/// A byte stream layered on top of a data channel `T`.
///
/// All traffic goes through `io`, which exchanges protobuf [`Message`] envelopes. Each envelope
/// carries an optional payload and an optional [`Flag`]. Flags such as `FIN` and
/// `STOP_SENDING` are what this type sends to close its write and read halves independently.
pub struct Stream<T> {
    /// Framed view of the data channel; each item sent or received is one [`Message`].
    io: FramedDc<T>,
    /// Open/closing/closed state of the read and write halves.
    state: State,
    /// Payload of the most recently received message that has not yet been copied out by
    /// `poll_read`.
    read_buffer: Bytes,
    /// Sending half of a oneshot whose receiver is held by the [`DropListener`] returned from
    /// `new`. `poll_close` takes it and sends `GracefullyClosed`. If the stream is dropped
    /// before that, dropping the sender cancels the oneshot instead.
    drop_notifier: Option<oneshot::Sender<GracefullyClosed>>,
}
69
70impl<T> Stream<T>
71where
72 T: AsyncRead + AsyncWrite + Unpin + Clone,
73{
74 pub fn new(data_channel: T) -> (Self, DropListener<T>) {
77 let (sender, receiver) = oneshot::channel();
78
79 let stream = Self {
80 io: framed_dc::new(data_channel.clone()),
81 state: State::Open,
82 read_buffer: Bytes::default(),
83 drop_notifier: Some(sender),
84 };
85 let listener = DropListener::new(framed_dc::new(data_channel), receiver);
86
87 (stream, listener)
88 }
89
90 pub fn poll_close_read(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
92 loop {
93 match self.state.close_read_barrier()? {
94 Some(Closing::Requested) => {
95 ready!(self.io.poll_ready_unpin(cx))?;
96
97 self.io.start_send_unpin(Message {
98 flag: Some(Flag::STOP_SENDING),
99 message: None,
100 })?;
101 self.state.close_read_message_sent();
102
103 continue;
104 }
105 Some(Closing::MessageSent) => {
106 ready!(self.io.poll_flush_unpin(cx))?;
107
108 self.state.read_closed();
109
110 return Poll::Ready(Ok(()));
111 }
112 None => return Poll::Ready(Ok(())),
113 }
114 }
115 }
116}
117
impl<T> AsyncRead for Stream<T>
where
    T: AsyncRead + AsyncWrite + Unpin,
{
    /// Reads payload bytes received over the data channel into `buf`.
    ///
    /// Bytes left over from a previously received message are served first. Otherwise the next
    /// message is polled, and its flag (if any) is applied to the state machine. A non-empty
    /// payload is then buffered and copied out on the next loop iteration.
    ///
    /// Returns `Ok(0)` when the underlying channel ends (handled like an inbound `FIN`), and
    /// also when a received message carries no payload or an empty one.
    ///
    /// # Errors
    /// Fails when `read_barrier` rejects reading in the current state, or when the channel yields
    /// an error (re-wrapped as `InvalidData` by `io_poll_next`).
    fn poll_read(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut [u8],
    ) -> Poll<io::Result<usize>> {
        loop {
            // Re-checked on every iteration because a flag handled below may have changed state.
            self.state.read_barrier()?;

            if !self.read_buffer.is_empty() {
                // Hand out as much of the buffered message as fits; the rest stays buffered.
                let n = std::cmp::min(self.read_buffer.len(), buf.len());
                let data = self.read_buffer.split_to(n);
                buf[0..n].copy_from_slice(&data[..]);

                return Poll::Ready(Ok(n));
            }

            // Split the borrow: `state` needs `read_buffer` mutably while `io` is also borrowed.
            let Self {
                read_buffer,
                io,
                state,
                ..
            } = &mut *self;

            match ready!(io_poll_next(io, cx))? {
                Some((flag, message)) => {
                    // The flag is applied before the payload is buffered.
                    if let Some(flag) = flag {
                        state.handle_inbound_flag(flag, read_buffer);
                    }

                    // Only reached with an empty buffer (see the early return above).
                    debug_assert!(read_buffer.is_empty());
                    match message {
                        Some(msg) if !msg.is_empty() => {
                            *read_buffer = msg.into();
                        }
                        _ => {
                            // NOTE(review): any frame without a payload, including a flag-only
                            // frame such as STOP_SENDING (which may not close the read side),
                            // is reported as Ok(0). AsyncRead callers treat Ok(0) as EOF.
                            // Confirm against state.rs whether this should loop instead.
                            tracing::debug!("poll_read buffer is empty, received None");
                            return Poll::Ready(Ok(0));
                        }
                    }
                }
                None => {
                    // The underlying channel has ended: treat it like an inbound FIN.
                    state.handle_inbound_flag(Flag::FIN, read_buffer);
                    return Poll::Ready(Ok(0));
                }
            }
        }
    }
}
170
171impl<T> AsyncWrite for Stream<T>
172where
173 T: AsyncRead + AsyncWrite + Unpin,
174{
175 fn poll_write(
176 mut self: Pin<&mut Self>,
177 cx: &mut Context<'_>,
178 buf: &[u8],
179 ) -> Poll<io::Result<usize>> {
180 while self.state.read_flags_in_async_write() {
181 let Self {
186 read_buffer,
187 io,
188 state,
189 ..
190 } = &mut *self;
191
192 match io_poll_next(io, cx)? {
193 Poll::Ready(Some((Some(flag), message))) => {
194 drop(message);
196 state.handle_inbound_flag(flag, read_buffer)
198 }
199 Poll::Ready(Some((None, message))) => drop(message),
200 Poll::Ready(None) | Poll::Pending => break,
201 }
202 }
203
204 self.state.write_barrier()?;
205
206 ready!(self.io.poll_ready_unpin(cx))?;
207
208 let n = usize::min(buf.len(), MAX_DATA_LEN);
209
210 Pin::new(&mut self.io).start_send(Message {
211 flag: None,
212 message: Some(buf[0..n].into()),
213 })?;
214
215 Poll::Ready(Ok(n))
216 }
217
218 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
219 self.io.poll_flush_unpin(cx).map_err(Into::into)
220 }
221
222 fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
223 loop {
224 match self.state.close_write_barrier()? {
225 Some(Closing::Requested) => {
226 ready!(self.io.poll_ready_unpin(cx))?;
227
228 self.io.start_send_unpin(Message {
229 flag: Some(Flag::FIN),
230 message: None,
231 })?;
232 self.state.close_write_message_sent();
233
234 continue;
235 }
236 Some(Closing::MessageSent) => {
237 ready!(self.io.poll_flush_unpin(cx))?;
238
239 self.state.write_closed();
240 let _ = self
241 .drop_notifier
242 .take()
243 .expect("to not close twice")
244 .send(GracefullyClosed {});
245
246 return Poll::Ready(Ok(()));
247 }
248 None => return Poll::Ready(Ok(())),
249 }
250 }
251 }
252}
253
254fn io_poll_next<T>(
255 io: &mut FramedDc<T>,
256 cx: &mut Context<'_>,
257) -> Poll<io::Result<Option<(Option<Flag>, Option<Vec<u8>>)>>>
258where
259 T: AsyncRead + AsyncWrite + Unpin,
260{
261 match ready!(io.poll_next_unpin(cx))
262 .transpose()
263 .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?
264 {
265 Some(Message { flag, message }) => Poll::Ready(Ok(Some((flag, message)))),
266 None => Poll::Ready(Ok(None)),
267 }
268}
269
#[cfg(test)]
mod tests {
    use asynchronous_codec::Encoder;
    use bytes::BytesMut;

    use super::*;
    use crate::stream::framed_dc::codec;

    /// A `FIN` frame whose payload is exactly `MAX_DATA_LEN` bytes must encode to exactly
    /// `MAX_MSG_LEN` bytes, which confirms the `VARINT_LEN` / `PROTO_OVERHEAD` budget.
    #[test]
    fn max_data_len() {
        let frame = Message {
            flag: Some(Flag::FIN),
            message: Some(vec![0u8; MAX_DATA_LEN]),
        };

        let mut encoded = BytesMut::new();
        codec().encode(frame, &mut encoded).unwrap();

        assert_eq!(encoded.len(), MAX_MSG_LEN);
    }
}