multistream_select/
protocol.rs

1// Copyright 2017 Parity Technologies (UK) Ltd.
2//
3// Permission is hereby granted, free of charge, to any person obtaining a
4// copy of this software and associated documentation files (the "Software"),
5// to deal in the Software without restriction, including without limitation
6// the rights to use, copy, modify, merge, publish, distribute, sublicense,
7// and/or sell copies of the Software, and to permit persons to whom the
8// Software is furnished to do so, subject to the following conditions:
9//
10// The above copyright notice and this permission notice shall be included in
11// all copies or substantial portions of the Software.
12//
13// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
14// OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
16// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
17// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
18// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
19// DEALINGS IN THE SOFTWARE.
20
//! Multistream-select protocol messages and I/O operations for
//! constructing protocol negotiation flows.
//!
//! A protocol negotiation flow is constructed by using the
//! `Stream` and `Sink` implementations of `MessageIO` and the
//! `Stream` and `AsyncWrite` implementations of `MessageReader`.
27
28use std::{
29    error::Error,
30    fmt, io,
31    pin::Pin,
32    task::{Context, Poll},
33};
34
35use bytes::{BufMut, Bytes, BytesMut};
36use futures::{io::IoSlice, prelude::*, ready};
37use unsigned_varint as uvi;
38
39use crate::{
40    length_delimited::{LengthDelimited, LengthDelimitedReader},
41    Version,
42};
43
/// Upper bound on how many protocol names a single `ls` response may carry.
const MAX_PROTOCOLS: usize = 1_000;

/// Wire bytes of the multistream-select 1.0.0 header line.
const MSG_MULTISTREAM_1_0: &[u8] = b"/multistream/1.0.0\n";
/// Wire bytes of the `na` ("not available") reply.
const MSG_PROTOCOL_NA: &[u8] = b"na\n";
/// Wire bytes of the `ls` (list protocols) request.
const MSG_LS: &[u8] = b"ls\n";
53
/// Header line a peer sends before multistream-select negotiation starts.
///
/// Each [`Version`] maps onto exactly one of these variants.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub(crate) enum HeaderLine {
    /// Header announcing `/multistream/1.0.0`.
    V1,
}
62
63impl From<Version> for HeaderLine {
64    fn from(v: Version) -> HeaderLine {
65        match v {
66            Version::V1 | Version::V1Lazy => HeaderLine::V1,
67        }
68    }
69}
70
/// Name of a protocol, as proposed or acknowledged during negotiation.
#[derive(Clone, Debug, Eq, PartialEq)]
pub(crate) struct Protocol(String);

impl AsRef<str> for Protocol {
    fn as_ref(&self) -> &str {
        &self.0
    }
}
79
80impl TryFrom<Bytes> for Protocol {
81    type Error = ProtocolError;
82
83    fn try_from(value: Bytes) -> Result<Self, Self::Error> {
84        if !value.as_ref().starts_with(b"/") {
85            return Err(ProtocolError::InvalidProtocol);
86        }
87        let protocol_as_string =
88            String::from_utf8(value.to_vec()).map_err(|_| ProtocolError::InvalidProtocol)?;
89
90        Ok(Protocol(protocol_as_string))
91    }
92}
93
94impl TryFrom<&[u8]> for Protocol {
95    type Error = ProtocolError;
96
97    fn try_from(value: &[u8]) -> Result<Self, Self::Error> {
98        Self::try_from(Bytes::copy_from_slice(value))
99    }
100}
101
102impl TryFrom<&str> for Protocol {
103    type Error = ProtocolError;
104
105    fn try_from(value: &str) -> Result<Self, Self::Error> {
106        if !value.starts_with('/') {
107            return Err(ProtocolError::InvalidProtocol);
108        }
109
110        Ok(Protocol(value.to_owned()))
111    }
112}
113
114impl fmt::Display for Protocol {
115    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
116        write!(f, "{}", self.0)
117    }
118}
119
120/// A multistream-select protocol message.
121///
122/// Multistream-select protocol messages are exchanged with the goal
123/// of agreeing on an application-layer protocol to use on an I/O stream.
124#[derive(Debug, Clone, PartialEq, Eq)]
125pub(crate) enum Message {
126    /// A header message identifies the multistream-select protocol
127    /// that the sender wishes to speak.
128    Header(HeaderLine),
129    /// A protocol message identifies a protocol request or acknowledgement.
130    Protocol(Protocol),
131    /// A message through which a peer requests the complete list of
132    /// supported protocols from the remote.
133    ListProtocols,
134    /// A message listing all supported protocols of a peer.
135    Protocols(Vec<Protocol>),
136    /// A message signaling that a requested protocol is not available.
137    NotAvailable,
138}
139
140impl Message {
141    /// Encodes a `Message` into its byte representation.
142    fn encode(&self, dest: &mut BytesMut) {
143        match self {
144            Message::Header(HeaderLine::V1) => {
145                dest.reserve(MSG_MULTISTREAM_1_0.len());
146                dest.put(MSG_MULTISTREAM_1_0);
147            }
148            Message::Protocol(p) => {
149                let len = p.as_ref().len() + 1; // + 1 for \n
150                dest.reserve(len);
151                dest.put(p.0.as_ref());
152                dest.put_u8(b'\n');
153            }
154            Message::ListProtocols => {
155                dest.reserve(MSG_LS.len());
156                dest.put(MSG_LS);
157            }
158            Message::Protocols(ps) => {
159                let mut buf = uvi::encode::usize_buffer();
160                let mut encoded = Vec::with_capacity(ps.len());
161                for p in ps {
162                    encoded.extend(uvi::encode::usize(p.as_ref().len() + 1, &mut buf)); // +1 for '\n'
163                    encoded.extend_from_slice(p.0.as_ref());
164                    encoded.push(b'\n')
165                }
166                encoded.push(b'\n');
167                dest.reserve(encoded.len());
168                dest.put(encoded.as_ref());
169            }
170            Message::NotAvailable => {
171                dest.reserve(MSG_PROTOCOL_NA.len());
172                dest.put(MSG_PROTOCOL_NA);
173            }
174        }
175    }
176
177    /// Decodes a `Message` from its byte representation.
178    fn decode(mut msg: Bytes) -> Result<Message, ProtocolError> {
179        if msg == MSG_MULTISTREAM_1_0 {
180            return Ok(Message::Header(HeaderLine::V1));
181        }
182
183        if msg == MSG_PROTOCOL_NA {
184            return Ok(Message::NotAvailable);
185        }
186
187        if msg == MSG_LS {
188            return Ok(Message::ListProtocols);
189        }
190
191        // If it starts with a `/`, ends with a line feed without any
192        // other line feeds in-between, it must be a protocol name.
193        if msg.first() == Some(&b'/')
194            && msg.last() == Some(&b'\n')
195            && !msg[..msg.len() - 1].contains(&b'\n')
196        {
197            let p = Protocol::try_from(msg.split_to(msg.len() - 1))?;
198            return Ok(Message::Protocol(p));
199        }
200
201        // At this point, it must be an `ls` response, i.e. one or more
202        // length-prefixed, newline-delimited protocol names.
203        let mut protocols = Vec::new();
204        let mut remaining: &[u8] = &msg;
205        loop {
206            // A well-formed message must be terminated with a newline.
207            if remaining == [b'\n'] {
208                break;
209            } else if protocols.len() == MAX_PROTOCOLS {
210                return Err(ProtocolError::TooManyProtocols);
211            }
212
213            // Decode the length of the next protocol name and check that
214            // it ends with a line feed.
215            let (len, tail) = uvi::decode::usize(remaining)?;
216            if len == 0 || len > tail.len() || tail[len - 1] != b'\n' {
217                return Err(ProtocolError::InvalidMessage);
218            }
219
220            // Parse the protocol name.
221            let p = Protocol::try_from(Bytes::copy_from_slice(&tail[..len - 1]))?;
222            protocols.push(p);
223
224            // Skip ahead to the next protocol.
225            remaining = &tail[len..];
226        }
227
228        Ok(Message::Protocols(protocols))
229    }
230}
231
232/// A `MessageIO` implements a [`Stream`] and [`Sink`] of [`Message`]s.
233#[pin_project::pin_project]
234pub(crate) struct MessageIO<R> {
235    #[pin]
236    inner: LengthDelimited<R>,
237}
238
239impl<R> MessageIO<R> {
240    /// Constructs a new `MessageIO` resource wrapping the given I/O stream.
241    pub(crate) fn new(inner: R) -> MessageIO<R>
242    where
243        R: AsyncRead + AsyncWrite,
244    {
245        Self {
246            inner: LengthDelimited::new(inner),
247        }
248    }
249
250    /// Converts the [`MessageIO`] into a [`MessageReader`], dropping the
251    /// [`Message`]-oriented `Sink` in favour of direct `AsyncWrite` access
252    /// to the underlying I/O stream.
253    ///
254    /// This is typically done if further negotiation messages are expected to be
255    /// received but no more messages are written, allowing the writing of
256    /// follow-up protocol data to commence.
257    pub(crate) fn into_reader(self) -> MessageReader<R> {
258        MessageReader {
259            inner: self.inner.into_reader(),
260        }
261    }
262
263    /// Drops the [`MessageIO`] resource, yielding the underlying I/O stream.
264    ///
265    /// # Panics
266    ///
267    /// Panics if the read buffer or write buffer is not empty, meaning that an incoming
268    /// protocol negotiation frame has been partially read or an outgoing frame
269    /// has not yet been flushed. The read buffer is guaranteed to be empty whenever
270    /// `MessageIO::poll` returned a message. The write buffer is guaranteed to be empty
271    /// when the sink has been flushed.
272    pub(crate) fn into_inner(self) -> R {
273        self.inner.into_inner()
274    }
275}
276
277impl<R> Sink<Message> for MessageIO<R>
278where
279    R: AsyncWrite,
280{
281    type Error = ProtocolError;
282
283    fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
284        self.project().inner.poll_ready(cx).map_err(From::from)
285    }
286
287    fn start_send(self: Pin<&mut Self>, item: Message) -> Result<(), Self::Error> {
288        let mut buf = BytesMut::new();
289        item.encode(&mut buf);
290        self.project()
291            .inner
292            .start_send(buf.freeze())
293            .map_err(From::from)
294    }
295
296    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
297        self.project().inner.poll_flush(cx).map_err(From::from)
298    }
299
300    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
301        self.project().inner.poll_close(cx).map_err(From::from)
302    }
303}
304
305impl<R> Stream for MessageIO<R>
306where
307    R: AsyncRead,
308{
309    type Item = Result<Message, ProtocolError>;
310
311    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
312        match poll_stream(self.project().inner, cx) {
313            Poll::Pending => Poll::Pending,
314            Poll::Ready(None) => Poll::Ready(None),
315            Poll::Ready(Some(Ok(m))) => Poll::Ready(Some(Ok(m))),
316            Poll::Ready(Some(Err(err))) => Poll::Ready(Some(Err(err))),
317        }
318    }
319}
320
321/// A `MessageReader` implements a `Stream` of `Message`s on an underlying
322/// I/O resource combined with direct `AsyncWrite` access.
323#[pin_project::pin_project]
324#[derive(Debug)]
325pub(crate) struct MessageReader<R> {
326    #[pin]
327    inner: LengthDelimitedReader<R>,
328}
329
330impl<R> MessageReader<R> {
331    /// Drops the `MessageReader` resource, yielding the underlying I/O stream
332    /// together with the remaining write buffer containing the protocol
333    /// negotiation frame data that has not yet been written to the I/O stream.
334    ///
335    /// # Panics
336    ///
337    /// Panics if the read buffer or write buffer is not empty, meaning that either
338    /// an incoming protocol negotiation frame has been partially read, or an
339    /// outgoing frame has not yet been flushed. The read buffer is guaranteed to
340    /// be empty whenever `MessageReader::poll` returned a message. The write
341    /// buffer is guaranteed to be empty whenever the sink has been flushed.
342    pub(crate) fn into_inner(self) -> R {
343        self.inner.into_inner()
344    }
345}
346
347impl<R> Stream for MessageReader<R>
348where
349    R: AsyncRead,
350{
351    type Item = Result<Message, ProtocolError>;
352
353    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
354        poll_stream(self.project().inner, cx)
355    }
356}
357
358impl<TInner> AsyncWrite for MessageReader<TInner>
359where
360    TInner: AsyncWrite,
361{
362    fn poll_write(
363        self: Pin<&mut Self>,
364        cx: &mut Context<'_>,
365        buf: &[u8],
366    ) -> Poll<Result<usize, io::Error>> {
367        self.project().inner.poll_write(cx, buf)
368    }
369
370    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
371        self.project().inner.poll_flush(cx)
372    }
373
374    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
375        self.project().inner.poll_close(cx)
376    }
377
378    fn poll_write_vectored(
379        self: Pin<&mut Self>,
380        cx: &mut Context<'_>,
381        bufs: &[IoSlice<'_>],
382    ) -> Poll<Result<usize, io::Error>> {
383        self.project().inner.poll_write_vectored(cx, bufs)
384    }
385}
386
387fn poll_stream<S>(
388    stream: Pin<&mut S>,
389    cx: &mut Context<'_>,
390) -> Poll<Option<Result<Message, ProtocolError>>>
391where
392    S: Stream<Item = Result<Bytes, io::Error>>,
393{
394    let msg = if let Some(msg) = ready!(stream.poll_next(cx)?) {
395        match Message::decode(msg) {
396            Ok(m) => m,
397            Err(err) => return Poll::Ready(Some(Err(err))),
398        }
399    } else {
400        return Poll::Ready(None);
401    };
402
403    tracing::trace!(message=?msg, "Received message");
404
405    Poll::Ready(Some(Ok(msg)))
406}
407
/// Errors that can occur while exchanging multistream-select messages.
#[derive(Debug)]
pub enum ProtocolError {
    /// An I/O error. Malformed varint length prefixes are also reported
    /// through this variant, as `InvalidData` errors.
    IoError(io::Error),

    /// The remote sent a malformed message.
    InvalidMessage,

    /// A protocol name was rejected, e.g. because it does not start with `/`
    /// or is not valid UTF-8.
    InvalidProtocol,

    /// The remote listed more protocols than are accepted.
    TooManyProtocols,
}
423
424impl From<io::Error> for ProtocolError {
425    fn from(err: io::Error) -> ProtocolError {
426        ProtocolError::IoError(err)
427    }
428}
429
430impl From<ProtocolError> for io::Error {
431    fn from(err: ProtocolError) -> Self {
432        if let ProtocolError::IoError(e) = err {
433            return e;
434        }
435        io::ErrorKind::InvalidData.into()
436    }
437}
438
439impl From<uvi::decode::Error> for ProtocolError {
440    fn from(err: uvi::decode::Error) -> ProtocolError {
441        Self::from(io::Error::new(io::ErrorKind::InvalidData, err.to_string()))
442    }
443}
444
445impl Error for ProtocolError {
446    fn source(&self) -> Option<&(dyn Error + 'static)> {
447        match *self {
448            ProtocolError::IoError(ref err) => Some(err),
449            _ => None,
450        }
451    }
452}
453
454impl fmt::Display for ProtocolError {
455    fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> {
456        match self {
457            ProtocolError::IoError(e) => write!(fmt, "I/O error: {e}"),
458            ProtocolError::InvalidMessage => write!(fmt, "Received an invalid message."),
459            ProtocolError::InvalidProtocol => write!(fmt, "A protocol (name) is invalid."),
460            ProtocolError::TooManyProtocols => write!(fmt, "Too many protocols received."),
461        }
462    }
463}
464
#[cfg(test)]
mod tests {
    use std::iter;

    use quickcheck::*;

    use super::*;

    impl Arbitrary for Protocol {
        /// Generates `/` followed by a random run of ASCII alphanumeric
        /// characters whose length is drawn from `1..g.size()`.
        fn arbitrary(g: &mut Gen) -> Protocol {
            let len = g.gen_range(1..g.size());
            let suffix: String = iter::repeat_with(|| char::arbitrary(g))
                .filter(char::is_ascii_alphanumeric)
                .take(len)
                .collect();
            Protocol(format!("/{suffix}"))
        }
    }

    impl Arbitrary for Message {
        /// Picks one of the five message variants uniformly.
        fn arbitrary(g: &mut Gen) -> Message {
            match g.gen_range(0..5u8) {
                0 => Message::Header(HeaderLine::V1),
                1 => Message::NotAvailable,
                2 => Message::ListProtocols,
                3 => Message::Protocol(Protocol::arbitrary(g)),
                4 => Message::Protocols(Vec::arbitrary(g)),
                _ => unreachable!("gen_range(0..5) only yields 0..=4"),
            }
        }
    }

    #[test]
    fn encode_decode_message() {
        // Every message must survive an encode/decode round trip unchanged.
        fn roundtrip(msg: Message) {
            let mut buf = BytesMut::new();
            msg.encode(&mut buf);
            let decoded = Message::decode(buf.freeze())
                .unwrap_or_else(|e| panic!("Decoding failed: {e:?}"));
            assert_eq!(decoded, msg);
        }
        quickcheck(roundtrip as fn(_))
    }
}