1use std::{
22 fmt,
23 hash::{Hash, Hasher},
24 io, mem,
25};
26
27use asynchronous_codec::{Decoder, Encoder};
28use bytes::{BufMut, Bytes, BytesMut};
29use libp2p_core::Endpoint;
30use unsigned_varint::{codec, encode};
31
/// Largest payload, in bytes, accepted in a single frame (1 MiB).
///
/// Enforced both when decoding (declared length) and when encoding (payload length).
pub(crate) const MAX_FRAME_SIZE: usize = 1 << 20;
36
/// Identifies a substream from the local node's side: a stream number plus the
/// role we play on that stream (`Dialer` is displayed as "initiator",
/// `Listener` as "receiver").
///
/// `PartialEq` and `Hash` are implemented by hand rather than derived; `Hash`
/// feeds only `num`, which is what allows the `nohash_hasher::IsEnabled` opt-in.
#[derive(Copy, Clone, Eq, Debug)]
pub(crate) struct LocalStreamId {
    /// Stream number; the encoder places it above the 3 flag bits of the header.
    num: u64,
    /// Our role on the stream; the encoder uses it to pick between the paired
    /// flag values of `Data`/`Close`/`Reset` frames.
    role: Endpoint,
}
58
59impl fmt::Display for LocalStreamId {
60 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
61 match self.role {
62 Endpoint::Dialer => write!(f, "({}/initiator)", self.num),
63 Endpoint::Listener => write!(f, "({}/receiver)", self.num),
64 }
65 }
66}
67
68impl PartialEq for LocalStreamId {
76 fn eq(&self, other: &Self) -> bool {
77 self.num.eq(&other.num) && self.role.eq(&other.role)
78 }
79}
80
81impl Hash for LocalStreamId {
82 fn hash<H: Hasher>(&self, state: &mut H) {
83 state.write_u64(self.num);
84 }
85}
86
// Opt into `nohash_hasher`'s pass-through hashing. This relies on the `Hash`
// impl for `LocalStreamId` writing exactly one `u64` (the stream number).
impl nohash_hasher::IsEnabled for LocalStreamId {}
88
/// A stream ID as carried by a frame received from the remote.
///
/// The decoder derives `role` from the frame's flag bits. `Open` frames always
/// yield `Dialer`; other kinds yield the role the sending encoder used to choose
/// the flag, which is presumably the sender's own role on the stream (confirm
/// against the mplex spec). `into_local` maps it to a `LocalStreamId` by
/// flipping the role.
#[derive(Copy, Clone, PartialEq, Eq, Hash, Debug)]
pub(crate) struct RemoteStreamId {
    /// Stream number, taken from the header bits above the 3 flag bits.
    num: u64,
    /// Role decoded from the frame flag.
    role: Endpoint,
}
99
100impl LocalStreamId {
101 pub(crate) fn dialer(num: u64) -> Self {
102 Self {
103 num,
104 role: Endpoint::Dialer,
105 }
106 }
107
108 #[cfg(test)]
109 pub(crate) fn listener(num: u64) -> Self {
110 Self {
111 num,
112 role: Endpoint::Listener,
113 }
114 }
115
116 pub(crate) fn next(self) -> Self {
117 Self {
118 num: self
119 .num
120 .checked_add(1)
121 .expect("Mplex substream ID overflowed"),
122 ..self
123 }
124 }
125
126 #[cfg(test)]
127 pub(crate) fn into_remote(self) -> RemoteStreamId {
128 RemoteStreamId {
129 num: self.num,
130 role: !self.role,
131 }
132 }
133}
134
135impl RemoteStreamId {
136 fn dialer(num: u64) -> Self {
137 Self {
138 num,
139 role: Endpoint::Dialer,
140 }
141 }
142
143 fn listener(num: u64) -> Self {
144 Self {
145 num,
146 role: Endpoint::Listener,
147 }
148 }
149
150 pub(crate) fn into_local(self) -> LocalStreamId {
153 LocalStreamId {
154 num: self.num,
155 role: !self.role,
156 }
157 }
158}
159
/// A single mplex frame, generic over the stream-ID type.
///
/// The codec encodes `Frame<LocalStreamId>` and decodes `Frame<RemoteStreamId>`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub(crate) enum Frame<T> {
    /// Opens a stream (flag 0). The decoder discards any payload bytes it carries.
    Open { stream_id: T },
    /// Carries payload bytes for a stream (flags 1/2).
    Data { stream_id: T, data: Bytes },
    /// Closes a stream (flags 3/4).
    Close { stream_id: T },
    /// Resets a stream (flags 5/6).
    Reset { stream_id: T },
}
168
169impl Frame<RemoteStreamId> {
170 pub(crate) fn remote_id(&self) -> RemoteStreamId {
171 match *self {
172 Frame::Open { stream_id } => stream_id,
173 Frame::Data { stream_id, .. } => stream_id,
174 Frame::Close { stream_id, .. } => stream_id,
175 Frame::Reset { stream_id, .. } => stream_id,
176 }
177 }
178}
179
/// Stateful encoder/decoder for mplex frames.
pub(crate) struct Codec {
    /// Decodes the varint header and varint length fields of incoming frames.
    varint_decoder: codec::Uvi<u64>,
    /// Progress through the frame currently being decoded, kept between calls
    /// so that decoding can resume when more bytes arrive.
    decoder_state: CodecDecodeState,
}
184
/// States of the incremental frame decoder.
#[derive(Debug, Clone)]
enum CodecDecodeState {
    /// Waiting for the varint frame header.
    Begin,
    /// Header decoded; waiting for the varint payload length.
    HasHeader(u64),
    /// Header and payload length (already bounded by `MAX_FRAME_SIZE`) known;
    /// waiting for that many payload bytes.
    HasHeaderAndLen(u64, usize),
    /// Placeholder stored while a state is being processed. It remains in place
    /// after a decode error, so every later call fails.
    Poisoned,
}
192
193impl Codec {
194 pub(crate) fn new() -> Codec {
195 Codec {
196 varint_decoder: codec::Uvi::default(),
197 decoder_state: CodecDecodeState::Begin,
198 }
199 }
200}
201
202impl Decoder for Codec {
203 type Item = Frame<RemoteStreamId>;
204 type Error = io::Error;
205
206 fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
207 loop {
208 match mem::replace(&mut self.decoder_state, CodecDecodeState::Poisoned) {
209 CodecDecodeState::Begin => match self.varint_decoder.decode(src)? {
210 Some(header) => {
211 self.decoder_state = CodecDecodeState::HasHeader(header);
212 }
213 None => {
214 self.decoder_state = CodecDecodeState::Begin;
215 return Ok(None);
216 }
217 },
218 CodecDecodeState::HasHeader(header) => match self.varint_decoder.decode(src)? {
219 Some(len) => {
220 if len as usize > MAX_FRAME_SIZE {
221 let msg = format!("Mplex frame length {len} exceeds maximum");
222 return Err(io::Error::new(io::ErrorKind::InvalidData, msg));
223 }
224
225 self.decoder_state =
226 CodecDecodeState::HasHeaderAndLen(header, len as usize);
227 }
228 None => {
229 self.decoder_state = CodecDecodeState::HasHeader(header);
230 return Ok(None);
231 }
232 },
233 CodecDecodeState::HasHeaderAndLen(header, len) => {
234 if src.len() < len {
235 self.decoder_state = CodecDecodeState::HasHeaderAndLen(header, len);
236 let to_reserve = len - src.len();
237 src.reserve(to_reserve);
238 return Ok(None);
239 }
240
241 let buf = src.split_to(len);
242 let num = header >> 3;
243 let out = match header & 7 {
244 0 => Frame::Open {
245 stream_id: RemoteStreamId::dialer(num),
246 },
247 1 => Frame::Data {
248 stream_id: RemoteStreamId::listener(num),
249 data: buf.freeze(),
250 },
251 2 => Frame::Data {
252 stream_id: RemoteStreamId::dialer(num),
253 data: buf.freeze(),
254 },
255 3 => Frame::Close {
256 stream_id: RemoteStreamId::listener(num),
257 },
258 4 => Frame::Close {
259 stream_id: RemoteStreamId::dialer(num),
260 },
261 5 => Frame::Reset {
262 stream_id: RemoteStreamId::listener(num),
263 },
264 6 => Frame::Reset {
265 stream_id: RemoteStreamId::dialer(num),
266 },
267 _ => {
268 let msg = format!("Invalid mplex header value 0x{header:x}");
269 return Err(io::Error::new(io::ErrorKind::InvalidData, msg));
270 }
271 };
272
273 self.decoder_state = CodecDecodeState::Begin;
274 return Ok(Some(out));
275 }
276
277 CodecDecodeState::Poisoned => {
278 return Err(io::Error::new(
279 io::ErrorKind::InvalidData,
280 "Mplex codec poisoned",
281 ));
282 }
283 }
284 }
285 }
286}
287
288impl Encoder for Codec {
289 type Item<'a> = Frame<LocalStreamId>;
290 type Error = io::Error;
291
292 fn encode(&mut self, item: Self::Item<'_>, dst: &mut BytesMut) -> Result<(), Self::Error> {
293 let (header, data) = match item {
294 Frame::Open { stream_id } => (stream_id.num << 3, Bytes::new()),
295 Frame::Data {
296 stream_id:
297 LocalStreamId {
298 num,
299 role: Endpoint::Listener,
300 },
301 data,
302 } => ((num << 3) | 1, data),
303 Frame::Data {
304 stream_id:
305 LocalStreamId {
306 num,
307 role: Endpoint::Dialer,
308 },
309 data,
310 } => ((num << 3) | 2, data),
311 Frame::Close {
312 stream_id:
313 LocalStreamId {
314 num,
315 role: Endpoint::Listener,
316 },
317 } => ((num << 3) | 3, Bytes::new()),
318 Frame::Close {
319 stream_id:
320 LocalStreamId {
321 num,
322 role: Endpoint::Dialer,
323 },
324 } => ((num << 3) | 4, Bytes::new()),
325 Frame::Reset {
326 stream_id:
327 LocalStreamId {
328 num,
329 role: Endpoint::Listener,
330 },
331 } => ((num << 3) | 5, Bytes::new()),
332 Frame::Reset {
333 stream_id:
334 LocalStreamId {
335 num,
336 role: Endpoint::Dialer,
337 },
338 } => ((num << 3) | 6, Bytes::new()),
339 };
340
341 let mut header_buf = encode::u64_buffer();
342 let header_bytes = encode::u64(header, &mut header_buf);
343
344 let data_len = data.as_ref().len();
345 let mut data_buf = encode::usize_buffer();
346 let data_len_bytes = encode::usize(data_len, &mut data_buf);
347
348 if data_len > MAX_FRAME_SIZE {
349 return Err(io::Error::new(
350 io::ErrorKind::InvalidData,
351 "data size exceed maximum",
352 ));
353 }
354
355 dst.reserve(header_bytes.len() + data_len_bytes.len() + data_len);
356 dst.put(header_bytes);
357 dst.put(data_len_bytes);
358 dst.put(data);
359 Ok(())
360 }
361}
362
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn encode_large_messages_fails() {
        let mut enc = Codec::new();
        let role = Endpoint::Dialer;
        // Payloads are built on the heap. A `[u8; MAX_FRAME_SIZE + 1]` temporary
        // would put over 1 MiB on the stack, close to the 2 MiB default stack of
        // spawned (test) threads.
        let data = Bytes::from(vec![123u8; MAX_FRAME_SIZE + 1]);
        let bad_msg = Frame::Data {
            stream_id: LocalStreamId { num: 123, role },
            data,
        };
        let mut out = BytesMut::new();
        match enc.encode(bad_msg, &mut out) {
            Err(e) => assert_eq!(e.to_string(), "data size exceed maximum"),
            _ => panic!("Can't send a message bigger than MAX_FRAME_SIZE"),
        }

        let data = Bytes::from(vec![123u8; MAX_FRAME_SIZE]);
        let ok_msg = Frame::Data {
            stream_id: LocalStreamId { num: 123, role },
            data,
        };
        assert!(enc.encode(ok_msg, &mut out).is_ok());
    }

    #[test]
    fn test_60bit_stream_id() {
        let mut codec = Codec::new();
        // A stream number that does not fit in 32 bits.
        let id: u64 = u32::MAX as u64 + 1;
        let stream_id = LocalStreamId {
            num: id,
            role: Endpoint::Dialer,
        };

        let original_frame = Frame::Open { stream_id };

        let mut enc_frame = BytesMut::new();
        codec
            .encode(original_frame, &mut enc_frame)
            .expect("Encoding to succeed.");

        let dec_stream_id = codec
            .decode(&mut enc_frame)
            .expect("Decoding to succeed.")
            .map(|f| f.remote_id())
            .unwrap();

        assert_eq!(dec_stream_id.num, stream_id.num);
    }

    #[test]
    fn data_close_reset_roundtrip_preserve_number_and_role() {
        for role in [Endpoint::Dialer, Endpoint::Listener] {
            let mut codec = Codec::new();
            let stream_id = LocalStreamId { num: 7, role };
            let expected_id = RemoteStreamId { num: 7, role };

            let mut buf = BytesMut::new();
            codec
                .encode(
                    Frame::Data {
                        stream_id,
                        data: Bytes::from_static(b"hello"),
                    },
                    &mut buf,
                )
                .expect("encoding Data to succeed");
            codec
                .encode(Frame::Close { stream_id }, &mut buf)
                .expect("encoding Close to succeed");
            codec
                .encode(Frame::Reset { stream_id }, &mut buf)
                .expect("encoding Reset to succeed");

            assert_eq!(
                codec.decode(&mut buf).expect("decoding Data"),
                Some(Frame::Data {
                    stream_id: expected_id,
                    data: Bytes::from_static(b"hello"),
                })
            );
            assert_eq!(
                codec.decode(&mut buf).expect("decoding Close"),
                Some(Frame::Close {
                    stream_id: expected_id
                })
            );
            assert_eq!(
                codec.decode(&mut buf).expect("decoding Reset"),
                Some(Frame::Reset {
                    stream_id: expected_id
                })
            );
            assert_eq!(codec.decode(&mut buf).expect("empty input is not an error"), None);
        }
    }

    #[test]
    fn decode_rejects_oversized_length() {
        let mut codec = Codec::new();
        let mut header_buf = encode::u64_buffer();
        let mut len_buf = encode::u64_buffer();
        let mut src = BytesMut::new();
        src.extend_from_slice(encode::u64((1 << 3) | 2, &mut header_buf));
        src.extend_from_slice(encode::u64(MAX_FRAME_SIZE as u64 + 1, &mut len_buf));

        let err = codec
            .decode(&mut src)
            .expect_err("a frame longer than MAX_FRAME_SIZE must be rejected");
        assert_eq!(err.kind(), io::ErrorKind::InvalidData);
    }

    #[test]
    fn decode_rejects_unknown_flag_and_stays_poisoned() {
        let mut codec = Codec::new();
        // Header 0x07 (stream 0, flag 7), payload length 0.
        let mut src = BytesMut::from(&[0x07u8, 0x00][..]);

        let err = codec
            .decode(&mut src)
            .expect_err("flag 7 must be rejected");
        assert_eq!(err.kind(), io::ErrorKind::InvalidData);

        let err = codec
            .decode(&mut src)
            .expect_err("codec must stay poisoned after an error");
        assert_eq!(err.to_string(), "Mplex codec poisoned");
    }
}