1use std::{
2 fmt,
3 future::Future,
4 mem,
5 pin::Pin,
6 task::{Context, Poll},
7 time::Duration,
8};
9
10use futures::{future, AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, Stream, StreamExt};
11use libp2p_core::{
12 muxing::StreamMuxerExt,
13 upgrade::{InboundConnectionUpgrade, OutboundConnectionUpgrade},
14 StreamMuxer, UpgradeInfo,
15};
16
17use crate::future::{BoxFuture, Either, FutureExt};
18
19pub async fn connected_muxers_on_memory_ring_buffer<MC, M, E>() -> (M, M)
20where
21 MC: InboundConnectionUpgrade<futures_ringbuf::Endpoint, Error = E, Output = M>
22 + OutboundConnectionUpgrade<futures_ringbuf::Endpoint, Error = E, Output = M>
23 + Send
24 + 'static
25 + Default,
26 <MC as UpgradeInfo>::Info: Send,
27 <<MC as UpgradeInfo>::InfoIter as IntoIterator>::IntoIter: Send,
28 <MC as InboundConnectionUpgrade<futures_ringbuf::Endpoint>>::Future: Send,
29 <MC as OutboundConnectionUpgrade<futures_ringbuf::Endpoint>>::Future: Send,
30 E: std::error::Error + Send + Sync + 'static,
31{
32 let (alice, bob) = futures_ringbuf::Endpoint::pair(100, 100);
33
34 let alice_upgrade = MC::default().upgrade_inbound(
35 alice,
36 MC::default().protocol_info().into_iter().next().unwrap(),
37 );
38
39 let bob_upgrade = MC::default().upgrade_outbound(
40 bob,
41 MC::default().protocol_info().into_iter().next().unwrap(),
42 );
43
44 futures::future::try_join(alice_upgrade, bob_upgrade)
45 .await
46 .unwrap()
47}
48
49pub async fn close_implies_flush<A, B, S, E>(alice: A, bob: B)
52where
53 A: StreamMuxer<Substream = S, Error = E> + Unpin,
54 B: StreamMuxer<Substream = S, Error = E> + Unpin,
55 S: AsyncRead + AsyncWrite + Send + Unpin + 'static,
56 E: fmt::Debug,
57{
58 run_commutative(
59 alice,
60 bob,
61 |mut stream| async move {
62 stream.write_all(b"PING").await.unwrap();
63 stream.close().await.unwrap();
64 },
65 |mut stream| async move {
66 let mut buf = Vec::new();
67 stream.read_to_end(&mut buf).await.unwrap();
68
69 assert_eq!(buf, b"PING");
70 },
71 )
72 .await;
73}
74
75pub async fn read_after_close<A, B, S, E>(alice: A, bob: B)
77where
78 A: StreamMuxer<Substream = S, Error = E> + Unpin,
79 B: StreamMuxer<Substream = S, Error = E> + Unpin,
80 S: AsyncRead + AsyncWrite + Send + Unpin + 'static,
81 E: fmt::Debug,
82{
83 run_commutative(
84 alice,
85 bob,
86 |mut stream| async move {
87 stream.write_all(b"PING").await.unwrap();
88 stream.close().await.unwrap();
89
90 let mut buf = Vec::new();
91 stream.read_to_end(&mut buf).await.unwrap();
92
93 assert_eq!(buf, b"PONG");
94 },
95 |mut stream| async move {
96 let mut buf = [0u8; 4];
97 stream.read_exact(&mut buf).await.unwrap();
98
99 assert_eq!(&buf, b"PING");
100
101 stream.write_all(b"PONG").await.unwrap();
102 stream.close().await.unwrap();
103 },
104 )
105 .await;
106}
107
108async fn run_commutative<A, B, S, E, F1, F2>(
111 mut alice: A,
112 mut bob: B,
113 alice_proto: impl Fn(S) -> F1 + Clone + 'static,
114 bob_proto: impl Fn(S) -> F2 + Clone + 'static,
115) where
116 A: StreamMuxer<Substream = S, Error = E> + Unpin,
117 B: StreamMuxer<Substream = S, Error = E> + Unpin,
118 S: AsyncRead + AsyncWrite + Send + Unpin + 'static,
119 E: fmt::Debug,
120 F1: Future<Output = ()> + Send + 'static,
121 F2: Future<Output = ()> + Send + 'static,
122{
123 run(&mut alice, &mut bob, alice_proto.clone(), bob_proto.clone()).await;
124 run(&mut bob, &mut alice, alice_proto, bob_proto).await;
125}
126
127async fn run<A, B, S, E, F1, F2>(
133 dialer: &mut A,
134 listener: &mut B,
135 alice_proto: impl Fn(S) -> F1 + 'static,
136 bob_proto: impl Fn(S) -> F2 + 'static,
137) where
138 A: StreamMuxer<Substream = S, Error = E> + Unpin,
139 B: StreamMuxer<Substream = S, Error = E> + Unpin,
140 S: AsyncRead + AsyncWrite + Send + Unpin + 'static,
141 E: fmt::Debug,
142 F1: Future<Output = ()> + Send + 'static,
143 F2: Future<Output = ()> + Send + 'static,
144{
145 let mut dialer = Harness::OutboundSetup {
146 muxer: dialer,
147 proto_fn: Box::new(move |s| alice_proto(s).boxed()),
148 };
149 let mut listener = Harness::InboundSetup {
150 muxer: listener,
151 proto_fn: Box::new(move |s| bob_proto(s).boxed()),
152 };
153
154 let mut dialer_complete = false;
155 let mut listener_complete = false;
156
157 loop {
158 match futures::future::select(dialer.next(), listener.next()).await {
159 Either::Left((Some(Event::SetupComplete), _)) => {
160 tracing::info!("Dialer opened outbound stream");
161 }
162 Either::Left((Some(Event::ProtocolComplete), _)) => {
163 tracing::info!("Dialer completed protocol");
164 dialer_complete = true
165 }
166 Either::Left((Some(Event::Timeout), _)) => {
167 panic!("Dialer protocol timed out");
168 }
169 Either::Right((Some(Event::SetupComplete), _)) => {
170 tracing::info!("Listener received inbound stream");
171 }
172 Either::Right((Some(Event::ProtocolComplete), _)) => {
173 tracing::info!("Listener completed protocol");
174 listener_complete = true
175 }
176 Either::Right((Some(Event::Timeout), _)) => {
177 panic!("Listener protocol timed out");
178 }
179 _ => unreachable!(),
180 }
181
182 if dialer_complete && listener_complete {
183 break;
184 }
185 }
186}
187
188enum Harness<'m, M>
189where
190 M: StreamMuxer,
191{
192 InboundSetup {
193 muxer: &'m mut M,
194 proto_fn: Box<dyn FnOnce(M::Substream) -> BoxFuture<'static, ()>>,
195 },
196 OutboundSetup {
197 muxer: &'m mut M,
198 proto_fn: Box<dyn FnOnce(M::Substream) -> BoxFuture<'static, ()>>,
199 },
200 Running {
201 muxer: &'m mut M,
202 timeout: futures_timer::Delay,
203 proto: BoxFuture<'static, ()>,
204 },
205 Complete {
206 muxer: &'m mut M,
207 },
208 Poisoned,
209}
210
211enum Event {
212 SetupComplete,
213 Timeout,
214 ProtocolComplete,
215}
216
217impl<M> Stream for Harness<'_, M>
218where
219 M: StreamMuxer + Unpin,
220{
221 type Item = Event;
222
223 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
224 let this = self.get_mut();
225
226 loop {
227 match mem::replace(this, Self::Poisoned) {
228 Harness::InboundSetup { muxer, proto_fn } => {
229 if let Poll::Ready(stream) = muxer.poll_inbound_unpin(cx) {
230 *this = Harness::Running {
231 muxer,
232 timeout: futures_timer::Delay::new(Duration::from_secs(10)),
233 proto: proto_fn(stream.unwrap()),
234 };
235 return Poll::Ready(Some(Event::SetupComplete));
236 }
237
238 if let Poll::Ready(event) = muxer.poll_unpin(cx) {
239 event.unwrap();
240
241 *this = Harness::InboundSetup { muxer, proto_fn };
242 continue;
243 }
244
245 *this = Harness::InboundSetup { muxer, proto_fn };
246 return Poll::Pending;
247 }
248 Harness::OutboundSetup { muxer, proto_fn } => {
249 if let Poll::Ready(stream) = muxer.poll_outbound_unpin(cx) {
250 *this = Harness::Running {
251 muxer,
252 timeout: futures_timer::Delay::new(Duration::from_secs(10)),
253 proto: proto_fn(stream.unwrap()),
254 };
255 return Poll::Ready(Some(Event::SetupComplete));
256 }
257
258 if let Poll::Ready(event) = muxer.poll_unpin(cx) {
259 event.unwrap();
260
261 *this = Harness::OutboundSetup { muxer, proto_fn };
262 continue;
263 }
264
265 *this = Harness::OutboundSetup { muxer, proto_fn };
266 return Poll::Pending;
267 }
268 Harness::Running {
269 muxer,
270 mut proto,
271 mut timeout,
272 } => {
273 if let Poll::Ready(event) = muxer.poll_unpin(cx) {
274 event.unwrap();
275
276 *this = Harness::Running {
277 muxer,
278 proto,
279 timeout,
280 };
281 continue;
282 }
283
284 if let Poll::Ready(()) = proto.poll_unpin(cx) {
285 *this = Harness::Complete { muxer };
286 return Poll::Ready(Some(Event::ProtocolComplete));
287 }
288
289 if let Poll::Ready(()) = timeout.poll_unpin(cx) {
290 return Poll::Ready(Some(Event::Timeout));
291 }
292
293 *this = Harness::Running {
294 muxer,
295 proto,
296 timeout,
297 };
298 return Poll::Pending;
299 }
300 Harness::Complete { muxer } => {
301 if let Poll::Ready(event) = muxer.poll_unpin(cx) {
302 event.unwrap();
303
304 *this = Harness::Complete { muxer };
305 continue;
306 }
307
308 *this = Harness::Complete { muxer };
309 return Poll::Pending;
310 }
311 Harness::Poisoned => {
312 unreachable!()
313 }
314 }
315 }
316 }
317}