libp2p_mplex/
codec.rs

1// Copyright 2018 Parity Technologies (UK) Ltd.
2//
3// Permission is hereby granted, free of charge, to any person obtaining a
4// copy of this software and associated documentation files (the "Software"),
5// to deal in the Software without restriction, including without limitation
6// the rights to use, copy, modify, merge, publish, distribute, sublicense,
7// and/or sell copies of the Software, and to permit persons to whom the
8// Software is furnished to do so, subject to the following conditions:
9//
10// The above copyright notice and this permission notice shall be included in
11// all copies or substantial portions of the Software.
12//
13// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
14// OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
16// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
17// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
18// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
19// DEALINGS IN THE SOFTWARE.
20
21use std::{
22    fmt,
23    hash::{Hash, Hasher},
24    io, mem,
25};
26
27use asynchronous_codec::{Decoder, Encoder};
28use bytes::{BufMut, Bytes, BytesMut};
29use libp2p_core::Endpoint;
30use unsigned_varint::{codec, encode};
31
/// Upper bound on the payload length of a single Mplex frame: 1 MiB, as per the spec.
///
/// Frames are buffered in full before being dispatched, so without a cap a remote
/// could announce an enormous frame (say, 4 TB of zeroes) and drive this process
/// into an out-of-memory failure.
pub(crate) const MAX_FRAME_SIZE: usize = 1 << 20;
36
37/// A unique identifier used by the local node for a substream.
38///
39/// `LocalStreamId`s are sent with frames to the remote, where
40/// they are received as `RemoteStreamId`s.
41///
42/// > **Note**: Streams are identified by a number and a role encoded as a flag
43/// > on each frame that is either odd (for receivers) or even (for initiators).
44/// > `Open` frames do not have a flag, but are sent unidirectionally. As a
45/// > consequence, we need to remember if a stream was initiated by us or remotely
46/// > and we store the information from our point of view as a `LocalStreamId`,
47/// > i.e. receiving an `Open` frame results in a local ID with role `Endpoint::Listener`,
48/// > whilst sending an `Open` frame results in a local ID with role `Endpoint::Dialer`.
49/// > Receiving a frame with a flag identifying the remote as a "receiver" means that
50/// > we initiated the stream, so the local ID has the role `Endpoint::Dialer`.
51/// > Conversely, when receiving a frame with a flag identifying the remote as a "sender",
52/// > the corresponding local ID has the role `Endpoint::Listener`.
53#[derive(Copy, Clone, Eq, Debug)]
54pub(crate) struct LocalStreamId {
55    num: u64,
56    role: Endpoint,
57}
58
59impl fmt::Display for LocalStreamId {
60    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
61        match self.role {
62            Endpoint::Dialer => write!(f, "({}/initiator)", self.num),
63            Endpoint::Listener => write!(f, "({}/receiver)", self.num),
64        }
65    }
66}
67
68/// Manual implementation of [`PartialEq`].
69///
70/// This is equivalent to the derived one but we purposely don't derive it because it triggers the
71/// `clippy::derive_hash_xor_eq` lint.
72///
73/// This [`PartialEq`] implementation satisfies the rule of v1 == v2 -> hash(v1) == hash(v2).
74/// The inverse is not true but does not have to be.
75impl PartialEq for LocalStreamId {
76    fn eq(&self, other: &Self) -> bool {
77        self.num.eq(&other.num) && self.role.eq(&other.role)
78    }
79}
80
81impl Hash for LocalStreamId {
82    fn hash<H: Hasher>(&self, state: &mut H) {
83        state.write_u64(self.num);
84    }
85}
86
87impl nohash_hasher::IsEnabled for LocalStreamId {}
88
89/// A unique identifier used by the remote node for a substream.
90///
91/// `RemoteStreamId`s are received with frames from the remote
92/// and mapped by the receiver to `LocalStreamId`s via
93/// [`RemoteStreamId::into_local()`].
94#[derive(Copy, Clone, PartialEq, Eq, Hash, Debug)]
95pub(crate) struct RemoteStreamId {
96    num: u64,
97    role: Endpoint,
98}
99
100impl LocalStreamId {
101    pub(crate) fn dialer(num: u64) -> Self {
102        Self {
103            num,
104            role: Endpoint::Dialer,
105        }
106    }
107
108    #[cfg(test)]
109    pub(crate) fn listener(num: u64) -> Self {
110        Self {
111            num,
112            role: Endpoint::Listener,
113        }
114    }
115
116    pub(crate) fn next(self) -> Self {
117        Self {
118            num: self
119                .num
120                .checked_add(1)
121                .expect("Mplex substream ID overflowed"),
122            ..self
123        }
124    }
125
126    #[cfg(test)]
127    pub(crate) fn into_remote(self) -> RemoteStreamId {
128        RemoteStreamId {
129            num: self.num,
130            role: !self.role,
131        }
132    }
133}
134
135impl RemoteStreamId {
136    fn dialer(num: u64) -> Self {
137        Self {
138            num,
139            role: Endpoint::Dialer,
140        }
141    }
142
143    fn listener(num: u64) -> Self {
144        Self {
145            num,
146            role: Endpoint::Listener,
147        }
148    }
149
150    /// Converts this `RemoteStreamId` into the corresponding `LocalStreamId`
151    /// that identifies the same substream.
152    pub(crate) fn into_local(self) -> LocalStreamId {
153        LocalStreamId {
154            num: self.num,
155            role: !self.role,
156        }
157    }
158}
159
160/// An Mplex protocol frame.
161#[derive(Debug, Clone, PartialEq, Eq)]
162pub(crate) enum Frame<T> {
163    Open { stream_id: T },
164    Data { stream_id: T, data: Bytes },
165    Close { stream_id: T },
166    Reset { stream_id: T },
167}
168
169impl Frame<RemoteStreamId> {
170    pub(crate) fn remote_id(&self) -> RemoteStreamId {
171        match *self {
172            Frame::Open { stream_id } => stream_id,
173            Frame::Data { stream_id, .. } => stream_id,
174            Frame::Close { stream_id, .. } => stream_id,
175            Frame::Reset { stream_id, .. } => stream_id,
176        }
177    }
178}
179
180pub(crate) struct Codec {
181    varint_decoder: codec::Uvi<u64>,
182    decoder_state: CodecDecodeState,
183}
184
/// Where the decoder is within the frame currently being read.
#[derive(Debug, Clone)]
enum CodecDecodeState {
    /// Waiting for the varint header (substream number and flag).
    Begin,
    /// Header read; waiting for the varint payload length.
    HasHeader(u64),
    /// Header and length read; waiting for the payload bytes.
    HasHeaderAndLen(u64, usize),
    /// Placeholder held while a state is being processed. A decoder left in
    /// this state (after an error) refuses any further input.
    Poisoned,
}
192
193impl Codec {
194    pub(crate) fn new() -> Codec {
195        Codec {
196            varint_decoder: codec::Uvi::default(),
197            decoder_state: CodecDecodeState::Begin,
198        }
199    }
200}
201
202impl Decoder for Codec {
203    type Item = Frame<RemoteStreamId>;
204    type Error = io::Error;
205
206    fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
207        loop {
208            match mem::replace(&mut self.decoder_state, CodecDecodeState::Poisoned) {
209                CodecDecodeState::Begin => match self.varint_decoder.decode(src)? {
210                    Some(header) => {
211                        self.decoder_state = CodecDecodeState::HasHeader(header);
212                    }
213                    None => {
214                        self.decoder_state = CodecDecodeState::Begin;
215                        return Ok(None);
216                    }
217                },
218                CodecDecodeState::HasHeader(header) => match self.varint_decoder.decode(src)? {
219                    Some(len) => {
220                        if len as usize > MAX_FRAME_SIZE {
221                            let msg = format!("Mplex frame length {len} exceeds maximum");
222                            return Err(io::Error::new(io::ErrorKind::InvalidData, msg));
223                        }
224
225                        self.decoder_state =
226                            CodecDecodeState::HasHeaderAndLen(header, len as usize);
227                    }
228                    None => {
229                        self.decoder_state = CodecDecodeState::HasHeader(header);
230                        return Ok(None);
231                    }
232                },
233                CodecDecodeState::HasHeaderAndLen(header, len) => {
234                    if src.len() < len {
235                        self.decoder_state = CodecDecodeState::HasHeaderAndLen(header, len);
236                        let to_reserve = len - src.len();
237                        src.reserve(to_reserve);
238                        return Ok(None);
239                    }
240
241                    let buf = src.split_to(len);
242                    let num = header >> 3;
243                    let out = match header & 7 {
244                        0 => Frame::Open {
245                            stream_id: RemoteStreamId::dialer(num),
246                        },
247                        1 => Frame::Data {
248                            stream_id: RemoteStreamId::listener(num),
249                            data: buf.freeze(),
250                        },
251                        2 => Frame::Data {
252                            stream_id: RemoteStreamId::dialer(num),
253                            data: buf.freeze(),
254                        },
255                        3 => Frame::Close {
256                            stream_id: RemoteStreamId::listener(num),
257                        },
258                        4 => Frame::Close {
259                            stream_id: RemoteStreamId::dialer(num),
260                        },
261                        5 => Frame::Reset {
262                            stream_id: RemoteStreamId::listener(num),
263                        },
264                        6 => Frame::Reset {
265                            stream_id: RemoteStreamId::dialer(num),
266                        },
267                        _ => {
268                            let msg = format!("Invalid mplex header value 0x{header:x}");
269                            return Err(io::Error::new(io::ErrorKind::InvalidData, msg));
270                        }
271                    };
272
273                    self.decoder_state = CodecDecodeState::Begin;
274                    return Ok(Some(out));
275                }
276
277                CodecDecodeState::Poisoned => {
278                    return Err(io::Error::new(
279                        io::ErrorKind::InvalidData,
280                        "Mplex codec poisoned",
281                    ));
282                }
283            }
284        }
285    }
286}
287
288impl Encoder for Codec {
289    type Item<'a> = Frame<LocalStreamId>;
290    type Error = io::Error;
291
292    fn encode(&mut self, item: Self::Item<'_>, dst: &mut BytesMut) -> Result<(), Self::Error> {
293        let (header, data) = match item {
294            Frame::Open { stream_id } => (stream_id.num << 3, Bytes::new()),
295            Frame::Data {
296                stream_id:
297                    LocalStreamId {
298                        num,
299                        role: Endpoint::Listener,
300                    },
301                data,
302            } => ((num << 3) | 1, data),
303            Frame::Data {
304                stream_id:
305                    LocalStreamId {
306                        num,
307                        role: Endpoint::Dialer,
308                    },
309                data,
310            } => ((num << 3) | 2, data),
311            Frame::Close {
312                stream_id:
313                    LocalStreamId {
314                        num,
315                        role: Endpoint::Listener,
316                    },
317            } => ((num << 3) | 3, Bytes::new()),
318            Frame::Close {
319                stream_id:
320                    LocalStreamId {
321                        num,
322                        role: Endpoint::Dialer,
323                    },
324            } => ((num << 3) | 4, Bytes::new()),
325            Frame::Reset {
326                stream_id:
327                    LocalStreamId {
328                        num,
329                        role: Endpoint::Listener,
330                    },
331            } => ((num << 3) | 5, Bytes::new()),
332            Frame::Reset {
333                stream_id:
334                    LocalStreamId {
335                        num,
336                        role: Endpoint::Dialer,
337                    },
338            } => ((num << 3) | 6, Bytes::new()),
339        };
340
341        let mut header_buf = encode::u64_buffer();
342        let header_bytes = encode::u64(header, &mut header_buf);
343
344        let data_len = data.as_ref().len();
345        let mut data_buf = encode::usize_buffer();
346        let data_len_bytes = encode::usize(data_len, &mut data_buf);
347
348        if data_len > MAX_FRAME_SIZE {
349            return Err(io::Error::new(
350                io::ErrorKind::InvalidData,
351                "data size exceed maximum",
352            ));
353        }
354
355        dst.reserve(header_bytes.len() + data_len_bytes.len() + data_len);
356        dst.put(header_bytes);
357        dst.put(data_len_bytes);
358        dst.put(data);
359        Ok(())
360    }
361}
362
#[cfg(test)]
mod tests {
    use super::*;

    /// The encoder rejects payloads above `MAX_FRAME_SIZE` (writing nothing)
    /// and accepts a payload of exactly that size.
    #[test]
    fn encode_large_messages_fails() {
        let mut enc = Codec::new();
        let role = Endpoint::Dialer;
        let mut out = BytesMut::new();

        // Build the 1 MiB payloads on the heap rather than from `[123u8; N]`
        // array literals.
        let data = Bytes::from(vec![123u8; MAX_FRAME_SIZE + 1]);
        let bad_msg = Frame::Data {
            stream_id: LocalStreamId { num: 123, role },
            data,
        };
        let err = enc
            .encode(bad_msg, &mut out)
            .expect_err("Can't send a message bigger than MAX_FRAME_SIZE");
        assert_eq!(err.to_string(), "data size exceed maximum");
        assert!(out.is_empty(), "A rejected frame must not be partially written");

        let data = Bytes::from(vec![123u8; MAX_FRAME_SIZE]);
        let ok_msg = Frame::Data {
            stream_id: LocalStreamId { num: 123, role },
            data,
        };
        assert!(enc.encode(ok_msg, &mut out).is_ok());
    }

    /// A stream number wider than 32 bits survives an encode/decode round trip.
    #[test]
    fn test_60bit_stream_id() {
        // Codec used for both encoding and decoding the frame.
        let mut codec = Codec::new();
        // A stream number that does not fit in a `u32`.
        let id: u64 = u32::MAX as u64 + 1;
        let stream_id = LocalStreamId {
            num: id,
            role: Endpoint::Dialer,
        };

        let mut enc_frame = BytesMut::new();
        codec
            .encode(Frame::Open { stream_id }, &mut enc_frame)
            .expect("Encoding to succeed.");

        let dec_stream_id = codec
            .decode(&mut enc_frame)
            .expect("Decoding to succeed.")
            .map(|f| f.remote_id())
            .expect("A complete frame to be decoded.");

        assert_eq!(dec_stream_id.num, stream_id.num);
        // The side receiving an `Open` frame sees the stream as a listener.
        assert_eq!(dec_stream_id.into_local(), LocalStreamId::listener(id));
        // The entire frame was consumed from the buffer.
        assert!(enc_frame.is_empty());
    }
}