1use std::{
28    io,
29    pin::Pin,
30    task::{Context, Poll},
31    time::Duration,
32};
33
34use futures::{
35    future::{Future, FutureExt},
36    io::{AsyncBufRead, AsyncRead, AsyncWrite, BufReader},
37    ready,
38};
39use futures_timer::Delay;
40
/// Future that copies data in both directions between two streams.
///
/// It resolves with `Ok(())` once both directions have reached EOF. It fails on any I/O error,
/// when `max_circuit_duration` has elapsed, or when more than `max_circuit_bytes` bytes have been
/// relayed in total.
pub(crate) struct CopyFuture<S, D> {
    /// One end of the relay. Data read from it is written to `dst`, and vice versa.
    src: BufReader<S>,
    /// The other end of the relay, the counterpart of `src`.
    dst: BufReader<D>,

    /// Timer bounding how long the relay may run. It is only polled once neither direction
    /// can make further progress.
    max_circuit_duration: Delay,
    /// Upper bound on the bytes relayed in both directions combined. `0` disables the limit.
    max_circuit_bytes: u64,
    /// Bytes relayed so far, summed over both directions.
    bytes_sent: u64,
}
49
50impl<S: AsyncRead, D: AsyncRead> CopyFuture<S, D> {
51    pub(crate) fn new(
52        src: S,
53        dst: D,
54        max_circuit_duration: Duration,
55        max_circuit_bytes: u64,
56    ) -> Self {
57        CopyFuture {
58            src: BufReader::new(src),
59            dst: BufReader::new(dst),
60            max_circuit_duration: Delay::new(max_circuit_duration),
61            max_circuit_bytes,
62            bytes_sent: Default::default(),
63        }
64    }
65}
66
67impl<S, D> Future for CopyFuture<S, D>
68where
69    S: AsyncRead + AsyncWrite + Unpin,
70    D: AsyncRead + AsyncWrite + Unpin,
71{
72    type Output = io::Result<()>;
73
74    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
75        let this = &mut *self;
76
77        loop {
78            if this.max_circuit_bytes > 0 && this.bytes_sent > this.max_circuit_bytes {
79                return Poll::Ready(Err(io::Error::other("Max circuit bytes reached.")));
80            }
81
82            enum Status {
83                Pending,
84                Done,
85                Progressed,
86            }
87
88            let src_status = match forward_data(&mut this.src, &mut this.dst, cx) {
89                Poll::Ready(Err(e)) => return Poll::Ready(Err(e)),
90                Poll::Ready(Ok(0)) => Status::Done,
91                Poll::Ready(Ok(i)) => {
92                    this.bytes_sent += i;
93                    Status::Progressed
94                }
95                Poll::Pending => Status::Pending,
96            };
97
98            let dst_status = match forward_data(&mut this.dst, &mut this.src, cx) {
99                Poll::Ready(Err(e)) => return Poll::Ready(Err(e)),
100                Poll::Ready(Ok(0)) => Status::Done,
101                Poll::Ready(Ok(i)) => {
102                    this.bytes_sent += i;
103                    Status::Progressed
104                }
105                Poll::Pending => Status::Pending,
106            };
107
108            match (src_status, dst_status) {
109                (Status::Done, Status::Done) => return Poll::Ready(Ok(())),
111                (Status::Progressed, _) | (_, Status::Progressed) => {}
113                (Status::Pending, Status::Pending) => break,
116                (Status::Pending, Status::Done) | (Status::Done, Status::Pending) => break,
119            }
120        }
121
122        if let Poll::Ready(()) = this.max_circuit_duration.poll_unpin(cx) {
123            return Poll::Ready(Err(io::ErrorKind::TimedOut.into()));
124        }
125
126        Poll::Pending
127    }
128}
129
130fn forward_data<S: AsyncBufRead + Unpin, D: AsyncWrite + Unpin>(
135    mut src: &mut S,
136    mut dst: &mut D,
137    cx: &mut Context<'_>,
138) -> Poll<io::Result<u64>> {
139    let buffer = match Pin::new(&mut src).poll_fill_buf(cx)? {
140        Poll::Ready(buffer) => buffer,
141        Poll::Pending => {
142            let _ = Pin::new(&mut dst).poll_flush(cx)?;
143            return Poll::Pending;
144        }
145    };
146
147    if buffer.is_empty() {
148        ready!(Pin::new(&mut dst).poll_flush(cx))?;
149        ready!(Pin::new(&mut dst).poll_close(cx))?;
150        return Poll::Ready(Ok(0));
151    }
152
153    let i = ready!(Pin::new(dst).poll_write(cx, buffer))?;
154    if i == 0 {
155        return Poll::Ready(Err(io::ErrorKind::WriteZero.into()));
156    }
157    Pin::new(src).consume(i);
158
159    Poll::Ready(Ok(i.try_into().expect("usize to fit into u64.")))
160}
161
#[cfg(test)]
mod tests {
    use std::io::ErrorKind;

    use futures::{executor::block_on, io::BufWriter};
    use quickcheck::QuickCheck;

    use super::*;

    /// Relays random payloads in both directions. Each side must receive exactly what the other
    /// side sent, or the copy must fail because the byte limit was exceeded.
    #[test]
    fn quickcheck() {
        /// In-memory stream: reads drain `read`, and writes are appended to `write`.
        struct Connection {
            read: Vec<u8>,
            write: Vec<u8>,
        }

        impl AsyncWrite for Connection {
            fn poll_write(
                mut self: std::pin::Pin<&mut Self>,
                cx: &mut Context<'_>,
                buf: &[u8],
            ) -> Poll<std::io::Result<usize>> {
                Pin::new(&mut self.write).poll_write(cx, buf)
            }

            fn poll_flush(
                mut self: std::pin::Pin<&mut Self>,
                cx: &mut Context<'_>,
            ) -> Poll<std::io::Result<()>> {
                Pin::new(&mut self.write).poll_flush(cx)
            }

            fn poll_close(
                mut self: std::pin::Pin<&mut Self>,
                cx: &mut Context<'_>,
            ) -> Poll<std::io::Result<()>> {
                Pin::new(&mut self.write).poll_close(cx)
            }
        }

        impl AsyncRead for Connection {
            fn poll_read(
                mut self: Pin<&mut Self>,
                _cx: &mut Context<'_>,
                buf: &mut [u8],
            ) -> Poll<std::io::Result<usize>> {
                // Serve as much as fits. Once `read` is empty this yields `Ok(0)`, i.e. EOF.
                let n = std::cmp::min(self.read.len(), buf.len());
                buf[0..n].copy_from_slice(&self.read[0..n]);
                self.read = self.read.split_off(n);
                Poll::Ready(Ok(n))
            }
        }

        fn prop(a: Vec<u8>, b: Vec<u8>, max_circuit_bytes: u64) {
            let connection_a = Connection {
                read: a.clone(),
                write: Vec::new(),
            };

            let connection_b = Connection {
                read: b.clone(),
                write: Vec::new(),
            };

            let mut copy_future = CopyFuture::new(
                connection_a,
                connection_b,
                Duration::from_secs(60),
                max_circuit_bytes,
            );

            match block_on(&mut copy_future) {
                Ok(()) => {
                    assert_eq!(copy_future.src.into_inner().write, b);
                    assert_eq!(copy_future.dst.into_inner().write, a);
                }
                Err(error) => {
                    assert_eq!(error.kind(), ErrorKind::Other);
                    assert_eq!(error.to_string(), "Max circuit bytes reached.");
                    // Widen the length to `u64` instead of narrowing `max_circuit_bytes` to
                    // `usize`, which could truncate on 32-bit targets.
                    assert!((a.len() + b.len()) as u64 > max_circuit_bytes);
                }
            }
        }

        QuickCheck::new().quickcheck(prop as fn(_, _, _))
    }

    /// A copy between two streams that never make progress must fail with `TimedOut` once the
    /// maximum circuit duration has elapsed.
    #[test]
    fn max_circuit_duration() {
        /// Stream whose every operation stays pending.
        struct PendingConnection {}

        impl AsyncWrite for PendingConnection {
            fn poll_write(
                self: std::pin::Pin<&mut Self>,
                _cx: &mut Context<'_>,
                _buf: &[u8],
            ) -> Poll<std::io::Result<usize>> {
                Poll::Pending
            }

            fn poll_flush(
                self: std::pin::Pin<&mut Self>,
                _cx: &mut Context<'_>,
            ) -> Poll<std::io::Result<()>> {
                Poll::Pending
            }

            fn poll_close(
                self: std::pin::Pin<&mut Self>,
                _cx: &mut Context<'_>,
            ) -> Poll<std::io::Result<()>> {
                Poll::Pending
            }
        }

        impl AsyncRead for PendingConnection {
            fn poll_read(
                self: Pin<&mut Self>,
                _cx: &mut Context<'_>,
                _buf: &mut [u8],
            ) -> Poll<std::io::Result<usize>> {
                Poll::Pending
            }
        }

        let copy_future = CopyFuture::new(
            PendingConnection {},
            PendingConnection {},
            Duration::from_millis(1),
            u64::MAX,
        );

        // Sleep past the 1ms circuit duration before the future is first polled.
        std::thread::sleep(Duration::from_millis(2));

        let error =
            block_on(copy_future).expect_err("Expect maximum circuit duration to be reached.");
        assert_eq!(error.kind(), ErrorKind::TimedOut);
    }

    /// `forward_data` must flush the destination when the source has no data available, so that
    /// data buffered on the write side is not held back.
    #[test]
    fn forward_data_should_flush_on_pending_source() {
        /// Source that yields the bytes of `read` one at a time, popping from the back, and then
        /// stays pending instead of signalling EOF.
        struct NeverEndingSource {
            read: Vec<u8>,
        }

        impl AsyncRead for NeverEndingSource {
            fn poll_read(
                mut self: Pin<&mut Self>,
                _cx: &mut Context<'_>,
                buf: &mut [u8],
            ) -> Poll<std::io::Result<usize>> {
                if let Some(b) = self.read.pop() {
                    buf[0] = b;
                    return Poll::Ready(Ok(1));
                }

                Poll::Pending
            }
        }

        /// Destination that records every call made to it.
        struct RecordingDestination {
            method_calls: Vec<Method>,
        }

        #[derive(Debug, PartialEq)]
        enum Method {
            Write(Vec<u8>),
            Flush,
            Close,
        }

        impl AsyncWrite for RecordingDestination {
            fn poll_write(
                mut self: std::pin::Pin<&mut Self>,
                _cx: &mut Context<'_>,
                buf: &[u8],
            ) -> Poll<std::io::Result<usize>> {
                self.method_calls.push(Method::Write(buf.to_vec()));
                Poll::Ready(Ok(buf.len()))
            }

            fn poll_flush(
                mut self: std::pin::Pin<&mut Self>,
                _cx: &mut Context<'_>,
            ) -> Poll<std::io::Result<()>> {
                self.method_calls.push(Method::Flush);
                Poll::Ready(Ok(()))
            }

            fn poll_close(
                mut self: std::pin::Pin<&mut Self>,
                _cx: &mut Context<'_>,
            ) -> Poll<std::io::Result<()>> {
                self.method_calls.push(Method::Close);
                Poll::Ready(Ok(()))
            }
        }

        let mut source = BufReader::new(NeverEndingSource { read: vec![1, 2] });

        // A `BufWriter` with room for 3 bytes holds both single-byte writes. They only reach
        // `RecordingDestination` once `forward_data` flushes.
        let mut destination = BufWriter::with_capacity(
            3,
            RecordingDestination {
                method_calls: vec![],
            },
        );

        let mut cx = Context::from_waker(futures::task::noop_waker_ref());

        assert!(
            matches!(
                forward_data(&mut source, &mut destination, &mut cx),
                Poll::Ready(Ok(1)),
            ),
            "Expect `forward_data` to forward one read from the source to the wrapped destination."
        );
        assert_eq!(
            destination.get_ref().method_calls.as_slice(), &[],
            "Given that destination is wrapped with a `BufWriter`, the write doesn't (yet) make it to \
            the destination. The source might have more data available, thus `forward_data` has not \
            yet flushed.",
        );

        assert!(
            matches!(
                forward_data(&mut source, &mut destination, &mut cx),
                Poll::Ready(Ok(1)),
            ),
            "Expect `forward_data` to forward one read from the source to the wrapped destination."
        );
        assert_eq!(
            destination.get_ref().method_calls.as_slice(), &[],
            "Given that destination is wrapped with a `BufWriter`, the write doesn't (yet) make it to \
            the destination. The source might have more data available, thus `forward_data` has not \
            yet flushed.",
        );

        assert!(
            matches!(
                forward_data(&mut source, &mut destination, &mut cx),
                Poll::Pending,
            ),
            "The source has no more reads available, but does not close i.e. does not return \
            `Poll::Ready(Ok(0))` but instead `Poll::Pending`. Thus `forward_data` returns \
            `Poll::Pending` as well."
        );
        assert_eq!(
            destination.get_ref().method_calls.as_slice(),
            &[Method::Write(vec![2, 1]), Method::Flush],
            "Given that source had no more reads, `forward_data` calls flush, thus instructing the \
            `BufWriter` to flush the two buffered writes down to the destination."
        );
    }
}