1pub(crate) use std::io::{Error, Result};
22use std::{
23 cmp,
24 collections::VecDeque,
25 fmt, io, mem,
26 sync::Arc,
27 task::{Context, Poll, Waker},
28};
29
30use asynchronous_codec::Framed;
31use bytes::Bytes;
32use futures::{
33 prelude::*,
34 ready,
35 stream::Fuse,
36 task::{waker_ref, ArcWake, AtomicWaker, WakerRef},
37};
38use nohash_hasher::{IntMap, IntSet};
39use parking_lot::Mutex;
40use smallvec::SmallVec;
41
42use crate::{
43 codec::{Codec, Frame, LocalStreamId, RemoteStreamId},
44 Config, MaxBufferBehaviour,
45};
/// A connection identifier.
///
/// Randomly generated in [`Multiplexed::new`] and only used to scope log
/// output (e.g. substream IDs) to a particular connection.
#[derive(Clone, Copy)]
struct ConnectionId(u64);

impl fmt::Debug for ConnectionId {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        // Render exactly like `Display`, so `?id` and `%id` log fields agree.
        fmt::Display::fmt(self, f)
    }
}

impl fmt::Display for ConnectionId {
    /// Renders the ID as exactly 16 lowercase hex digits.
    ///
    /// The `0` flag pads with zeros; a bare width (`{:16x}`) would pad with
    /// spaces and put leading blanks into every `connection=` log field.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        write!(f, "{:016x}", self.0)
    }
}
64pub(crate) struct Multiplexed<C> {
66 id: ConnectionId,
68 status: Status,
70 io: Fuse<Framed<C, Codec>>,
72 config: Config,
74 open_buffer: VecDeque<LocalStreamId>,
78 pending_flush_open: IntSet<LocalStreamId>,
81 blocking_stream: Option<LocalStreamId>,
85 pending_frames: VecDeque<Frame<LocalStreamId>>,
93 substreams: IntMap<LocalStreamId, SubstreamState>,
95 next_outbound_stream_id: LocalStreamId,
97 notifier_read: Arc<NotifierRead>,
99 notifier_write: Arc<NotifierWrite>,
101 notifier_open: NotifierOpen,
107}
108
/// Lifecycle state of a [`Multiplexed`] connection.
#[derive(Debug)]
enum Status {
    /// Normal operation.
    Open,
    /// The connection failed; holds the error that caused the failure.
    Err(io::Error),
    /// The underlying I/O stream was closed via `poll_close`.
    Closed,
}
119
120impl<C> Multiplexed<C>
121where
122 C: AsyncRead + AsyncWrite + Unpin,
123{
124 pub(crate) fn new(io: C, config: Config) -> Self {
126 let id = ConnectionId(rand::random());
127 tracing::debug!(connection=%id, "New multiplexed connection");
128 Multiplexed {
129 id,
130 config,
131 status: Status::Open,
132 io: Framed::new(io, Codec::new()).fuse(),
133 open_buffer: Default::default(),
134 substreams: Default::default(),
135 pending_flush_open: Default::default(),
136 pending_frames: Default::default(),
137 blocking_stream: None,
138 next_outbound_stream_id: LocalStreamId::dialer(0),
139 notifier_read: Arc::new(NotifierRead {
140 read_stream: Mutex::new(Default::default()),
141 next_stream: AtomicWaker::new(),
142 }),
143 notifier_write: Arc::new(NotifierWrite {
144 pending: Mutex::new(Default::default()),
145 }),
146 notifier_open: NotifierOpen {
147 pending: Default::default(),
148 },
149 }
150 }
151
152 pub(crate) fn poll_flush(&mut self, cx: &Context<'_>) -> Poll<io::Result<()>> {
154 match &self.status {
155 Status::Closed => return Poll::Ready(Ok(())),
156 Status::Err(e) => return Poll::Ready(Err(io::Error::new(e.kind(), e.to_string()))),
157 Status::Open => {}
158 }
159
160 ready!(self.send_pending_frames(cx))?;
162
163 let waker = NotifierWrite::register(&self.notifier_write, cx.waker());
165 match ready!(self.io.poll_flush_unpin(&mut Context::from_waker(&waker))) {
166 Err(e) => Poll::Ready(self.on_error(e)),
167 Ok(()) => {
168 self.pending_flush_open = Default::default();
169 Poll::Ready(Ok(()))
170 }
171 }
172 }
173
174 pub(crate) fn poll_close(&mut self, cx: &Context<'_>) -> Poll<io::Result<()>> {
180 match &self.status {
181 Status::Closed => return Poll::Ready(Ok(())),
182 Status::Err(e) => return Poll::Ready(Err(io::Error::new(e.kind(), e.to_string()))),
183 Status::Open => {}
184 }
185
186 let waker = NotifierWrite::register(&self.notifier_write, cx.waker());
190 match self.io.poll_close_unpin(&mut Context::from_waker(&waker)) {
191 Poll::Pending => Poll::Pending,
192 Poll::Ready(Err(e)) => Poll::Ready(self.on_error(e)),
193 Poll::Ready(Ok(())) => {
194 self.pending_frames = VecDeque::new();
195 self.open_buffer = Default::default();
198 self.substreams = Default::default();
199 self.status = Status::Closed;
200 Poll::Ready(Ok(()))
201 }
202 }
203 }
204
205 pub(crate) fn poll_next_stream(&mut self, cx: &Context<'_>) -> Poll<io::Result<LocalStreamId>> {
219 self.guard_open()?;
220
221 if let Some(stream_id) = self.open_buffer.pop_back() {
223 return Poll::Ready(Ok(stream_id));
224 }
225
226 debug_assert!(self.open_buffer.is_empty());
227 let mut num_buffered = 0;
228
229 loop {
230 if num_buffered == self.config.max_buffer_len {
235 cx.waker().wake_by_ref();
236 return Poll::Pending;
237 }
238
239 match ready!(self.poll_read_frame(cx, None))? {
241 Frame::Open { stream_id } => {
242 if let Some(id) = self.on_open(stream_id)? {
243 return Poll::Ready(Ok(id));
244 }
245 }
246 Frame::Data { stream_id, data } => {
247 self.buffer(stream_id.into_local(), data)?;
248 num_buffered += 1;
249 }
250 Frame::Close { stream_id } => {
251 self.on_close(stream_id.into_local());
252 }
253 Frame::Reset { stream_id } => self.on_reset(stream_id.into_local()),
254 }
255 }
256 }
257
258 pub(crate) fn poll_open_stream(&mut self, cx: &Context<'_>) -> Poll<io::Result<LocalStreamId>> {
260 self.guard_open()?;
261
262 if self.substreams.len() >= self.config.max_substreams {
264 tracing::debug!(
265 connection=%self.id,
266 total_substreams=%self.substreams.len(),
267 max_substreams=%self.config.max_substreams,
268 "Maximum number of substreams reached"
269 );
270 self.notifier_open.register(cx.waker());
271 return Poll::Pending;
272 }
273
274 let waker = NotifierWrite::register(&self.notifier_write, cx.waker());
276 match ready!(self.io.poll_ready_unpin(&mut Context::from_waker(&waker))) {
277 Ok(()) => {
278 let stream_id = self.next_outbound_stream_id();
279 let frame = Frame::Open { stream_id };
280 match self.io.start_send_unpin(frame) {
281 Ok(()) => {
282 self.substreams.insert(
283 stream_id,
284 SubstreamState::Open {
285 buf: Default::default(),
286 },
287 );
288 tracing::debug!(
289 connection=%self.id,
290 substream=%stream_id,
291 total_substreams=%self.substreams.len(),
292 "New outbound substream"
293 );
294 self.pending_flush_open.insert(stream_id);
297 Poll::Ready(Ok(stream_id))
298 }
299 Err(e) => Poll::Ready(self.on_error(e)),
300 }
301 }
302 Err(e) => Poll::Ready(self.on_error(e)),
303 }
304 }
305
306 pub(crate) fn drop_stream(&mut self, id: LocalStreamId) {
327 match self.status {
329 Status::Closed | Status::Err(_) => return,
330 Status::Open => {}
331 }
332
333 self.notifier_read.wake_read_stream(id);
338
339 match self.substreams.remove(&id) {
341 None => {}
342 Some(state) => {
343 let below_limit = self.substreams.len() == self.config.max_substreams - 1;
346 if below_limit {
347 self.notifier_open.wake_all();
348 }
349 match state {
351 SubstreamState::Closed { .. } => {}
352 SubstreamState::SendClosed { .. } => {}
353 SubstreamState::Reset { .. } => {}
354 SubstreamState::RecvClosed { .. } => {
355 if self.check_max_pending_frames().is_err() {
356 return;
357 }
358 tracing::trace!(
359 connection=%self.id,
360 substream=%id,
361 "Pending close for substream"
362 );
363 self.pending_frames
364 .push_front(Frame::Close { stream_id: id });
365 }
366 SubstreamState::Open { .. } => {
367 if self.check_max_pending_frames().is_err() {
368 return;
369 }
370 tracing::trace!(
371 connection=%self.id,
372 substream=%id,
373 "Pending reset for substream"
374 );
375 self.pending_frames
376 .push_front(Frame::Reset { stream_id: id });
377 }
378 }
379 }
380 }
381 }
382
383 pub(crate) fn poll_write_stream(
385 &mut self,
386 cx: &Context<'_>,
387 id: LocalStreamId,
388 buf: &[u8],
389 ) -> Poll<io::Result<usize>> {
390 self.guard_open()?;
391
392 match self.substreams.get(&id) {
394 None | Some(SubstreamState::Reset { .. }) => {
395 return Poll::Ready(Err(io::ErrorKind::BrokenPipe.into()))
396 }
397 Some(SubstreamState::SendClosed { .. }) | Some(SubstreamState::Closed { .. }) => {
398 return Poll::Ready(Err(io::ErrorKind::WriteZero.into()))
399 }
400 Some(SubstreamState::Open { .. }) | Some(SubstreamState::RecvClosed { .. }) => {
401 }
403 }
404
405 let frame_len = cmp::min(buf.len(), self.config.split_send_size);
407
408 ready!(self.poll_send_frame(cx, || {
410 let data = Bytes::copy_from_slice(&buf[..frame_len]);
411 Frame::Data {
412 stream_id: id,
413 data,
414 }
415 }))?;
416
417 Poll::Ready(Ok(frame_len))
418 }
419
/// Reads the next data frame for substream `id`.
///
/// Buffered frames for `id` are returned first. Otherwise new frames are
/// read from the connection:
/// - data for other substreams is buffered for them;
/// - newly opened inbound substreams are queued in `open_buffer`, waking
///   the task in `poll_next_stream`;
/// - `Close`/`Reset` frames are applied.
///
/// Returns `Ok(Some(data))` for the next frame of `id`. Returns `Ok(None)`
/// once `id` can no longer be read: it is unknown, closed by the remote, or
/// reset, including when a `Close`/`Reset` for `id` arrives during this
/// call. Returns an error if the connection is not open or reading fails.
pub(crate) fn poll_read_stream(
    &mut self,
    cx: &Context<'_>,
    id: LocalStreamId,
) -> Poll<io::Result<Option<Bytes>>> {
    self.guard_open()?;

    // Serve already-buffered frames first.
    if let Some(state) = self.substreams.get_mut(&id) {
        let buf = state.recv_buf();
        if !buf.is_empty() {
            if self.blocking_stream == Some(id) {
                // Taking a frame from the full buffer unblocks frame reading;
                // wake all registered readers so they can make progress.
                self.blocking_stream = None;
                ArcWake::wake_by_ref(&self.notifier_read);
            }
            let data = buf.remove(0);
            return Poll::Ready(Ok(Some(data)));
        }
        // Buffer is empty: release heap memory if the SmallVec had spilled.
        buf.shrink_to_fit();
    }

    let mut num_buffered = 0;

    loop {
        // After buffering `max_buffer_len` frames for other substreams, yield
        // (rescheduling this task) so their readers can drain them.
        if num_buffered == self.config.max_buffer_len {
            cx.waker().wake_by_ref();
            return Poll::Pending;
        }

        // The receiving half of `id` is gone (unknown, RecvClosed, Closed or
        // Reset): report EOF.
        if !self.can_read(&id) {
            return Poll::Ready(Ok(None));
        }

        match ready!(self.poll_read_frame(cx, Some(id)))? {
            Frame::Data { data, stream_id } if stream_id.into_local() == id => {
                return Poll::Ready(Ok(Some(data)))
            }
            Frame::Data { stream_id, data } => {
                // Data for another substream: keep it for that stream's reader.
                self.buffer(stream_id.into_local(), data)?;
                num_buffered += 1;
            }
            frame @ Frame::Open { .. } => {
                // New inbound substream: queue it for `poll_next_stream`.
                if let Some(id) = self.on_open(frame.remote_id())? {
                    self.open_buffer.push_front(id);
                    tracing::trace!(
                        connection=%self.id,
                        inbound_stream=%id,
                        inbound_buffer_len=%self.open_buffer.len(),
                        "Buffered new inbound stream"
                    );
                    self.notifier_read.wake_next_stream();
                }
            }
            Frame::Close { stream_id } => {
                let stream_id = stream_id.into_local();
                self.on_close(stream_id);
                // The remote closed the substream we are reading: EOF.
                if id == stream_id {
                    return Poll::Ready(Ok(None));
                }
            }
            Frame::Reset { stream_id } => {
                let stream_id = stream_id.into_local();
                self.on_reset(stream_id);
                // The remote reset the substream we are reading: EOF.
                if id == stream_id {
                    return Poll::Ready(Ok(None));
                }
            }
        }
    }
}
520
521 pub(crate) fn poll_flush_stream(
527 &mut self,
528 cx: &Context<'_>,
529 id: LocalStreamId,
530 ) -> Poll<io::Result<()>> {
531 self.guard_open()?;
532
533 ready!(self.poll_flush(cx))?;
534 tracing::trace!(
535 connection=%self.id,
536 substream=%id,
537 "Flushed substream"
538 );
539
540 Poll::Ready(Ok(()))
541 }
542
543 pub(crate) fn poll_close_stream(
547 &mut self,
548 cx: &Context<'_>,
549 id: LocalStreamId,
550 ) -> Poll<io::Result<()>> {
551 self.guard_open()?;
552
553 match self.substreams.remove(&id) {
554 None => Poll::Ready(Ok(())),
555 Some(SubstreamState::SendClosed { buf }) => {
556 self.substreams
557 .insert(id, SubstreamState::SendClosed { buf });
558 Poll::Ready(Ok(()))
559 }
560 Some(SubstreamState::Closed { buf }) => {
561 self.substreams.insert(id, SubstreamState::Closed { buf });
562 Poll::Ready(Ok(()))
563 }
564 Some(SubstreamState::Reset { buf }) => {
565 self.substreams.insert(id, SubstreamState::Reset { buf });
566 Poll::Ready(Ok(()))
567 }
568 Some(SubstreamState::Open { buf }) => {
569 if self
570 .poll_send_frame(cx, || Frame::Close { stream_id: id })?
571 .is_pending()
572 {
573 self.substreams.insert(id, SubstreamState::Open { buf });
574 Poll::Pending
575 } else {
576 tracing::debug!(
577 connection=%self.id,
578 substream=%id,
579 "Closed substream (half-close)"
580 );
581 self.substreams
582 .insert(id, SubstreamState::SendClosed { buf });
583 Poll::Ready(Ok(()))
584 }
585 }
586 Some(SubstreamState::RecvClosed { buf }) => {
587 if self
588 .poll_send_frame(cx, || Frame::Close { stream_id: id })?
589 .is_pending()
590 {
591 self.substreams
592 .insert(id, SubstreamState::RecvClosed { buf });
593 Poll::Pending
594 } else {
595 tracing::debug!(
596 connection=%self.id,
597 substream=%id,
598 "Closed substream"
599 );
600 self.substreams.insert(id, SubstreamState::Closed { buf });
601 Poll::Ready(Ok(()))
602 }
603 }
604 }
605 }
606
607 fn poll_send_frame<F>(&mut self, cx: &Context<'_>, frame: F) -> Poll<io::Result<()>>
612 where
613 F: FnOnce() -> Frame<LocalStreamId>,
614 {
615 let waker = NotifierWrite::register(&self.notifier_write, cx.waker());
616 match ready!(self.io.poll_ready_unpin(&mut Context::from_waker(&waker))) {
617 Ok(()) => {
618 let frame = frame();
619 tracing::trace!(connection=%self.id, ?frame, "Sending frame");
620 match self.io.start_send_unpin(frame) {
621 Ok(()) => Poll::Ready(Ok(())),
622 Err(e) => Poll::Ready(self.on_error(e)),
623 }
624 }
625 Err(e) => Poll::Ready(self.on_error(e)),
626 }
627 }
628
/// Reads the next frame from the underlying I/O stream.
///
/// `stream_id` is `Some(id)` when reading on behalf of substream `id`, or
/// `None` when reading for `poll_next_stream`. Before reading:
/// - queued frames are sent without waiting for them; only an error aborts;
/// - a deferred `Open` flush for `id` is performed, if any.
///
/// While a substream is blocking (its buffer is full under
/// `MaxBufferBehaviour::Block`), no frame is read and `Pending` is returned
/// after arranging a wake-up. End of the I/O stream is reported as
/// `UnexpectedEof`. Read errors fail the connection via `on_error`.
fn poll_read_frame(
    &mut self,
    cx: &Context<'_>,
    stream_id: Option<LocalStreamId>,
) -> Poll<io::Result<Frame<RemoteStreamId>>> {
    // Try to send pending frames without blocking on them.
    if let Poll::Ready(Err(e)) = self.send_pending_frames(cx) {
        return Poll::Ready(Err(e));
    }

    // `poll_open_stream` defers flushing `Open` frames; flush before waiting
    // for data on such a substream.
    if let Some(id) = &stream_id {
        if self.pending_flush_open.contains(id) {
            tracing::trace!(
                connection=%self.id,
                substream=%id,
                "Executing pending flush for substream"
            );
            ready!(self.poll_flush(cx))?;
            self.pending_flush_open = Default::default();
        }
    }

    // A substream buffer is full: no new frames may be read until a frame
    // is taken from that buffer (see `poll_read_stream`).
    if let Some(blocked_id) = &self.blocking_stream {
        // Try to wake the task reading from the blocked substream.
        if !self.notifier_read.wake_read_stream(*blocked_id) {
            // Nobody is waiting on the blocked substream: reschedule this
            // task so progress remains possible.
            tracing::trace!(
                connection=%self.id,
                "No task to read from blocked stream. Waking current task."
            );
            cx.waker().wake_by_ref();
        } else if let Some(id) = stream_id {
            // Another task was woken; stay registered for `id` so we are
            // woken again when reading resumes.
            debug_assert!(
                blocked_id != &id,
                "Unexpected attempt at reading a new \
                frame from a substream with a full buffer."
            );
            let _ = NotifierRead::register_read_stream(&self.notifier_read, cx.waker(), id);
        } else {
            // Another task was woken; stay registered for new inbound streams.
            let _ = NotifierRead::register_next_stream(&self.notifier_read, cx.waker());
        }

        return Poll::Pending;
    }

    // Register interest and poll the I/O stream through `NotifierRead`'s
    // waker, which wakes all registered readers (see its `ArcWake` impl).
    let waker = match stream_id {
        Some(id) => NotifierRead::register_read_stream(&self.notifier_read, cx.waker(), id),
        None => NotifierRead::register_next_stream(&self.notifier_read, cx.waker()),
    };
    match ready!(self.io.poll_next_unpin(&mut Context::from_waker(&waker))) {
        Some(Ok(frame)) => {
            tracing::trace!(connection=%self.id, ?frame, "Received frame");
            Poll::Ready(Ok(frame))
        }
        Some(Err(e)) => Poll::Ready(self.on_error(e)),
        // The framed stream ended: treat as an unexpected EOF.
        None => Poll::Ready(self.on_error(io::ErrorKind::UnexpectedEof.into())),
    }
}
705
706 fn on_open(&mut self, id: RemoteStreamId) -> io::Result<Option<LocalStreamId>> {
708 let id = id.into_local();
709
710 if self.substreams.contains_key(&id) {
711 tracing::debug!(
712 connection=%self.id,
713 substream=%id,
714 "Received unexpected `Open` frame for open substream",
715 );
716 return self.on_error(io::Error::other(
717 "Protocol error: Received `Open` frame for open substream.",
718 ));
719 }
720
721 if self.substreams.len() >= self.config.max_substreams {
722 tracing::debug!(
723 connection=%self.id,
724 max_substreams=%self.config.max_substreams,
725 "Maximum number of substreams exceeded"
726 );
727 self.check_max_pending_frames()?;
728 tracing::debug!(
729 connection=%self.id,
730 substream=%id,
731 "Pending reset for new substream"
732 );
733 self.pending_frames
734 .push_front(Frame::Reset { stream_id: id });
735 return Ok(None);
736 }
737
738 self.substreams.insert(
739 id,
740 SubstreamState::Open {
741 buf: Default::default(),
742 },
743 );
744
745 tracing::debug!(
746 connection=%self.id,
747 substream=%id,
748 total_substreams=%self.substreams.len(),
749 "New inbound substream"
750 );
751
752 Ok(Some(id))
753 }
754
755 fn on_reset(&mut self, id: LocalStreamId) {
757 if let Some(state) = self.substreams.remove(&id) {
758 match state {
759 SubstreamState::Closed { .. } => {
760 tracing::trace!(
761 connection=%self.id,
762 substream=%id,
763 "Ignoring reset for mutually closed substream"
764 );
765 }
766 SubstreamState::Reset { .. } => {
767 tracing::trace!(
768 connection=%self.id,
769 substream=%id,
770 "Ignoring redundant reset for already reset substream"
771 );
772 }
773 SubstreamState::RecvClosed { buf }
774 | SubstreamState::SendClosed { buf }
775 | SubstreamState::Open { buf } => {
776 tracing::debug!(
777 connection=%self.id,
778 substream=%id,
779 "Substream reset by remote"
780 );
781 self.substreams.insert(id, SubstreamState::Reset { buf });
782 NotifierRead::wake_read_stream(&self.notifier_read, id);
785 }
786 }
787 } else {
788 tracing::trace!(
789 connection=%self.id,
790 substream=%id,
791 "Ignoring `Reset` for unknown substream, possibly dropped earlier"
792 );
793 }
794 }
795
796 fn on_close(&mut self, id: LocalStreamId) {
798 if let Some(state) = self.substreams.remove(&id) {
799 match state {
800 SubstreamState::RecvClosed { .. } | SubstreamState::Closed { .. } => {
801 tracing::debug!(
802 connection=%self.id,
803 substream=%id,
804 "Ignoring `Close` frame for closed substream"
805 );
806 self.substreams.insert(id, state);
807 }
808 SubstreamState::Reset { buf } => {
809 tracing::debug!(
810 connection=%self.id,
811 substream=%id,
812 "Ignoring `Close` frame for already reset substream"
813 );
814 self.substreams.insert(id, SubstreamState::Reset { buf });
815 }
816 SubstreamState::SendClosed { buf } => {
817 tracing::debug!(
818 connection=%self.id,
819 substream=%id,
820 "Substream closed by remote (SendClosed -> Closed)"
821 );
822 self.substreams.insert(id, SubstreamState::Closed { buf });
823 self.notifier_read.wake_read_stream(id);
825 }
826 SubstreamState::Open { buf } => {
827 tracing::debug!(
828 connection=%self.id,
829 substream=%id,
830 "Substream closed by remote (Open -> RecvClosed)"
831 );
832 self.substreams
833 .insert(id, SubstreamState::RecvClosed { buf });
834 self.notifier_read.wake_read_stream(id);
836 }
837 }
838 } else {
839 tracing::trace!(
840 connection=%self.id,
841 substream=%id,
842 "Ignoring `Close` for unknown substream, possibly dropped earlier."
843 );
844 }
845 }
846
847 fn next_outbound_stream_id(&mut self) -> LocalStreamId {
849 let id = self.next_outbound_stream_id;
850 self.next_outbound_stream_id = self.next_outbound_stream_id.next();
851 id
852 }
853
854 fn can_read(&self, id: &LocalStreamId) -> bool {
856 matches!(
857 self.substreams.get(id),
858 Some(SubstreamState::Open { .. }) | Some(SubstreamState::SendClosed { .. })
859 )
860 }
861
862 fn send_pending_frames(&mut self, cx: &Context<'_>) -> Poll<io::Result<()>> {
864 while let Some(frame) = self.pending_frames.pop_back() {
865 if self.poll_send_frame(cx, || frame.clone())?.is_pending() {
866 self.pending_frames.push_back(frame);
867 return Poll::Pending;
868 }
869 }
870
871 Poll::Ready(Ok(()))
872 }
873
874 fn on_error<T>(&mut self, e: io::Error) -> io::Result<T> {
876 tracing::debug!(
877 connection=%self.id,
878 "Multiplexed connection failed: {:?}",
879 e
880 );
881 self.status = Status::Err(io::Error::new(e.kind(), e.to_string()));
882 self.pending_frames = Default::default();
883 self.substreams = Default::default();
884 self.open_buffer = Default::default();
885 Err(e)
886 }
887
888 fn guard_open(&self) -> io::Result<()> {
891 match &self.status {
892 Status::Closed => Err(io::Error::other("Connection is closed")),
893 Status::Err(e) => Err(io::Error::new(e.kind(), e.to_string())),
894 Status::Open => Ok(()),
895 }
896 }
897
898 fn check_max_pending_frames(&mut self) -> io::Result<()> {
901 if self.pending_frames.len() >= self.config.max_substreams + EXTRA_PENDING_FRAMES {
902 return self.on_error(io::Error::other("Too many pending frames."));
903 }
904 Ok(())
905 }
906
907 fn buffer(&mut self, id: LocalStreamId, data: Bytes) -> io::Result<()> {
917 let Some(state) = self.substreams.get_mut(&id) else {
918 tracing::trace!(
919 connection=%self.id,
920 substream=%id,
921 data=?data,
922 "Dropping data for unknown substream"
923 );
924 return Ok(());
925 };
926
927 let Some(buf) = state.recv_buf_open() else {
928 tracing::trace!(
929 connection=%self.id,
930 substream=%id,
931 data=?data,
932 "Dropping data for closed or reset substream",
933 );
934 return Ok(());
935 };
936
937 debug_assert!(buf.len() <= self.config.max_buffer_len);
938 tracing::trace!(
939 connection=%self.id,
940 substream=%id,
941 data=?data,
942 data_buffer=%buf.len() + 1,
943 "Buffering data for substream"
944 );
945 buf.push(data);
946 self.notifier_read.wake_read_stream(id);
947 if buf.len() > self.config.max_buffer_len {
948 tracing::debug!(
949 connection=%self.id,
950 substream=%id,
951 "Frame buffer of substream is full"
952 );
953 match self.config.max_buffer_behaviour {
954 MaxBufferBehaviour::ResetStream => {
955 let buf = buf.clone();
956 self.check_max_pending_frames()?;
957 self.substreams.insert(id, SubstreamState::Reset { buf });
958 tracing::debug!(
959 connection=%self.id,
960 substream=%id,
961 "Pending reset for stream"
962 );
963 self.pending_frames
964 .push_front(Frame::Reset { stream_id: id });
965 }
966 MaxBufferBehaviour::Block => {
967 self.blocking_stream = Some(id);
968 }
969 }
970 }
971
972 Ok(())
973 }
974}
975
976type RecvBuf = SmallVec<[Bytes; 10]>;
977
978#[derive(Clone, Debug)]
980enum SubstreamState {
981 Open { buf: RecvBuf },
983 SendClosed { buf: RecvBuf },
986 RecvClosed { buf: RecvBuf },
989 Closed { buf: RecvBuf },
993 Reset { buf: RecvBuf },
996}
997
998impl SubstreamState {
999 fn recv_buf(&mut self) -> &mut RecvBuf {
1001 match self {
1002 SubstreamState::Open { buf } => buf,
1003 SubstreamState::SendClosed { buf } => buf,
1004 SubstreamState::RecvClosed { buf } => buf,
1005 SubstreamState::Closed { buf } => buf,
1006 SubstreamState::Reset { buf } => buf,
1007 }
1008 }
1009
1010 fn recv_buf_open(&mut self) -> Option<&mut RecvBuf> {
1013 match self {
1014 SubstreamState::Open { buf } => Some(buf),
1015 SubstreamState::SendClosed { buf } => Some(buf),
1016 SubstreamState::RecvClosed { .. } => None,
1017 SubstreamState::Closed { .. } => None,
1018 SubstreamState::Reset { .. } => None,
1019 }
1020 }
1021}
1022
1023struct NotifierRead {
1024 next_stream: AtomicWaker,
1027 read_stream: Mutex<IntMap<LocalStreamId, Waker>>,
1030}
1031
1032impl NotifierRead {
1033 #[must_use]
1039 fn register_read_stream<'a>(
1040 self: &'a Arc<Self>,
1041 waker: &Waker,
1042 id: LocalStreamId,
1043 ) -> WakerRef<'a> {
1044 let mut pending = self.read_stream.lock();
1045 pending.insert(id, waker.clone());
1046 waker_ref(self)
1047 }
1048
1049 #[must_use]
1054 fn register_next_stream<'a>(self: &'a Arc<Self>, waker: &Waker) -> WakerRef<'a> {
1055 self.next_stream.register(waker);
1056 waker_ref(self)
1057 }
1058
1059 fn wake_read_stream(&self, id: LocalStreamId) -> bool {
1062 let mut pending = self.read_stream.lock();
1063
1064 if let Some(waker) = pending.remove(&id) {
1065 waker.wake();
1066 return true;
1067 }
1068
1069 false
1070 }
1071
1072 fn wake_next_stream(&self) {
1074 self.next_stream.wake();
1075 }
1076}
1077
1078impl ArcWake for NotifierRead {
1079 fn wake_by_ref(this: &Arc<Self>) {
1080 let wakers = mem::take(&mut *this.read_stream.lock());
1081 for (_, waker) in wakers {
1082 waker.wake();
1083 }
1084 this.wake_next_stream();
1085 }
1086}
1087
1088struct NotifierWrite {
1089 pending: Mutex<Vec<Waker>>,
1092}
1093
1094impl NotifierWrite {
1095 #[must_use]
1100 fn register<'a>(self: &'a Arc<Self>, waker: &Waker) -> WakerRef<'a> {
1101 let mut pending = self.pending.lock();
1102 if pending.iter().all(|w| !w.will_wake(waker)) {
1103 pending.push(waker.clone());
1104 }
1105 waker_ref(self)
1106 }
1107}
1108
1109impl ArcWake for NotifierWrite {
1110 fn wake_by_ref(this: &Arc<Self>) {
1111 let wakers = mem::take(&mut *this.pending.lock());
1112 for waker in wakers {
1113 waker.wake();
1114 }
1115 }
1116}
1117
/// Registry of tasks waiting to open an outbound substream once the
/// substream limit allows it.
struct NotifierOpen {
    /// Distinct waiting tasks.
    pending: Vec<Waker>,
}

impl NotifierOpen {
    /// Registers `waker` unless an equivalent waker is already registered.
    fn register(&mut self, waker: &Waker) {
        let already_registered = self.pending.iter().any(|w| w.will_wake(waker));
        if !already_registered {
            self.pending.push(waker.clone());
        }
    }

    /// Wakes and unregisters every waiting task.
    fn wake_all(&mut self) {
        for waker in mem::take(&mut self.pending) {
            waker.wake();
        }
    }
}
1139
/// Headroom, beyond `max_substreams`, for the queue of pending outbound
/// `Close`/`Reset` frames. `check_max_pending_frames` fails the connection
/// once the queue reaches `max_substreams + EXTRA_PENDING_FRAMES`.
const EXTRA_PENDING_FRAMES: usize = 1_000;
1148
#[cfg(test)]
mod tests {
    use std::{collections::HashSet, num::NonZeroU8, ops::DerefMut, pin::Pin};

    use asynchronous_codec::{Decoder, Encoder};
    use bytes::BytesMut;
    use quickcheck::*;
    use tokio::runtime::Runtime;

    use super::*;

    // Picks either buffer-overflow behaviour for property tests.
    impl Arbitrary for MaxBufferBehaviour {
        fn arbitrary(g: &mut Gen) -> MaxBufferBehaviour {
            *g.choose(&[MaxBufferBehaviour::Block, MaxBufferBehaviour::ResetStream])
                .unwrap()
        }
    }

    // Generates a `Config` whose numeric limits are all at least 1.
    impl Arbitrary for Config {
        fn arbitrary(g: &mut Gen) -> Config {
            Config {
                max_substreams: g.gen_range(1..100),
                max_buffer_len: g.gen_range(1..1000),
                max_buffer_behaviour: MaxBufferBehaviour::arbitrary(g),
                split_send_size: g.gen_range(1..10000),
                protocol_name: crate::config::DEFAULT_MPLEX_PROTOCOL_NAME,
            }
        }
    }

    /// In-memory transport for the tests.
    struct Connection {
        /// Bytes served to reads, i.e. data "received from the remote".
        r_buf: BytesMut,
        /// Bytes written by the multiplexer, i.e. data "sent to the remote".
        w_buf: BytesMut,
        /// When set, every read fails with `UnexpectedEof`.
        eof: bool,
    }

    impl AsyncRead for Connection {
        // Serves bytes from `r_buf` and returns `Pending` once it is drained.
        // No waker is stored, so the tests must not rely on being woken.
        fn poll_read(
            mut self: Pin<&mut Self>,
            _: &mut Context<'_>,
            buf: &mut [u8],
        ) -> Poll<io::Result<usize>> {
            if self.eof {
                return Poll::Ready(Err(io::ErrorKind::UnexpectedEof.into()));
            }
            let n = std::cmp::min(buf.len(), self.r_buf.len());
            let data = self.r_buf.split_to(n);
            buf[..n].copy_from_slice(&data[..]);
            if n == 0 {
                Poll::Pending
            } else {
                Poll::Ready(Ok(n))
            }
        }
    }

    // Writes always succeed immediately into `w_buf`; flush and close are no-ops.
    impl AsyncWrite for Connection {
        fn poll_write(
            mut self: Pin<&mut Self>,
            _: &mut Context<'_>,
            buf: &[u8],
        ) -> Poll<io::Result<usize>> {
            self.w_buf.extend_from_slice(buf);
            Poll::Ready(Ok(buf.len()))
        }

        fn poll_flush(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<io::Result<()>> {
            Poll::Ready(Ok(()))
        }

        fn poll_close(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<io::Result<()>> {
            Poll::Ready(Ok(()))
        }
    }

    /// Checks both `MaxBufferBehaviour`s when the receive buffer of a
    /// substream overflows by `overflow` frames.
    #[test]
    fn max_buffer_behaviour() {
        use tracing_subscriber::EnvFilter;
        let _ = tracing_subscriber::fmt()
            .with_env_filter(EnvFilter::from_default_env())
            .try_init();

        fn prop(cfg: Config, overflow: NonZeroU8) {
            let mut r_buf = BytesMut::new();
            let mut codec = Codec::new();

            // The remote (acting as dialer) opens the maximum number of substreams.
            for i in 0..cfg.max_substreams {
                let stream_id = LocalStreamId::dialer(i as u64);
                codec.encode(Frame::Open { stream_id }, &mut r_buf).unwrap();
            }

            // The remote then sends more data frames on its first substream
            // than the receive buffer can hold.
            let stream_id = LocalStreamId::dialer(0);
            let data = Bytes::from("Hello world");
            for _ in 0..cfg.max_buffer_len + overflow.get() as usize {
                codec
                    .encode(
                        Frame::Data {
                            stream_id,
                            data: data.clone(),
                        },
                        &mut r_buf,
                    )
                    .unwrap();
            }

            // Multiplex over a transport pre-filled with those frames.
            let conn = Connection {
                r_buf,
                w_buf: BytesMut::new(),
                eof: false,
            };
            let mut m = Multiplexed::new(conn, cfg.clone());

            let rt = Runtime::new().unwrap();
            rt.block_on(future::poll_fn(move |cx| {
                // The `Open` frames come first, so every inbound substream
                // must be available immediately (with listener IDs locally).
                for i in 0..cfg.max_substreams {
                    match m.poll_next_stream(cx) {
                        Poll::Pending => panic!("Expected new inbound stream."),
                        Poll::Ready(Err(e)) => panic!("{e:?}"),
                        Poll::Ready(Ok(id)) => {
                            assert_eq!(id, LocalStreamId::listener(i as u64));
                        }
                    };
                }

                // Waiting for another inbound stream reads the data frames.
                // The first poll buffers `max_buffer_len` of them for
                // substream 0 and then yields.
                let id = LocalStreamId::listener(0);
                match m.poll_next_stream(cx) {
                    Poll::Ready(r) => panic!("Unexpected result for next stream: {r:?}"),
                    Poll::Pending => {
                        assert_eq!(
                            m.substreams.get_mut(&id).unwrap().recv_buf().len(),
                            cfg.max_buffer_len
                        );
                        // The second poll buffers one more frame, overflowing
                        // the buffer; no further frames are buffered for it.
                        match m.poll_next_stream(cx) {
                            Poll::Ready(r) => panic!("Unexpected result for next stream: {r:?}"),
                            Poll::Pending => {
                                assert_eq!(
                                    m.substreams.get_mut(&id).unwrap().recv_buf().len(),
                                    cfg.max_buffer_len + 1
                                );
                            }
                        }
                    }
                }

                // Check the overflow handling.
                match cfg.max_buffer_behaviour {
                    MaxBufferBehaviour::ResetStream => {
                        // The substream was reset: after a flush, the first
                        // frame written to the transport must be its `Reset`.
                        let _ = m.poll_flush_stream(cx, id);
                        let w_buf = &mut m.io.get_mut().deref_mut().w_buf;
                        let frame = codec.decode(w_buf).unwrap();
                        let stream_id = stream_id.into_remote();
                        assert_eq!(frame, Some(Frame::Reset { stream_id }));
                    }
                    MaxBufferBehaviour::Block => {
                        // Frame reading is blocked: neither new inbound streams
                        // nor the other substreams can make progress.
                        assert!(m.poll_next_stream(cx).is_pending());
                        for i in 1..cfg.max_substreams {
                            let id = LocalStreamId::listener(i as u64);
                            assert!(m.poll_read_stream(cx, id).is_pending());
                        }
                    }
                }

                // Drain all buffered frames of the overflowed substream.
                for _ in 0..cfg.max_buffer_len + 1 {
                    match m.poll_read_stream(cx, id) {
                        Poll::Ready(Ok(Some(bytes))) => {
                            assert_eq!(bytes, data);
                        }
                        x => panic!("Unexpected: {x:?}"),
                    }
                }

                match cfg.max_buffer_behaviour {
                    MaxBufferBehaviour::ResetStream => {
                        // Once drained, the reset substream reports EOF.
                        match m.poll_read_stream(cx, id) {
                            Poll::Ready(Ok(None)) => {}
                            poll => panic!("Unexpected: {poll:?}"),
                        }
                    }
                    MaxBufferBehaviour::Block => {
                        // Reading resumed: further data frames are delivered,
                        // unless the single overflowing frame was the last one.
                        match m.poll_read_stream(cx, id) {
                            Poll::Ready(Ok(Some(bytes))) => assert_eq!(bytes, data),
                            Poll::Pending => assert_eq!(overflow.get(), 1),
                            poll => panic!("Unexpected: {poll:?}"),
                        }
                    }
                }

                Poll::Ready(())
            }));
        }

        quickcheck(prop as fn(_, _))
    }

    /// Checks that a transport error surfaces on every substream and that all
    /// substream state is discarded afterwards.
    #[test]
    fn close_on_error() {
        use tracing_subscriber::EnvFilter;
        let _ = tracing_subscriber::fmt()
            .with_env_filter(EnvFilter::from_default_env())
            .try_init();

        fn prop(cfg: Config, num_streams: NonZeroU8) {
            let num_streams = cmp::min(cfg.max_substreams, num_streams.get() as usize);

            // Start with an empty, healthy transport.
            let conn = Connection {
                r_buf: BytesMut::new(),
                w_buf: BytesMut::new(),
                eof: false,
            };
            let mut m = Multiplexed::new(conn, cfg);

            let mut opened = HashSet::new();
            let rt = Runtime::new().unwrap();
            rt.block_on(future::poll_fn(move |cx| {
                // Open outbound substreams; none of them has data to read yet.
                for _ in 0..num_streams {
                    let id = ready!(m.poll_open_stream(cx)).unwrap();
                    assert!(opened.insert(id));
                    assert!(m.poll_read_stream(cx, id).is_pending());
                }

                // Make every subsequent transport read fail.
                m.io.get_mut().deref_mut().eof = true;

                // Every substream reports the failure: the first read hits the
                // transport error, later ones see the stored error.
                assert!(opened.iter().all(|id| match m.poll_read_stream(cx, *id) {
                    Poll::Ready(Err(e)) => e.kind() == io::ErrorKind::UnexpectedEof,
                    _ => false,
                }));

                // The failure discarded all substream state.
                assert!(m.substreams.is_empty());

                Poll::Ready(())
            }))
        }

        quickcheck(prop as fn(_, _))
    }
}