multistream_select/
dialer_select.rs

1// Copyright 2017 Parity Technologies (UK) Ltd.
2//
3// Permission is hereby granted, free of charge, to any person obtaining a
4// copy of this software and associated documentation files (the "Software"),
5// to deal in the Software without restriction, including without limitation
6// the rights to use, copy, modify, merge, publish, distribute, sublicense,
7// and/or sell copies of the Software, and to permit persons to whom the
8// Software is furnished to do so, subject to the following conditions:
9//
10// The above copyright notice and this permission notice shall be included in
11// all copies or substantial portions of the Software.
12//
13// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
14// OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
16// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
17// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
18// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
19// DEALINGS IN THE SOFTWARE.
20
21//! Protocol negotiation strategies for the peer acting as the dialer.
22
23use std::{
24    convert::TryFrom as _,
25    iter, mem,
26    pin::Pin,
27    task::{Context, Poll},
28};
29
30use futures::prelude::*;
31
32use crate::{
33    protocol::{HeaderLine, Message, MessageIO, Protocol, ProtocolError},
34    Negotiated, NegotiationError, Version,
35};
36
37/// Returns a `Future` that negotiates a protocol on the given I/O stream
38/// for a peer acting as the _dialer_ (or _initiator_).
39///
40/// This function is given an I/O stream and a list of protocols and returns a
41/// computation that performs the protocol negotiation with the remote. The
42/// returned `Future` resolves with the name of the negotiated protocol and
43/// a [`Negotiated`] I/O stream.
44///
45/// Within the scope of this library, a dialer always commits to a specific
46/// multistream-select [`Version`], whereas a listener always supports
47/// all versions supported by this library. Frictionless multistream-select
48/// protocol upgrades may thus proceed by deployments with updated listeners,
49/// eventually followed by deployments of dialers choosing the newer protocol.
50pub fn dialer_select_proto<R, I>(
51    inner: R,
52    protocols: I,
53    version: Version,
54) -> DialerSelectFuture<R, I::IntoIter>
55where
56    R: AsyncRead + AsyncWrite,
57    I: IntoIterator,
58    I::Item: AsRef<str>,
59{
60    let protocols = protocols.into_iter().peekable();
61    DialerSelectFuture {
62        version,
63        protocols,
64        state: State::SendHeader {
65            io: MessageIO::new(inner),
66        },
67    }
68}
69
70/// A `Future` returned by [`dialer_select_proto`] which negotiates
71/// a protocol iteratively by considering one protocol after the other.
72#[pin_project::pin_project]
73pub struct DialerSelectFuture<R, I: Iterator> {
74    // TODO: It would be nice if eventually N = I::Item = Protocol.
75    protocols: iter::Peekable<I>,
76    state: State<R, I::Item>,
77    version: Version,
78}
79
80enum State<R, N> {
81    SendHeader { io: MessageIO<R> },
82    SendProtocol { io: MessageIO<R>, protocol: N },
83    FlushProtocol { io: MessageIO<R>, protocol: N },
84    AwaitProtocol { io: MessageIO<R>, protocol: N },
85    Done,
86}
87
88impl<R, I> Future for DialerSelectFuture<R, I>
89where
90    // The Unpin bound here is required because we produce
91    // a `Negotiated<R>` as the output. It also makes
92    // the implementation considerably easier to write.
93    R: AsyncRead + AsyncWrite + Unpin,
94    I: Iterator,
95    I::Item: AsRef<str>,
96{
97    type Output = Result<(I::Item, Negotiated<R>), NegotiationError>;
98
99    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
100        let this = self.project();
101
102        loop {
103            match mem::replace(this.state, State::Done) {
104                State::SendHeader { mut io } => {
105                    match Pin::new(&mut io).poll_ready(cx)? {
106                        Poll::Ready(()) => {}
107                        Poll::Pending => {
108                            *this.state = State::SendHeader { io };
109                            return Poll::Pending;
110                        }
111                    }
112
113                    let h = HeaderLine::from(*this.version);
114                    if let Err(err) = Pin::new(&mut io).start_send(Message::Header(h)) {
115                        return Poll::Ready(Err(From::from(err)));
116                    }
117
118                    let protocol = this.protocols.next().ok_or(NegotiationError::Failed)?;
119
120                    // The dialer always sends the header and the first protocol
121                    // proposal in one go for efficiency.
122                    *this.state = State::SendProtocol { io, protocol };
123                }
124
125                State::SendProtocol { mut io, protocol } => {
126                    match Pin::new(&mut io).poll_ready(cx)? {
127                        Poll::Ready(()) => {}
128                        Poll::Pending => {
129                            *this.state = State::SendProtocol { io, protocol };
130                            return Poll::Pending;
131                        }
132                    }
133
134                    let p = Protocol::try_from(protocol.as_ref())?;
135                    if let Err(err) = Pin::new(&mut io).start_send(Message::Protocol(p.clone())) {
136                        return Poll::Ready(Err(From::from(err)));
137                    }
138                    tracing::debug!(protocol=%p, "Dialer: Proposed protocol");
139
140                    if this.protocols.peek().is_some() {
141                        *this.state = State::FlushProtocol { io, protocol }
142                    } else {
143                        match this.version {
144                            Version::V1 => *this.state = State::FlushProtocol { io, protocol },
145                            // This is the only effect that `V1Lazy` has compared to `V1`:
146                            // Optimistically settling on the only protocol that
147                            // the dialer supports for this negotiation. Notably,
148                            // the dialer expects a regular `V1` response.
149                            Version::V1Lazy => {
150                                tracing::debug!(protocol=%p, "Dialer: Expecting proposed protocol");
151                                let hl = HeaderLine::from(Version::V1Lazy);
152                                let io = Negotiated::expecting(io.into_reader(), p, Some(hl));
153                                return Poll::Ready(Ok((protocol, io)));
154                            }
155                        }
156                    }
157                }
158
159                State::FlushProtocol { mut io, protocol } => {
160                    match Pin::new(&mut io).poll_flush(cx)? {
161                        Poll::Ready(()) => *this.state = State::AwaitProtocol { io, protocol },
162                        Poll::Pending => {
163                            *this.state = State::FlushProtocol { io, protocol };
164                            return Poll::Pending;
165                        }
166                    }
167                }
168
169                State::AwaitProtocol { mut io, protocol } => {
170                    let msg = match Pin::new(&mut io).poll_next(cx)? {
171                        Poll::Ready(Some(msg)) => msg,
172                        Poll::Pending => {
173                            *this.state = State::AwaitProtocol { io, protocol };
174                            return Poll::Pending;
175                        }
176                        // Treat EOF error as [`NegotiationError::Failed`], not as
177                        // [`NegotiationError::ProtocolError`], allowing dropping or closing an I/O
178                        // stream as a permissible way to "gracefully" fail a negotiation.
179                        Poll::Ready(None) => return Poll::Ready(Err(NegotiationError::Failed)),
180                    };
181
182                    match msg {
183                        Message::Header(v) if v == HeaderLine::from(*this.version) => {
184                            *this.state = State::AwaitProtocol { io, protocol };
185                        }
186                        Message::Protocol(ref p) if p.as_ref() == protocol.as_ref() => {
187                            tracing::debug!(protocol=%p, "Dialer: Received confirmation for protocol");
188                            let io = Negotiated::completed(io.into_inner());
189                            return Poll::Ready(Ok((protocol, io)));
190                        }
191                        Message::NotAvailable => {
192                            tracing::debug!(
193                                protocol=%protocol.as_ref(),
194                                "Dialer: Received rejection of protocol"
195                            );
196                            let protocol = this.protocols.next().ok_or(NegotiationError::Failed)?;
197                            *this.state = State::SendProtocol { io, protocol }
198                        }
199                        _ => return Poll::Ready(Err(ProtocolError::InvalidMessage.into())),
200                    }
201                }
202
203                State::Done => panic!("State::poll called after completion"),
204            }
205        }
206    }
207}
208
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use async_std::{
        future::timeout,
        net::{TcpListener, TcpStream},
    };
    use quickcheck::{Arbitrary, Gen, GenRange};
    use tracing::metadata::LevelFilter;
    use tracing_subscriber::EnvFilter;

    use super::*;
    use crate::listener_select_proto;

    /// A dialer proposing `/proto3` then `/proto2` settles on `/proto2` with a
    /// listener supporting `/proto1` and `/proto2`. This is checked for both
    /// `V1` and `V1Lazy`. The negotiated stream must then carry application
    /// data in both directions.
    #[test]
    fn select_proto_basic() {
        async fn run(version: Version) {
            let (client_connection, server_connection) = futures_ringbuf::Endpoint::pair(100, 100);

            let server = async_std::task::spawn(async move {
                let protos = vec!["/proto1", "/proto2"];
                let (proto, mut io) = listener_select_proto(server_connection, protos)
                    .await
                    .unwrap();
                assert_eq!(proto, "/proto2");

                let mut out = vec![0; 32];
                let n = io.read(&mut out).await.unwrap();
                out.truncate(n);
                assert_eq!(out, b"ping");

                io.write_all(b"pong").await.unwrap();
                io.flush().await.unwrap();
            });

            let client = async_std::task::spawn(async move {
                let protos = vec!["/proto3", "/proto2"];
                let (proto, mut io) = dialer_select_proto(client_connection, protos, version)
                    .await
                    .unwrap();
                assert_eq!(proto, "/proto2");

                io.write_all(b"ping").await.unwrap();
                io.flush().await.unwrap();

                let mut out = vec![0; 32];
                let n = io.read(&mut out).await.unwrap();
                out.truncate(n);
                assert_eq!(out, b"pong");
            });

            server.await;
            client.await;
        }

        async_std::task::block_on(run(Version::V1));
        async_std::task::block_on(run(Version::V1Lazy));
    }

    /// Tests the expected behaviour of failed negotiations.
    ///
    /// The dialer (`/proto1`, `/proto2`) and listener (`/proto3`, `/proto4`)
    /// protocol sets never overlap. Every negotiation must therefore end in
    /// `NegotiationError::Failed`, never in a protocol error, on both sides.
    /// This holds even when a `V1Lazy` dialer has already written early data.
    #[test]
    fn negotiation_failed() {
        fn prop(
            version: Version,
            DialerProtos(dial_protos): DialerProtos,
            ListenerProtos(listen_protos): ListenerProtos,
            DialPayload(dial_payload): DialPayload,
        ) {
            let _ = tracing_subscriber::fmt()
                .with_env_filter(
                    EnvFilter::builder()
                        .with_default_directive(LevelFilter::DEBUG.into())
                        .from_env_lossy(),
                )
                .try_init();

            async_std::task::block_on(async move {
                // Bind to loopback only. Connecting to the unspecified address
                // `0.0.0.0` is not portable (it fails on Windows, for example),
                // and a test has no reason to listen on every interface.
                let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
                let addr = listener.local_addr().unwrap();

                let server = async_std::task::spawn(async move {
                    let server_connection = listener.accept().await.unwrap().0;

                    let io = match timeout(
                        Duration::from_secs(2),
                        listener_select_proto(server_connection, listen_protos),
                    )
                    .await
                    .unwrap()
                    {
                        Ok((_, io)) => io,
                        Err(NegotiationError::Failed) => return,
                        Err(NegotiationError::ProtocolError(e)) => {
                            panic!("Unexpected protocol error {e}")
                        }
                    };
                    match io.complete().await {
                        Err(NegotiationError::Failed) => {}
                        _ => panic!("listener: expected the negotiation to fail"),
                    }
                });

                let client = async_std::task::spawn(async move {
                    let client_connection = TcpStream::connect(addr).await.unwrap();

                    let mut io = match timeout(
                        Duration::from_secs(2),
                        dialer_select_proto(client_connection, dial_protos, version),
                    )
                    .await
                    .unwrap()
                    {
                        Err(NegotiationError::Failed) => return,
                        Ok((_, io)) => io,
                        Err(NegotiationError::ProtocolError(e)) => {
                            panic!("Unexpected protocol error {e}")
                        }
                    };
                    // When `V1Lazy` is used, the dialer may write a payload
                    // before it has received confirmation of the last
                    // proposed protocol.

                    tracing::info!("Writing early data");

                    io.write_all(&dial_payload).await.unwrap();
                    match io.complete().await {
                        Err(NegotiationError::Failed) => {}
                        _ => panic!("dialer: expected the negotiation to fail"),
                    }
                });

                server.await;
                client.await;

                tracing::info!("---------------------------------------")
            });
        }

        quickcheck::QuickCheck::new()
            .tests(1000)
            .quickcheck(prop as fn(_, _, _, _));
    }

    #[async_std::test]
    async fn v1_lazy_do_not_wait_for_negotiation_on_poll_close() {
        let (client_connection, _server_connection) =
            futures_ringbuf::Endpoint::pair(1024 * 1024, 1);

        let client = async_std::task::spawn(async move {
            // Single protocol to allow for lazy (or optimistic) protocol negotiation.
            let protos = vec!["/proto1"];
            let (proto, mut io) = dialer_select_proto(client_connection, protos, Version::V1Lazy)
                .await
                .unwrap();
            assert_eq!(proto, "/proto1");

            // The client can close the connection even though protocol negotiation
            // is not yet done, i.e. `_server_connection` has not been touched.
            io.close().await.unwrap();
        });

        async_std::future::timeout(Duration::from_secs(10), client)
            .await
            .unwrap();
    }

    /// Dialer protocol sets, deliberately disjoint from [`ListenerProtos`].
    #[derive(Clone, Debug)]
    struct DialerProtos(Vec<&'static str>);

    impl Arbitrary for DialerProtos {
        fn arbitrary(g: &mut Gen) -> Self {
            if bool::arbitrary(g) {
                DialerProtos(vec!["/proto1"])
            } else {
                DialerProtos(vec!["/proto1", "/proto2"])
            }
        }
    }

    /// Listener protocol sets, deliberately disjoint from [`DialerProtos`].
    #[derive(Clone, Debug)]
    struct ListenerProtos(Vec<&'static str>);

    impl Arbitrary for ListenerProtos {
        fn arbitrary(g: &mut Gen) -> Self {
            if bool::arbitrary(g) {
                ListenerProtos(vec!["/proto3"])
            } else {
                ListenerProtos(vec!["/proto3", "/proto4"])
            }
        }
    }

    /// Early data written by the dialer: zero or one non-zero byte.
    #[derive(Clone, Debug)]
    struct DialPayload(Vec<u8>);

    impl Arbitrary for DialPayload {
        fn arbitrary(g: &mut Gen) -> Self {
            DialPayload(
                (0..g.gen_range(0..2u8))
                    // We must not generate 0, as that would produce a different error.
                    .map(|_| g.gen_range(1..255))
                    .collect(),
            )
        }
    }
}