1pub(crate) use std::io::{Error, Result};
22use std::{
23 cmp,
24 collections::VecDeque,
25 fmt, io, mem,
26 sync::Arc,
27 task::{Context, Poll, Waker},
28};
29
30use asynchronous_codec::Framed;
31use bytes::Bytes;
32use futures::{
33 prelude::*,
34 ready,
35 stream::Fuse,
36 task::{waker_ref, ArcWake, AtomicWaker, WakerRef},
37};
38use nohash_hasher::{IntMap, IntSet};
39use parking_lot::Mutex;
40use smallvec::SmallVec;
41
42use crate::{
43 codec::{Codec, Frame, LocalStreamId, RemoteStreamId},
44 Config, MaxBufferBehaviour,
45};
/// A random identifier attached to a multiplexed connection, used to
/// correlate log events that belong to the same connection.
#[derive(Clone, Copy)]
struct ConnectionId(u64);

impl fmt::Display for ConnectionId {
    /// Renders the identifier as lowercase hexadecimal, right-aligned in a
    /// field of width 16 (padded with spaces).
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let Self(raw) = self;
        write!(f, "{:16x}", raw)
    }
}

impl fmt::Debug for ConnectionId {
    /// The debug representation is the same as the display representation.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        fmt::Display::fmt(self, f)
    }
}
/// A connection that multiplexes many substreams over one framed I/O resource.
pub(crate) struct Multiplexed<C> {
    /// Randomly chosen identifier, attached to the log events of this connection.
    id: ConnectionId,
    /// Whether the connection is open, closed or has failed.
    status: Status,
    /// The underlying I/O resource, framed with the crate's `Codec` and fused.
    io: Fuse<Framed<C, Codec>>,
    /// The configuration (substream limit, buffer limit and behaviour, send split size).
    config: Config,
    /// Inbound substreams that were opened by the remote while frames were read
    /// on behalf of another substream, and that `poll_next_stream` has not yet
    /// returned.
    open_buffer: VecDeque<LocalStreamId>,
    /// Outbound substreams whose `Open` frame was handed to `io` but for which
    /// no flush has completed yet.
    pending_flush_open: IntSet<LocalStreamId>,
    /// The substream whose receive buffer overflowed under
    /// `MaxBufferBehaviour::Block`. While set, no new frames are read from `io`.
    blocking_stream: Option<LocalStreamId>,
    /// `Close` and `Reset` frames waiting to be sent. Frames are pushed at the
    /// front and sent from the back.
    pending_frames: VecDeque<Frame<LocalStreamId>>,
    /// The state of every tracked substream, keyed by its local ID.
    substreams: IntMap<LocalStreamId, SubstreamState>,
    /// The ID to assign to the next outbound substream.
    next_outbound_stream_id: LocalStreamId,
    /// Wakers of tasks waiting to read frames (per substream or for a new
    /// inbound substream).
    notifier_read: Arc<NotifierRead>,
    /// Wakers of tasks waiting to write to or flush `io`.
    notifier_write: Arc<NotifierWrite>,
    /// Wakers of tasks waiting to open an outbound substream because the
    /// substream limit was reached.
    notifier_open: NotifierOpen,
}
108
/// The lifecycle status of a `Multiplexed` connection.
#[derive(Debug)]
enum Status {
    /// The connection is usable.
    Open,
    /// The connection was closed via `poll_close`.
    Closed,
    /// The connection failed with the given error; see `on_error`.
    Err(io::Error),
}
119
120impl<C> Multiplexed<C>
121where
122 C: AsyncRead + AsyncWrite + Unpin,
123{
124 pub(crate) fn new(io: C, config: Config) -> Self {
126 let id = ConnectionId(rand::random());
127 tracing::debug!(connection=%id, "New multiplexed connection");
128 Multiplexed {
129 id,
130 config,
131 status: Status::Open,
132 io: Framed::new(io, Codec::new()).fuse(),
133 open_buffer: Default::default(),
134 substreams: Default::default(),
135 pending_flush_open: Default::default(),
136 pending_frames: Default::default(),
137 blocking_stream: None,
138 next_outbound_stream_id: LocalStreamId::dialer(0),
139 notifier_read: Arc::new(NotifierRead {
140 read_stream: Mutex::new(Default::default()),
141 next_stream: AtomicWaker::new(),
142 }),
143 notifier_write: Arc::new(NotifierWrite {
144 pending: Mutex::new(Default::default()),
145 }),
146 notifier_open: NotifierOpen {
147 pending: Default::default(),
148 },
149 }
150 }
151
152 pub(crate) fn poll_flush(&mut self, cx: &Context<'_>) -> Poll<io::Result<()>> {
154 match &self.status {
155 Status::Closed => return Poll::Ready(Ok(())),
156 Status::Err(e) => return Poll::Ready(Err(io::Error::new(e.kind(), e.to_string()))),
157 Status::Open => {}
158 }
159
160 ready!(self.send_pending_frames(cx))?;
162
163 let waker = NotifierWrite::register(&self.notifier_write, cx.waker());
165 match ready!(self.io.poll_flush_unpin(&mut Context::from_waker(&waker))) {
166 Err(e) => Poll::Ready(self.on_error(e)),
167 Ok(()) => {
168 self.pending_flush_open = Default::default();
169 Poll::Ready(Ok(()))
170 }
171 }
172 }
173
174 pub(crate) fn poll_close(&mut self, cx: &Context<'_>) -> Poll<io::Result<()>> {
180 match &self.status {
181 Status::Closed => return Poll::Ready(Ok(())),
182 Status::Err(e) => return Poll::Ready(Err(io::Error::new(e.kind(), e.to_string()))),
183 Status::Open => {}
184 }
185
186 let waker = NotifierWrite::register(&self.notifier_write, cx.waker());
190 match self.io.poll_close_unpin(&mut Context::from_waker(&waker)) {
191 Poll::Pending => Poll::Pending,
192 Poll::Ready(Err(e)) => Poll::Ready(self.on_error(e)),
193 Poll::Ready(Ok(())) => {
194 self.pending_frames = VecDeque::new();
195 self.open_buffer = Default::default();
198 self.substreams = Default::default();
199 self.status = Status::Closed;
200 Poll::Ready(Ok(()))
201 }
202 }
203 }
204
205 pub(crate) fn poll_next_stream(&mut self, cx: &Context<'_>) -> Poll<io::Result<LocalStreamId>> {
219 self.guard_open()?;
220
221 if let Some(stream_id) = self.open_buffer.pop_back() {
223 return Poll::Ready(Ok(stream_id));
224 }
225
226 debug_assert!(self.open_buffer.is_empty());
227 let mut num_buffered = 0;
228
229 loop {
230 if num_buffered == self.config.max_buffer_len {
235 cx.waker().wake_by_ref();
236 return Poll::Pending;
237 }
238
239 match ready!(self.poll_read_frame(cx, None))? {
241 Frame::Open { stream_id } => {
242 if let Some(id) = self.on_open(stream_id)? {
243 return Poll::Ready(Ok(id));
244 }
245 }
246 Frame::Data { stream_id, data } => {
247 self.buffer(stream_id.into_local(), data)?;
248 num_buffered += 1;
249 }
250 Frame::Close { stream_id } => {
251 self.on_close(stream_id.into_local());
252 }
253 Frame::Reset { stream_id } => self.on_reset(stream_id.into_local()),
254 }
255 }
256 }
257
258 pub(crate) fn poll_open_stream(&mut self, cx: &Context<'_>) -> Poll<io::Result<LocalStreamId>> {
260 self.guard_open()?;
261
262 if self.substreams.len() >= self.config.max_substreams {
264 tracing::debug!(
265 connection=%self.id,
266 total_substreams=%self.substreams.len(),
267 max_substreams=%self.config.max_substreams,
268 "Maximum number of substreams reached"
269 );
270 self.notifier_open.register(cx.waker());
271 return Poll::Pending;
272 }
273
274 let waker = NotifierWrite::register(&self.notifier_write, cx.waker());
276 match ready!(self.io.poll_ready_unpin(&mut Context::from_waker(&waker))) {
277 Ok(()) => {
278 let stream_id = self.next_outbound_stream_id();
279 let frame = Frame::Open { stream_id };
280 match self.io.start_send_unpin(frame) {
281 Ok(()) => {
282 self.substreams.insert(
283 stream_id,
284 SubstreamState::Open {
285 buf: Default::default(),
286 },
287 );
288 tracing::debug!(
289 connection=%self.id,
290 substream=%stream_id,
291 total_substreams=%self.substreams.len(),
292 "New outbound substream"
293 );
294 self.pending_flush_open.insert(stream_id);
297 Poll::Ready(Ok(stream_id))
298 }
299 Err(e) => Poll::Ready(self.on_error(e)),
300 }
301 }
302 Err(e) => Poll::Ready(self.on_error(e)),
303 }
304 }
305
306 pub(crate) fn drop_stream(&mut self, id: LocalStreamId) {
327 match self.status {
329 Status::Closed | Status::Err(_) => return,
330 Status::Open => {}
331 }
332
333 self.notifier_read.wake_read_stream(id);
338
339 match self.substreams.remove(&id) {
341 None => {}
342 Some(state) => {
343 let below_limit = self.substreams.len() == self.config.max_substreams - 1;
346 if below_limit {
347 self.notifier_open.wake_all();
348 }
349 match state {
351 SubstreamState::Closed { .. } => {}
352 SubstreamState::SendClosed { .. } => {}
353 SubstreamState::Reset { .. } => {}
354 SubstreamState::RecvClosed { .. } => {
355 if self.check_max_pending_frames().is_err() {
356 return;
357 }
358 tracing::trace!(
359 connection=%self.id,
360 substream=%id,
361 "Pending close for substream"
362 );
363 self.pending_frames
364 .push_front(Frame::Close { stream_id: id });
365 }
366 SubstreamState::Open { .. } => {
367 if self.check_max_pending_frames().is_err() {
368 return;
369 }
370 tracing::trace!(
371 connection=%self.id,
372 substream=%id,
373 "Pending reset for substream"
374 );
375 self.pending_frames
376 .push_front(Frame::Reset { stream_id: id });
377 }
378 }
379 }
380 }
381 }
382
383 pub(crate) fn poll_write_stream(
385 &mut self,
386 cx: &Context<'_>,
387 id: LocalStreamId,
388 buf: &[u8],
389 ) -> Poll<io::Result<usize>> {
390 self.guard_open()?;
391
392 match self.substreams.get(&id) {
394 None | Some(SubstreamState::Reset { .. }) => {
395 return Poll::Ready(Err(io::ErrorKind::BrokenPipe.into()))
396 }
397 Some(SubstreamState::SendClosed { .. }) | Some(SubstreamState::Closed { .. }) => {
398 return Poll::Ready(Err(io::ErrorKind::WriteZero.into()))
399 }
400 Some(SubstreamState::Open { .. }) | Some(SubstreamState::RecvClosed { .. }) => {
401 }
403 }
404
405 let frame_len = cmp::min(buf.len(), self.config.split_send_size);
407
408 ready!(self.poll_send_frame(cx, || {
410 let data = Bytes::copy_from_slice(&buf[..frame_len]);
411 Frame::Data {
412 stream_id: id,
413 data,
414 }
415 }))?;
416
417 Poll::Ready(Ok(frame_len))
418 }
419
    /// Reads the next chunk of data for the substream `id`.
    ///
    /// Data buffered for the substream is returned first. Otherwise frames
    /// are read from the connection until a `Data` frame for `id` arrives.
    /// Frames for other substreams are processed along the way: `Data` is
    /// buffered, accepted `Open` frames are queued in `open_buffer`, and
    /// `Close`/`Reset` frames update the state of their substream.
    ///
    /// Returns `Ok(None)` when the substream cannot be read from any more
    /// (it is unknown, its receiving side is closed, or it was reset).
    ///
    /// After buffering `max_buffer_len` frames for other substreams within
    /// one call, the current task is woken and `Poll::Pending` is returned.
    pub(crate) fn poll_read_stream(
        &mut self,
        cx: &Context<'_>,
        id: LocalStreamId,
    ) -> Poll<io::Result<Option<Bytes>>> {
        self.guard_open()?;

        // Try to serve the read from the substream's buffer first.
        if let Some(state) = self.substreams.get_mut(&id) {
            let buf = state.recv_buf();
            if !buf.is_empty() {
                if self.blocking_stream == Some(id) {
                    // This substream was blocking frame reads; taking a frame
                    // out of its buffer lifts the block, so wake all readers.
                    self.blocking_stream = None;
                    ArcWake::wake_by_ref(&self.notifier_read);
                }
                let data = buf.remove(0);
                return Poll::Ready(Ok(Some(data)));
            }
            // The buffer is empty: release memory it may have allocated.
            buf.shrink_to_fit();
        }

        // Number of frames buffered for other substreams during this call.
        let mut num_buffered = 0;

        loop {
            // Yield once enough frames were buffered to possibly fill another
            // substream's buffer, waking this task so it is polled again.
            if num_buffered == self.config.max_buffer_len {
                cx.waker().wake_by_ref();
                return Poll::Pending;
            }

            // Signal end-of-stream when the substream is not readable.
            if !self.can_read(&id) {
                return Poll::Ready(Ok(None));
            }

            match ready!(self.poll_read_frame(cx, Some(id)))? {
                Frame::Data { data, stream_id } if stream_id.into_local() == id => {
                    return Poll::Ready(Ok(Some(data)))
                }
                Frame::Data { stream_id, data } => {
                    // Data for a different substream: buffer it there.
                    self.buffer(stream_id.into_local(), data)?;
                    num_buffered += 1;
                }
                frame @ Frame::Open { .. } => {
                    if let Some(id) = self.on_open(frame.remote_id())? {
                        // Keep the new inbound substream for `poll_next_stream`
                        // and wake the task waiting there.
                        self.open_buffer.push_front(id);
                        tracing::trace!(
                            connection=%self.id,
                            inbound_stream=%id,
                            inbound_buffer_len=%self.open_buffer.len(),
                            "Buffered new inbound stream"
                        );
                        self.notifier_read.wake_next_stream();
                    }
                }
                Frame::Close { stream_id } => {
                    let stream_id = stream_id.into_local();
                    self.on_close(stream_id);
                    if id == stream_id {
                        return Poll::Ready(Ok(None));
                    }
                }
                Frame::Reset { stream_id } => {
                    let stream_id = stream_id.into_local();
                    self.on_reset(stream_id);
                    if id == stream_id {
                        return Poll::Ready(Ok(None));
                    }
                }
            }
        }
    }
520
521 pub(crate) fn poll_flush_stream(
527 &mut self,
528 cx: &Context<'_>,
529 id: LocalStreamId,
530 ) -> Poll<io::Result<()>> {
531 self.guard_open()?;
532
533 ready!(self.poll_flush(cx))?;
534 tracing::trace!(
535 connection=%self.id,
536 substream=%id,
537 "Flushed substream"
538 );
539
540 Poll::Ready(Ok(()))
541 }
542
543 pub(crate) fn poll_close_stream(
547 &mut self,
548 cx: &Context<'_>,
549 id: LocalStreamId,
550 ) -> Poll<io::Result<()>> {
551 self.guard_open()?;
552
553 match self.substreams.remove(&id) {
554 None => Poll::Ready(Ok(())),
555 Some(SubstreamState::SendClosed { buf }) => {
556 self.substreams
557 .insert(id, SubstreamState::SendClosed { buf });
558 Poll::Ready(Ok(()))
559 }
560 Some(SubstreamState::Closed { buf }) => {
561 self.substreams.insert(id, SubstreamState::Closed { buf });
562 Poll::Ready(Ok(()))
563 }
564 Some(SubstreamState::Reset { buf }) => {
565 self.substreams.insert(id, SubstreamState::Reset { buf });
566 Poll::Ready(Ok(()))
567 }
568 Some(SubstreamState::Open { buf }) => {
569 if self
570 .poll_send_frame(cx, || Frame::Close { stream_id: id })?
571 .is_pending()
572 {
573 self.substreams.insert(id, SubstreamState::Open { buf });
574 Poll::Pending
575 } else {
576 tracing::debug!(
577 connection=%self.id,
578 substream=%id,
579 "Closed substream (half-close)"
580 );
581 self.substreams
582 .insert(id, SubstreamState::SendClosed { buf });
583 Poll::Ready(Ok(()))
584 }
585 }
586 Some(SubstreamState::RecvClosed { buf }) => {
587 if self
588 .poll_send_frame(cx, || Frame::Close { stream_id: id })?
589 .is_pending()
590 {
591 self.substreams
592 .insert(id, SubstreamState::RecvClosed { buf });
593 Poll::Pending
594 } else {
595 tracing::debug!(
596 connection=%self.id,
597 substream=%id,
598 "Closed substream"
599 );
600 self.substreams.insert(id, SubstreamState::Closed { buf });
601 Poll::Ready(Ok(()))
602 }
603 }
604 }
605 }
606
607 fn poll_send_frame<F>(&mut self, cx: &Context<'_>, frame: F) -> Poll<io::Result<()>>
612 where
613 F: FnOnce() -> Frame<LocalStreamId>,
614 {
615 let waker = NotifierWrite::register(&self.notifier_write, cx.waker());
616 match ready!(self.io.poll_ready_unpin(&mut Context::from_waker(&waker))) {
617 Ok(()) => {
618 let frame = frame();
619 tracing::trace!(connection=%self.id, ?frame, "Sending frame");
620 match self.io.start_send_unpin(frame) {
621 Ok(()) => Poll::Ready(Ok(())),
622 Err(e) => Poll::Ready(self.on_error(e)),
623 }
624 }
625 Err(e) => Poll::Ready(self.on_error(e)),
626 }
627 }
628
    /// Reads the next frame from the underlying I/O stream.
    ///
    /// `stream_id` is the substream on whose behalf the frame is read, or
    /// `None` when reading on behalf of `poll_next_stream`. It determines
    /// which waker slot of the read notifier the current task is stored in.
    ///
    /// Returns `Poll::Pending` without reading while a substream with a full
    /// buffer blocks reading. The end of the underlying stream is reported as
    /// an `UnexpectedEof` error; like any read error, it fails the connection
    /// via `on_error`.
    fn poll_read_frame(
        &mut self,
        cx: &Context<'_>,
        stream_id: Option<LocalStreamId>,
    ) -> Poll<io::Result<Frame<RemoteStreamId>>> {
        // Try to send pending frames without waiting for them: only an error
        // stops the read, `Pending` is ignored.
        if let Poll::Ready(Err(e)) = self.send_pending_frames(cx) {
            return Poll::Ready(Err(e));
        }

        // If the `Open` frame of this substream has not been flushed yet,
        // flush before waiting for frames on it.
        if let Some(id) = &stream_id {
            if self.pending_flush_open.contains(id) {
                tracing::trace!(
                    connection=%self.id,
                    substream=%id,
                    "Executing pending flush for substream"
                );
                ready!(self.poll_flush(cx))?;
                self.pending_flush_open = Default::default();
            }
        }

        // While a substream's buffer is full, no new frames are read.
        if let Some(blocked_id) = &self.blocking_stream {
            // Try to wake the task that reads from the blocked substream.
            if !self.notifier_read.wake_read_stream(*blocked_id) {
                // No such task is registered: schedule the current task again
                // so that progress remains possible.
                tracing::trace!(
                    connection=%self.id,
                    "No task to read from blocked stream. Waking current task."
                );
                cx.waker().wake_by_ref();
            } else if let Some(id) = stream_id {
                // Another task was woken; remember that the current task still
                // wants to read for its own substream.
                debug_assert!(
                    blocked_id != &id,
                    "Unexpected attempt at reading a new \
                    frame from a substream with a full buffer."
                );
                let _ = NotifierRead::register_read_stream(&self.notifier_read, cx.waker(), id);
            } else {
                // Another task was woken; remember that the current task still
                // waits for a new inbound substream.
                let _ = NotifierRead::register_next_stream(&self.notifier_read, cx.waker());
            }

            return Poll::Pending;
        }

        // Register the current task, then poll the stream with the shared
        // read notifier as waker.
        let waker = match stream_id {
            Some(id) => NotifierRead::register_read_stream(&self.notifier_read, cx.waker(), id),
            None => NotifierRead::register_next_stream(&self.notifier_read, cx.waker()),
        };
        match ready!(self.io.poll_next_unpin(&mut Context::from_waker(&waker))) {
            Some(Ok(frame)) => {
                tracing::trace!(connection=%self.id, ?frame, "Received frame");
                Poll::Ready(Ok(frame))
            }
            Some(Err(e)) => Poll::Ready(self.on_error(e)),
            None => Poll::Ready(self.on_error(io::ErrorKind::UnexpectedEof.into())),
        }
    }
705
706 fn on_open(&mut self, id: RemoteStreamId) -> io::Result<Option<LocalStreamId>> {
708 let id = id.into_local();
709
710 if self.substreams.contains_key(&id) {
711 tracing::debug!(
712 connection=%self.id,
713 substream=%id,
714 "Received unexpected `Open` frame for open substream",
715 );
716 return self.on_error(io::Error::new(
717 io::ErrorKind::Other,
718 "Protocol error: Received `Open` frame for open substream.",
719 ));
720 }
721
722 if self.substreams.len() >= self.config.max_substreams {
723 tracing::debug!(
724 connection=%self.id,
725 max_substreams=%self.config.max_substreams,
726 "Maximum number of substreams exceeded"
727 );
728 self.check_max_pending_frames()?;
729 tracing::debug!(
730 connection=%self.id,
731 substream=%id,
732 "Pending reset for new substream"
733 );
734 self.pending_frames
735 .push_front(Frame::Reset { stream_id: id });
736 return Ok(None);
737 }
738
739 self.substreams.insert(
740 id,
741 SubstreamState::Open {
742 buf: Default::default(),
743 },
744 );
745
746 tracing::debug!(
747 connection=%self.id,
748 substream=%id,
749 total_substreams=%self.substreams.len(),
750 "New inbound substream"
751 );
752
753 Ok(Some(id))
754 }
755
756 fn on_reset(&mut self, id: LocalStreamId) {
758 if let Some(state) = self.substreams.remove(&id) {
759 match state {
760 SubstreamState::Closed { .. } => {
761 tracing::trace!(
762 connection=%self.id,
763 substream=%id,
764 "Ignoring reset for mutually closed substream"
765 );
766 }
767 SubstreamState::Reset { .. } => {
768 tracing::trace!(
769 connection=%self.id,
770 substream=%id,
771 "Ignoring redundant reset for already reset substream"
772 );
773 }
774 SubstreamState::RecvClosed { buf }
775 | SubstreamState::SendClosed { buf }
776 | SubstreamState::Open { buf } => {
777 tracing::debug!(
778 connection=%self.id,
779 substream=%id,
780 "Substream reset by remote"
781 );
782 self.substreams.insert(id, SubstreamState::Reset { buf });
783 NotifierRead::wake_read_stream(&self.notifier_read, id);
786 }
787 }
788 } else {
789 tracing::trace!(
790 connection=%self.id,
791 substream=%id,
792 "Ignoring `Reset` for unknown substream, possibly dropped earlier"
793 );
794 }
795 }
796
797 fn on_close(&mut self, id: LocalStreamId) {
799 if let Some(state) = self.substreams.remove(&id) {
800 match state {
801 SubstreamState::RecvClosed { .. } | SubstreamState::Closed { .. } => {
802 tracing::debug!(
803 connection=%self.id,
804 substream=%id,
805 "Ignoring `Close` frame for closed substream"
806 );
807 self.substreams.insert(id, state);
808 }
809 SubstreamState::Reset { buf } => {
810 tracing::debug!(
811 connection=%self.id,
812 substream=%id,
813 "Ignoring `Close` frame for already reset substream"
814 );
815 self.substreams.insert(id, SubstreamState::Reset { buf });
816 }
817 SubstreamState::SendClosed { buf } => {
818 tracing::debug!(
819 connection=%self.id,
820 substream=%id,
821 "Substream closed by remote (SendClosed -> Closed)"
822 );
823 self.substreams.insert(id, SubstreamState::Closed { buf });
824 self.notifier_read.wake_read_stream(id);
826 }
827 SubstreamState::Open { buf } => {
828 tracing::debug!(
829 connection=%self.id,
830 substream=%id,
831 "Substream closed by remote (Open -> RecvClosed)"
832 );
833 self.substreams
834 .insert(id, SubstreamState::RecvClosed { buf });
835 self.notifier_read.wake_read_stream(id);
837 }
838 }
839 } else {
840 tracing::trace!(
841 connection=%self.id,
842 substream=%id,
843 "Ignoring `Close` for unknown substream, possibly dropped earlier."
844 );
845 }
846 }
847
848 fn next_outbound_stream_id(&mut self) -> LocalStreamId {
850 let id = self.next_outbound_stream_id;
851 self.next_outbound_stream_id = self.next_outbound_stream_id.next();
852 id
853 }
854
855 fn can_read(&self, id: &LocalStreamId) -> bool {
857 matches!(
858 self.substreams.get(id),
859 Some(SubstreamState::Open { .. }) | Some(SubstreamState::SendClosed { .. })
860 )
861 }
862
863 fn send_pending_frames(&mut self, cx: &Context<'_>) -> Poll<io::Result<()>> {
865 while let Some(frame) = self.pending_frames.pop_back() {
866 if self.poll_send_frame(cx, || frame.clone())?.is_pending() {
867 self.pending_frames.push_back(frame);
868 return Poll::Pending;
869 }
870 }
871
872 Poll::Ready(Ok(()))
873 }
874
875 fn on_error<T>(&mut self, e: io::Error) -> io::Result<T> {
877 tracing::debug!(
878 connection=%self.id,
879 "Multiplexed connection failed: {:?}",
880 e
881 );
882 self.status = Status::Err(io::Error::new(e.kind(), e.to_string()));
883 self.pending_frames = Default::default();
884 self.substreams = Default::default();
885 self.open_buffer = Default::default();
886 Err(e)
887 }
888
889 fn guard_open(&self) -> io::Result<()> {
892 match &self.status {
893 Status::Closed => Err(io::Error::new(io::ErrorKind::Other, "Connection is closed")),
894 Status::Err(e) => Err(io::Error::new(e.kind(), e.to_string())),
895 Status::Open => Ok(()),
896 }
897 }
898
899 fn check_max_pending_frames(&mut self) -> io::Result<()> {
902 if self.pending_frames.len() >= self.config.max_substreams + EXTRA_PENDING_FRAMES {
903 return self.on_error(io::Error::new(
904 io::ErrorKind::Other,
905 "Too many pending frames.",
906 ));
907 }
908 Ok(())
909 }
910
    /// Buffers a data frame for a substream.
    ///
    /// The data is dropped if the substream is unknown or if its receiving
    /// side is closed or it was reset. Otherwise the data is appended to the
    /// substream's receive buffer and a task waiting to read from the
    /// substream is woken.
    ///
    /// If the buffer then holds more than `max_buffer_len` frames, the
    /// configured `MaxBufferBehaviour` applies: either the substream is put
    /// into the `Reset` state (keeping the buffered data) and a `Reset` frame
    /// is queued, or the substream is recorded as blocking further reads.
    ///
    /// # Errors
    ///
    /// Fails the connection if a `Reset` frame must be queued while too many
    /// frames are already pending.
    fn buffer(&mut self, id: LocalStreamId, data: Bytes) -> io::Result<()> {
        let Some(state) = self.substreams.get_mut(&id) else {
            tracing::trace!(
                connection=%self.id,
                substream=%id,
                data=?data,
                "Dropping data for unknown substream"
            );
            return Ok(());
        };

        let Some(buf) = state.recv_buf_open() else {
            tracing::trace!(
                connection=%self.id,
                substream=%id,
                data=?data,
                "Dropping data for closed or reset substream",
            );
            return Ok(());
        };

        // On entry the buffer holds at most `max_buffer_len` frames; it may
        // exceed the limit by one after the push below.
        debug_assert!(buf.len() <= self.config.max_buffer_len);
        tracing::trace!(
            connection=%self.id,
            substream=%id,
            data=?data,
            data_buffer=%buf.len() + 1,
            "Buffering data for substream"
        );
        buf.push(data);
        self.notifier_read.wake_read_stream(id);
        if buf.len() > self.config.max_buffer_len {
            tracing::debug!(
                connection=%self.id,
                substream=%id,
                "Frame buffer of substream is full"
            );
            match self.config.max_buffer_behaviour {
                MaxBufferBehaviour::ResetStream => {
                    // Copy the buffer out first: `buf` borrows from
                    // `self.substreams`, which is modified below.
                    let buf = buf.clone();
                    self.check_max_pending_frames()?;
                    self.substreams.insert(id, SubstreamState::Reset { buf });
                    tracing::debug!(
                        connection=%self.id,
                        substream=%id,
                        "Pending reset for stream"
                    );
                    self.pending_frames
                        .push_front(Frame::Reset { stream_id: id });
                }
                MaxBufferBehaviour::Block => {
                    // Checked by `poll_read_frame`, which stops reading new
                    // frames while this is set.
                    self.blocking_stream = Some(id);
                }
            }
        }

        Ok(())
    }
978}
979
/// The receive buffer of a substream: a list of data frames that keeps up to
/// 10 entries inline.
type RecvBuf = SmallVec<[Bytes; 10]>;
981
/// The state of a substream. Every state carries the substream's receive
/// buffer.
#[derive(Clone, Debug)]
enum SubstreamState {
    /// The substream is open for both reading and writing.
    Open { buf: RecvBuf },
    /// The local side closed its sending side (a `Close` frame was sent);
    /// data can still be received.
    SendClosed { buf: RecvBuf },
    /// The remote closed its sending side (a `Close` frame was received);
    /// data can still be sent.
    RecvClosed { buf: RecvBuf },
    /// Both sides closed the substream.
    Closed { buf: RecvBuf },
    /// The substream was reset, either by the remote or locally because its
    /// receive buffer overflowed.
    Reset { buf: RecvBuf },
}
1001
1002impl SubstreamState {
1003 fn recv_buf(&mut self) -> &mut RecvBuf {
1005 match self {
1006 SubstreamState::Open { buf } => buf,
1007 SubstreamState::SendClosed { buf } => buf,
1008 SubstreamState::RecvClosed { buf } => buf,
1009 SubstreamState::Closed { buf } => buf,
1010 SubstreamState::Reset { buf } => buf,
1011 }
1012 }
1013
1014 fn recv_buf_open(&mut self) -> Option<&mut RecvBuf> {
1017 match self {
1018 SubstreamState::Open { buf } => Some(buf),
1019 SubstreamState::SendClosed { buf } => Some(buf),
1020 SubstreamState::RecvClosed { .. } => None,
1021 SubstreamState::Closed { .. } => None,
1022 SubstreamState::Reset { .. } => None,
1023 }
1024 }
1025}
1026
/// Holds the wakers of tasks that wait for frames to be read.
struct NotifierRead {
    /// The waker of the task waiting for a new inbound substream.
    next_stream: AtomicWaker,
    /// The wakers of tasks waiting to read from a substream, one per
    /// substream ID.
    read_stream: Mutex<IntMap<LocalStreamId, Waker>>,
}
1035
1036impl NotifierRead {
1037 #[must_use]
1043 fn register_read_stream<'a>(
1044 self: &'a Arc<Self>,
1045 waker: &Waker,
1046 id: LocalStreamId,
1047 ) -> WakerRef<'a> {
1048 let mut pending = self.read_stream.lock();
1049 pending.insert(id, waker.clone());
1050 waker_ref(self)
1051 }
1052
1053 #[must_use]
1058 fn register_next_stream<'a>(self: &'a Arc<Self>, waker: &Waker) -> WakerRef<'a> {
1059 self.next_stream.register(waker);
1060 waker_ref(self)
1061 }
1062
1063 fn wake_read_stream(&self, id: LocalStreamId) -> bool {
1066 let mut pending = self.read_stream.lock();
1067
1068 if let Some(waker) = pending.remove(&id) {
1069 waker.wake();
1070 return true;
1071 }
1072
1073 false
1074 }
1075
1076 fn wake_next_stream(&self) {
1078 self.next_stream.wake();
1079 }
1080}
1081
1082impl ArcWake for NotifierRead {
1083 fn wake_by_ref(this: &Arc<Self>) {
1084 let wakers = mem::take(&mut *this.read_stream.lock());
1085 for (_, waker) in wakers {
1086 waker.wake();
1087 }
1088 this.wake_next_stream();
1089 }
1090}
1091
/// Holds the wakers of tasks that wait for write or flush progress.
struct NotifierWrite {
    /// The wakers of the waiting tasks.
    pending: Mutex<Vec<Waker>>,
}
1097
1098impl NotifierWrite {
1099 #[must_use]
1104 fn register<'a>(self: &'a Arc<Self>, waker: &Waker) -> WakerRef<'a> {
1105 let mut pending = self.pending.lock();
1106 if pending.iter().all(|w| !w.will_wake(waker)) {
1107 pending.push(waker.clone());
1108 }
1109 waker_ref(self)
1110 }
1111}
1112
1113impl ArcWake for NotifierWrite {
1114 fn wake_by_ref(this: &Arc<Self>) {
1115 let wakers = mem::take(&mut *this.pending.lock());
1116 for waker in wakers {
1117 waker.wake();
1118 }
1119 }
1120}
1121
/// Holds the wakers of tasks that wait to open an outbound substream.
struct NotifierOpen {
    /// The wakers of the waiting tasks.
    pending: Vec<Waker>,
}

impl NotifierOpen {
    /// Adds `waker` to the waiting tasks, unless a stored waker already
    /// reports (via `will_wake`) that it wakes the same task.
    fn register(&mut self, waker: &Waker) {
        let already_known = self.pending.iter().any(|known| known.will_wake(waker));
        if !already_known {
            self.pending.push(waker.clone());
        }
    }

    /// Wakes every waiting task and forgets the wakers.
    fn wake_all(&mut self) {
        mem::take(&mut self.pending)
            .into_iter()
            .for_each(Waker::wake);
    }
}
1143
/// The number of pending `Close`/`Reset` frames tolerated beyond the
/// configured substream limit. `check_max_pending_frames` fails the
/// connection once `max_substreams + EXTRA_PENDING_FRAMES` frames are pending.
const EXTRA_PENDING_FRAMES: usize = 1000;
1152
1153#[cfg(test)]
1154mod tests {
1155 use std::{collections::HashSet, num::NonZeroU8, ops::DerefMut, pin::Pin};
1156
1157 use async_std::task;
1158 use asynchronous_codec::{Decoder, Encoder};
1159 use bytes::BytesMut;
1160 use quickcheck::*;
1161
1162 use super::*;
1163
1164 impl Arbitrary for MaxBufferBehaviour {
1165 fn arbitrary(g: &mut Gen) -> MaxBufferBehaviour {
1166 *g.choose(&[MaxBufferBehaviour::Block, MaxBufferBehaviour::ResetStream])
1167 .unwrap()
1168 }
1169 }
1170
    impl Arbitrary for Config {
        /// Generates a configuration with small random limits, so that the
        /// limits are reached quickly in the property tests.
        fn arbitrary(g: &mut Gen) -> Config {
            Config {
                max_substreams: g.gen_range(1..100),
                max_buffer_len: g.gen_range(1..1000),
                max_buffer_behaviour: MaxBufferBehaviour::arbitrary(g),
                split_send_size: g.gen_range(1..10000),
                protocol_name: crate::config::DEFAULT_MPLEX_PROTOCOL_NAME,
            }
        }
    }
1182
    /// An in-memory connection used as the I/O resource in the tests.
    struct Connection {
        /// The bytes handed out by `poll_read`.
        r_buf: BytesMut,
        /// The bytes collected by `poll_write`.
        w_buf: BytesMut,
        /// When set, `poll_read` fails with `UnexpectedEof`.
        eof: bool,
    }
1192
1193 impl AsyncRead for Connection {
1194 fn poll_read(
1195 mut self: Pin<&mut Self>,
1196 _: &mut Context<'_>,
1197 buf: &mut [u8],
1198 ) -> Poll<io::Result<usize>> {
1199 if self.eof {
1200 return Poll::Ready(Err(io::ErrorKind::UnexpectedEof.into()));
1201 }
1202 let n = std::cmp::min(buf.len(), self.r_buf.len());
1203 let data = self.r_buf.split_to(n);
1204 buf[..n].copy_from_slice(&data[..]);
1205 if n == 0 {
1206 Poll::Pending
1207 } else {
1208 Poll::Ready(Ok(n))
1209 }
1210 }
1211 }
1212
1213 impl AsyncWrite for Connection {
1214 fn poll_write(
1215 mut self: Pin<&mut Self>,
1216 _: &mut Context<'_>,
1217 buf: &[u8],
1218 ) -> Poll<io::Result<usize>> {
1219 self.w_buf.extend_from_slice(buf);
1220 Poll::Ready(Ok(buf.len()))
1221 }
1222
1223 fn poll_flush(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<io::Result<()>> {
1224 Poll::Ready(Ok(()))
1225 }
1226
1227 fn poll_close(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<io::Result<()>> {
1228 Poll::Ready(Ok(()))
1229 }
1230 }
1231
    /// Property test: overflowing the receive buffer of a substream triggers
    /// the configured `MaxBufferBehaviour`, and the buffered data can still
    /// be read afterwards.
    #[test]
    fn max_buffer_behaviour() {
        use tracing_subscriber::EnvFilter;
        let _ = tracing_subscriber::fmt()
            .with_env_filter(EnvFilter::from_default_env())
            .try_init();

        fn prop(cfg: Config, overflow: NonZeroU8) {
            let mut r_buf = BytesMut::new();
            let mut codec = Codec::new();

            // Encode `Open` frames for the maximum number of inbound substreams.
            for i in 0..cfg.max_substreams {
                let stream_id = LocalStreamId::dialer(i as u64);
                codec.encode(Frame::Open { stream_id }, &mut r_buf).unwrap();
            }

            // Encode more data frames for substream 0 than its buffer may hold.
            let stream_id = LocalStreamId::dialer(0);
            let data = Bytes::from("Hello world");
            for _ in 0..cfg.max_buffer_len + overflow.get() as usize {
                codec
                    .encode(
                        Frame::Data {
                            stream_id,
                            data: data.clone(),
                        },
                        &mut r_buf,
                    )
                    .unwrap();
            }

            // Set up the multiplexed connection over the prepared bytes.
            let conn = Connection {
                r_buf,
                w_buf: BytesMut::new(),
                eof: false,
            };
            let mut m = Multiplexed::new(conn, cfg.clone());

            task::block_on(future::poll_fn(move |cx| {
                // First, all inbound substreams are received.
                for i in 0..cfg.max_substreams {
                    match m.poll_next_stream(cx) {
                        Poll::Pending => panic!("Expected new inbound stream."),
                        Poll::Ready(Err(e)) => panic!("{e:?}"),
                        Poll::Ready(Ok(id)) => {
                            assert_eq!(id, LocalStreamId::listener(i as u64));
                        }
                    };
                }

                // Polling for another inbound substream now reads and buffers
                // the data frames. The first poll yields after buffering
                // `max_buffer_len` frames.
                let id = LocalStreamId::listener(0);
                match m.poll_next_stream(cx) {
                    Poll::Ready(r) => panic!("Unexpected result for next stream: {r:?}"),
                    Poll::Pending => {
                        assert_eq!(
                            m.substreams.get_mut(&id).unwrap().recv_buf().len(),
                            cfg.max_buffer_len
                        );
                        // The next poll buffers one more frame, which exceeds
                        // the limit and triggers the max. buffer behaviour.
                        match m.poll_next_stream(cx) {
                            Poll::Ready(r) => panic!("Unexpected result for next stream: {r:?}"),
                            Poll::Pending => {
                                assert_eq!(
                                    m.substreams.get_mut(&id).unwrap().recv_buf().len(),
                                    cfg.max_buffer_len + 1
                                );
                            }
                        }
                    }
                }

                // Expect either a `Reset` frame to be sent, or all reads of
                // new frames to be blocked.
                match cfg.max_buffer_behaviour {
                    MaxBufferBehaviour::ResetStream => {
                        let _ = m.poll_flush_stream(cx, id);
                        let w_buf = &mut m.io.get_mut().deref_mut().w_buf;
                        let frame = codec.decode(w_buf).unwrap();
                        let stream_id = stream_id.into_remote();
                        assert_eq!(frame, Some(Frame::Reset { stream_id }));
                    }
                    MaxBufferBehaviour::Block => {
                        assert!(m.poll_next_stream(cx).is_pending());
                        for i in 1..cfg.max_substreams {
                            let id = LocalStreamId::listener(i as u64);
                            assert!(m.poll_read_stream(cx, id).is_pending());
                        }
                    }
                }

                // Drain the buffer by reading from the substream.
                for _ in 0..cfg.max_buffer_len + 1 {
                    match m.poll_read_stream(cx, id) {
                        Poll::Ready(Ok(Some(bytes))) => {
                            assert_eq!(bytes, data);
                        }
                        x => panic!("Unexpected: {x:?}"),
                    }
                }

                // Read once more after the buffer was drained: a reset
                // substream reports the end of the stream, a formerly blocking
                // substream yields more data if there is any left.
                match cfg.max_buffer_behaviour {
                    MaxBufferBehaviour::ResetStream => {
                        match m.poll_read_stream(cx, id) {
                            Poll::Ready(Ok(None)) => {}
                            poll => panic!("Unexpected: {poll:?}"),
                        }
                    }
                    MaxBufferBehaviour::Block => {
                        match m.poll_read_stream(cx, id) {
                            Poll::Ready(Ok(Some(bytes))) => assert_eq!(bytes, data),
                            Poll::Pending => assert_eq!(overflow.get(), 1),
                            poll => panic!("Unexpected: {poll:?}"),
                        }
                    }
                }

                Poll::Ready(())
            }));
        }

        quickcheck(prop as fn(_, _))
    }
1369
    /// Property test: a read error on the connection is reported to the
    /// reading substream and all substream state is discarded.
    #[test]
    fn close_on_error() {
        use tracing_subscriber::EnvFilter;
        let _ = tracing_subscriber::fmt()
            .with_env_filter(EnvFilter::from_default_env())
            .try_init();

        fn prop(cfg: Config, num_streams: NonZeroU8) {
            // Never try to open more substreams than the configuration allows.
            let num_streams = cmp::min(cfg.max_substreams, num_streams.get() as usize);

            // Set up the multiplexed connection; there is nothing to read.
            let conn = Connection {
                r_buf: BytesMut::new(),
                w_buf: BytesMut::new(),
                eof: false,
            };
            let mut m = Multiplexed::new(conn, cfg);

            // Run the test.
            let mut opened = HashSet::new();
            task::block_on(future::poll_fn(move |cx| {
                // Open the substreams; reading from each is pending.
                for _ in 0..num_streams {
                    let id = ready!(m.poll_open_stream(cx)).unwrap();
                    assert!(opened.insert(id));
                    assert!(m.poll_read_stream(cx, id).is_pending());
                }

                // Make the connection fail on the next read.
                m.io.get_mut().deref_mut().eof = true;

                // `all` short-circuits: only the first read returns the
                // original `UnexpectedEof`, and the assertion requires every
                // read that runs to report that kind.
                assert!(opened.iter().all(|id| match m.poll_read_stream(cx, *id) {
                    Poll::Ready(Err(e)) => e.kind() == io::ErrorKind::UnexpectedEof,
                    _ => false,
                }));

                // The error discarded all substream state.
                assert!(m.substreams.is_empty());

                Poll::Ready(())
            }))
        }

        quickcheck(prop as fn(_, _))
    }
1416}