libp2p_core/muxing/
boxed.rs1use std::{
2    error::Error,
3    fmt, io,
4    io::{IoSlice, IoSliceMut},
5    pin::Pin,
6    task::{Context, Poll},
7};
8
9use futures::{AsyncRead, AsyncWrite};
10use pin_project::pin_project;
11
12use crate::muxing::{StreamMuxer, StreamMuxerEvent};
13
14pub struct StreamMuxerBox {
16    inner: Pin<Box<dyn StreamMuxer<Substream = SubstreamBox, Error = io::Error> + Send>>,
17}
18
19impl fmt::Debug for StreamMuxerBox {
20    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
21        f.debug_struct("StreamMuxerBox").finish_non_exhaustive()
22    }
23}
24
25pub struct SubstreamBox(Pin<Box<dyn AsyncReadWrite + Send>>);
30
31#[pin_project]
32struct Wrap<T>
33where
34    T: StreamMuxer,
35{
36    #[pin]
37    inner: T,
38}
39
40impl<T> StreamMuxer for Wrap<T>
41where
42    T: StreamMuxer,
43    T::Substream: Send + 'static,
44    T::Error: Send + Sync + 'static,
45{
46    type Substream = SubstreamBox;
47    type Error = io::Error;
48
49    fn poll_inbound(
50        self: Pin<&mut Self>,
51        cx: &mut Context<'_>,
52    ) -> Poll<Result<Self::Substream, Self::Error>> {
53        self.project()
54            .inner
55            .poll_inbound(cx)
56            .map_ok(SubstreamBox::new)
57            .map_err(into_io_error)
58    }
59
60    fn poll_outbound(
61        self: Pin<&mut Self>,
62        cx: &mut Context<'_>,
63    ) -> Poll<Result<Self::Substream, Self::Error>> {
64        self.project()
65            .inner
66            .poll_outbound(cx)
67            .map_ok(SubstreamBox::new)
68            .map_err(into_io_error)
69    }
70
71    #[inline]
72    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
73        self.project().inner.poll_close(cx).map_err(into_io_error)
74    }
75
76    fn poll(
77        self: Pin<&mut Self>,
78        cx: &mut Context<'_>,
79    ) -> Poll<Result<StreamMuxerEvent, Self::Error>> {
80        self.project().inner.poll(cx).map_err(into_io_error)
81    }
82}
83
/// Wraps an arbitrary error in an [`io::Error`].
///
/// The result is built with [`io::Error::other`], which takes ownership of `err` as the
/// error's payload.
fn into_io_error<E: Error + Send + Sync + 'static>(err: E) -> io::Error {
    io::Error::other(err)
}
90
91impl StreamMuxerBox {
92    pub fn new<T>(muxer: T) -> StreamMuxerBox
94    where
95        T: StreamMuxer + Send + 'static,
96        T::Substream: Send + 'static,
97        T::Error: Send + Sync + 'static,
98    {
99        let wrap = Wrap { inner: muxer };
100
101        StreamMuxerBox {
102            inner: Box::pin(wrap),
103        }
104    }
105
106    fn project(
107        self: Pin<&mut Self>,
108    ) -> Pin<&mut (dyn StreamMuxer<Substream = SubstreamBox, Error = io::Error> + Send)> {
109        self.get_mut().inner.as_mut()
110    }
111}
112
113impl StreamMuxer for StreamMuxerBox {
114    type Substream = SubstreamBox;
115    type Error = io::Error;
116
117    fn poll_inbound(
118        self: Pin<&mut Self>,
119        cx: &mut Context<'_>,
120    ) -> Poll<Result<Self::Substream, Self::Error>> {
121        self.project().poll_inbound(cx)
122    }
123
124    fn poll_outbound(
125        self: Pin<&mut Self>,
126        cx: &mut Context<'_>,
127    ) -> Poll<Result<Self::Substream, Self::Error>> {
128        self.project().poll_outbound(cx)
129    }
130
131    #[inline]
132    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
133        self.project().poll_close(cx)
134    }
135
136    fn poll(
137        self: Pin<&mut Self>,
138        cx: &mut Context<'_>,
139    ) -> Poll<Result<StreamMuxerEvent, Self::Error>> {
140        self.project().poll(cx)
141    }
142}
143
144impl SubstreamBox {
145    pub fn new<S: AsyncRead + AsyncWrite + Send + 'static>(stream: S) -> Self {
148        Self(Box::pin(stream))
149    }
150}
151
152impl fmt::Debug for SubstreamBox {
153    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
154        write!(f, "SubstreamBox({})", self.0.type_name())
155    }
156}
157
158trait AsyncReadWrite: AsyncRead + AsyncWrite {
160    fn type_name(&self) -> &'static str;
164}
165
166impl<S> AsyncReadWrite for S
167where
168    S: AsyncRead + AsyncWrite,
169{
170    fn type_name(&self) -> &'static str {
171        std::any::type_name::<S>()
172    }
173}
174
175impl AsyncRead for SubstreamBox {
176    fn poll_read(
177        mut self: Pin<&mut Self>,
178        cx: &mut Context<'_>,
179        buf: &mut [u8],
180    ) -> Poll<std::io::Result<usize>> {
181        self.0.as_mut().poll_read(cx, buf)
182    }
183
184    fn poll_read_vectored(
185        mut self: Pin<&mut Self>,
186        cx: &mut Context<'_>,
187        bufs: &mut [IoSliceMut<'_>],
188    ) -> Poll<std::io::Result<usize>> {
189        self.0.as_mut().poll_read_vectored(cx, bufs)
190    }
191}
192
193impl AsyncWrite for SubstreamBox {
194    fn poll_write(
195        mut self: Pin<&mut Self>,
196        cx: &mut Context<'_>,
197        buf: &[u8],
198    ) -> Poll<std::io::Result<usize>> {
199        self.0.as_mut().poll_write(cx, buf)
200    }
201
202    fn poll_write_vectored(
203        mut self: Pin<&mut Self>,
204        cx: &mut Context<'_>,
205        bufs: &[IoSlice<'_>],
206    ) -> Poll<std::io::Result<usize>> {
207        self.0.as_mut().poll_write_vectored(cx, bufs)
208    }
209
210    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
211        self.0.as_mut().poll_flush(cx)
212    }
213
214    fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
215        self.0.as_mut().poll_close(cx)
216    }
217}