1use std::{
24 convert::TryFrom as _,
25 iter, mem,
26 pin::Pin,
27 task::{Context, Poll},
28};
29
30use futures::prelude::*;
31
32use crate::{
33 protocol::{HeaderLine, Message, MessageIO, Protocol, ProtocolError},
34 Negotiated, NegotiationError, Version,
35};
36
37pub fn dialer_select_proto<R, I>(
51 inner: R,
52 protocols: I,
53 version: Version,
54) -> DialerSelectFuture<R, I::IntoIter>
55where
56 R: AsyncRead + AsyncWrite,
57 I: IntoIterator,
58 I::Item: AsRef<str>,
59{
60 let protocols = protocols.into_iter().peekable();
61 DialerSelectFuture {
62 version,
63 protocols,
64 state: State::SendHeader {
65 io: MessageIO::new(inner),
66 },
67 }
68}
69
70#[pin_project::pin_project]
73pub struct DialerSelectFuture<R, I: Iterator> {
74 protocols: iter::Peekable<I>,
76 state: State<R, I::Item>,
77 version: Version,
78}
79
80enum State<R, N> {
81 SendHeader { io: MessageIO<R> },
82 SendProtocol { io: MessageIO<R>, protocol: N },
83 FlushProtocol { io: MessageIO<R>, protocol: N },
84 AwaitProtocol { io: MessageIO<R>, protocol: N },
85 Done,
86}
87
/// Drives the dialer side of the negotiation.
///
/// Each iteration takes the current [`State`] out of `self` with `mem::replace`, leaving
/// `State::Done` in its place. It stores the follow-up state before looping again or
/// returning `Poll::Pending`. Paths that return `Poll::Ready` (success or error) leave
/// `State::Done` behind, so polling again after completion panics.
impl<R, I> Future for DialerSelectFuture<R, I>
where
    R: AsyncRead + AsyncWrite + Unpin,
    I: Iterator,
    I::Item: AsRef<str>,
{
    /// The accepted protocol (as yielded by the protocol iterator) together with the
    /// I/O stream to use for it, or the reason the negotiation failed.
    type Output = Result<(I::Item, Negotiated<R>), NegotiationError>;

    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let this = self.project();

        loop {
            // Take ownership of the current state; non-terminal paths below write the
            // next state back into `this.state`.
            match mem::replace(this.state, State::Done) {
                State::SendHeader { mut io } => {
                    // Wait until the sink can accept a message, restoring the state if not.
                    match Pin::new(&mut io).poll_ready(cx)? {
                        Poll::Ready(()) => {}
                        Poll::Pending => {
                            *this.state = State::SendHeader { io };
                            return Poll::Pending;
                        }
                    }

                    let h = HeaderLine::from(*this.version);
                    if let Err(err) = Pin::new(&mut io).start_send(Message::Header(h)) {
                        return Poll::Ready(Err(From::from(err)));
                    }

                    // With no candidate protocols the negotiation fails (after the header
                    // has been queued).
                    let protocol = this.protocols.next().ok_or(NegotiationError::Failed)?;

                    // No flush here: the header is only queued, and the first proposal
                    // follows directly in `SendProtocol`.
                    *this.state = State::SendProtocol { io, protocol };
                }

                State::SendProtocol { mut io, protocol } => {
                    match Pin::new(&mut io).poll_ready(cx)? {
                        Poll::Ready(()) => {}
                        Poll::Pending => {
                            *this.state = State::SendProtocol { io, protocol };
                            return Poll::Pending;
                        }
                    }

                    // `Protocol::try_from` may reject the name; its error is converted
                    // into a `NegotiationError` by `?`.
                    let p = Protocol::try_from(protocol.as_ref())?;
                    if let Err(err) = Pin::new(&mut io).start_send(Message::Protocol(p.clone())) {
                        return Poll::Ready(Err(From::from(err)));
                    }
                    tracing::debug!(protocol=%p, "Dialer: Proposed protocol");

                    if this.protocols.peek().is_some() {
                        // More candidates remain: flush and wait for the listener's answer.
                        *this.state = State::FlushProtocol { io, protocol }
                    } else {
                        match this.version {
                            Version::V1 => *this.state = State::FlushProtocol { io, protocol },
                            // Last candidate under `V1Lazy`: settle on it optimistically,
                            // without flushing or awaiting a confirmation here. The returned
                            // stream is built with `Negotiated::expecting`, presumably so the
                            // listener's response is checked later on that stream (confirm
                            // against `Negotiated`).
                            Version::V1Lazy => {
                                tracing::debug!(protocol=%p, "Dialer: Expecting proposed protocol");
                                let hl = HeaderLine::from(Version::V1Lazy);
                                let io = Negotiated::expecting(io.into_reader(), p, Some(hl));
                                return Poll::Ready(Ok((protocol, io)));
                            }
                        }
                    }
                }

                State::FlushProtocol { mut io, protocol } => {
                    // Flush everything queued so far, then wait for the listener's answer.
                    match Pin::new(&mut io).poll_flush(cx)? {
                        Poll::Ready(()) => *this.state = State::AwaitProtocol { io, protocol },
                        Poll::Pending => {
                            *this.state = State::FlushProtocol { io, protocol };
                            return Poll::Pending;
                        }
                    }
                }

                State::AwaitProtocol { mut io, protocol } => {
                    let msg = match Pin::new(&mut io).poll_next(cx)? {
                        Poll::Ready(Some(msg)) => msg,
                        Poll::Pending => {
                            *this.state = State::AwaitProtocol { io, protocol };
                            return Poll::Pending;
                        }
                        // End of stream is reported as `NegotiationError::Failed` rather
                        // than as a protocol error, so the remote closing the stream ends
                        // the negotiation as a plain failure.
                        Poll::Ready(None) => return Poll::Ready(Err(NegotiationError::Failed)),
                    };

                    match msg {
                        // A header matching our version: keep waiting for the answer to
                        // the proposal.
                        Message::Header(v) if v == HeaderLine::from(*this.version) => {
                            *this.state = State::AwaitProtocol { io, protocol };
                        }
                        // The listener confirmed exactly the protocol we proposed.
                        Message::Protocol(ref p) if p.as_ref() == protocol.as_ref() => {
                            tracing::debug!(protocol=%p, "Dialer: Received confirmation for protocol");
                            let io = Negotiated::completed(io.into_inner());
                            return Poll::Ready(Ok((protocol, io)));
                        }
                        // Rejected: propose the next candidate, or fail if none remain.
                        Message::NotAvailable => {
                            tracing::debug!(
                                protocol=%protocol.as_ref(),
                                "Dialer: Received rejection of protocol"
                            );
                            let protocol = this.protocols.next().ok_or(NegotiationError::Failed)?;
                            *this.state = State::SendProtocol { io, protocol }
                        }
                        // Anything else is a protocol violation. This includes a different
                        // protocol or a header line that does not match our version.
                        _ => return Poll::Ready(Err(ProtocolError::InvalidMessage.into())),
                    }
                }

                State::Done => panic!("State::poll called after completion"),
            }
        }
    }
}
208
#[cfg(test)]
mod tests {

    use std::time::Duration;

    use quickcheck::{Arbitrary, Gen, GenRange};
    use tokio::{
        net::{TcpListener, TcpStream},
        runtime::Runtime,
        time::timeout,
    };
    use tokio_util::compat::TokioAsyncReadCompatExt;
    use tracing::metadata::LevelFilter;
    use tracing_subscriber::EnvFilter;

    use super::*;
    use crate::listener_select_proto;

    /// The dialer's first proposal (`/proto3`) is unknown to the listener, so both sides
    /// must end up on `/proto2` and then be able to exchange application data.
    /// Runs once with `Version::V1` and once with `Version::V1Lazy`.
    #[test]
    fn select_proto_basic() {
        async fn run(version: Version) {
            let (client_connection, server_connection) = futures_ringbuf::Endpoint::pair(100, 100);

            let server = tokio::task::spawn(async move {
                let protos = vec!["/proto1", "/proto2"];
                let (proto, mut io) = listener_select_proto(server_connection, protos)
                    .await
                    .unwrap();
                assert_eq!(proto, "/proto2");

                let mut out = vec![0; 32];
                let n = io.read(&mut out).await.unwrap();
                out.truncate(n);
                assert_eq!(out, b"ping");

                io.write_all(b"pong").await.unwrap();
                io.flush().await.unwrap();
            });

            let client = tokio::task::spawn(async move {
                let protos = vec!["/proto3", "/proto2"];
                let (proto, mut io) = dialer_select_proto(client_connection, protos, version)
                    .await
                    .unwrap();
                assert_eq!(proto, "/proto2");

                io.write_all(b"ping").await.unwrap();
                io.flush().await.unwrap();

                let mut out = vec![0; 32];
                let n = io.read(&mut out).await.unwrap();
                out.truncate(n);
                assert_eq!(out, b"pong");
            });

            server.await.unwrap();
            client.await.unwrap();
        }

        let rt = Runtime::new().unwrap();
        rt.block_on(run(Version::V1));
        rt.block_on(run(Version::V1Lazy));
    }

    /// Dialer and listener support disjoint protocol sets, so every side must end with
    /// `NegotiationError::Failed`. The failure surfaces either from the select future or,
    /// if a side returned a stream anyway, from `complete()`.
    #[test]
    fn negotiation_failed() {
        fn prop(
            version: Version,
            DialerProtos(dial_protos): DialerProtos,
            ListenerProtos(listen_protos): ListenerProtos,
            DialPayload(dial_payload): DialPayload,
        ) {
            let _ = tracing_subscriber::fmt()
                .with_env_filter(
                    EnvFilter::builder()
                        .with_default_directive(LevelFilter::DEBUG.into())
                        .from_env_lossy(),
                )
                .try_init();

            let rt = Runtime::new().unwrap();
            rt.block_on(async move {
                // Bind to loopback. The client connects to `local_addr()`, and connecting
                // to the unspecified address `0.0.0.0` is not portable (it fails on
                // Windows). The test also has no reason to listen on every interface.
                let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
                let addr = listener.local_addr().unwrap();

                let server = tokio::task::spawn(async move {
                    let server_connection = listener.accept().await.unwrap().0.compat();

                    let io = match timeout(
                        Duration::from_secs(2),
                        listener_select_proto(server_connection, listen_protos),
                    )
                    .await
                    .expect("listener negotiation timed out")
                    {
                        Ok((_, io)) => io,
                        Err(NegotiationError::Failed) => return,
                        Err(NegotiationError::ProtocolError(e)) => {
                            panic!("Unexpected protocol error {e}")
                        }
                    };
                    match io.complete().await {
                        Err(NegotiationError::Failed) => {}
                        Err(NegotiationError::ProtocolError(e)) => {
                            panic!("Unexpected protocol error {e}")
                        }
                        Ok(_) => panic!("Listener unexpectedly completed the negotiation"),
                    }
                });

                let client = tokio::task::spawn(async move {
                    let client_connection = TcpStream::connect(addr).await.unwrap().compat();

                    let mut io = match timeout(
                        Duration::from_secs(2),
                        dialer_select_proto(client_connection, dial_protos, version),
                    )
                    .await
                    .expect("dialer negotiation timed out")
                    {
                        Err(NegotiationError::Failed) => return,
                        Ok((_, io)) => io,
                        Err(NegotiationError::ProtocolError(e)) => {
                            panic!("Unexpected protocol error {e}")
                        }
                    };

                    // The dialer returned a stream without a confirmation from the
                    // listener (presumably the optimistic `V1Lazy` path). Write data
                    // before the outcome is known; the rejection must still surface
                    // from `complete()`.
                    tracing::info!("Writing early data");

                    io.write_all(&dial_payload).await.unwrap();
                    match io.complete().await {
                        Err(NegotiationError::Failed) => {}
                        Err(NegotiationError::ProtocolError(e)) => {
                            panic!("Unexpected protocol error {e}")
                        }
                        Ok(_) => panic!("Dialer unexpectedly completed the negotiation"),
                    }
                });

                server.await.unwrap();
                client.await.unwrap();

                tracing::info!("---------------------------------------")
            });
        }

        quickcheck::QuickCheck::new()
            .tests(1000)
            .quickcheck(prop as fn(_, _, _, _));
    }

    /// With `V1Lazy` and a single protocol, the dialer settles optimistically. Closing the
    /// stream must therefore not wait for a listener response that never arrives.
    #[tokio::test]
    async fn v1_lazy_do_not_wait_for_negotiation_on_poll_close() {
        let (client_connection, _server_connection) =
            futures_ringbuf::Endpoint::pair(1024 * 1024, 1);

        let client = tokio::task::spawn(async move {
            // A single protocol allows the lazy (optimistic) negotiation path.
            let protos = vec!["/proto1"];
            let (proto, mut io) = dialer_select_proto(client_connection, protos, Version::V1Lazy)
                .await
                .unwrap();
            assert_eq!(proto, "/proto1");

            // `_server_connection` is never read from or written to, so the negotiation
            // is still incomplete when the client closes.
            io.close().await.unwrap();
        });

        match tokio::time::timeout(Duration::from_secs(10), client).await {
            Ok(join_result) => join_result.expect("Client task should complete successfully"),
            Err(_elapsed) => {
                panic!("Expected the client task to complete before timeout");
            }
        }
    }

    /// Dialer protocol sets; never overlapping with [`ListenerProtos`].
    #[derive(Clone, Debug)]
    struct DialerProtos(Vec<&'static str>);

    impl Arbitrary for DialerProtos {
        fn arbitrary(g: &mut Gen) -> Self {
            if bool::arbitrary(g) {
                DialerProtos(vec!["/proto1"])
            } else {
                DialerProtos(vec!["/proto1", "/proto2"])
            }
        }
    }

    /// Listener protocol sets; never overlapping with [`DialerProtos`].
    #[derive(Clone, Debug)]
    struct ListenerProtos(Vec<&'static str>);

    impl Arbitrary for ListenerProtos {
        fn arbitrary(g: &mut Gen) -> Self {
            if bool::arbitrary(g) {
                ListenerProtos(vec!["/proto3"])
            } else {
                ListenerProtos(vec!["/proto3", "/proto4"])
            }
        }
    }

    /// Early data written by the dialer: zero or one non-zero byte.
    #[derive(Clone, Debug)]
    struct DialPayload(Vec<u8>);

    impl Arbitrary for DialPayload {
        fn arbitrary(g: &mut Gen) -> Self {
            DialPayload(
                (0..g.gen_range(0..2u8))
                    .map(|_| g.gen_range(1..255))
                    .collect(),
            )
        }
    }
}