multistream_select/
dialer_select.rs

1// Copyright 2017 Parity Technologies (UK) Ltd.
2//
3// Permission is hereby granted, free of charge, to any person obtaining a
4// copy of this software and associated documentation files (the "Software"),
5// to deal in the Software without restriction, including without limitation
6// the rights to use, copy, modify, merge, publish, distribute, sublicense,
7// and/or sell copies of the Software, and to permit persons to whom the
8// Software is furnished to do so, subject to the following conditions:
9//
10// The above copyright notice and this permission notice shall be included in
11// all copies or substantial portions of the Software.
12//
13// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
14// OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
16// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
17// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
18// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
19// DEALINGS IN THE SOFTWARE.
20
21//! Protocol negotiation strategies for the peer acting as the dialer.
22
23use std::{
24    convert::TryFrom as _,
25    iter, mem,
26    pin::Pin,
27    task::{Context, Poll},
28};
29
30use futures::prelude::*;
31
32use crate::{
33    protocol::{HeaderLine, Message, MessageIO, Protocol, ProtocolError},
34    Negotiated, NegotiationError, Version,
35};
36
37/// Returns a `Future` that negotiates a protocol on the given I/O stream
38/// for a peer acting as the _dialer_ (or _initiator_).
39///
40/// This function is given an I/O stream and a list of protocols and returns a
41/// computation that performs the protocol negotiation with the remote. The
42/// returned `Future` resolves with the name of the negotiated protocol and
43/// a [`Negotiated`] I/O stream.
44///
45/// Within the scope of this library, a dialer always commits to a specific
46/// multistream-select [`Version`], whereas a listener always supports
47/// all versions supported by this library. Frictionless multistream-select
48/// protocol upgrades may thus proceed by deployments with updated listeners,
49/// eventually followed by deployments of dialers choosing the newer protocol.
50pub fn dialer_select_proto<R, I>(
51    inner: R,
52    protocols: I,
53    version: Version,
54) -> DialerSelectFuture<R, I::IntoIter>
55where
56    R: AsyncRead + AsyncWrite,
57    I: IntoIterator,
58    I::Item: AsRef<str>,
59{
60    let protocols = protocols.into_iter().peekable();
61    DialerSelectFuture {
62        version,
63        protocols,
64        state: State::SendHeader {
65            io: MessageIO::new(inner),
66        },
67    }
68}
69
70/// A `Future` returned by [`dialer_select_proto`] which negotiates
71/// a protocol iteratively by considering one protocol after the other.
72#[pin_project::pin_project]
73pub struct DialerSelectFuture<R, I: Iterator> {
74    // TODO: It would be nice if eventually N = I::Item = Protocol.
75    protocols: iter::Peekable<I>,
76    state: State<R, I::Item>,
77    version: Version,
78}
79
80enum State<R, N> {
81    SendHeader { io: MessageIO<R> },
82    SendProtocol { io: MessageIO<R>, protocol: N },
83    FlushProtocol { io: MessageIO<R>, protocol: N },
84    AwaitProtocol { io: MessageIO<R>, protocol: N },
85    Done,
86}
87
/// Drives the dialer side of the negotiation as a state machine.
///
/// Each call to `poll` takes the current state out of `self.state`, leaving
/// `State::Done` in its place, and advances it as far as possible:
///
/// - `SendHeader`: queues the header line for the chosen [`Version`] and pulls
///   the first candidate protocol. If there is no candidate, it resolves with
///   [`NegotiationError::Failed`].
/// - `SendProtocol`: queues a proposal for the current candidate.
///   - If further candidates remain, or the version is `V1`, the proposal is
///     flushed and a reply is awaited.
///   - Otherwise (`V1Lazy` with the last candidate), the future resolves
///     immediately with a [`Negotiated`] stream that still expects the
///     listener's confirmation.
/// - `FlushProtocol`: flushes the queued messages to the remote.
/// - `AwaitProtocol`: reads the listener's reply.
///   - A header for our version is skipped.
///   - A confirmation of the proposed protocol resolves the future.
///   - A rejection moves on to the next candidate, or fails if none remain.
///
/// # Errors
///
/// - [`NegotiationError::Failed`] if no candidate is left, or if the remote
///   closes the stream while a reply is awaited.
/// - A protocol error in these cases:
///   - an unexpected message;
///   - a candidate name that cannot be converted into a `Protocol`;
///   - an error from the underlying `MessageIO`, converted via `From`.
///
/// # Panics
///
/// Panics if polled again after it has resolved, whether it succeeded or
/// failed.
impl<R, I> Future for DialerSelectFuture<R, I>
where
    // The Unpin bound here is required because we produce
    // a `Negotiated<R>` as the output. It also makes
    // the implementation considerably easier to write.
    R: AsyncRead + AsyncWrite + Unpin,
    I: Iterator,
    I::Item: AsRef<str>,
{
    /// The accepted protocol name and the negotiated I/O stream.
    type Output = Result<(I::Item, Negotiated<R>), NegotiationError>;

    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let this = self.project();

        loop {
            // Move the current state out so its fields can be consumed by value.
            // Every arm must either store the next state back (before looping or
            // returning `Pending`) or return a final result, which leaves `Done`.
            match mem::replace(this.state, State::Done) {
                State::SendHeader { mut io } => {
                    // Wait until the sink can accept a message. On `Pending`,
                    // restore the state so the next poll resumes here.
                    match Pin::new(&mut io).poll_ready(cx)? {
                        Poll::Ready(()) => {}
                        Poll::Pending => {
                            *this.state = State::SendHeader { io };
                            return Poll::Pending;
                        }
                    }

                    let h = HeaderLine::from(*this.version);
                    if let Err(err) = Pin::new(&mut io).start_send(Message::Header(h)) {
                        return Poll::Ready(Err(From::from(err)));
                    }

                    // With no candidate protocols at all, the negotiation fails.
                    let protocol = this.protocols.next().ok_or(NegotiationError::Failed)?;

                    // The dialer always sends the header and the first protocol
                    // proposal in one go for efficiency.
                    *this.state = State::SendProtocol { io, protocol };
                }

                State::SendProtocol { mut io, protocol } => {
                    match Pin::new(&mut io).poll_ready(cx)? {
                        Poll::Ready(()) => {}
                        Poll::Pending => {
                            *this.state = State::SendProtocol { io, protocol };
                            return Poll::Pending;
                        }
                    }

                    // Fails if the candidate name is not a valid protocol name.
                    let p = Protocol::try_from(protocol.as_ref())?;
                    if let Err(err) = Pin::new(&mut io).start_send(Message::Protocol(p.clone())) {
                        return Poll::Ready(Err(From::from(err)));
                    }
                    tracing::debug!(protocol=%p, "Dialer: Proposed protocol");

                    // Peek (without consuming) to learn whether this is the last candidate.
                    if this.protocols.peek().is_some() {
                        *this.state = State::FlushProtocol { io, protocol }
                    } else {
                        match this.version {
                            Version::V1 => *this.state = State::FlushProtocol { io, protocol },
                            // This is the only effect that `V1Lazy` has compared to `V1`:
                            // Optimistically settling on the only protocol that
                            // the dialer supports for this negotiation. Notably,
                            // the dialer expects a regular `V1` response.
                            Version::V1Lazy => {
                                tracing::debug!(protocol=%p, "Dialer: Expecting proposed protocol");
                                let hl = HeaderLine::from(Version::V1Lazy);
                                let io = Negotiated::expecting(io.into_reader(), p, Some(hl));
                                return Poll::Ready(Ok((protocol, io)));
                            }
                        }
                    }
                }

                State::FlushProtocol { mut io, protocol } => {
                    // Push the queued header/proposal out before reading a reply.
                    match Pin::new(&mut io).poll_flush(cx)? {
                        Poll::Ready(()) => *this.state = State::AwaitProtocol { io, protocol },
                        Poll::Pending => {
                            *this.state = State::FlushProtocol { io, protocol };
                            return Poll::Pending;
                        }
                    }
                }

                State::AwaitProtocol { mut io, protocol } => {
                    let msg = match Pin::new(&mut io).poll_next(cx)? {
                        Poll::Ready(Some(msg)) => msg,
                        Poll::Pending => {
                            *this.state = State::AwaitProtocol { io, protocol };
                            return Poll::Pending;
                        }
                        // Treat EOF error as [`NegotiationError::Failed`], not as
                        // [`NegotiationError::ProtocolError`], allowing dropping or closing an I/O
                        // stream as a permissible way to "gracefully" fail a negotiation.
                        Poll::Ready(None) => return Poll::Ready(Err(NegotiationError::Failed)),
                    };

                    match msg {
                        // A header line for our own version is acceptable here.
                        // Skip it and keep waiting for the protocol reply.
                        Message::Header(v) if v == HeaderLine::from(*this.version) => {
                            *this.state = State::AwaitProtocol { io, protocol };
                        }
                        // The listener confirmed exactly the protocol we proposed.
                        Message::Protocol(ref p) if p.as_ref() == protocol.as_ref() => {
                            tracing::debug!(protocol=%p, "Dialer: Received confirmation for protocol");
                            let io = Negotiated::completed(io.into_inner());
                            return Poll::Ready(Ok((protocol, io)));
                        }
                        // Rejected: propose the next candidate, or fail if none remain.
                        Message::NotAvailable => {
                            tracing::debug!(
                                protocol=%protocol.as_ref(),
                                "Dialer: Received rejection of protocol"
                            );
                            let protocol = this.protocols.next().ok_or(NegotiationError::Failed)?;
                            *this.state = State::SendProtocol { io, protocol }
                        }
                        // Any other message is rejected. This includes a header
                        // for another version and a confirmation of a different
                        // protocol.
                        _ => return Poll::Ready(Err(ProtocolError::InvalidMessage.into())),
                    }
                }

                // Only reachable when polled after the future has already resolved.
                State::Done => panic!("State::poll called after completion"),
            }
        }
    }
}
208
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use quickcheck::{Arbitrary, Gen, GenRange};
    use tokio::{
        net::{TcpListener, TcpStream},
        runtime::Runtime,
        time::timeout,
    };
    use tokio_util::compat::TokioAsyncReadCompatExt;
    use tracing::metadata::LevelFilter;
    use tracing_subscriber::EnvFilter;

    use super::*;
    use crate::listener_select_proto;

    /// Negotiates `/proto2` with both `V1` and `V1Lazy`. It is the dialer's
    /// second choice and the only protocol both sides support. The test then
    /// exchanges a ping/pong over the negotiated stream.
    #[test]
    fn select_proto_basic() {
        async fn run(version: Version) {
            let (client_connection, server_connection) = futures_ringbuf::Endpoint::pair(100, 100);

            let server = tokio::task::spawn(async move {
                let protos = vec!["/proto1", "/proto2"];
                let (proto, mut io) = listener_select_proto(server_connection, protos)
                    .await
                    .unwrap();
                assert_eq!(proto, "/proto2");

                let mut out = vec![0; 32];
                let n = io.read(&mut out).await.unwrap();
                out.truncate(n);
                assert_eq!(out, b"ping");

                io.write_all(b"pong").await.unwrap();
                io.flush().await.unwrap();
            });

            let client = tokio::task::spawn(async move {
                let protos = vec!["/proto3", "/proto2"];
                let (proto, mut io) = dialer_select_proto(client_connection, protos, version)
                    .await
                    .unwrap();
                assert_eq!(proto, "/proto2");

                io.write_all(b"ping").await.unwrap();
                io.flush().await.unwrap();

                let mut out = vec![0; 32];
                let n = io.read(&mut out).await.unwrap();
                out.truncate(n);
                assert_eq!(out, b"pong");
            });

            server.await.unwrap();
            client.await.unwrap();
        }

        let rt = Runtime::new().unwrap();
        rt.block_on(run(Version::V1));
        rt.block_on(run(Version::V1Lazy));
    }

    /// Tests the expected behaviour of failed negotiations.
    ///
    /// The dialer and listener protocol sets never overlap, so negotiation must
    /// fail on both sides. It fails either during selection itself or, with
    /// `V1Lazy`, when the optimistically negotiated stream is completed.
    #[test]
    fn negotiation_failed() {
        fn prop(
            version: Version,
            DialerProtos(dial_protos): DialerProtos,
            ListenerProtos(listen_protos): ListenerProtos,
            DialPayload(dial_payload): DialPayload,
        ) {
            let _ = tracing_subscriber::fmt()
                .with_env_filter(
                    EnvFilter::builder()
                        .with_default_directive(LevelFilter::DEBUG.into())
                        .from_env_lossy(),
                )
                .try_init();

            let rt = Runtime::new().unwrap();
            rt.block_on(async move {
                // Bind to loopback: the client connects to the listener's
                // `local_addr()`, and connecting to the unspecified address
                // `0.0.0.0` is not portable across platforms.
                let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
                let addr = listener.local_addr().unwrap();

                let server = tokio::task::spawn(async move {
                    let server_connection = listener.accept().await.unwrap().0.compat();

                    let io = match timeout(
                        Duration::from_secs(2),
                        listener_select_proto(server_connection, listen_protos),
                    )
                    .await
                    .expect("listener negotiation timed out")
                    {
                        Ok((_, io)) => io,
                        Err(NegotiationError::Failed) => return,
                        Err(NegotiationError::ProtocolError(e)) => {
                            panic!("Unexpected protocol error {e}")
                        }
                    };
                    assert!(
                        matches!(io.complete().await, Err(NegotiationError::Failed)),
                        "listener: expected negotiation to fail on completion"
                    );
                });

                let client = tokio::task::spawn(async move {
                    let client_connection = TcpStream::connect(addr).await.unwrap().compat();

                    let mut io = match timeout(
                        Duration::from_secs(2),
                        dialer_select_proto(client_connection, dial_protos, version),
                    )
                    .await
                    .expect("dialer negotiation timed out")
                    {
                        Err(NegotiationError::Failed) => return,
                        Ok((_, io)) => io,
                        Err(e) => panic!("Unexpected negotiation error {e}"),
                    };
                    // The dialer may write a payload that is even sent before it
                    // got confirmation of the last proposed protocol, when `V1Lazy`
                    // is used.

                    tracing::info!("Writing early data");

                    io.write_all(&dial_payload).await.unwrap();
                    assert!(
                        matches!(io.complete().await, Err(NegotiationError::Failed)),
                        "dialer: expected negotiation to fail on completion"
                    );
                });

                server.await.unwrap();
                client.await.unwrap();

                tracing::info!("---------------------------------------");
            });
        }

        quickcheck::QuickCheck::new()
            .tests(1000)
            .quickcheck(prop as fn(_, _, _, _));
    }

    /// With `V1Lazy` and a single protocol, the dialer resolves without waiting
    /// for the listener. Closing the stream must therefore not block on the
    /// (never arriving) confirmation.
    #[tokio::test]
    async fn v1_lazy_do_not_wait_for_negotiation_on_poll_close() {
        let (client_connection, _server_connection) =
            futures_ringbuf::Endpoint::pair(1024 * 1024, 1);

        let client = tokio::task::spawn(async move {
            // Single protocol to allow for lazy (or optimistic) protocol negotiation.
            let protos = vec!["/proto1"];
            let (proto, mut io) = dialer_select_proto(client_connection, protos, Version::V1Lazy)
                .await
                .unwrap();
            assert_eq!(proto, "/proto1");

            // client can close the connection even though protocol negotiation is not yet done,
            // i.e. `_server_connection` had been untouched.
            io.close().await.unwrap();
        });

        match tokio::time::timeout(Duration::from_secs(10), client).await {
            Ok(join_result) => join_result.expect("Client task should complete successfully"),
            Err(_elapsed) => {
                panic!("Expected the client task to complete before timeout");
            }
        }
    }

    /// Dialer candidate lists. They never overlap with [`ListenerProtos`].
    #[derive(Clone, Debug)]
    struct DialerProtos(Vec<&'static str>);

    impl Arbitrary for DialerProtos {
        fn arbitrary(g: &mut Gen) -> Self {
            if bool::arbitrary(g) {
                DialerProtos(vec!["/proto1"])
            } else {
                DialerProtos(vec!["/proto1", "/proto2"])
            }
        }
    }

    /// Listener protocol lists. They never overlap with [`DialerProtos`].
    #[derive(Clone, Debug)]
    struct ListenerProtos(Vec<&'static str>);

    impl Arbitrary for ListenerProtos {
        fn arbitrary(g: &mut Gen) -> Self {
            if bool::arbitrary(g) {
                ListenerProtos(vec!["/proto3"])
            } else {
                ListenerProtos(vec!["/proto3", "/proto4"])
            }
        }
    }

    /// Early data the dialer writes after (possibly optimistic) negotiation.
    /// It holds zero or one byte.
    #[derive(Clone, Debug)]
    struct DialPayload(Vec<u8>);

    impl Arbitrary for DialPayload {
        fn arbitrary(g: &mut Gen) -> Self {
            DialPayload(
                (0..g.gen_range(0..2u8))
                    // We can't generate 0 (bytes come from `1..255`), as that
                    // would produce a different error.
                    .map(|_| g.gen_range(1..255))
                    .collect(),
            )
        }
    }
}