1use std::{
28 io,
29 pin::Pin,
30 task::{Context, Poll},
31 time::Duration,
32};
33
34use futures::{
35 future::{Future, FutureExt},
36 io::{AsyncBufRead, AsyncRead, AsyncWrite, BufReader},
37 ready,
38};
39use futures_timer::Delay;
40
/// Future that relays data in both directions between two streams.
///
/// It resolves once both directions have reached end of stream. It fails on any I/O error,
/// once more than `max_circuit_bytes` bytes have been forwarded (when that limit is non-zero),
/// or once `max_circuit_duration` elapses.
pub(crate) struct CopyFuture<S, D> {
    /// Data read from this stream is written to `dst`; data read from `dst` is written back here.
    src: BufReader<S>,
    /// Counterpart of `src`.
    dst: BufReader<D>,

    /// Deadline for the whole relay; once it fires the future fails with `TimedOut`.
    max_circuit_duration: Delay,
    /// Upper bound on `bytes_sent`; `0` disables the limit.
    max_circuit_bytes: u64,
    /// Total number of bytes forwarded so far, summed over both directions.
    bytes_sent: u64,
}
49
50impl<S: AsyncRead, D: AsyncRead> CopyFuture<S, D> {
51 pub(crate) fn new(
52 src: S,
53 dst: D,
54 max_circuit_duration: Duration,
55 max_circuit_bytes: u64,
56 ) -> Self {
57 CopyFuture {
58 src: BufReader::new(src),
59 dst: BufReader::new(dst),
60 max_circuit_duration: Delay::new(max_circuit_duration),
61 max_circuit_bytes,
62 bytes_sent: Default::default(),
63 }
64 }
65}
66
67impl<S, D> Future for CopyFuture<S, D>
68where
69 S: AsyncRead + AsyncWrite + Unpin,
70 D: AsyncRead + AsyncWrite + Unpin,
71{
72 type Output = io::Result<()>;
73
74 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
75 let this = &mut *self;
76
77 loop {
78 if this.max_circuit_bytes > 0 && this.bytes_sent > this.max_circuit_bytes {
79 return Poll::Ready(Err(io::Error::new(
80 io::ErrorKind::Other,
81 "Max circuit bytes reached.",
82 )));
83 }
84
85 enum Status {
86 Pending,
87 Done,
88 Progressed,
89 }
90
91 let src_status = match forward_data(&mut this.src, &mut this.dst, cx) {
92 Poll::Ready(Err(e)) => return Poll::Ready(Err(e)),
93 Poll::Ready(Ok(0)) => Status::Done,
94 Poll::Ready(Ok(i)) => {
95 this.bytes_sent += i;
96 Status::Progressed
97 }
98 Poll::Pending => Status::Pending,
99 };
100
101 let dst_status = match forward_data(&mut this.dst, &mut this.src, cx) {
102 Poll::Ready(Err(e)) => return Poll::Ready(Err(e)),
103 Poll::Ready(Ok(0)) => Status::Done,
104 Poll::Ready(Ok(i)) => {
105 this.bytes_sent += i;
106 Status::Progressed
107 }
108 Poll::Pending => Status::Pending,
109 };
110
111 match (src_status, dst_status) {
112 (Status::Done, Status::Done) => return Poll::Ready(Ok(())),
114 (Status::Progressed, _) | (_, Status::Progressed) => {}
116 (Status::Pending, Status::Pending) => break,
119 (Status::Pending, Status::Done) | (Status::Done, Status::Pending) => break,
122 }
123 }
124
125 if let Poll::Ready(()) = this.max_circuit_duration.poll_unpin(cx) {
126 return Poll::Ready(Err(io::ErrorKind::TimedOut.into()));
127 }
128
129 Poll::Pending
130 }
131}
132
133fn forward_data<S: AsyncBufRead + Unpin, D: AsyncWrite + Unpin>(
138 mut src: &mut S,
139 mut dst: &mut D,
140 cx: &mut Context<'_>,
141) -> Poll<io::Result<u64>> {
142 let buffer = match Pin::new(&mut src).poll_fill_buf(cx)? {
143 Poll::Ready(buffer) => buffer,
144 Poll::Pending => {
145 let _ = Pin::new(&mut dst).poll_flush(cx)?;
146 return Poll::Pending;
147 }
148 };
149
150 if buffer.is_empty() {
151 ready!(Pin::new(&mut dst).poll_flush(cx))?;
152 ready!(Pin::new(&mut dst).poll_close(cx))?;
153 return Poll::Ready(Ok(0));
154 }
155
156 let i = ready!(Pin::new(dst).poll_write(cx, buffer))?;
157 if i == 0 {
158 return Poll::Ready(Err(io::ErrorKind::WriteZero.into()));
159 }
160 Pin::new(src).consume(i);
161
162 Poll::Ready(Ok(i.try_into().expect("usize to fit into u64.")))
163}
164
#[cfg(test)]
mod tests {
    use std::io::ErrorKind;

    use futures::{executor::block_on, io::BufWriter};
    use quickcheck::QuickCheck;

    use super::*;

    /// Relays random payloads in both directions and checks the byte limit exactly.
    ///
    /// The relay must succeed iff the limit is disabled (`0`) or not exceeded. On success every
    /// byte must have arrived at the opposite side.
    #[test]
    fn quickcheck() {
        /// In-memory connection: reads drain `read`, writes append to `write`.
        struct Connection {
            read: Vec<u8>,
            write: Vec<u8>,
        }

        impl AsyncWrite for Connection {
            fn poll_write(
                mut self: std::pin::Pin<&mut Self>,
                cx: &mut Context<'_>,
                buf: &[u8],
            ) -> Poll<std::io::Result<usize>> {
                Pin::new(&mut self.write).poll_write(cx, buf)
            }

            fn poll_flush(
                mut self: std::pin::Pin<&mut Self>,
                cx: &mut Context<'_>,
            ) -> Poll<std::io::Result<()>> {
                Pin::new(&mut self.write).poll_flush(cx)
            }

            fn poll_close(
                mut self: std::pin::Pin<&mut Self>,
                cx: &mut Context<'_>,
            ) -> Poll<std::io::Result<()>> {
                Pin::new(&mut self.write).poll_close(cx)
            }
        }

        impl AsyncRead for Connection {
            fn poll_read(
                mut self: Pin<&mut Self>,
                _cx: &mut Context<'_>,
                buf: &mut [u8],
            ) -> Poll<std::io::Result<usize>> {
                let n = std::cmp::min(self.read.len(), buf.len());
                buf[0..n].copy_from_slice(&self.read[0..n]);
                self.read = self.read.split_off(n);
                Poll::Ready(Ok(n))
            }
        }

        fn prop(a: Vec<u8>, b: Vec<u8>, max_circuit_bytes: u64) {
            let connection_a = Connection {
                read: a.clone(),
                write: Vec::new(),
            };

            let connection_b = Connection {
                read: b.clone(),
                write: Vec::new(),
            };

            let mut copy_future = CopyFuture::new(
                connection_a,
                connection_b,
                Duration::from_secs(60),
                max_circuit_bytes,
            );

            // A complete relay forwards every byte of `a` and `b` exactly once. Compare in
            // `u64` so the limit is never truncated (unlike a `u64 as usize` cast).
            let total_bytes = (a.len() + b.len()) as u64;

            match block_on(&mut copy_future) {
                Ok(()) => {
                    assert!(
                        max_circuit_bytes == 0 || total_bytes <= max_circuit_bytes,
                        "Expect the copy to fail when the byte limit is exceeded."
                    );
                    assert_eq!(copy_future.src.into_inner().write, b);
                    assert_eq!(copy_future.dst.into_inner().write, a);
                }
                Err(error) => {
                    assert_eq!(error.kind(), ErrorKind::Other);
                    assert_eq!(error.to_string(), "Max circuit bytes reached.");
                    assert!(
                        max_circuit_bytes > 0,
                        "A byte limit of 0 is expected to disable the limit."
                    );
                    assert!(total_bytes > max_circuit_bytes);
                }
            }
        }

        QuickCheck::new().quickcheck(prop as fn(_, _, _))
    }

    /// Checks that a relay on which nothing ever happens fails with `TimedOut` once the
    /// maximum circuit duration has elapsed.
    #[test]
    fn max_circuit_duration() {
        /// Connection on which every operation stays pending forever.
        struct PendingConnection {}

        impl AsyncWrite for PendingConnection {
            fn poll_write(
                self: std::pin::Pin<&mut Self>,
                _cx: &mut Context<'_>,
                _buf: &[u8],
            ) -> Poll<std::io::Result<usize>> {
                Poll::Pending
            }

            fn poll_flush(
                self: std::pin::Pin<&mut Self>,
                _cx: &mut Context<'_>,
            ) -> Poll<std::io::Result<()>> {
                Poll::Pending
            }

            fn poll_close(
                self: std::pin::Pin<&mut Self>,
                _cx: &mut Context<'_>,
            ) -> Poll<std::io::Result<()>> {
                Poll::Pending
            }
        }

        impl AsyncRead for PendingConnection {
            fn poll_read(
                self: Pin<&mut Self>,
                _cx: &mut Context<'_>,
                _buf: &mut [u8],
            ) -> Poll<std::io::Result<usize>> {
                Poll::Pending
            }
        }

        let copy_future = CopyFuture::new(
            PendingConnection {},
            PendingConnection {},
            Duration::from_millis(1),
            u64::MAX,
        );

        // The deadline is armed at construction, so it has already passed by the first poll.
        std::thread::sleep(Duration::from_millis(2));

        let error =
            block_on(copy_future).expect_err("Expect maximum circuit duration to be reached.");
        assert_eq!(error.kind(), ErrorKind::TimedOut);
    }

    /// Checks that `forward_data` leaves buffered writes alone while the source keeps
    /// producing data. Once the source goes pending, the destination must be flushed.
    #[test]
    fn forward_data_should_flush_on_pending_source() {
        /// Source that yields the bytes of `read` one at a time (from the back) and then
        /// stays pending forever instead of signalling end of stream.
        struct NeverEndingSource {
            read: Vec<u8>,
        }

        impl AsyncRead for NeverEndingSource {
            fn poll_read(
                mut self: Pin<&mut Self>,
                _cx: &mut Context<'_>,
                buf: &mut [u8],
            ) -> Poll<std::io::Result<usize>> {
                if let Some(b) = self.read.pop() {
                    buf[0] = b;
                    return Poll::Ready(Ok(1));
                }

                Poll::Pending
            }
        }

        /// Destination that records every call made on it.
        struct RecordingDestination {
            method_calls: Vec<Method>,
        }

        #[derive(Debug, PartialEq)]
        enum Method {
            Write(Vec<u8>),
            Flush,
            Close,
        }

        impl AsyncWrite for RecordingDestination {
            fn poll_write(
                mut self: std::pin::Pin<&mut Self>,
                _cx: &mut Context<'_>,
                buf: &[u8],
            ) -> Poll<std::io::Result<usize>> {
                self.method_calls.push(Method::Write(buf.to_vec()));
                Poll::Ready(Ok(buf.len()))
            }

            fn poll_flush(
                mut self: std::pin::Pin<&mut Self>,
                _cx: &mut Context<'_>,
            ) -> Poll<std::io::Result<()>> {
                self.method_calls.push(Method::Flush);
                Poll::Ready(Ok(()))
            }

            fn poll_close(
                mut self: std::pin::Pin<&mut Self>,
                _cx: &mut Context<'_>,
            ) -> Poll<std::io::Result<()>> {
                self.method_calls.push(Method::Close);
                Poll::Ready(Ok(()))
            }
        }

        let mut source = BufReader::new(NeverEndingSource { read: vec![1, 2] });

        // Capacity 3 is large enough to hold both forwarded bytes without flushing on its own.
        let mut destination = BufWriter::with_capacity(
            3,
            RecordingDestination {
                method_calls: vec![],
            },
        );

        let mut cx = Context::from_waker(futures::task::noop_waker_ref());

        assert!(
            matches!(
                forward_data(&mut source, &mut destination, &mut cx),
                Poll::Ready(Ok(1)),
            ),
            "Expect `forward_data` to forward one read from the source to the wrapped destination."
        );
        assert_eq!(
            destination.get_ref().method_calls.as_slice(), &[],
            "Given that destination is wrapped with a `BufWriter`, the write doesn't (yet) make it to \
             the destination. The source might have more data available, thus `forward_data` has not \
             yet flushed.",
        );

        assert!(
            matches!(
                forward_data(&mut source, &mut destination, &mut cx),
                Poll::Ready(Ok(1)),
            ),
            "Expect `forward_data` to forward one read from the source to the wrapped destination."
        );
        assert_eq!(
            destination.get_ref().method_calls.as_slice(), &[],
            "Given that destination is wrapped with a `BufWriter`, the write doesn't (yet) make it to \
             the destination. The source might have more data available, thus `forward_data` has not \
             yet flushed.",
        );

        assert!(
            matches!(
                forward_data(&mut source, &mut destination, &mut cx),
                Poll::Pending,
            ),
            "The source has no more reads available but does not close either, i.e. it returns \
             neither `Poll::Ready(Ok(1))` nor `Poll::Ready(Ok(0))` but `Poll::Pending`. Thus \
             `forward_data` returns `Poll::Pending` as well."
        );
        assert_eq!(
            destination.get_ref().method_calls.as_slice(),
            &[Method::Write(vec![2, 1]), Method::Flush],
            "Given that source had no more reads, `forward_data` calls flush, thus instructing the \
             `BufWriter` to flush the two buffered writes down to the destination."
        );
    }
}