1use std::{
29 error::Error,
30 fmt, io,
31 pin::Pin,
32 task::{Context, Poll},
33};
34
35use bytes::{BufMut, Bytes, BytesMut};
36use futures::{io::IoSlice, prelude::*, ready};
37use unsigned_varint as uvi;
38
39use crate::{
40 length_delimited::{LengthDelimited, LengthDelimitedReader},
41 Version,
42};
43
/// Upper bound on the number of protocol names accepted in a single [`Message::Protocols`]
/// message; `Message::decode` rejects longer lists with [`ProtocolError::TooManyProtocols`].
const MAX_PROTOCOLS: usize = 1000;

/// Wire encoding of the multistream-select 1.0.0 header line ([`HeaderLine::V1`]).
const MSG_MULTISTREAM_1_0: &[u8] = b"/multistream/1.0.0\n";
/// Wire encoding of [`Message::NotAvailable`].
const MSG_PROTOCOL_NA: &[u8] = b"na\n";
/// Wire encoding of [`Message::ListProtocols`].
const MSG_LS: &[u8] = b"ls\n";
53
/// The header line that opens a multistream-select exchange.
///
/// Encoded on the wire as `/multistream/1.0.0\n`.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub(crate) enum HeaderLine {
    /// `/multistream/1.0.0`; produced for both `Version::V1` and `Version::V1Lazy`.
    V1,
}
62
63impl From<Version> for HeaderLine {
64 fn from(v: Version) -> HeaderLine {
65 match v {
66 Version::V1 | Version::V1Lazy => HeaderLine::V1,
67 }
68 }
69}
70
/// A protocol name exchanged during negotiation.
///
/// The `TryFrom` constructors in this module only accept names that start with `/`
/// (and, for byte input, are valid UTF-8).
#[derive(Clone, Debug, PartialEq, Eq)]
pub(crate) struct Protocol(String);
74impl AsRef<str> for Protocol {
75 fn as_ref(&self) -> &str {
76 self.0.as_ref()
77 }
78}
79
80impl TryFrom<Bytes> for Protocol {
81 type Error = ProtocolError;
82
83 fn try_from(value: Bytes) -> Result<Self, Self::Error> {
84 if !value.as_ref().starts_with(b"/") {
85 return Err(ProtocolError::InvalidProtocol);
86 }
87 let protocol_as_string =
88 String::from_utf8(value.to_vec()).map_err(|_| ProtocolError::InvalidProtocol)?;
89
90 Ok(Protocol(protocol_as_string))
91 }
92}
93
94impl TryFrom<&[u8]> for Protocol {
95 type Error = ProtocolError;
96
97 fn try_from(value: &[u8]) -> Result<Self, Self::Error> {
98 Self::try_from(Bytes::copy_from_slice(value))
99 }
100}
101
102impl TryFrom<&str> for Protocol {
103 type Error = ProtocolError;
104
105 fn try_from(value: &str) -> Result<Self, Self::Error> {
106 if !value.starts_with('/') {
107 return Err(ProtocolError::InvalidProtocol);
108 }
109
110 Ok(Protocol(value.to_owned()))
111 }
112}
113
114impl fmt::Display for Protocol {
115 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
116 write!(f, "{}", self.0)
117 }
118}
119
/// A single multistream-select message.
///
/// Each message is sent and received as the payload of one length-delimited frame (see
/// [`MessageIO`]); `Message::encode` and `Message::decode` define the payload format.
#[derive(Debug, Clone, PartialEq, Eq)]
pub(crate) enum Message {
    /// The header line (`/multistream/1.0.0\n`).
    Header(HeaderLine),
    /// A single protocol name, encoded as the name followed by `\n`.
    Protocol(Protocol),
    /// The `ls\n` message.
    ListProtocols,
    /// A list of protocol names: each is varint-length-prefixed (the length counts its
    /// trailing `\n`) and `\n`-terminated, and the whole list ends with one extra `\n`.
    Protocols(Vec<Protocol>),
    /// The protocol-not-available message (`na\n`).
    NotAvailable,
}
139
140impl Message {
141 fn encode(&self, dest: &mut BytesMut) {
143 match self {
144 Message::Header(HeaderLine::V1) => {
145 dest.reserve(MSG_MULTISTREAM_1_0.len());
146 dest.put(MSG_MULTISTREAM_1_0);
147 }
148 Message::Protocol(p) => {
149 let len = p.as_ref().len() + 1; dest.reserve(len);
151 dest.put(p.0.as_ref());
152 dest.put_u8(b'\n');
153 }
154 Message::ListProtocols => {
155 dest.reserve(MSG_LS.len());
156 dest.put(MSG_LS);
157 }
158 Message::Protocols(ps) => {
159 let mut buf = uvi::encode::usize_buffer();
160 let mut encoded = Vec::with_capacity(ps.len());
161 for p in ps {
162 encoded.extend(uvi::encode::usize(p.as_ref().len() + 1, &mut buf)); encoded.extend_from_slice(p.0.as_ref());
164 encoded.push(b'\n')
165 }
166 encoded.push(b'\n');
167 dest.reserve(encoded.len());
168 dest.put(encoded.as_ref());
169 }
170 Message::NotAvailable => {
171 dest.reserve(MSG_PROTOCOL_NA.len());
172 dest.put(MSG_PROTOCOL_NA);
173 }
174 }
175 }
176
177 fn decode(mut msg: Bytes) -> Result<Message, ProtocolError> {
179 if msg == MSG_MULTISTREAM_1_0 {
180 return Ok(Message::Header(HeaderLine::V1));
181 }
182
183 if msg == MSG_PROTOCOL_NA {
184 return Ok(Message::NotAvailable);
185 }
186
187 if msg == MSG_LS {
188 return Ok(Message::ListProtocols);
189 }
190
191 if msg.first() == Some(&b'/')
194 && msg.last() == Some(&b'\n')
195 && !msg[..msg.len() - 1].contains(&b'\n')
196 {
197 let p = Protocol::try_from(msg.split_to(msg.len() - 1))?;
198 return Ok(Message::Protocol(p));
199 }
200
201 let mut protocols = Vec::new();
204 let mut remaining: &[u8] = &msg;
205 loop {
206 if remaining == [b'\n'] {
208 break;
209 } else if protocols.len() == MAX_PROTOCOLS {
210 return Err(ProtocolError::TooManyProtocols);
211 }
212
213 let (len, tail) = uvi::decode::usize(remaining)?;
216 if len == 0 || len > tail.len() || tail[len - 1] != b'\n' {
217 return Err(ProtocolError::InvalidMessage);
218 }
219
220 let p = Protocol::try_from(Bytes::copy_from_slice(&tail[..len - 1]))?;
222 protocols.push(p);
223
224 remaining = &tail[len..];
226 }
227
228 Ok(Message::Protocols(protocols))
229 }
230}
231
/// A [`Stream`] and [`Sink`] of [`Message`]s over an I/O resource, with each message carried
/// in one [`LengthDelimited`] frame.
#[pin_project::pin_project]
pub(crate) struct MessageIO<R> {
    /// Length-delimited framing over the wrapped I/O resource.
    #[pin]
    inner: LengthDelimited<R>,
}
238
239impl<R> MessageIO<R> {
240 pub(crate) fn new(inner: R) -> MessageIO<R>
242 where
243 R: AsyncRead + AsyncWrite,
244 {
245 Self {
246 inner: LengthDelimited::new(inner),
247 }
248 }
249
250 pub(crate) fn into_reader(self) -> MessageReader<R> {
258 MessageReader {
259 inner: self.inner.into_reader(),
260 }
261 }
262
263 pub(crate) fn into_inner(self) -> R {
273 self.inner.into_inner()
274 }
275}
276
277impl<R> Sink<Message> for MessageIO<R>
278where
279 R: AsyncWrite,
280{
281 type Error = ProtocolError;
282
283 fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
284 self.project().inner.poll_ready(cx).map_err(From::from)
285 }
286
287 fn start_send(self: Pin<&mut Self>, item: Message) -> Result<(), Self::Error> {
288 let mut buf = BytesMut::new();
289 item.encode(&mut buf);
290 self.project()
291 .inner
292 .start_send(buf.freeze())
293 .map_err(From::from)
294 }
295
296 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
297 self.project().inner.poll_flush(cx).map_err(From::from)
298 }
299
300 fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
301 self.project().inner.poll_close(cx).map_err(From::from)
302 }
303}
304
305impl<R> Stream for MessageIO<R>
306where
307 R: AsyncRead,
308{
309 type Item = Result<Message, ProtocolError>;
310
311 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
312 match poll_stream(self.project().inner, cx) {
313 Poll::Pending => Poll::Pending,
314 Poll::Ready(None) => Poll::Ready(None),
315 Poll::Ready(Some(Ok(m))) => Poll::Ready(Some(Ok(m))),
316 Poll::Ready(Some(Err(err))) => Poll::Ready(Some(Err(err))),
317 }
318 }
319}
320
/// A [`Stream`] of [`Message`]s decoded from a [`LengthDelimitedReader`].
///
/// Obtained via [`MessageIO::into_reader`]. Its `AsyncWrite` implementation forwards directly
/// to the wrapped `LengthDelimitedReader`.
#[pin_project::pin_project]
#[derive(Debug)]
pub(crate) struct MessageReader<R> {
    /// Length-delimited reader over the wrapped I/O resource.
    #[pin]
    inner: LengthDelimitedReader<R>,
}
329
330impl<R> MessageReader<R> {
331 pub(crate) fn into_inner(self) -> R {
343 self.inner.into_inner()
344 }
345}
346
347impl<R> Stream for MessageReader<R>
348where
349 R: AsyncRead,
350{
351 type Item = Result<Message, ProtocolError>;
352
353 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
354 poll_stream(self.project().inner, cx)
355 }
356}
357
358impl<TInner> AsyncWrite for MessageReader<TInner>
359where
360 TInner: AsyncWrite,
361{
362 fn poll_write(
363 self: Pin<&mut Self>,
364 cx: &mut Context<'_>,
365 buf: &[u8],
366 ) -> Poll<Result<usize, io::Error>> {
367 self.project().inner.poll_write(cx, buf)
368 }
369
370 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
371 self.project().inner.poll_flush(cx)
372 }
373
374 fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
375 self.project().inner.poll_close(cx)
376 }
377
378 fn poll_write_vectored(
379 self: Pin<&mut Self>,
380 cx: &mut Context<'_>,
381 bufs: &[IoSlice<'_>],
382 ) -> Poll<Result<usize, io::Error>> {
383 self.project().inner.poll_write_vectored(cx, bufs)
384 }
385}
386
387fn poll_stream<S>(
388 stream: Pin<&mut S>,
389 cx: &mut Context<'_>,
390) -> Poll<Option<Result<Message, ProtocolError>>>
391where
392 S: Stream<Item = Result<Bytes, io::Error>>,
393{
394 let msg = if let Some(msg) = ready!(stream.poll_next(cx)?) {
395 match Message::decode(msg) {
396 Ok(m) => m,
397 Err(err) => return Poll::Ready(Some(Err(err))),
398 }
399 } else {
400 return Poll::Ready(None);
401 };
402
403 tracing::trace!(message=?msg, "Received message");
404
405 Poll::Ready(Some(Ok(msg)))
406}
407
/// Errors that can occur while sending, receiving or decoding multistream-select messages.
#[derive(Debug)]
pub enum ProtocolError {
    /// An I/O error. Unsigned-varint length prefixes that fail to decode are also reported
    /// through this variant, with kind [`io::ErrorKind::InvalidData`].
    IoError(io::Error),

    /// A received message could not be decoded.
    InvalidMessage,

    /// A protocol name is invalid: it does not start with `/` or is not valid UTF-8.
    InvalidProtocol,

    /// A received protocol list has more than the accepted maximum of 1000 entries.
    TooManyProtocols,
}
423
424impl From<io::Error> for ProtocolError {
425 fn from(err: io::Error) -> ProtocolError {
426 ProtocolError::IoError(err)
427 }
428}
429
430impl From<ProtocolError> for io::Error {
431 fn from(err: ProtocolError) -> Self {
432 if let ProtocolError::IoError(e) = err {
433 return e;
434 }
435 io::ErrorKind::InvalidData.into()
436 }
437}
438
439impl From<uvi::decode::Error> for ProtocolError {
440 fn from(err: uvi::decode::Error) -> ProtocolError {
441 Self::from(io::Error::new(io::ErrorKind::InvalidData, err.to_string()))
442 }
443}
444
445impl Error for ProtocolError {
446 fn source(&self) -> Option<&(dyn Error + 'static)> {
447 match *self {
448 ProtocolError::IoError(ref err) => Some(err),
449 _ => None,
450 }
451 }
452}
453
454impl fmt::Display for ProtocolError {
455 fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> {
456 match self {
457 ProtocolError::IoError(e) => write!(fmt, "I/O error: {e}"),
458 ProtocolError::InvalidMessage => write!(fmt, "Received an invalid message."),
459 ProtocolError::InvalidProtocol => write!(fmt, "A protocol (name) is invalid."),
460 ProtocolError::TooManyProtocols => write!(fmt, "Too many protocols received."),
461 }
462 }
463}
464
// Property tests: every `Message` must survive an `encode` -> `decode` round trip unchanged.
#[cfg(test)]
mod tests {
    use std::iter;

    use quickcheck::*;

    use super::*;

    impl Arbitrary for Protocol {
        /// Generates `/` followed by `n` ASCII-alphanumeric characters, with `n` drawn from
        /// `1..g.size()`.
        ///
        /// NOTE(review): `1..g.size()` is an empty range when `g.size() <= 1`; presumably fine
        /// with the generator's default size — confirm how `gen_range` handles that case.
        fn arbitrary(g: &mut Gen) -> Protocol {
            let n = g.gen_range(1..g.size());
            let p: String = iter::repeat(())
                .map(|()| char::arbitrary(g))
                .filter(|&c| c.is_ascii_alphanumeric())
                .take(n)
                .collect();
            Protocol(format!("/{p}"))
        }
    }

    impl Arbitrary for Message {
        /// Picks one of the five `Message` variants uniformly at random.
        fn arbitrary(g: &mut Gen) -> Message {
            match g.gen_range(0..5u8) {
                0 => Message::Header(HeaderLine::V1),
                1 => Message::NotAvailable,
                2 => Message::ListProtocols,
                3 => Message::Protocol(Protocol::arbitrary(g)),
                4 => Message::Protocols(Vec::arbitrary(g)),
                // `gen_range(0..5)` only yields 0..=4.
                _ => panic!(),
            }
        }
    }

    #[test]
    fn encode_decode_message() {
        // Encoding followed by decoding must reproduce the original message exactly.
        fn prop(msg: Message) {
            let mut buf = BytesMut::new();
            msg.encode(&mut buf);
            match Message::decode(buf.freeze()) {
                Ok(m) => assert_eq!(m, msg),
                Err(e) => panic!("Decoding failed: {e:?}"),
            }
        }
        quickcheck(prop as fn(_))
    }
}