1use std::{
24 convert::TryFrom as _,
25 iter, mem,
26 pin::Pin,
27 task::{Context, Poll},
28};
29
30use futures::prelude::*;
31
32use crate::{
33 protocol::{HeaderLine, Message, MessageIO, Protocol, ProtocolError},
34 Negotiated, NegotiationError, Version,
35};
36
37pub fn dialer_select_proto<R, I>(
51 inner: R,
52 protocols: I,
53 version: Version,
54) -> DialerSelectFuture<R, I::IntoIter>
55where
56 R: AsyncRead + AsyncWrite,
57 I: IntoIterator,
58 I::Item: AsRef<str>,
59{
60 let protocols = protocols.into_iter().peekable();
61 DialerSelectFuture {
62 version,
63 protocols,
64 state: State::SendHeader {
65 io: MessageIO::new(inner),
66 },
67 }
68}
69
/// A `Future` returned by [`dialer_select_proto`] which negotiates a protocol
/// by proposing the candidate protocols one after the other.
#[pin_project::pin_project]
pub struct DialerSelectFuture<R, I: Iterator> {
    /// Candidate protocols not yet proposed. Peekable so the state machine can
    /// detect when the protocol being proposed is the last one.
    protocols: iter::Peekable<I>,
    /// Current state of the negotiation state machine.
    state: State<R, I::Item>,
    /// Negotiation version; determines the header line that is sent and expected,
    /// and whether the `V1Lazy` optimistic settlement applies.
    version: Version,
}
79
/// States of the dialer's negotiation state machine, advanced by
/// `DialerSelectFuture::poll`.
enum State<R, N> {
    /// The header line still has to be queued for sending.
    SendHeader { io: MessageIO<R> },
    /// `protocol` still has to be queued as the next proposal.
    SendProtocol { io: MessageIO<R>, protocol: N },
    /// The proposal for `protocol` is queued; the stream must be flushed
    /// before a reply can be awaited.
    FlushProtocol { io: MessageIO<R>, protocol: N },
    /// Waiting for the listener to confirm or reject `protocol`.
    AwaitProtocol { io: MessageIO<R>, protocol: N },
    /// Left in place while a state is being processed and after the future has
    /// completed; polling in this state panics.
    Done,
}
87
/// Drives the dialer side of the negotiation.
///
/// Each iteration takes the current state out of `self.state` (leaving
/// `State::Done` behind) and stores the successor state back before looping
/// again or returning `Poll::Pending`.
impl<R, I> Future for DialerSelectFuture<R, I>
where
    // NOTE(review): `Unpin` is presumably what permits the `Pin::new(&mut io)`
    // calls on `MessageIO<R>` below and producing `Negotiated<R>` as output —
    // confirm against `MessageIO`'s definition.
    R: AsyncRead + AsyncWrite + Unpin,
    I: Iterator,
    I::Item: AsRef<str>,
{
    type Output = Result<(I::Item, Negotiated<R>), NegotiationError>;

    /// Advances the negotiation as far as the underlying I/O currently allows.
    ///
    /// Resolves to the agreed-upon protocol together with a `Negotiated` stream.
    ///
    /// # Errors
    ///
    /// - `NegotiationError::Failed` if there is no candidate to propose, if
    ///   every proposal was rejected, or if the listener's stream ends while a
    ///   reply is awaited.
    /// - An error if a candidate cannot be converted into a `Protocol`.
    /// - `ProtocolError::InvalidMessage` if the listener replies with an
    ///   unexpected message.
    /// - Errors from the underlying `MessageIO`, converted with `From`.
    ///
    /// # Panics
    ///
    /// Panics if polled again after it has returned `Poll::Ready`.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let this = self.project();

        loop {
            // Take ownership of the current state; every non-terminal path below
            // must put a state back into `this.state`.
            match mem::replace(this.state, State::Done) {
                State::SendHeader { mut io } => {
                    match Pin::new(&mut io).poll_ready(cx)? {
                        Poll::Ready(()) => {}
                        Poll::Pending => {
                            *this.state = State::SendHeader { io };
                            return Poll::Pending;
                        }
                    }

                    let h = HeaderLine::from(*this.version);
                    if let Err(err) = Pin::new(&mut io).start_send(Message::Header(h)) {
                        return Poll::Ready(Err(From::from(err)));
                    }

                    let protocol = this.protocols.next().ok_or(NegotiationError::Failed)?;

                    // No flush here: the header and the first proposal are sent
                    // together, the flush happens in `FlushProtocol`.
                    *this.state = State::SendProtocol { io, protocol };
                }

                State::SendProtocol { mut io, protocol } => {
                    match Pin::new(&mut io).poll_ready(cx)? {
                        Poll::Ready(()) => {}
                        Poll::Pending => {
                            *this.state = State::SendProtocol { io, protocol };
                            return Poll::Pending;
                        }
                    }

                    let p = Protocol::try_from(protocol.as_ref())?;
                    if let Err(err) = Pin::new(&mut io).start_send(Message::Protocol(p.clone())) {
                        return Poll::Ready(Err(From::from(err)));
                    }
                    tracing::debug!(protocol=%p, "Dialer: Proposed protocol");

                    if this.protocols.peek().is_some() {
                        *this.state = State::FlushProtocol { io, protocol }
                    } else {
                        match this.version {
                            Version::V1 => *this.state = State::FlushProtocol { io, protocol },
                            // The only difference between `V1Lazy` and `V1`: when
                            // proposing the last candidate, settle on it optimistically
                            // and hand out a `Negotiated` stream that still expects the
                            // listener's confirmation, instead of waiting for it here.
                            Version::V1Lazy => {
                                tracing::debug!(protocol=%p, "Dialer: Expecting proposed protocol");
                                let hl = HeaderLine::from(Version::V1Lazy);
                                let io = Negotiated::expecting(io.into_reader(), p, Some(hl));
                                return Poll::Ready(Ok((protocol, io)));
                            }
                        }
                    }
                }

                State::FlushProtocol { mut io, protocol } => {
                    match Pin::new(&mut io).poll_flush(cx)? {
                        Poll::Ready(()) => *this.state = State::AwaitProtocol { io, protocol },
                        Poll::Pending => {
                            *this.state = State::FlushProtocol { io, protocol };
                            return Poll::Pending;
                        }
                    }
                }

                State::AwaitProtocol { mut io, protocol } => {
                    let msg = match Pin::new(&mut io).poll_next(cx)? {
                        Poll::Ready(Some(msg)) => msg,
                        Poll::Pending => {
                            *this.state = State::AwaitProtocol { io, protocol };
                            return Poll::Pending;
                        }
                        // End of stream is reported as `NegotiationError::Failed`
                        // rather than as a protocol error, so a peer closing the
                        // stream counts as an ordinary negotiation failure.
                        Poll::Ready(None) => return Poll::Ready(Err(NegotiationError::Failed)),
                    };

                    match msg {
                        // A header line matching our version is consumed; keep
                        // waiting for the reply to the proposal.
                        Message::Header(v) if v == HeaderLine::from(*this.version) => {
                            *this.state = State::AwaitProtocol { io, protocol };
                        }
                        Message::Protocol(ref p) if p.as_ref() == protocol.as_ref() => {
                            tracing::debug!(protocol=%p, "Dialer: Received confirmation for protocol");
                            let io = Negotiated::completed(io.into_inner());
                            return Poll::Ready(Ok((protocol, io)));
                        }
                        Message::NotAvailable => {
                            tracing::debug!(
                                protocol=%protocol.as_ref(),
                                "Dialer: Received rejection of protocol"
                            );
                            // Move on to the next candidate, or fail if none is left.
                            let protocol = this.protocols.next().ok_or(NegotiationError::Failed)?;
                            *this.state = State::SendProtocol { io, protocol }
                        }
                        _ => return Poll::Ready(Err(ProtocolError::InvalidMessage.into())),
                    }
                }

                State::Done => panic!("State::poll called after completion"),
            }
        }
    }
}
208
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use async_std::{
        future::timeout,
        net::{TcpListener, TcpStream},
    };
    use quickcheck::{Arbitrary, Gen, GenRange};
    use tracing::metadata::LevelFilter;
    use tracing_subscriber::EnvFilter;

    use super::*;
    use crate::listener_select_proto;

    /// The dialer first proposes `/proto3` (unsupported by the listener), then
    /// `/proto2`. Both sides must agree on `/proto2` and exchange application
    /// data over the negotiated stream, for both `V1` and `V1Lazy`.
    #[test]
    fn select_proto_basic() {
        async fn run(version: Version) {
            let (client_connection, server_connection) = futures_ringbuf::Endpoint::pair(100, 100);

            let server = async_std::task::spawn(async move {
                let protos = vec!["/proto1", "/proto2"];
                let (proto, mut io) = listener_select_proto(server_connection, protos)
                    .await
                    .unwrap();
                assert_eq!(proto, "/proto2");

                let mut out = vec![0; 32];
                let n = io.read(&mut out).await.unwrap();
                out.truncate(n);
                assert_eq!(out, b"ping");

                io.write_all(b"pong").await.unwrap();
                io.flush().await.unwrap();
            });

            let client = async_std::task::spawn(async move {
                let protos = vec!["/proto3", "/proto2"];
                let (proto, mut io) = dialer_select_proto(client_connection, protos, version)
                    .await
                    .unwrap();
                assert_eq!(proto, "/proto2");

                io.write_all(b"ping").await.unwrap();
                io.flush().await.unwrap();

                let mut out = vec![0; 32];
                let n = io.read(&mut out).await.unwrap();
                out.truncate(n);
                assert_eq!(out, b"pong");
            });

            server.await;
            client.await;
        }

        async_std::task::block_on(run(Version::V1));
        async_std::task::block_on(run(Version::V1Lazy));
    }

    /// The dialer's protocols (`/proto1`, `/proto2`) and the listener's
    /// (`/proto3`, `/proto4`) are disjoint. Negotiation must therefore always
    /// end in `NegotiationError::Failed`: either directly, or, for an optimistic
    /// `V1Lazy` dialer, when completing the `Negotiated` stream.
    #[test]
    fn negotiation_failed() {
        fn prop(
            version: Version,
            DialerProtos(dial_protos): DialerProtos,
            ListenerProtos(listen_protos): ListenerProtos,
            DialPayload(dial_payload): DialPayload,
        ) {
            let _ = tracing_subscriber::fmt()
                .with_env_filter(
                    EnvFilter::builder()
                        .with_default_directive(LevelFilter::DEBUG.into())
                        .from_env_lossy(),
                )
                .try_init();

            async_std::task::block_on(async move {
                // Bind to loopback rather than the unspecified address `0.0.0.0`:
                // the client connects to `local_addr()`, and connecting to the
                // unspecified address is not portable (Windows rejects it with
                // `WSAEADDRNOTAVAIL`). It also keeps the listener off external
                // interfaces.
                let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
                let addr = listener.local_addr().unwrap();

                let server = async_std::task::spawn(async move {
                    let server_connection = listener.accept().await.unwrap().0;

                    let io = match timeout(
                        Duration::from_secs(2),
                        listener_select_proto(server_connection, listen_protos),
                    )
                    .await
                    .expect("listener negotiation timed out")
                    {
                        Ok((_, io)) => io,
                        Err(NegotiationError::Failed) => return,
                        Err(NegotiationError::ProtocolError(e)) => {
                            panic!("Unexpected protocol error {e}")
                        }
                    };
                    match io.complete().await {
                        Err(NegotiationError::Failed) => {}
                        _ => panic!("Listener: expected completion to fail with `NegotiationError::Failed`"),
                    }
                });

                let client = async_std::task::spawn(async move {
                    let client_connection = TcpStream::connect(addr).await.unwrap();

                    let mut io = match timeout(
                        Duration::from_secs(2),
                        dialer_select_proto(client_connection, dial_protos, version),
                    )
                    .await
                    .expect("dialer negotiation timed out")
                    {
                        Err(NegotiationError::Failed) => return,
                        Ok((_, io)) => io,
                        Err(NegotiationError::ProtocolError(e)) => {
                            panic!("Unexpected protocol error {e}")
                        }
                    };
                    // With `V1Lazy` the dialer may write a payload before it got
                    // confirmation of its last proposed protocol.
                    tracing::info!("Writing early data");

                    io.write_all(&dial_payload).await.unwrap();
                    match io.complete().await {
                        Err(NegotiationError::Failed) => {}
                        _ => panic!("Dialer: expected completion to fail with `NegotiationError::Failed`"),
                    }
                });

                server.await;
                client.await;

                tracing::info!("---------------------------------------")
            });
        }

        quickcheck::QuickCheck::new()
            .tests(1000)
            .quickcheck(prop as fn(_, _, _, _));
    }

    /// A `V1Lazy` dialer with a single protocol settles optimistically, so
    /// closing the stream must not block on the listener (which never answers
    /// here).
    #[async_std::test]
    async fn v1_lazy_do_not_wait_for_negotiation_on_poll_close() {
        let (client_connection, _server_connection) =
            futures_ringbuf::Endpoint::pair(1024 * 1024, 1);

        let client = async_std::task::spawn(async move {
            // A single protocol enables lazy (optimistic) negotiation.
            let protos = vec!["/proto1"];
            let (proto, mut io) = dialer_select_proto(client_connection, protos, Version::V1Lazy)
                .await
                .unwrap();
            assert_eq!(proto, "/proto1");

            // Closing succeeds although `_server_connection` was never touched,
            // i.e. negotiation has not been confirmed yet.
            io.close().await.unwrap();
        });

        async_std::future::timeout(Duration::from_secs(10), client)
            .await
            .unwrap();
    }

    #[derive(Clone, Debug)]
    struct DialerProtos(Vec<&'static str>);

    impl Arbitrary for DialerProtos {
        fn arbitrary(g: &mut Gen) -> Self {
            if bool::arbitrary(g) {
                DialerProtos(vec!["/proto1"])
            } else {
                DialerProtos(vec!["/proto1", "/proto2"])
            }
        }
    }

    #[derive(Clone, Debug)]
    struct ListenerProtos(Vec<&'static str>);

    impl Arbitrary for ListenerProtos {
        fn arbitrary(g: &mut Gen) -> Self {
            if bool::arbitrary(g) {
                ListenerProtos(vec!["/proto3"])
            } else {
                ListenerProtos(vec!["/proto3", "/proto4"])
            }
        }
    }

    #[derive(Clone, Debug)]
    struct DialPayload(Vec<u8>);

    impl Arbitrary for DialPayload {
        fn arbitrary(g: &mut Gen) -> Self {
            DialPayload(
                (0..g.gen_range(0..2u8))
                    // Bytes start at 1. NOTE(review): presumably a 0 byte would
                    // make the listener report a different error than `Failed` —
                    // confirm.
                    .map(|_| g.gen_range(1..255))
                    .collect(),
            )
        }
    }
}