libp2p_rendezvous/
codec.rs

1// Copyright 2021 COMIT Network.
2//
3// Permission is hereby granted, free of charge, to any person obtaining a
4// copy of this software and associated documentation files (the "Software"),
5// to deal in the Software without restriction, including without limitation
6// the rights to use, copy, modify, merge, publish, distribute, sublicense,
7// and/or sell copies of the Software, and to permit persons to whom the
8// Software is furnished to do so, subject to the following conditions:
9//
10// The above copyright notice and this permission notice shall be included in
11// all copies or substantial portions of the Software.
12//
13// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
14// OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
16// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
17// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
18// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
19// DEALINGS IN THE SOFTWARE.
20
21use std::{fmt, io};
22
23use async_trait::async_trait;
24use asynchronous_codec::{BytesMut, Decoder, Encoder, FramedRead, FramedWrite};
25use futures::{AsyncRead, AsyncWrite, SinkExt, StreamExt};
26use libp2p_core::{peer_record, signed_envelope, PeerRecord, SignedEnvelope};
27use libp2p_swarm::StreamProtocol;
28use quick_protobuf_codec::Codec as ProtobufCodec;
29use rand::RngCore;
30
31use crate::DEFAULT_TTL;
32
/// Time-to-live value of a registration, as carried in REGISTER / REGISTER_RESPONSE /
/// DISCOVER_RESPONSE messages.
pub type Ttl = u64;
/// Value of the optional `limit` field of a DISCOVER request (presumably the maximum number of
/// registrations the requester wants back — TODO confirm against the behaviour code).
pub(crate) type Limit = u64;

/// Size limit (1 MiB) passed to the protobuf codec for every message this codec encodes or
/// decodes.
const MAX_MESSAGE_LEN_BYTES: usize = 1024 * 1024;
37
/// A rendezvous protocol message: one of the requests a client sends or the matching response.
// NOTE(review): presumably allowed because `Register` holds a `PeerRecord` inline, making that
// variant much larger than the others — confirm before considering boxing it.
#[allow(clippy::large_enum_variant)]
#[derive(Debug, Clone, PartialEq)]
pub enum Message {
    /// Request to register under a namespace.
    Register(NewRegistration),
    /// Reply to a REGISTER request: the granted TTL, or an error code.
    RegisterResponse(Result<Ttl, ErrorCode>),
    /// Request to unregister from the given namespace.
    Unregister(Namespace),
    /// Request to discover registrations.
    Discover {
        /// Namespace to query; `None` asks about all namespaces.
        namespace: Option<Namespace>,
        /// Cookie handed out by an earlier DISCOVER_RESPONSE, for subsequent DISCOVER requests.
        cookie: Option<Cookie>,
        /// Optional limit requested by the client.
        limit: Option<Limit>,
    },
    /// Reply to a DISCOVER request: the registrations found plus a cookie for follow-up
    /// requests, or an error code.
    DiscoverResponse(Result<(Vec<Registration>, Cookie), ErrorCode>),
}
51
/// A rendezvous namespace.
///
/// The inner string is private; the constructors in this module reject values longer than
/// `crate::MAX_NAMESPACE` bytes.
#[derive(Debug, PartialEq, Eq, Hash, Clone)]
pub struct Namespace(String);
54
55impl Namespace {
56    /// Creates a new [`Namespace`] from a static string.
57    ///
58    /// This will panic if the namespace is too long. We accepting panicking in this case because we
59    /// are enforcing a `static lifetime which means this value can only be a constant in the
60    /// program and hence we hope the developer checked that it is of an acceptable length.
61    pub fn from_static(value: &'static str) -> Self {
62        if value.len() > crate::MAX_NAMESPACE {
63            panic!("Namespace '{value}' is too long!")
64        }
65
66        Namespace(value.to_owned())
67    }
68
69    pub fn new(value: String) -> Result<Self, NamespaceTooLong> {
70        if value.len() > crate::MAX_NAMESPACE {
71            return Err(NamespaceTooLong);
72        }
73
74        Ok(Namespace(value))
75    }
76}
77
78impl From<Namespace> for String {
79    fn from(namespace: Namespace) -> Self {
80        namespace.0
81    }
82}
83
84impl fmt::Display for Namespace {
85    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
86        write!(f, "{}", self.0)
87    }
88}
89
90impl PartialEq<str> for Namespace {
91    fn eq(&self, other: &str) -> bool {
92        self.0.eq(other)
93    }
94}
95
96impl PartialEq<Namespace> for str {
97    fn eq(&self, other: &Namespace) -> bool {
98        other.0.eq(self)
99    }
100}
101
/// Error returned by [`Namespace::new`] when the value exceeds `crate::MAX_NAMESPACE` bytes.
#[derive(Debug, thiserror::Error)]
#[error("Namespace is too long")]
pub struct NamespaceTooLong;
105
/// Token returned in DISCOVER responses and presented again in subsequent DISCOVER requests.
#[derive(Debug, Eq, PartialEq, Hash, Clone)]
pub struct Cookie {
    // Randomly generated by the constructors; first 8 bytes (big-endian) of the wire encoding.
    id: u64,
    // Namespace the cookie is scoped to; `None` for a cookie covering all namespaces.
    namespace: Option<Namespace>,
}
111
112impl Cookie {
113    /// Construct a new [`Cookie`] for a given namespace.
114    ///
115    /// This cookie will only be valid for subsequent DISCOVER requests targeting the same
116    /// namespace.
117    pub fn for_namespace(namespace: Namespace) -> Self {
118        Self {
119            id: rand::thread_rng().next_u64(),
120            namespace: Some(namespace),
121        }
122    }
123
124    /// Construct a new [`Cookie`] for a DISCOVER request that inquires about all namespaces.
125    pub fn for_all_namespaces() -> Self {
126        Self {
127            id: rand::random(),
128            namespace: None,
129        }
130    }
131
132    pub fn into_wire_encoding(self) -> Vec<u8> {
133        let id_bytes = self.id.to_be_bytes();
134        let namespace = self.namespace.map(|ns| ns.0).unwrap_or_default();
135
136        let mut buffer = Vec::with_capacity(id_bytes.len() + namespace.len());
137        buffer.extend_from_slice(&id_bytes);
138        buffer.extend_from_slice(namespace.as_bytes());
139
140        buffer
141    }
142
143    pub fn from_wire_encoding(mut bytes: Vec<u8>) -> Result<Self, InvalidCookie> {
144        // check length early to avoid panic during slicing
145        if bytes.len() < 8 {
146            return Err(InvalidCookie);
147        }
148
149        let namespace = bytes.split_off(8);
150        let namespace = if namespace.is_empty() {
151            None
152        } else {
153            Some(
154                Namespace::new(String::from_utf8(namespace).map_err(|_| InvalidCookie)?)
155                    .map_err(|_| InvalidCookie)?,
156            )
157        };
158
159        let bytes = <[u8; 8]>::try_from(bytes).map_err(|_| InvalidCookie)?;
160        let id = u64::from_be_bytes(bytes);
161
162        Ok(Self { id, namespace })
163    }
164
165    pub fn namespace(&self) -> Option<&Namespace> {
166        self.namespace.as_ref()
167    }
168}
169
/// Error returned by [`Cookie::from_wire_encoding`] for bytes that do not form a valid cookie.
#[derive(Debug, thiserror::Error)]
#[error("The cookie was malformed")]
pub struct InvalidCookie;
173
174#[derive(Debug, Clone, PartialEq)]
175pub struct NewRegistration {
176    pub namespace: Namespace,
177    pub record: PeerRecord,
178    pub ttl: Option<u64>,
179}
180
181impl NewRegistration {
182    pub fn new(namespace: Namespace, record: PeerRecord, ttl: Option<Ttl>) -> Self {
183        Self {
184            namespace,
185            record,
186            ttl,
187        }
188    }
189
190    pub fn effective_ttl(&self) -> Ttl {
191        self.ttl.unwrap_or(DEFAULT_TTL)
192    }
193}
194
/// A registration as reported in a successful DISCOVER response.
///
/// Unlike [`NewRegistration`], the TTL is always present: decoding rejects a reported
/// registration without one.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Registration {
    pub namespace: Namespace,
    pub record: PeerRecord,
    pub ttl: Ttl,
}
201
/// Error codes a rendezvous response can carry.
///
/// Each variant maps one-to-one onto a non-`OK` wire `ResponseStatus` (see the conversions
/// between `ErrorCode` and `proto::ResponseStatus` in this module).
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub enum ErrorCode {
    InvalidNamespace,
    InvalidSignedPeerRecord,
    InvalidTtl,
    InvalidCookie,
    NotAuthorized,
    InternalError,
    Unavailable,
}
212
213impl Encoder for Codec {
214    type Item<'a> = Message;
215    type Error = Error;
216
217    fn encode(&mut self, item: Self::Item<'_>, dst: &mut BytesMut) -> Result<(), Self::Error> {
218        let mut pb: ProtobufCodec<proto::Message> = ProtobufCodec::new(MAX_MESSAGE_LEN_BYTES);
219
220        pb.encode(proto::Message::from(item), dst)?;
221
222        Ok(())
223    }
224}
225
226impl Decoder for Codec {
227    type Item = Message;
228    type Error = Error;
229
230    fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
231        let mut pb: ProtobufCodec<proto::Message> = ProtobufCodec::new(MAX_MESSAGE_LEN_BYTES);
232
233        let Some(message) = pb.decode(src)? else {
234            return Ok(None);
235        };
236
237        Ok(Some(message.try_into()?))
238    }
239}
240
/// Codec for rendezvous [`Message`]s.
///
/// It holds no state: every encode/decode builds its own protobuf codec, and the
/// request-response methods clone it into a fresh framed reader/writer per stream operation.
#[derive(Clone, Default)]
pub struct Codec {}
243
244#[async_trait]
245impl libp2p_request_response::Codec for Codec {
246    type Protocol = StreamProtocol;
247    type Request = Message;
248    type Response = Message;
249
250    async fn read_request<T>(&mut self, _: &Self::Protocol, io: &mut T) -> io::Result<Self::Request>
251    where
252        T: AsyncRead + Unpin + Send,
253    {
254        let message = FramedRead::new(io, self.clone())
255            .next()
256            .await
257            .ok_or(io::ErrorKind::UnexpectedEof)??;
258
259        Ok(message)
260    }
261
262    async fn read_response<T>(
263        &mut self,
264        _: &Self::Protocol,
265        io: &mut T,
266    ) -> io::Result<Self::Response>
267    where
268        T: AsyncRead + Unpin + Send,
269    {
270        let message = FramedRead::new(io, self.clone())
271            .next()
272            .await
273            .ok_or(io::ErrorKind::UnexpectedEof)??;
274
275        Ok(message)
276    }
277
278    async fn write_request<T>(
279        &mut self,
280        _: &Self::Protocol,
281        io: &mut T,
282        req: Self::Request,
283    ) -> io::Result<()>
284    where
285        T: AsyncWrite + Unpin + Send,
286    {
287        FramedWrite::new(io, self.clone()).send(req).await?;
288
289        Ok(())
290    }
291
292    async fn write_response<T>(
293        &mut self,
294        _: &Self::Protocol,
295        io: &mut T,
296        res: Self::Response,
297    ) -> io::Result<()>
298    where
299        T: AsyncWrite + Unpin + Send,
300    {
301        FramedWrite::new(io, self.clone()).send(res).await?;
302
303        Ok(())
304    }
305}
306
/// Errors produced while encoding or decoding rendezvous messages.
#[derive(Debug, thiserror::Error)]
pub enum Error {
    /// Failure reported by the underlying protobuf codec.
    #[error(transparent)]
    Codec(#[from] quick_protobuf_codec::Error),
    /// I/O failure.
    #[error("Failed to read/write")]
    Io(#[from] std::io::Error),
    /// The protobuf message decoded fine but does not form a valid [`Message`].
    #[error("Failed to convert wire message to internal data model")]
    Conversion(#[from] ConversionError),
}
316
317impl From<Error> for std::io::Error {
318    fn from(value: Error) -> Self {
319        match value {
320            Error::Io(e) => e,
321            Error::Codec(e) => io::Error::from(e),
322            Error::Conversion(e) => io::Error::new(io::ErrorKind::InvalidInput, e),
323        }
324    }
325}
326
327impl From<Message> for proto::Message {
328    fn from(message: Message) -> Self {
329        match message {
330            Message::Register(NewRegistration {
331                namespace,
332                record,
333                ttl,
334            }) => proto::Message {
335                type_pb: Some(proto::MessageType::REGISTER),
336                register: Some(proto::Register {
337                    ns: Some(namespace.into()),
338                    ttl,
339                    signedPeerRecord: Some(record.into_signed_envelope().into_protobuf_encoding()),
340                }),
341                registerResponse: None,
342                unregister: None,
343                discover: None,
344                discoverResponse: None,
345            },
346            Message::RegisterResponse(Ok(ttl)) => proto::Message {
347                type_pb: Some(proto::MessageType::REGISTER_RESPONSE),
348                registerResponse: Some(proto::RegisterResponse {
349                    status: Some(proto::ResponseStatus::OK),
350                    statusText: None,
351                    ttl: Some(ttl),
352                }),
353                register: None,
354                discover: None,
355                unregister: None,
356                discoverResponse: None,
357            },
358            Message::RegisterResponse(Err(error)) => proto::Message {
359                type_pb: Some(proto::MessageType::REGISTER_RESPONSE),
360                registerResponse: Some(proto::RegisterResponse {
361                    status: Some(proto::ResponseStatus::from(error)),
362                    statusText: None,
363                    ttl: None,
364                }),
365                register: None,
366                discover: None,
367                unregister: None,
368                discoverResponse: None,
369            },
370            Message::Unregister(namespace) => proto::Message {
371                type_pb: Some(proto::MessageType::UNREGISTER),
372                unregister: Some(proto::Unregister {
373                    ns: Some(namespace.into()),
374                    id: None,
375                }),
376                register: None,
377                registerResponse: None,
378                discover: None,
379                discoverResponse: None,
380            },
381            Message::Discover {
382                namespace,
383                cookie,
384                limit,
385            } => proto::Message {
386                type_pb: Some(proto::MessageType::DISCOVER),
387                discover: Some(proto::Discover {
388                    ns: namespace.map(|ns| ns.into()),
389                    cookie: cookie.map(|cookie| cookie.into_wire_encoding()),
390                    limit,
391                }),
392                register: None,
393                registerResponse: None,
394                unregister: None,
395                discoverResponse: None,
396            },
397            Message::DiscoverResponse(Ok((registrations, cookie))) => proto::Message {
398                type_pb: Some(proto::MessageType::DISCOVER_RESPONSE),
399                discoverResponse: Some(proto::DiscoverResponse {
400                    registrations: registrations
401                        .into_iter()
402                        .map(|reggo| proto::Register {
403                            ns: Some(reggo.namespace.into()),
404                            ttl: Some(reggo.ttl),
405                            signedPeerRecord: Some(
406                                reggo.record.into_signed_envelope().into_protobuf_encoding(),
407                            ),
408                        })
409                        .collect(),
410                    status: Some(proto::ResponseStatus::OK),
411                    statusText: None,
412                    cookie: Some(cookie.into_wire_encoding()),
413                }),
414                register: None,
415                discover: None,
416                unregister: None,
417                registerResponse: None,
418            },
419            Message::DiscoverResponse(Err(error)) => proto::Message {
420                type_pb: Some(proto::MessageType::DISCOVER_RESPONSE),
421                discoverResponse: Some(proto::DiscoverResponse {
422                    registrations: Vec::new(),
423                    status: Some(proto::ResponseStatus::from(error)),
424                    statusText: None,
425                    cookie: None,
426                }),
427                register: None,
428                discover: None,
429                unregister: None,
430                registerResponse: None,
431            },
432        }
433    }
434}
435
436impl TryFrom<proto::Message> for Message {
437    type Error = ConversionError;
438
439    fn try_from(message: proto::Message) -> Result<Self, Self::Error> {
440        let message = match message {
441            proto::Message {
442                type_pb: Some(proto::MessageType::REGISTER),
443                register:
444                    Some(proto::Register {
445                        ns,
446                        ttl,
447                        signedPeerRecord: Some(signed_peer_record),
448                    }),
449                ..
450            } => Message::Register(NewRegistration {
451                namespace: ns
452                    .map(Namespace::new)
453                    .transpose()?
454                    .ok_or(ConversionError::MissingNamespace)?,
455                ttl,
456                record: PeerRecord::from_signed_envelope(SignedEnvelope::from_protobuf_encoding(
457                    &signed_peer_record,
458                )?)?,
459            }),
460            proto::Message {
461                type_pb: Some(proto::MessageType::REGISTER_RESPONSE),
462                registerResponse:
463                    Some(proto::RegisterResponse {
464                        status: Some(proto::ResponseStatus::OK),
465                        ttl,
466                        ..
467                    }),
468                ..
469            } => Message::RegisterResponse(Ok(ttl.ok_or(ConversionError::MissingTtl)?)),
470            proto::Message {
471                type_pb: Some(proto::MessageType::DISCOVER),
472                discover: Some(proto::Discover { ns, limit, cookie }),
473                ..
474            } => Message::Discover {
475                namespace: ns.map(Namespace::new).transpose()?,
476                cookie: cookie.map(Cookie::from_wire_encoding).transpose()?,
477                limit,
478            },
479            proto::Message {
480                type_pb: Some(proto::MessageType::DISCOVER_RESPONSE),
481                discoverResponse:
482                    Some(proto::DiscoverResponse {
483                        registrations,
484                        status: Some(proto::ResponseStatus::OK),
485                        cookie: Some(cookie),
486                        ..
487                    }),
488                ..
489            } => {
490                let registrations = registrations
491                    .into_iter()
492                    .map(|reggo| {
493                        Ok(Registration {
494                            namespace: reggo
495                                .ns
496                                .map(Namespace::new)
497                                .transpose()?
498                                .ok_or(ConversionError::MissingNamespace)?,
499                            record: PeerRecord::from_signed_envelope(
500                                SignedEnvelope::from_protobuf_encoding(
501                                    &reggo
502                                        .signedPeerRecord
503                                        .ok_or(ConversionError::MissingSignedPeerRecord)?,
504                                )?,
505                            )?,
506                            ttl: reggo.ttl.ok_or(ConversionError::MissingTtl)?,
507                        })
508                    })
509                    .collect::<Result<Vec<_>, ConversionError>>()?;
510                let cookie = Cookie::from_wire_encoding(cookie)?;
511
512                Message::DiscoverResponse(Ok((registrations, cookie)))
513            }
514            proto::Message {
515                type_pb: Some(proto::MessageType::REGISTER_RESPONSE),
516                registerResponse:
517                    Some(proto::RegisterResponse {
518                        status: Some(response_status),
519                        ..
520                    }),
521                ..
522            } => Message::RegisterResponse(Err(response_status.try_into()?)),
523            proto::Message {
524                type_pb: Some(proto::MessageType::UNREGISTER),
525                unregister: Some(proto::Unregister { ns, .. }),
526                ..
527            } => Message::Unregister(
528                ns.map(Namespace::new)
529                    .transpose()?
530                    .ok_or(ConversionError::MissingNamespace)?,
531            ),
532            proto::Message {
533                type_pb: Some(proto::MessageType::DISCOVER_RESPONSE),
534                discoverResponse:
535                    Some(proto::DiscoverResponse {
536                        status: Some(response_status),
537                        ..
538                    }),
539                ..
540            } => Message::DiscoverResponse(Err(response_status.try_into()?)),
541            _ => return Err(ConversionError::InconsistentWireMessage),
542        };
543
544        Ok(message)
545    }
546}
547
548#[derive(Debug, thiserror::Error)]
549pub enum ConversionError {
550    #[error("The wire message is consistent")]
551    InconsistentWireMessage,
552    #[error("Missing namespace field")]
553    MissingNamespace,
554    #[error("Invalid namespace")]
555    InvalidNamespace(#[from] NamespaceTooLong),
556    #[error("Missing signed peer record field")]
557    MissingSignedPeerRecord,
558    #[error("Missing TTL field")]
559    MissingTtl,
560    #[error("Bad status code")]
561    BadStatusCode,
562    #[error("Failed to decode signed envelope")]
563    BadSignedEnvelope(#[from] signed_envelope::DecodingError),
564    #[error("Failed to decode envelope as signed peer record")]
565    BadSignedPeerRecord(#[from] peer_record::FromEnvelopeError),
566    #[error(transparent)]
567    BadCookie(#[from] InvalidCookie),
568    #[error("The requested PoW difficulty is out of range")]
569    PoWDifficultyOutOfRange,
570    #[error("The provided PoW hash is not 32 bytes long")]
571    BadPoWHash,
572}
573
574impl ConversionError {
575    pub fn to_error_code(&self) -> ErrorCode {
576        match self {
577            ConversionError::MissingNamespace => ErrorCode::InvalidNamespace,
578            ConversionError::MissingSignedPeerRecord => ErrorCode::InvalidSignedPeerRecord,
579            ConversionError::BadSignedEnvelope(_) => ErrorCode::InvalidSignedPeerRecord,
580            ConversionError::BadSignedPeerRecord(_) => ErrorCode::InvalidSignedPeerRecord,
581            ConversionError::BadCookie(_) => ErrorCode::InvalidCookie,
582            ConversionError::MissingTtl => ErrorCode::InvalidTtl,
583            ConversionError::InconsistentWireMessage => ErrorCode::InternalError,
584            ConversionError::BadStatusCode => ErrorCode::InternalError,
585            ConversionError::PoWDifficultyOutOfRange => ErrorCode::InternalError,
586            ConversionError::BadPoWHash => ErrorCode::InternalError,
587            ConversionError::InvalidNamespace(_) => ErrorCode::InvalidNamespace,
588        }
589    }
590}
591
592impl TryFrom<proto::ResponseStatus> for ErrorCode {
593    type Error = UnmappableStatusCode;
594
595    fn try_from(value: proto::ResponseStatus) -> Result<Self, Self::Error> {
596        use proto::ResponseStatus::*;
597
598        let code = match value {
599            OK => return Err(UnmappableStatusCode(value)),
600            E_INVALID_NAMESPACE => ErrorCode::InvalidNamespace,
601            E_INVALID_SIGNED_PEER_RECORD => ErrorCode::InvalidSignedPeerRecord,
602            E_INVALID_TTL => ErrorCode::InvalidTtl,
603            E_INVALID_COOKIE => ErrorCode::InvalidCookie,
604            E_NOT_AUTHORIZED => ErrorCode::NotAuthorized,
605            E_INTERNAL_ERROR => ErrorCode::InternalError,
606            E_UNAVAILABLE => ErrorCode::Unavailable,
607        };
608
609        Ok(code)
610    }
611}
612
613impl From<ErrorCode> for proto::ResponseStatus {
614    fn from(error_code: ErrorCode) -> Self {
615        use proto::ResponseStatus::*;
616
617        match error_code {
618            ErrorCode::InvalidNamespace => E_INVALID_NAMESPACE,
619            ErrorCode::InvalidSignedPeerRecord => E_INVALID_SIGNED_PEER_RECORD,
620            ErrorCode::InvalidTtl => E_INVALID_TTL,
621            ErrorCode::InvalidCookie => E_INVALID_COOKIE,
622            ErrorCode::NotAuthorized => E_NOT_AUTHORIZED,
623            ErrorCode::InternalError => E_INTERNAL_ERROR,
624            ErrorCode::Unavailable => E_UNAVAILABLE,
625        }
626    }
627}
628
629impl From<UnmappableStatusCode> for ConversionError {
630    fn from(_: UnmappableStatusCode) -> Self {
631        ConversionError::InconsistentWireMessage
632    }
633}
634
635#[derive(Debug, thiserror::Error)]
636#[error("The response code ({0:?}) cannot be mapped to our ErrorCode enum")]
637pub struct UnmappableStatusCode(proto::ResponseStatus);
638
/// Protobuf types for the rendezvous wire format, pulled in from `generated/mod.rs`.
mod proto {
    // NOTE(review): presumably needed because the generated code declares `pub` items that are
    // unreachable from outside this crate — confirm against the generator output.
    #![allow(unreachable_pub)]
    include!("generated/mod.rs");
    pub(crate) use self::rendezvous::pb::{mod_Message::*, Message};
}
644
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn cookie_wire_encoding_roundtrip() {
        let cookie = Cookie::for_namespace(Namespace::from_static("foo"));

        let bytes = cookie.clone().into_wire_encoding();
        let parsed = Cookie::from_wire_encoding(bytes).unwrap();

        assert_eq!(parsed, cookie);
    }

    #[test]
    fn cookie_wire_encoding_length() {
        let cookie = Cookie::for_namespace(Namespace::from_static("foo"));

        let bytes = cookie.into_wire_encoding();

        assert_eq!(bytes.len(), 8 + 3)
    }

    #[test]
    fn cookie_for_all_namespaces_roundtrip() {
        let cookie = Cookie::for_all_namespaces();

        let bytes = cookie.clone().into_wire_encoding();
        assert_eq!(bytes.len(), 8);

        let parsed = Cookie::from_wire_encoding(bytes).unwrap();
        assert_eq!(parsed, cookie);
        assert!(parsed.namespace().is_none());
    }

    #[test]
    fn cookie_shorter_than_id_is_rejected() {
        assert!(Cookie::from_wire_encoding(Vec::new()).is_err());
        assert!(Cookie::from_wire_encoding(vec![0; 7]).is_err());
    }

    #[test]
    fn cookie_with_invalid_utf8_namespace_is_rejected() {
        let mut bytes = vec![0; 8];
        bytes.push(0xff);

        assert!(Cookie::from_wire_encoding(bytes).is_err());
    }

    #[test]
    fn namespace_length_limit_is_inclusive() {
        assert!(Namespace::new("a".repeat(crate::MAX_NAMESPACE)).is_ok());
        assert!(Namespace::new("a".repeat(crate::MAX_NAMESPACE + 1)).is_err());
    }

    #[test]
    fn messages_without_peer_records_roundtrip_through_protobuf() {
        let namespace = Namespace::from_static("foo");
        let messages = [
            Message::RegisterResponse(Ok(60)),
            Message::RegisterResponse(Err(ErrorCode::InvalidTtl)),
            Message::Unregister(namespace.clone()),
            Message::Discover {
                namespace: Some(namespace.clone()),
                cookie: Some(Cookie::for_namespace(namespace)),
                limit: Some(10),
            },
            Message::Discover {
                namespace: None,
                cookie: None,
                limit: None,
            },
            Message::DiscoverResponse(Err(ErrorCode::InvalidCookie)),
        ];

        for message in messages {
            let parsed = Message::try_from(proto::Message::from(message.clone())).unwrap();
            assert_eq!(parsed, message);
        }
    }
}