1use std::{fmt, io};
22
23use async_trait::async_trait;
24use asynchronous_codec::{BytesMut, Decoder, Encoder, FramedRead, FramedWrite};
25use futures::{AsyncRead, AsyncWrite, SinkExt, StreamExt};
26use libp2p_core::{peer_record, signed_envelope, PeerRecord, SignedEnvelope};
27use libp2p_swarm::StreamProtocol;
28use quick_protobuf_codec::Codec as ProtobufCodec;
29use rand::RngCore;
30
31use crate::DEFAULT_TTL;
32
/// Time-to-live of a registration as carried on the wire.
///
/// NOTE(review): the unit is not visible here (presumably seconds, like `DEFAULT_TTL`) — confirm.
pub type Ttl = u64;
/// Optional cap a client attaches to a `Discover` request.
pub(crate) type Limit = u64;

/// Largest protobuf message (in bytes) the codec will encode or decode: 1 MiB.
const MAX_MESSAGE_LEN_BYTES: usize = 1 << 20;
37
38#[allow(clippy::large_enum_variant)]
39#[derive(Debug, Clone, PartialEq)]
40pub enum Message {
41 Register(NewRegistration),
42 RegisterResponse(Result<Ttl, ErrorCode>),
43 Unregister(Namespace),
44 Discover {
45 namespace: Option<Namespace>,
46 cookie: Option<Cookie>,
47 limit: Option<Limit>,
48 },
49 DiscoverResponse(Result<(Vec<Registration>, Cookie), ErrorCode>),
50}
51
/// A rendezvous namespace: a string of at most `crate::MAX_NAMESPACE` bytes.
///
/// Build one with [`Namespace::new`] (fallible) or [`Namespace::from_static`] (panicking).
#[derive(Debug, PartialEq, Eq, Hash, Clone)]
pub struct Namespace(String);
54
55impl Namespace {
56 pub fn from_static(value: &'static str) -> Self {
62 if value.len() > crate::MAX_NAMESPACE {
63 panic!("Namespace '{value}' is too long!")
64 }
65
66 Namespace(value.to_owned())
67 }
68
69 pub fn new(value: String) -> Result<Self, NamespaceTooLong> {
70 if value.len() > crate::MAX_NAMESPACE {
71 return Err(NamespaceTooLong);
72 }
73
74 Ok(Namespace(value))
75 }
76}
77
78impl From<Namespace> for String {
79 fn from(namespace: Namespace) -> Self {
80 namespace.0
81 }
82}
83
84impl fmt::Display for Namespace {
85 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
86 write!(f, "{}", self.0)
87 }
88}
89
90impl PartialEq<str> for Namespace {
91 fn eq(&self, other: &str) -> bool {
92 self.0.eq(other)
93 }
94}
95
96impl PartialEq<Namespace> for str {
97 fn eq(&self, other: &Namespace) -> bool {
98 other.0.eq(self)
99 }
100}
101
102#[derive(Debug, thiserror::Error)]
103#[error("Namespace is too long")]
104pub struct NamespaceTooLong;
105
106#[derive(Debug, Eq, PartialEq, Hash, Clone)]
107pub struct Cookie {
108 id: u64,
109 namespace: Option<Namespace>,
110}
111
112impl Cookie {
113 pub fn for_namespace(namespace: Namespace) -> Self {
118 Self {
119 id: rand::thread_rng().next_u64(),
120 namespace: Some(namespace),
121 }
122 }
123
124 pub fn for_all_namespaces() -> Self {
126 Self {
127 id: rand::random(),
128 namespace: None,
129 }
130 }
131
132 pub fn into_wire_encoding(self) -> Vec<u8> {
133 let id_bytes = self.id.to_be_bytes();
134 let namespace = self.namespace.map(|ns| ns.0).unwrap_or_default();
135
136 let mut buffer = Vec::with_capacity(id_bytes.len() + namespace.len());
137 buffer.extend_from_slice(&id_bytes);
138 buffer.extend_from_slice(namespace.as_bytes());
139
140 buffer
141 }
142
143 pub fn from_wire_encoding(mut bytes: Vec<u8>) -> Result<Self, InvalidCookie> {
144 if bytes.len() < 8 {
146 return Err(InvalidCookie);
147 }
148
149 let namespace = bytes.split_off(8);
150 let namespace = if namespace.is_empty() {
151 None
152 } else {
153 Some(
154 Namespace::new(String::from_utf8(namespace).map_err(|_| InvalidCookie)?)
155 .map_err(|_| InvalidCookie)?,
156 )
157 };
158
159 let bytes = <[u8; 8]>::try_from(bytes).map_err(|_| InvalidCookie)?;
160 let id = u64::from_be_bytes(bytes);
161
162 Ok(Self { id, namespace })
163 }
164
165 pub fn namespace(&self) -> Option<&Namespace> {
166 self.namespace.as_ref()
167 }
168}
169
170#[derive(Debug, thiserror::Error)]
171#[error("The cookie was malformed")]
172pub struct InvalidCookie;
173
174#[derive(Debug, Clone, PartialEq)]
175pub struct NewRegistration {
176 pub namespace: Namespace,
177 pub record: PeerRecord,
178 pub ttl: Option<u64>,
179}
180
181impl NewRegistration {
182 pub fn new(namespace: Namespace, record: PeerRecord, ttl: Option<Ttl>) -> Self {
183 Self {
184 namespace,
185 record,
186 ttl,
187 }
188 }
189
190 pub fn effective_ttl(&self) -> Ttl {
191 self.ttl.unwrap_or(DEFAULT_TTL)
192 }
193}
194
195#[derive(Debug, Clone, PartialEq, Eq)]
196pub struct Registration {
197 pub namespace: Namespace,
198 pub record: PeerRecord,
199 pub ttl: Ttl,
200}
201
/// Protocol-level error codes; each maps one-to-one onto a non-`OK` `proto::ResponseStatus`.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub enum ErrorCode {
    /// Wire status `E_INVALID_NAMESPACE`.
    InvalidNamespace,
    /// Wire status `E_INVALID_SIGNED_PEER_RECORD`.
    InvalidSignedPeerRecord,
    /// Wire status `E_INVALID_TTL`.
    InvalidTtl,
    /// Wire status `E_INVALID_COOKIE`.
    InvalidCookie,
    /// Wire status `E_NOT_AUTHORIZED`.
    NotAuthorized,
    /// Wire status `E_INTERNAL_ERROR`.
    InternalError,
    /// Wire status `E_UNAVAILABLE`.
    Unavailable,
}
212
213impl Encoder for Codec {
214 type Item<'a> = Message;
215 type Error = Error;
216
217 fn encode(&mut self, item: Self::Item<'_>, dst: &mut BytesMut) -> Result<(), Self::Error> {
218 let mut pb: ProtobufCodec<proto::Message> = ProtobufCodec::new(MAX_MESSAGE_LEN_BYTES);
219
220 pb.encode(proto::Message::from(item), dst)?;
221
222 Ok(())
223 }
224}
225
226impl Decoder for Codec {
227 type Item = Message;
228 type Error = Error;
229
230 fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
231 let mut pb: ProtobufCodec<proto::Message> = ProtobufCodec::new(MAX_MESSAGE_LEN_BYTES);
232
233 let Some(message) = pb.decode(src)? else {
234 return Ok(None);
235 };
236
237 Ok(Some(message.try_into()?))
238 }
239}
240
/// Stateless codec for the rendezvous protocol.
///
/// Used as a frame codec (`Encoder` / `Decoder`) and as a `libp2p_request_response::Codec`.
#[derive(Clone, Default)]
pub struct Codec {}
243
244#[async_trait]
245impl libp2p_request_response::Codec for Codec {
246 type Protocol = StreamProtocol;
247 type Request = Message;
248 type Response = Message;
249
250 async fn read_request<T>(&mut self, _: &Self::Protocol, io: &mut T) -> io::Result<Self::Request>
251 where
252 T: AsyncRead + Unpin + Send,
253 {
254 let message = FramedRead::new(io, self.clone())
255 .next()
256 .await
257 .ok_or(io::ErrorKind::UnexpectedEof)??;
258
259 Ok(message)
260 }
261
262 async fn read_response<T>(
263 &mut self,
264 _: &Self::Protocol,
265 io: &mut T,
266 ) -> io::Result<Self::Response>
267 where
268 T: AsyncRead + Unpin + Send,
269 {
270 let message = FramedRead::new(io, self.clone())
271 .next()
272 .await
273 .ok_or(io::ErrorKind::UnexpectedEof)??;
274
275 Ok(message)
276 }
277
278 async fn write_request<T>(
279 &mut self,
280 _: &Self::Protocol,
281 io: &mut T,
282 req: Self::Request,
283 ) -> io::Result<()>
284 where
285 T: AsyncWrite + Unpin + Send,
286 {
287 FramedWrite::new(io, self.clone()).send(req).await?;
288
289 Ok(())
290 }
291
292 async fn write_response<T>(
293 &mut self,
294 _: &Self::Protocol,
295 io: &mut T,
296 res: Self::Response,
297 ) -> io::Result<()>
298 where
299 T: AsyncWrite + Unpin + Send,
300 {
301 FramedWrite::new(io, self.clone()).send(res).await?;
302
303 Ok(())
304 }
305}
306
307#[derive(Debug, thiserror::Error)]
308pub enum Error {
309 #[error(transparent)]
310 Codec(#[from] quick_protobuf_codec::Error),
311 #[error("Failed to read/write")]
312 Io(#[from] std::io::Error),
313 #[error("Failed to convert wire message to internal data model")]
314 Conversion(#[from] ConversionError),
315}
316
317impl From<Error> for std::io::Error {
318 fn from(value: Error) -> Self {
319 match value {
320 Error::Io(e) => e,
321 Error::Codec(e) => io::Error::from(e),
322 Error::Conversion(e) => io::Error::new(io::ErrorKind::InvalidInput, e),
323 }
324 }
325}
326
327impl From<Message> for proto::Message {
328 fn from(message: Message) -> Self {
329 match message {
330 Message::Register(NewRegistration {
331 namespace,
332 record,
333 ttl,
334 }) => proto::Message {
335 type_pb: Some(proto::MessageType::REGISTER),
336 register: Some(proto::Register {
337 ns: Some(namespace.into()),
338 ttl,
339 signedPeerRecord: Some(record.into_signed_envelope().into_protobuf_encoding()),
340 }),
341 registerResponse: None,
342 unregister: None,
343 discover: None,
344 discoverResponse: None,
345 },
346 Message::RegisterResponse(Ok(ttl)) => proto::Message {
347 type_pb: Some(proto::MessageType::REGISTER_RESPONSE),
348 registerResponse: Some(proto::RegisterResponse {
349 status: Some(proto::ResponseStatus::OK),
350 statusText: None,
351 ttl: Some(ttl),
352 }),
353 register: None,
354 discover: None,
355 unregister: None,
356 discoverResponse: None,
357 },
358 Message::RegisterResponse(Err(error)) => proto::Message {
359 type_pb: Some(proto::MessageType::REGISTER_RESPONSE),
360 registerResponse: Some(proto::RegisterResponse {
361 status: Some(proto::ResponseStatus::from(error)),
362 statusText: None,
363 ttl: None,
364 }),
365 register: None,
366 discover: None,
367 unregister: None,
368 discoverResponse: None,
369 },
370 Message::Unregister(namespace) => proto::Message {
371 type_pb: Some(proto::MessageType::UNREGISTER),
372 unregister: Some(proto::Unregister {
373 ns: Some(namespace.into()),
374 id: None,
375 }),
376 register: None,
377 registerResponse: None,
378 discover: None,
379 discoverResponse: None,
380 },
381 Message::Discover {
382 namespace,
383 cookie,
384 limit,
385 } => proto::Message {
386 type_pb: Some(proto::MessageType::DISCOVER),
387 discover: Some(proto::Discover {
388 ns: namespace.map(|ns| ns.into()),
389 cookie: cookie.map(|cookie| cookie.into_wire_encoding()),
390 limit,
391 }),
392 register: None,
393 registerResponse: None,
394 unregister: None,
395 discoverResponse: None,
396 },
397 Message::DiscoverResponse(Ok((registrations, cookie))) => proto::Message {
398 type_pb: Some(proto::MessageType::DISCOVER_RESPONSE),
399 discoverResponse: Some(proto::DiscoverResponse {
400 registrations: registrations
401 .into_iter()
402 .map(|reggo| proto::Register {
403 ns: Some(reggo.namespace.into()),
404 ttl: Some(reggo.ttl),
405 signedPeerRecord: Some(
406 reggo.record.into_signed_envelope().into_protobuf_encoding(),
407 ),
408 })
409 .collect(),
410 status: Some(proto::ResponseStatus::OK),
411 statusText: None,
412 cookie: Some(cookie.into_wire_encoding()),
413 }),
414 register: None,
415 discover: None,
416 unregister: None,
417 registerResponse: None,
418 },
419 Message::DiscoverResponse(Err(error)) => proto::Message {
420 type_pb: Some(proto::MessageType::DISCOVER_RESPONSE),
421 discoverResponse: Some(proto::DiscoverResponse {
422 registrations: Vec::new(),
423 status: Some(proto::ResponseStatus::from(error)),
424 statusText: None,
425 cookie: None,
426 }),
427 register: None,
428 discover: None,
429 unregister: None,
430 registerResponse: None,
431 },
432 }
433 }
434}
435
436impl TryFrom<proto::Message> for Message {
437 type Error = ConversionError;
438
439 fn try_from(message: proto::Message) -> Result<Self, Self::Error> {
440 let message = match message {
441 proto::Message {
442 type_pb: Some(proto::MessageType::REGISTER),
443 register:
444 Some(proto::Register {
445 ns,
446 ttl,
447 signedPeerRecord: Some(signed_peer_record),
448 }),
449 ..
450 } => Message::Register(NewRegistration {
451 namespace: ns
452 .map(Namespace::new)
453 .transpose()?
454 .ok_or(ConversionError::MissingNamespace)?,
455 ttl,
456 record: PeerRecord::from_signed_envelope(SignedEnvelope::from_protobuf_encoding(
457 &signed_peer_record,
458 )?)?,
459 }),
460 proto::Message {
461 type_pb: Some(proto::MessageType::REGISTER_RESPONSE),
462 registerResponse:
463 Some(proto::RegisterResponse {
464 status: Some(proto::ResponseStatus::OK),
465 ttl,
466 ..
467 }),
468 ..
469 } => Message::RegisterResponse(Ok(ttl.ok_or(ConversionError::MissingTtl)?)),
470 proto::Message {
471 type_pb: Some(proto::MessageType::DISCOVER),
472 discover: Some(proto::Discover { ns, limit, cookie }),
473 ..
474 } => Message::Discover {
475 namespace: ns.map(Namespace::new).transpose()?,
476 cookie: cookie.map(Cookie::from_wire_encoding).transpose()?,
477 limit,
478 },
479 proto::Message {
480 type_pb: Some(proto::MessageType::DISCOVER_RESPONSE),
481 discoverResponse:
482 Some(proto::DiscoverResponse {
483 registrations,
484 status: Some(proto::ResponseStatus::OK),
485 cookie: Some(cookie),
486 ..
487 }),
488 ..
489 } => {
490 let registrations = registrations
491 .into_iter()
492 .map(|reggo| {
493 Ok(Registration {
494 namespace: reggo
495 .ns
496 .map(Namespace::new)
497 .transpose()?
498 .ok_or(ConversionError::MissingNamespace)?,
499 record: PeerRecord::from_signed_envelope(
500 SignedEnvelope::from_protobuf_encoding(
501 ®go
502 .signedPeerRecord
503 .ok_or(ConversionError::MissingSignedPeerRecord)?,
504 )?,
505 )?,
506 ttl: reggo.ttl.ok_or(ConversionError::MissingTtl)?,
507 })
508 })
509 .collect::<Result<Vec<_>, ConversionError>>()?;
510 let cookie = Cookie::from_wire_encoding(cookie)?;
511
512 Message::DiscoverResponse(Ok((registrations, cookie)))
513 }
514 proto::Message {
515 type_pb: Some(proto::MessageType::REGISTER_RESPONSE),
516 registerResponse:
517 Some(proto::RegisterResponse {
518 status: Some(response_status),
519 ..
520 }),
521 ..
522 } => Message::RegisterResponse(Err(response_status.try_into()?)),
523 proto::Message {
524 type_pb: Some(proto::MessageType::UNREGISTER),
525 unregister: Some(proto::Unregister { ns, .. }),
526 ..
527 } => Message::Unregister(
528 ns.map(Namespace::new)
529 .transpose()?
530 .ok_or(ConversionError::MissingNamespace)?,
531 ),
532 proto::Message {
533 type_pb: Some(proto::MessageType::DISCOVER_RESPONSE),
534 discoverResponse:
535 Some(proto::DiscoverResponse {
536 status: Some(response_status),
537 ..
538 }),
539 ..
540 } => Message::DiscoverResponse(Err(response_status.try_into()?)),
541 _ => return Err(ConversionError::InconsistentWireMessage),
542 };
543
544 Ok(message)
545 }
546}
547
548#[derive(Debug, thiserror::Error)]
549pub enum ConversionError {
550 #[error("The wire message is consistent")]
551 InconsistentWireMessage,
552 #[error("Missing namespace field")]
553 MissingNamespace,
554 #[error("Invalid namespace")]
555 InvalidNamespace(#[from] NamespaceTooLong),
556 #[error("Missing signed peer record field")]
557 MissingSignedPeerRecord,
558 #[error("Missing TTL field")]
559 MissingTtl,
560 #[error("Bad status code")]
561 BadStatusCode,
562 #[error("Failed to decode signed envelope")]
563 BadSignedEnvelope(#[from] signed_envelope::DecodingError),
564 #[error("Failed to decode envelope as signed peer record")]
565 BadSignedPeerRecord(#[from] peer_record::FromEnvelopeError),
566 #[error(transparent)]
567 BadCookie(#[from] InvalidCookie),
568 #[error("The requested PoW difficulty is out of range")]
569 PoWDifficultyOutOfRange,
570 #[error("The provided PoW hash is not 32 bytes long")]
571 BadPoWHash,
572}
573
574impl ConversionError {
575 pub fn to_error_code(&self) -> ErrorCode {
576 match self {
577 ConversionError::MissingNamespace => ErrorCode::InvalidNamespace,
578 ConversionError::MissingSignedPeerRecord => ErrorCode::InvalidSignedPeerRecord,
579 ConversionError::BadSignedEnvelope(_) => ErrorCode::InvalidSignedPeerRecord,
580 ConversionError::BadSignedPeerRecord(_) => ErrorCode::InvalidSignedPeerRecord,
581 ConversionError::BadCookie(_) => ErrorCode::InvalidCookie,
582 ConversionError::MissingTtl => ErrorCode::InvalidTtl,
583 ConversionError::InconsistentWireMessage => ErrorCode::InternalError,
584 ConversionError::BadStatusCode => ErrorCode::InternalError,
585 ConversionError::PoWDifficultyOutOfRange => ErrorCode::InternalError,
586 ConversionError::BadPoWHash => ErrorCode::InternalError,
587 ConversionError::InvalidNamespace(_) => ErrorCode::InvalidNamespace,
588 }
589 }
590}
591
592impl TryFrom<proto::ResponseStatus> for ErrorCode {
593 type Error = UnmappableStatusCode;
594
595 fn try_from(value: proto::ResponseStatus) -> Result<Self, Self::Error> {
596 use proto::ResponseStatus::*;
597
598 let code = match value {
599 OK => return Err(UnmappableStatusCode(value)),
600 E_INVALID_NAMESPACE => ErrorCode::InvalidNamespace,
601 E_INVALID_SIGNED_PEER_RECORD => ErrorCode::InvalidSignedPeerRecord,
602 E_INVALID_TTL => ErrorCode::InvalidTtl,
603 E_INVALID_COOKIE => ErrorCode::InvalidCookie,
604 E_NOT_AUTHORIZED => ErrorCode::NotAuthorized,
605 E_INTERNAL_ERROR => ErrorCode::InternalError,
606 E_UNAVAILABLE => ErrorCode::Unavailable,
607 };
608
609 Ok(code)
610 }
611}
612
613impl From<ErrorCode> for proto::ResponseStatus {
614 fn from(error_code: ErrorCode) -> Self {
615 use proto::ResponseStatus::*;
616
617 match error_code {
618 ErrorCode::InvalidNamespace => E_INVALID_NAMESPACE,
619 ErrorCode::InvalidSignedPeerRecord => E_INVALID_SIGNED_PEER_RECORD,
620 ErrorCode::InvalidTtl => E_INVALID_TTL,
621 ErrorCode::InvalidCookie => E_INVALID_COOKIE,
622 ErrorCode::NotAuthorized => E_NOT_AUTHORIZED,
623 ErrorCode::InternalError => E_INTERNAL_ERROR,
624 ErrorCode::Unavailable => E_UNAVAILABLE,
625 }
626 }
627}
628
629impl From<UnmappableStatusCode> for ConversionError {
630 fn from(_: UnmappableStatusCode) -> Self {
631 ConversionError::InconsistentWireMessage
632 }
633}
634
635#[derive(Debug, thiserror::Error)]
636#[error("The response code ({0:?}) cannot be mapped to our ErrorCode enum")]
637pub struct UnmappableStatusCode(proto::ResponseStatus);
638
// Protobuf types for the rendezvous wire format, pulled in from generated code.
mod proto {
    // NOTE(review): presumably allowed because the generated items are `pub` while this
    // module is private — confirm against the generator output.
    #![allow(unreachable_pub)]
    include!("generated/mod.rs");
    // Re-export `Message` plus everything nested under `mod_Message` (e.g. `MessageType`,
    // `ResponseStatus`, `Register`), so the rest of this file can refer to them as `proto::X`.
    pub(crate) use self::rendezvous::pb::{mod_Message::*, Message};
}
644
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn cookie_wire_encoding_roundtrip() {
        let original = Cookie::for_namespace(Namespace::from_static("foo"));

        let encoded = original.clone().into_wire_encoding();
        let decoded =
            Cookie::from_wire_encoding(encoded).expect("a freshly encoded cookie must decode");

        assert_eq!(decoded, original);
    }

    #[test]
    fn cookie_wire_encoding_length() {
        let encoded = Cookie::for_namespace(Namespace::from_static("foo")).into_wire_encoding();

        // 8 bytes of big-endian id followed by the 3 bytes of "foo".
        assert_eq!(encoded.len(), 8 + 3)
    }
}