1use std::time::Duration;
22
23use asynchronous_codec::{Framed, FramedParts};
24use bytes::Bytes;
25use either::Either;
26use futures::prelude::*;
27use libp2p_core::Multiaddr;
28use libp2p_identity::PeerId;
29use libp2p_swarm::Stream;
30use thiserror::Error;
31use web_time::SystemTime;
32
33use crate::{proto, proto::message_v2::pb::mod_HopMessage::Type, protocol::MAX_MESSAGE_SIZE};
34
35#[derive(Debug, Error)]
36pub enum Error {
37 #[error(transparent)]
38 Codec(#[from] quick_protobuf_codec::Error),
39 #[error("Stream closed")]
40 StreamClosed,
41 #[error("Failed to parse peer id.")]
42 ParsePeerId,
43 #[error("Expected 'peer' field to be set.")]
44 MissingPeer,
45 #[error("Unexpected message type 'status'")]
46 UnexpectedTypeStatus,
47}
48
49pub struct ReservationReq {
50 substream: Framed<Stream, quick_protobuf_codec::Codec<proto::HopMessage>>,
51 reservation_duration: Duration,
52 max_circuit_duration: Duration,
53 max_circuit_bytes: u64,
54}
55
56impl ReservationReq {
57 pub async fn accept(self, addrs: Vec<Multiaddr>) -> Result<(), Error> {
58 if addrs.is_empty() {
59 tracing::debug!(
60 "Accepting relay reservation without providing external addresses of local node. \
61 Thus the remote node might not be able to advertise its relayed address."
62 )
63 }
64
65 let msg = proto::HopMessage {
66 type_pb: proto::HopMessageType::STATUS,
67 peer: None,
68 reservation: Some(proto::Reservation {
69 addrs: addrs.into_iter().map(|a| a.to_vec()).collect(),
70 expire: (SystemTime::now() + self.reservation_duration)
71 .duration_since(SystemTime::UNIX_EPOCH)
72 .unwrap()
73 .as_secs(),
74 voucher: None,
75 }),
76 limit: Some(proto::Limit {
77 duration: Some(
78 self.max_circuit_duration
79 .as_secs()
80 .try_into()
81 .expect("`max_circuit_duration` not to exceed `u32::MAX`."),
82 ),
83 data: Some(self.max_circuit_bytes),
84 }),
85 status: Some(proto::Status::OK),
86 };
87
88 self.send(msg).await
89 }
90
91 pub async fn deny(self, status: proto::Status) -> Result<(), Error> {
92 let msg = proto::HopMessage {
93 type_pb: proto::HopMessageType::STATUS,
94 peer: None,
95 reservation: None,
96 limit: None,
97 status: Some(status),
98 };
99
100 self.send(msg).await
101 }
102
103 async fn send(mut self, msg: proto::HopMessage) -> Result<(), Error> {
104 self.substream.send(msg).await?;
105 self.substream.flush().await?;
106 self.substream.close().await?;
107
108 Ok(())
109 }
110}
111
112pub struct CircuitReq {
113 dst: PeerId,
114 substream: Framed<Stream, quick_protobuf_codec::Codec<proto::HopMessage>>,
115 max_circuit_duration: Duration,
116 max_circuit_bytes: u64,
117}
118
119impl CircuitReq {
120 pub fn dst(&self) -> PeerId {
121 self.dst
122 }
123
124 pub async fn accept(mut self) -> Result<(Stream, Bytes), Error> {
125 let msg = proto::HopMessage {
126 type_pb: proto::HopMessageType::STATUS,
127 peer: None,
128 reservation: None,
129 limit: Some(proto::Limit {
130 duration: Some(
131 self.max_circuit_duration
132 .as_secs()
133 .try_into()
134 .expect("`max_circuit_duration` not to exceed `u32::MAX`."),
135 ),
136 data: Some(self.max_circuit_bytes),
137 }),
138 status: Some(proto::Status::OK),
139 };
140
141 self.send(msg).await?;
142
143 let FramedParts {
144 io,
145 read_buffer,
146 write_buffer,
147 ..
148 } = self.substream.into_parts();
149 assert!(
150 write_buffer.is_empty(),
151 "Expect a flushed Framed to have an empty write buffer."
152 );
153
154 Ok((io, read_buffer.freeze()))
155 }
156
157 pub async fn deny(mut self, status: proto::Status) -> Result<(), Error> {
158 let msg = proto::HopMessage {
159 type_pb: proto::HopMessageType::STATUS,
160 peer: None,
161 reservation: None,
162 limit: None,
163 status: Some(status),
164 };
165 self.send(msg).await?;
166 self.substream.close().await.map_err(Into::into)
167 }
168
169 async fn send(&mut self, msg: proto::HopMessage) -> Result<(), quick_protobuf_codec::Error> {
170 self.substream.send(msg).await?;
171 self.substream.flush().await?;
172
173 Ok(())
174 }
175}
176
177pub(crate) async fn handle_inbound_request(
178 io: Stream,
179 reservation_duration: Duration,
180 max_circuit_duration: Duration,
181 max_circuit_bytes: u64,
182) -> Result<Either<ReservationReq, CircuitReq>, Error> {
183 let mut substream = Framed::new(io, quick_protobuf_codec::Codec::new(MAX_MESSAGE_SIZE));
184
185 let res = substream.next().await;
186
187 if let None | Some(Err(_)) = res {
188 return Err(Error::StreamClosed);
189 }
190
191 let proto::HopMessage {
192 type_pb,
193 peer,
194 reservation: _,
195 limit: _,
196 status: _,
197 } = res.unwrap().expect("should be ok");
198
199 let req = match type_pb {
200 Type::RESERVE => Either::Left(ReservationReq {
201 substream,
202 reservation_duration,
203 max_circuit_duration,
204 max_circuit_bytes,
205 }),
206 Type::CONNECT => {
207 let peer_id_res = match peer {
208 Some(r) => PeerId::from_bytes(&r.id),
209 None => return Err(Error::MissingPeer),
210 };
211
212 let dst = peer_id_res.map_err(|_| Error::ParsePeerId)?;
213
214 Either::Right(CircuitReq {
215 dst,
216 substream,
217 max_circuit_duration,
218 max_circuit_bytes,
219 })
220 }
221 Type::STATUS => return Err(Error::UnexpectedTypeStatus),
222 };
223
224 Ok(req)
225}