multistream_select/
negotiated.rs

1// Copyright 2019 Parity Technologies (UK) Ltd.
2//
3// Permission is hereby granted, free of charge, to any person obtaining a
4// copy of this software and associated documentation files (the "Software"),
5// to deal in the Software without restriction, including without limitation
6// the rights to use, copy, modify, merge, publish, distribute, sublicense,
7// and/or sell copies of the Software, and to permit persons to whom the
8// Software is furnished to do so, subject to the following conditions:
9//
10// The above copyright notice and this permission notice shall be included in
11// all copies or substantial portions of the Software.
12//
13// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
14// OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
16// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
17// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
18// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
19// DEALINGS IN THE SOFTWARE.
20
21use std::{
22    error::Error,
23    fmt, io, mem,
24    pin::Pin,
25    task::{Context, Poll},
26};
27
28use futures::{
29    io::{IoSlice, IoSliceMut},
30    prelude::*,
31    ready,
32};
33use pin_project::pin_project;
34
35use crate::protocol::{HeaderLine, Message, MessageReader, Protocol, ProtocolError};
36
/// An I/O stream that has settled on an (application-layer) protocol to use.
///
/// A `Negotiated` represents an I/O stream that has _settled_ on a protocol
/// to use. In particular, it is not implied that all of the protocol negotiation
/// frames have yet been sent and / or received, just that the selected protocol
/// is fully determined. This is to allow the last protocol negotiation frames
/// sent by a peer to be combined in a single write, possibly piggy-backing
/// data from the negotiated protocol on top.
///
/// Reading from a `Negotiated` I/O stream that still has pending negotiation
/// protocol data to send implicitly triggers flushing of all yet unsent data.
///
/// Writes are accepted in every valid state, including while the remote's
/// confirmation of the protocol is still pending. Use [`Negotiated::complete`]
/// to explicitly wait for that confirmation.
#[pin_project]
#[derive(Debug)]
pub struct Negotiated<TInner> {
    /// The negotiation state, which owns the underlying I/O stream.
    #[pin]
    state: State<TInner>,
}
54
/// A `Future` that waits on the completion of protocol negotiation.
///
/// Created by [`Negotiated::complete`]. Resolves to the [`Negotiated`] stream
/// once the remote has confirmed the protocol, or to a [`NegotiationError`].
#[derive(Debug)]
pub struct NegotiatedComplete<TInner> {
    /// The stream being driven; `None` once it has been returned as the
    /// future's successful output.
    inner: Option<Negotiated<TInner>>,
}
60
61impl<TInner> Future for NegotiatedComplete<TInner>
62where
63    // `Unpin` is required not because of
64    // implementation details but because we produce
65    // the `Negotiated` as the output of the
66    // future.
67    TInner: AsyncRead + AsyncWrite + Unpin,
68{
69    type Output = Result<Negotiated<TInner>, NegotiationError>;
70
71    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
72        let mut io = self
73            .inner
74            .take()
75            .expect("NegotiatedFuture called after completion.");
76        match Negotiated::poll(Pin::new(&mut io), cx) {
77            Poll::Pending => {
78                self.inner = Some(io);
79                Poll::Pending
80            }
81            Poll::Ready(Ok(())) => Poll::Ready(Ok(io)),
82            Poll::Ready(Err(err)) => {
83                self.inner = Some(io);
84                Poll::Ready(Err(err))
85            }
86        }
87    }
88}
89
90impl<TInner> Negotiated<TInner> {
91    /// Creates a `Negotiated` in state [`State::Completed`].
92    pub(crate) fn completed(io: TInner) -> Self {
93        Negotiated {
94            state: State::Completed { io },
95        }
96    }
97
98    /// Creates a `Negotiated` in state [`State::Expecting`] that is still
99    /// expecting confirmation of the given `protocol`.
100    pub(crate) fn expecting(
101        io: MessageReader<TInner>,
102        protocol: Protocol,
103        header: Option<HeaderLine>,
104    ) -> Self {
105        Negotiated {
106            state: State::Expecting {
107                io,
108                protocol,
109                header,
110            },
111        }
112    }
113
114    /// Polls the `Negotiated` for completion.
115    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), NegotiationError>>
116    where
117        TInner: AsyncRead + AsyncWrite + Unpin,
118    {
119        // Flush any pending negotiation data.
120        match self.as_mut().poll_flush(cx) {
121            Poll::Ready(Ok(())) => {}
122            Poll::Pending => return Poll::Pending,
123            Poll::Ready(Err(e)) => {
124                // If the remote closed the stream, it is important to still
125                // continue reading the data that was sent, if any.
126                if e.kind() != io::ErrorKind::WriteZero {
127                    return Poll::Ready(Err(e.into()));
128                }
129            }
130        }
131
132        let mut this = self.project();
133
134        if let StateProj::Completed { .. } = this.state.as_mut().project() {
135            return Poll::Ready(Ok(()));
136        }
137
138        // Read outstanding protocol negotiation messages.
139        loop {
140            match mem::replace(&mut *this.state, State::Invalid) {
141                State::Expecting {
142                    mut io,
143                    header,
144                    protocol,
145                } => {
146                    let msg = match Pin::new(&mut io).poll_next(cx)? {
147                        Poll::Ready(Some(msg)) => msg,
148                        Poll::Pending => {
149                            *this.state = State::Expecting {
150                                io,
151                                header,
152                                protocol,
153                            };
154                            return Poll::Pending;
155                        }
156                        Poll::Ready(None) => {
157                            return Poll::Ready(Err(ProtocolError::IoError(
158                                io::ErrorKind::UnexpectedEof.into(),
159                            )
160                            .into()));
161                        }
162                    };
163
164                    if let Message::Header(h) = &msg {
165                        if Some(h) == header.as_ref() {
166                            *this.state = State::Expecting {
167                                io,
168                                protocol,
169                                header: None,
170                            };
171                            continue;
172                        }
173                    }
174
175                    if let Message::Protocol(p) = &msg {
176                        if p.as_ref() == protocol.as_ref() {
177                            tracing::debug!(protocol=%p, "Negotiated: Received confirmation for protocol");
178                            *this.state = State::Completed {
179                                io: io.into_inner(),
180                            };
181                            return Poll::Ready(Ok(()));
182                        }
183                    }
184
185                    return Poll::Ready(Err(NegotiationError::Failed));
186                }
187
188                _ => panic!("Negotiated: Invalid state"),
189            }
190        }
191    }
192
193    /// Returns a [`NegotiatedComplete`] future that waits for protocol
194    /// negotiation to complete.
195    pub fn complete(self) -> NegotiatedComplete<TInner> {
196        NegotiatedComplete { inner: Some(self) }
197    }
198}
199
/// The states of a `Negotiated` I/O stream.
#[pin_project(project = StateProj)]
#[derive(Debug)]
enum State<R> {
    /// In this state, a `Negotiated` is still expecting to
    /// receive confirmation of the protocol it has optimistically
    /// settled on.
    Expecting {
        /// The underlying I/O stream, wrapped in a `MessageReader` from which
        /// negotiation messages are read.
        #[pin]
        io: MessageReader<R>,
        /// The expected negotiation header/preamble (i.e. multistream-select version),
        /// if one is still expected to be received.
        header: Option<HeaderLine>,
        /// The expected application protocol (i.e. name and version).
        protocol: Protocol,
    },

    /// In this state, a protocol has been agreed upon and I/O
    /// on the underlying stream can commence.
    Completed {
        /// The underlying I/O stream.
        #[pin]
        io: R,
    },

    /// Placeholder stored while `Negotiated::poll` has moved the `Expecting`
    /// fields out of the state (e.g. on the way to `Completed`).
    /// `Negotiated`'s I/O methods panic if they observe this state.
    Invalid,
}
229
230impl<TInner> AsyncRead for Negotiated<TInner>
231where
232    TInner: AsyncRead + AsyncWrite + Unpin,
233{
234    fn poll_read(
235        mut self: Pin<&mut Self>,
236        cx: &mut Context<'_>,
237        buf: &mut [u8],
238    ) -> Poll<Result<usize, io::Error>> {
239        loop {
240            if let StateProj::Completed { io } = self.as_mut().project().state.project() {
241                // If protocol negotiation is complete, commence with reading.
242                return io.poll_read(cx, buf);
243            }
244
245            // Poll the `Negotiated`, driving protocol negotiation to completion,
246            // including flushing of any remaining data.
247            match self.as_mut().poll(cx) {
248                Poll::Ready(Ok(())) => {}
249                Poll::Pending => return Poll::Pending,
250                Poll::Ready(Err(err)) => return Poll::Ready(Err(From::from(err))),
251            }
252        }
253    }
254
255    // TODO: implement once method is stabilized in the futures crate
256    // unsafe fn initializer(&self) -> Initializer {
257    // match &self.state {
258    // State::Completed { io, .. } => io.initializer(),
259    // State::Expecting { io, .. } => io.inner_ref().initializer(),
260    // State::Invalid => panic!("Negotiated: Invalid state"),
261    // }
262    // }
263
264    fn poll_read_vectored(
265        mut self: Pin<&mut Self>,
266        cx: &mut Context<'_>,
267        bufs: &mut [IoSliceMut<'_>],
268    ) -> Poll<Result<usize, io::Error>> {
269        loop {
270            if let StateProj::Completed { io } = self.as_mut().project().state.project() {
271                // If protocol negotiation is complete, commence with reading.
272                return io.poll_read_vectored(cx, bufs);
273            }
274
275            // Poll the `Negotiated`, driving protocol negotiation to completion,
276            // including flushing of any remaining data.
277            match self.as_mut().poll(cx) {
278                Poll::Ready(Ok(())) => {}
279                Poll::Pending => return Poll::Pending,
280                Poll::Ready(Err(err)) => return Poll::Ready(Err(From::from(err))),
281            }
282        }
283    }
284}
285
286impl<TInner> AsyncWrite for Negotiated<TInner>
287where
288    TInner: AsyncWrite + AsyncRead + Unpin,
289{
290    fn poll_write(
291        self: Pin<&mut Self>,
292        cx: &mut Context<'_>,
293        buf: &[u8],
294    ) -> Poll<Result<usize, io::Error>> {
295        match self.project().state.project() {
296            StateProj::Completed { io } => io.poll_write(cx, buf),
297            StateProj::Expecting { io, .. } => io.poll_write(cx, buf),
298            StateProj::Invalid => panic!("Negotiated: Invalid state"),
299        }
300    }
301
302    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
303        match self.project().state.project() {
304            StateProj::Completed { io } => io.poll_flush(cx),
305            StateProj::Expecting { io, .. } => io.poll_flush(cx),
306            StateProj::Invalid => panic!("Negotiated: Invalid state"),
307        }
308    }
309
310    fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
311        // Ensure all data has been flushed, including optimistic multistream-select messages.
312        ready!(self
313            .as_mut()
314            .poll_flush(cx)
315            .map_err(Into::<io::Error>::into)?);
316
317        // Continue with the shutdown of the underlying I/O stream.
318        match self.project().state.project() {
319            StateProj::Completed { io, .. } => io.poll_close(cx),
320            StateProj::Expecting { io, .. } => {
321                let close_poll = io.poll_close(cx);
322                if let Poll::Ready(Ok(())) = close_poll {
323                    tracing::debug!("Stream closed. Confirmation from remote for optimstic protocol negotiation still pending")
324                }
325                close_poll
326            }
327            StateProj::Invalid => panic!("Negotiated: Invalid state"),
328        }
329    }
330
331    fn poll_write_vectored(
332        self: Pin<&mut Self>,
333        cx: &mut Context<'_>,
334        bufs: &[IoSlice<'_>],
335    ) -> Poll<Result<usize, io::Error>> {
336        match self.project().state.project() {
337            StateProj::Completed { io } => io.poll_write_vectored(cx, bufs),
338            StateProj::Expecting { io, .. } => io.poll_write_vectored(cx, bufs),
339            StateProj::Invalid => panic!("Negotiated: Invalid state"),
340        }
341    }
342}
343
/// Error that can happen when negotiating a protocol with the remote.
///
/// It can be built from a [`ProtocolError`] or an [`io::Error`], and converts
/// into an [`io::Error`] so that it can be reported through the `AsyncRead`
/// implementation of [`Negotiated`].
#[derive(Debug)]
pub enum NegotiationError {
    /// A protocol error occurred during the negotiation.
    ProtocolError(ProtocolError),

    /// Protocol negotiation failed because no protocol could be agreed upon.
    Failed,
}
353
354impl From<ProtocolError> for NegotiationError {
355    fn from(err: ProtocolError) -> NegotiationError {
356        NegotiationError::ProtocolError(err)
357    }
358}
359
360impl From<io::Error> for NegotiationError {
361    fn from(err: io::Error) -> NegotiationError {
362        ProtocolError::from(err).into()
363    }
364}
365
366impl From<NegotiationError> for io::Error {
367    fn from(err: NegotiationError) -> io::Error {
368        if let NegotiationError::ProtocolError(e) = err {
369            return e.into();
370        }
371        io::Error::new(io::ErrorKind::Other, err)
372    }
373}
374
375impl Error for NegotiationError {
376    fn source(&self) -> Option<&(dyn Error + 'static)> {
377        match self {
378            NegotiationError::ProtocolError(err) => Some(err),
379            _ => None,
380        }
381    }
382}
383
384impl fmt::Display for NegotiationError {
385    fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> {
386        match self {
387            NegotiationError::ProtocolError(p) => {
388                fmt.write_fmt(format_args!("Protocol error: {p}"))
389            }
390            NegotiationError::Failed => fmt.write_str("Protocol negotiation failed."),
391        }
392    }
393}