1use std::{
28 io,
29 pin::Pin,
30 task::{Context, Poll},
31 time::Duration,
32};
33
34use futures::{
35 future::{Future, FutureExt},
36 io::{AsyncBufRead, AsyncRead, AsyncWrite, BufReader},
37 ready,
38};
39use futures_timer::Delay;
40
41pub(crate) struct CopyFuture<S, D> {
42 src: BufReader<S>,
43 dst: BufReader<D>,
44
45 max_circuit_duration: Delay,
46 max_circuit_bytes: u64,
47 bytes_sent: u64,
48}
49
50impl<S: AsyncRead, D: AsyncRead> CopyFuture<S, D> {
51 pub(crate) fn new(
52 src: S,
53 dst: D,
54 max_circuit_duration: Duration,
55 max_circuit_bytes: u64,
56 ) -> Self {
57 CopyFuture {
58 src: BufReader::new(src),
59 dst: BufReader::new(dst),
60 max_circuit_duration: Delay::new(max_circuit_duration),
61 max_circuit_bytes,
62 bytes_sent: Default::default(),
63 }
64 }
65}
66
67impl<S, D> Future for CopyFuture<S, D>
68where
69 S: AsyncRead + AsyncWrite + Unpin,
70 D: AsyncRead + AsyncWrite + Unpin,
71{
72 type Output = io::Result<()>;
73
74 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
75 let this = &mut *self;
76
77 loop {
78 if this.max_circuit_bytes > 0 && this.bytes_sent > this.max_circuit_bytes {
79 return Poll::Ready(Err(io::Error::other("Max circuit bytes reached.")));
80 }
81
82 enum Status {
83 Pending,
84 Done,
85 Progressed,
86 }
87
88 let src_status = match forward_data(&mut this.src, &mut this.dst, cx) {
89 Poll::Ready(Err(e)) => return Poll::Ready(Err(e)),
90 Poll::Ready(Ok(0)) => Status::Done,
91 Poll::Ready(Ok(i)) => {
92 this.bytes_sent += i;
93 Status::Progressed
94 }
95 Poll::Pending => Status::Pending,
96 };
97
98 let dst_status = match forward_data(&mut this.dst, &mut this.src, cx) {
99 Poll::Ready(Err(e)) => return Poll::Ready(Err(e)),
100 Poll::Ready(Ok(0)) => Status::Done,
101 Poll::Ready(Ok(i)) => {
102 this.bytes_sent += i;
103 Status::Progressed
104 }
105 Poll::Pending => Status::Pending,
106 };
107
108 match (src_status, dst_status) {
109 (Status::Done, Status::Done) => return Poll::Ready(Ok(())),
111 (Status::Progressed, _) | (_, Status::Progressed) => {}
113 (Status::Pending, Status::Pending) => break,
116 (Status::Pending, Status::Done) | (Status::Done, Status::Pending) => break,
119 }
120 }
121
122 if let Poll::Ready(()) = this.max_circuit_duration.poll_unpin(cx) {
123 return Poll::Ready(Err(io::ErrorKind::TimedOut.into()));
124 }
125
126 Poll::Pending
127 }
128}
129
130fn forward_data<S: AsyncBufRead + Unpin, D: AsyncWrite + Unpin>(
135 mut src: &mut S,
136 mut dst: &mut D,
137 cx: &mut Context<'_>,
138) -> Poll<io::Result<u64>> {
139 let buffer = match Pin::new(&mut src).poll_fill_buf(cx)? {
140 Poll::Ready(buffer) => buffer,
141 Poll::Pending => {
142 let _ = Pin::new(&mut dst).poll_flush(cx)?;
143 return Poll::Pending;
144 }
145 };
146
147 if buffer.is_empty() {
148 ready!(Pin::new(&mut dst).poll_flush(cx))?;
149 ready!(Pin::new(&mut dst).poll_close(cx))?;
150 return Poll::Ready(Ok(0));
151 }
152
153 let i = ready!(Pin::new(dst).poll_write(cx, buffer))?;
154 if i == 0 {
155 return Poll::Ready(Err(io::ErrorKind::WriteZero.into()));
156 }
157 Pin::new(src).consume(i);
158
159 Poll::Ready(Ok(i.try_into().expect("usize to fit into u64.")))
160}
161
#[cfg(test)]
mod tests {
    use std::io::ErrorKind;

    use futures::{executor::block_on, io::BufWriter};
    use quickcheck::QuickCheck;

    use super::*;

    /// Relaying two in-memory connections either hands each side exactly the bytes the other
    /// side offered, or fails with the byte-limit error when the combined payload exceeds
    /// `max_circuit_bytes`.
    #[test]
    fn quickcheck() {
        // In-memory duplex stream: reads drain `read`, writes append to `write`.
        struct Connection {
            read: Vec<u8>,
            write: Vec<u8>,
        }

        impl AsyncWrite for Connection {
            fn poll_write(
                mut self: Pin<&mut Self>,
                cx: &mut Context<'_>,
                buf: &[u8],
            ) -> Poll<std::io::Result<usize>> {
                Pin::new(&mut self.write).poll_write(cx, buf)
            }

            fn poll_flush(
                mut self: Pin<&mut Self>,
                cx: &mut Context<'_>,
            ) -> Poll<std::io::Result<()>> {
                Pin::new(&mut self.write).poll_flush(cx)
            }

            fn poll_close(
                mut self: Pin<&mut Self>,
                cx: &mut Context<'_>,
            ) -> Poll<std::io::Result<()>> {
                Pin::new(&mut self.write).poll_close(cx)
            }
        }

        impl AsyncRead for Connection {
            // Serves as many remaining bytes as fit into `buf`; once `read` is empty this
            // returns `Ok(0)`, i.e. EOF.
            fn poll_read(
                mut self: Pin<&mut Self>,
                _cx: &mut Context<'_>,
                buf: &mut [u8],
            ) -> Poll<std::io::Result<usize>> {
                let n = std::cmp::min(self.read.len(), buf.len());
                buf[0..n].copy_from_slice(&self.read[0..n]);
                self.read = self.read.split_off(n);
                Poll::Ready(Ok(n))
            }
        }

        fn prop(a: Vec<u8>, b: Vec<u8>, max_circuit_bytes: u64) {
            let connection_a = Connection {
                read: a.clone(),
                write: Vec::new(),
            };

            let connection_b = Connection {
                read: b.clone(),
                write: Vec::new(),
            };

            // A long duration, so only the byte limit can end the circuit early.
            let mut copy_future = CopyFuture::new(
                connection_a,
                connection_b,
                Duration::from_secs(60),
                max_circuit_bytes,
            );

            match block_on(&mut copy_future) {
                Ok(()) => {
                    assert_eq!(copy_future.src.into_inner().write, b);
                    assert_eq!(copy_future.dst.into_inner().write, a);
                }
                Err(error) => {
                    assert_eq!(error.kind(), ErrorKind::Other);
                    assert_eq!(error.to_string(), "Max circuit bytes reached.");
                    // Compare as `u64`: casting the limit down to `usize` would truncate it on
                    // 32-bit targets and weaken this check.
                    let total_len = (a.len() + b.len()) as u64;
                    assert!(total_len > max_circuit_bytes);
                }
            }
        }

        QuickCheck::new().quickcheck(prop as fn(_, _, _))
    }

    /// A circuit on which neither side ever becomes ready fails with `TimedOut` once
    /// `max_circuit_duration` has elapsed.
    #[test]
    fn max_circuit_duration() {
        // Stream that is never ready for reading, writing, flushing, or closing.
        struct PendingConnection {}

        impl AsyncWrite for PendingConnection {
            fn poll_write(
                self: Pin<&mut Self>,
                _cx: &mut Context<'_>,
                _buf: &[u8],
            ) -> Poll<std::io::Result<usize>> {
                Poll::Pending
            }

            fn poll_flush(
                self: Pin<&mut Self>,
                _cx: &mut Context<'_>,
            ) -> Poll<std::io::Result<()>> {
                Poll::Pending
            }

            fn poll_close(
                self: Pin<&mut Self>,
                _cx: &mut Context<'_>,
            ) -> Poll<std::io::Result<()>> {
                Poll::Pending
            }
        }

        impl AsyncRead for PendingConnection {
            fn poll_read(
                self: Pin<&mut Self>,
                _cx: &mut Context<'_>,
                _buf: &mut [u8],
            ) -> Poll<std::io::Result<usize>> {
                Poll::Pending
            }
        }

        let copy_future = CopyFuture::new(
            PendingConnection {},
            PendingConnection {},
            Duration::from_millis(1),
            u64::MAX,
        );

        // Sleep past the 1 ms circuit duration before polling the future.
        std::thread::sleep(Duration::from_millis(2));

        let error =
            block_on(copy_future).expect_err("Expect maximum circuit duration to be reached.");
        assert_eq!(error.kind(), ErrorKind::TimedOut);
    }

    /// `forward_data` must flush the destination once the source stops producing data, without
    /// the source having to close.
    #[test]
    fn forward_data_should_flush_on_pending_source() {
        // Hands out the bytes of `read` one per `poll_read` call (last element first) and then
        // returns `Poll::Pending` forever; it never signals EOF.
        struct NeverEndingSource {
            read: Vec<u8>,
        }

        impl AsyncRead for NeverEndingSource {
            fn poll_read(
                mut self: Pin<&mut Self>,
                _cx: &mut Context<'_>,
                buf: &mut [u8],
            ) -> Poll<std::io::Result<usize>> {
                if let Some(b) = self.read.pop() {
                    buf[0] = b;
                    return Poll::Ready(Ok(1));
                }

                Poll::Pending
            }
        }

        // Records every `AsyncWrite` call it receives so the test can see what reached it.
        struct RecordingDestination {
            method_calls: Vec<Method>,
        }

        #[derive(Debug, PartialEq)]
        enum Method {
            Write(Vec<u8>),
            Flush,
            Close,
        }

        impl AsyncWrite for RecordingDestination {
            fn poll_write(
                mut self: Pin<&mut Self>,
                _cx: &mut Context<'_>,
                buf: &[u8],
            ) -> Poll<std::io::Result<usize>> {
                self.method_calls.push(Method::Write(buf.to_vec()));
                Poll::Ready(Ok(buf.len()))
            }

            fn poll_flush(
                mut self: Pin<&mut Self>,
                _cx: &mut Context<'_>,
            ) -> Poll<std::io::Result<()>> {
                self.method_calls.push(Method::Flush);
                Poll::Ready(Ok(()))
            }

            fn poll_close(
                mut self: Pin<&mut Self>,
                _cx: &mut Context<'_>,
            ) -> Poll<std::io::Result<()>> {
                self.method_calls.push(Method::Close);
                Poll::Ready(Ok(()))
            }
        }

        // The source offers two single-byte reads and then stays pending.
        let mut source = BufReader::new(NeverEndingSource { read: vec![1, 2] });

        // A `BufWriter` capacity of 3 exceeds the two bytes the source hands out, so without an
        // explicit flush the writes would stay in the `BufWriter` and never reach the
        // destination.
        let mut destination = BufWriter::with_capacity(
            3,
            RecordingDestination {
                method_calls: vec![],
            },
        );

        let mut cx = Context::from_waker(futures::task::noop_waker_ref());

        assert!(
            matches!(
                forward_data(&mut source, &mut destination, &mut cx),
                Poll::Ready(Ok(1)),
            ),
            "Expect `forward_data` to forward one read from the source to the wrapped destination."
        );
        assert_eq!(
            destination.get_ref().method_calls.as_slice(), &[],
            "Given that destination is wrapped with a `BufWriter`, the write doesn't (yet) make it \
             to the destination. The source might have more data available, thus `forward_data` \
             has not yet flushed.",
        );

        assert!(
            matches!(
                forward_data(&mut source, &mut destination, &mut cx),
                Poll::Ready(Ok(1)),
            ),
            "Expect `forward_data` to forward one read from the source to the wrapped destination."
        );
        assert_eq!(
            destination.get_ref().method_calls.as_slice(), &[],
            "Given that destination is wrapped with a `BufWriter`, the write doesn't (yet) make it \
             to the destination. The source might have more data available, thus `forward_data` \
             has not yet flushed.",
        );

        assert!(
            matches!(
                forward_data(&mut source, &mut destination, &mut cx),
                Poll::Pending,
            ),
            "The source has no more reads available, but does not close i.e. does not return \
             `Poll::Ready(Ok(0))` but instead `Poll::Pending`. Thus `forward_data` returns \
             `Poll::Pending` as well."
        );
        assert_eq!(
            destination.get_ref().method_calls.as_slice(),
            &[Method::Write(vec![2, 1]), Method::Flush],
            "Given that source had no more reads, `forward_data` calls flush, thus instructing the \
             `BufWriter` to flush the two buffered writes down to the destination."
        );
    }
}