multistream_select/
listener_select.rs

1// Copyright 2017 Parity Technologies (UK) Ltd.
2//
3// Permission is hereby granted, free of charge, to any person obtaining a
4// copy of this software and associated documentation files (the "Software"),
5// to deal in the Software without restriction, including without limitation
6// the rights to use, copy, modify, merge, publish, distribute, sublicense,
7// and/or sell copies of the Software, and to permit persons to whom the
8// Software is furnished to do so, subject to the following conditions:
9//
10// The above copyright notice and this permission notice shall be included in
11// all copies or substantial portions of the Software.
12//
13// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
14// OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
16// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
17// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
18// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
19// DEALINGS IN THE SOFTWARE.
20
21//! Protocol negotiation strategies for the peer acting as the listener
22//! in a multistream-select protocol negotiation.
23
24use std::{
25    convert::TryFrom as _,
26    mem,
27    pin::Pin,
28    task::{Context, Poll},
29};
30
31use futures::prelude::*;
32use smallvec::SmallVec;
33
34use crate::{
35    protocol::{HeaderLine, Message, MessageIO, Protocol, ProtocolError},
36    Negotiated, NegotiationError,
37};
38
39/// Returns a `Future` that negotiates a protocol on the given I/O stream
40/// for a peer acting as the _listener_ (or _responder_).
41///
42/// This function is given an I/O stream and a list of protocols and returns a
43/// computation that performs the protocol negotiation with the remote. The
44/// returned `Future` resolves with the name of the negotiated protocol and
45/// a [`Negotiated`] I/O stream.
46pub fn listener_select_proto<R, I>(inner: R, protocols: I) -> ListenerSelectFuture<R, I::Item>
47where
48    R: AsyncRead + AsyncWrite,
49    I: IntoIterator,
50    I::Item: AsRef<str>,
51{
52    let protocols = protocols
53        .into_iter()
54        .filter_map(|n| match Protocol::try_from(n.as_ref()) {
55            Ok(p) => Some((n, p)),
56            Err(e) => {
57                tracing::warn!(
58                    "Listener: Ignoring invalid protocol: {} due to {}",
59                    n.as_ref(),
60                    e
61                );
62                None
63            }
64        });
65    ListenerSelectFuture {
66        protocols: SmallVec::from_iter(protocols),
67        state: State::RecvHeader {
68            io: MessageIO::new(inner),
69        },
70        last_sent_na: false,
71    }
72}
73
74/// The `Future` returned by [`listener_select_proto`] that performs a
75/// multistream-select protocol negotiation on an underlying I/O stream.
76#[pin_project::pin_project]
77pub struct ListenerSelectFuture<R, N> {
78    // TODO: It would be nice if eventually N = Protocol, which has a
79    // few more implications on the API.
80    protocols: SmallVec<[(N, Protocol); 8]>,
81    state: State<R, N>,
82    /// Whether the last message sent was a protocol rejection (i.e. `na\n`).
83    ///
84    /// If the listener reads garbage or EOF after such a rejection,
85    /// the dialer is likely using `V1Lazy` and negotiation must be
86    /// considered failed, but not with a protocol violation or I/O
87    /// error.
88    last_sent_na: bool,
89}
90
91enum State<R, N> {
92    RecvHeader {
93        io: MessageIO<R>,
94    },
95    SendHeader {
96        io: MessageIO<R>,
97    },
98    RecvMessage {
99        io: MessageIO<R>,
100    },
101    SendMessage {
102        io: MessageIO<R>,
103        message: Message,
104        protocol: Option<N>,
105    },
106    Flush {
107        io: MessageIO<R>,
108        protocol: Option<N>,
109    },
110    Done,
111}
112
113impl<R, N> Future for ListenerSelectFuture<R, N>
114where
115    // The Unpin bound here is required because
116    // we produce a `Negotiated<R>` as the output.
117    // It also makes the implementation considerably
118    // easier to write.
119    R: AsyncRead + AsyncWrite + Unpin,
120    N: AsRef<str> + Clone,
121{
122    type Output = Result<(N, Negotiated<R>), NegotiationError>;
123
124    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
125        let this = self.project();
126
127        loop {
128            match mem::replace(this.state, State::Done) {
129                State::RecvHeader { mut io } => {
130                    match io.poll_next_unpin(cx) {
131                        Poll::Ready(Some(Ok(Message::Header(HeaderLine::V1)))) => {
132                            *this.state = State::SendHeader { io }
133                        }
134                        Poll::Ready(Some(Ok(_))) => {
135                            return Poll::Ready(Err(ProtocolError::InvalidMessage.into()))
136                        }
137                        Poll::Ready(Some(Err(err))) => return Poll::Ready(Err(From::from(err))),
138                        // Treat EOF error as [`NegotiationError::Failed`], not as
139                        // [`NegotiationError::ProtocolError`], allowing dropping or closing an I/O
140                        // stream as a permissible way to "gracefully" fail a negotiation.
141                        Poll::Ready(None) => return Poll::Ready(Err(NegotiationError::Failed)),
142                        Poll::Pending => {
143                            *this.state = State::RecvHeader { io };
144                            return Poll::Pending;
145                        }
146                    }
147                }
148
149                State::SendHeader { mut io } => {
150                    match Pin::new(&mut io).poll_ready(cx) {
151                        Poll::Pending => {
152                            *this.state = State::SendHeader { io };
153                            return Poll::Pending;
154                        }
155                        Poll::Ready(Ok(())) => {}
156                        Poll::Ready(Err(err)) => return Poll::Ready(Err(From::from(err))),
157                    }
158
159                    let msg = Message::Header(HeaderLine::V1);
160                    if let Err(err) = Pin::new(&mut io).start_send(msg) {
161                        return Poll::Ready(Err(From::from(err)));
162                    }
163
164                    *this.state = State::Flush { io, protocol: None };
165                }
166
167                State::RecvMessage { mut io } => {
168                    let msg = match Pin::new(&mut io).poll_next(cx) {
169                        Poll::Ready(Some(Ok(msg))) => msg,
170                        // Treat EOF error as [`NegotiationError::Failed`], not as
171                        // [`NegotiationError::ProtocolError`], allowing dropping or closing an I/O
172                        // stream as a permissible way to "gracefully" fail a negotiation.
173                        //
174                        // This is e.g. important when a listener rejects a protocol with
175                        // [`Message::NotAvailable`] and the dialer does not have alternative
176                        // protocols to propose. Then the dialer will stop the negotiation and drop
177                        // the corresponding stream. As a listener this EOF should be interpreted as
178                        // a failed negotiation.
179                        Poll::Ready(None) => return Poll::Ready(Err(NegotiationError::Failed)),
180                        Poll::Pending => {
181                            *this.state = State::RecvMessage { io };
182                            return Poll::Pending;
183                        }
184                        Poll::Ready(Some(Err(err))) => {
185                            if *this.last_sent_na {
186                                // When we read garbage or EOF after having already rejected a
187                                // protocol, the dialer is most likely using `V1Lazy` and has
188                                // optimistically settled on this protocol, so this is really a
189                                // failed negotiation, not a protocol violation. In this case
190                                // the dialer also raises `NegotiationError::Failed` when finally
191                                // reading the `N/A` response.
192                                if let ProtocolError::InvalidMessage = &err {
193                                    tracing::trace!(
194                                        "Listener: Negotiation failed with invalid \
195                                        message after protocol rejection."
196                                    );
197                                    return Poll::Ready(Err(NegotiationError::Failed));
198                                }
199                                if let ProtocolError::IoError(e) = &err {
200                                    if e.kind() == std::io::ErrorKind::UnexpectedEof {
201                                        tracing::trace!(
202                                            "Listener: Negotiation failed with EOF \
203                                            after protocol rejection."
204                                        );
205                                        return Poll::Ready(Err(NegotiationError::Failed));
206                                    }
207                                }
208                            }
209
210                            return Poll::Ready(Err(From::from(err)));
211                        }
212                    };
213
214                    match msg {
215                        Message::ListProtocols => {
216                            let supported =
217                                this.protocols.iter().map(|(_, p)| p).cloned().collect();
218                            let message = Message::Protocols(supported);
219                            *this.state = State::SendMessage {
220                                io,
221                                message,
222                                protocol: None,
223                            }
224                        }
225                        Message::Protocol(p) => {
226                            let protocol = this.protocols.iter().find_map(|(name, proto)| {
227                                if &p == proto {
228                                    Some(name.clone())
229                                } else {
230                                    None
231                                }
232                            });
233
234                            let message = if protocol.is_some() {
235                                tracing::debug!(protocol=%p, "Listener: confirming protocol");
236                                Message::Protocol(p.clone())
237                            } else {
238                                tracing::debug!(protocol=%p.as_ref(), "Listener: rejecting protocol");
239                                Message::NotAvailable
240                            };
241
242                            *this.state = State::SendMessage {
243                                io,
244                                message,
245                                protocol,
246                            };
247                        }
248                        _ => return Poll::Ready(Err(ProtocolError::InvalidMessage.into())),
249                    }
250                }
251
252                State::SendMessage {
253                    mut io,
254                    message,
255                    protocol,
256                } => {
257                    match Pin::new(&mut io).poll_ready(cx) {
258                        Poll::Pending => {
259                            *this.state = State::SendMessage {
260                                io,
261                                message,
262                                protocol,
263                            };
264                            return Poll::Pending;
265                        }
266                        Poll::Ready(Ok(())) => {}
267                        Poll::Ready(Err(err)) => return Poll::Ready(Err(From::from(err))),
268                    }
269
270                    if let Message::NotAvailable = &message {
271                        *this.last_sent_na = true;
272                    } else {
273                        *this.last_sent_na = false;
274                    }
275
276                    if let Err(err) = Pin::new(&mut io).start_send(message) {
277                        return Poll::Ready(Err(From::from(err)));
278                    }
279
280                    *this.state = State::Flush { io, protocol };
281                }
282
283                State::Flush { mut io, protocol } => {
284                    match Pin::new(&mut io).poll_flush(cx) {
285                        Poll::Pending => {
286                            *this.state = State::Flush { io, protocol };
287                            return Poll::Pending;
288                        }
289                        Poll::Ready(Ok(())) => {
290                            // If a protocol has been selected, finish negotiation.
291                            // Otherwise expect to receive another message.
292                            match protocol {
293                                Some(protocol) => {
294                                    tracing::debug!(
295                                        protocol=%protocol.as_ref(),
296                                        "Listener: sent confirmed protocol"
297                                    );
298                                    let io = Negotiated::completed(io.into_inner());
299                                    return Poll::Ready(Ok((protocol, io)));
300                                }
301                                None => *this.state = State::RecvMessage { io },
302                            }
303                        }
304                        Poll::Ready(Err(err)) => return Poll::Ready(Err(From::from(err))),
305                    }
306                }
307
308                State::Done => panic!("State::poll called after completion"),
309            }
310        }
311    }
312}