libp2p_relay/
copy_future.rs

1// Copyright 2020 Parity Technologies (UK) Ltd.
2// Copyright 2021 Protocol Labs.
3//
4// Permission is hereby granted, free of charge, to any person obtaining a
5// copy of this software and associated documentation files (the "Software"),
6// to deal in the Software without restriction, including without limitation
7// the rights to use, copy, modify, merge, publish, distribute, sublicense,
8// and/or sell copies of the Software, and to permit persons to whom the
9// Software is furnished to do so, subject to the following conditions:
10//
11// The above copyright notice and this permission notice shall be included in
12// all copies or substantial portions of the Software.
13//
14// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
15// OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
16// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
17// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
18// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
19// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
20// DEALINGS IN THE SOFTWARE.
21
22//! Helper to interconnect two substreams, connecting the receiver side of A with the sender side of
23//! B and vice versa.
24//!
25//! Inspired by [`futures::io::Copy`].
26
27use std::{
28    io,
29    pin::Pin,
30    task::{Context, Poll},
31    time::Duration,
32};
33
34use futures::{
35    future::{Future, FutureExt},
36    io::{AsyncBufRead, AsyncRead, AsyncWrite, BufReader},
37    ready,
38};
39use futures_timer::Delay;
40
41pub(crate) struct CopyFuture<S, D> {
42    src: BufReader<S>,
43    dst: BufReader<D>,
44
45    max_circuit_duration: Delay,
46    max_circuit_bytes: u64,
47    bytes_sent: u64,
48}
49
50impl<S: AsyncRead, D: AsyncRead> CopyFuture<S, D> {
51    pub(crate) fn new(
52        src: S,
53        dst: D,
54        max_circuit_duration: Duration,
55        max_circuit_bytes: u64,
56    ) -> Self {
57        CopyFuture {
58            src: BufReader::new(src),
59            dst: BufReader::new(dst),
60            max_circuit_duration: Delay::new(max_circuit_duration),
61            max_circuit_bytes,
62            bytes_sent: Default::default(),
63        }
64    }
65}
66
67impl<S, D> Future for CopyFuture<S, D>
68where
69    S: AsyncRead + AsyncWrite + Unpin,
70    D: AsyncRead + AsyncWrite + Unpin,
71{
72    type Output = io::Result<()>;
73
74    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
75        let this = &mut *self;
76
77        loop {
78            if this.max_circuit_bytes > 0 && this.bytes_sent > this.max_circuit_bytes {
79                return Poll::Ready(Err(io::Error::other("Max circuit bytes reached.")));
80            }
81
82            enum Status {
83                Pending,
84                Done,
85                Progressed,
86            }
87
88            let src_status = match forward_data(&mut this.src, &mut this.dst, cx) {
89                Poll::Ready(Err(e)) => return Poll::Ready(Err(e)),
90                Poll::Ready(Ok(0)) => Status::Done,
91                Poll::Ready(Ok(i)) => {
92                    this.bytes_sent += i;
93                    Status::Progressed
94                }
95                Poll::Pending => Status::Pending,
96            };
97
98            let dst_status = match forward_data(&mut this.dst, &mut this.src, cx) {
99                Poll::Ready(Err(e)) => return Poll::Ready(Err(e)),
100                Poll::Ready(Ok(0)) => Status::Done,
101                Poll::Ready(Ok(i)) => {
102                    this.bytes_sent += i;
103                    Status::Progressed
104                }
105                Poll::Pending => Status::Pending,
106            };
107
108            match (src_status, dst_status) {
109                // Both source and destination are done sending data.
110                (Status::Done, Status::Done) => return Poll::Ready(Ok(())),
111                // Either source or destination made progress.
112                (Status::Progressed, _) | (_, Status::Progressed) => {}
113                // Both are pending. Check if max circuit duration timer fired, otherwise return
114                // Poll::Pending.
115                (Status::Pending, Status::Pending) => break,
116                // One is done sending data, the other is pending. Check if timer fired, otherwise
117                // return Poll::Pending.
118                (Status::Pending, Status::Done) | (Status::Done, Status::Pending) => break,
119            }
120        }
121
122        if let Poll::Ready(()) = this.max_circuit_duration.poll_unpin(cx) {
123            return Poll::Ready(Err(io::ErrorKind::TimedOut.into()));
124        }
125
126        Poll::Pending
127    }
128}
129
130/// Forwards data from `source` to `destination`.
131///
132/// Returns `0` when done, i.e. `source` having reached EOF, returns number of bytes sent otherwise,
133/// thus indicating progress.
134fn forward_data<S: AsyncBufRead + Unpin, D: AsyncWrite + Unpin>(
135    mut src: &mut S,
136    mut dst: &mut D,
137    cx: &mut Context<'_>,
138) -> Poll<io::Result<u64>> {
139    let buffer = match Pin::new(&mut src).poll_fill_buf(cx)? {
140        Poll::Ready(buffer) => buffer,
141        Poll::Pending => {
142            let _ = Pin::new(&mut dst).poll_flush(cx)?;
143            return Poll::Pending;
144        }
145    };
146
147    if buffer.is_empty() {
148        ready!(Pin::new(&mut dst).poll_flush(cx))?;
149        ready!(Pin::new(&mut dst).poll_close(cx))?;
150        return Poll::Ready(Ok(0));
151    }
152
153    let i = ready!(Pin::new(dst).poll_write(cx, buffer))?;
154    if i == 0 {
155        return Poll::Ready(Err(io::ErrorKind::WriteZero.into()));
156    }
157    Pin::new(src).consume(i);
158
159    Poll::Ready(Ok(i.try_into().expect("usize to fit into u64.")))
160}
161
#[cfg(test)]
mod tests {
    use std::io::ErrorKind;

    use futures::{executor::block_on, io::BufWriter};
    use quickcheck::QuickCheck;

    use super::*;

    /// Relaying arbitrary payloads either delivers each payload to the opposite side, or fails
    /// with the byte-limit error when the combined payload exceeds a non-zero limit.
    #[test]
    fn quickcheck() {
        /// In-memory endpoint: reads are served from `read`, writes are appended to `write`.
        struct Connection {
            read: Vec<u8>,
            write: Vec<u8>,
        }

        impl AsyncWrite for Connection {
            fn poll_write(
                mut self: std::pin::Pin<&mut Self>,
                cx: &mut Context<'_>,
                buf: &[u8],
            ) -> Poll<std::io::Result<usize>> {
                Pin::new(&mut self.write).poll_write(cx, buf)
            }

            fn poll_flush(
                mut self: std::pin::Pin<&mut Self>,
                cx: &mut Context<'_>,
            ) -> Poll<std::io::Result<()>> {
                Pin::new(&mut self.write).poll_flush(cx)
            }

            fn poll_close(
                mut self: std::pin::Pin<&mut Self>,
                cx: &mut Context<'_>,
            ) -> Poll<std::io::Result<()>> {
                Pin::new(&mut self.write).poll_close(cx)
            }
        }

        impl AsyncRead for Connection {
            fn poll_read(
                mut self: Pin<&mut Self>,
                _cx: &mut Context<'_>,
                buf: &mut [u8],
            ) -> Poll<std::io::Result<usize>> {
                let n = std::cmp::min(self.read.len(), buf.len());
                buf[0..n].copy_from_slice(&self.read[0..n]);
                self.read = self.read.split_off(n);
                Poll::Ready(Ok(n))
            }
        }

        fn prop(a: Vec<u8>, b: Vec<u8>, max_circuit_bytes: u64) {
            let connection_a = Connection {
                read: a.clone(),
                write: Vec::new(),
            };

            let connection_b = Connection {
                read: b.clone(),
                write: Vec::new(),
            };

            let mut copy_future = CopyFuture::new(
                connection_a,
                connection_b,
                Duration::from_secs(60),
                max_circuit_bytes,
            );

            // Total payload relayed across both directions. The comparison is done in `u64` so
            // that `max_circuit_bytes` is never truncated, as it would be by `as usize` on
            // 32-bit targets.
            let total_bytes =
                u64::try_from(a.len() + b.len()).expect("usize to fit into u64.");

            match block_on(&mut copy_future) {
                Ok(()) => {
                    assert_eq!(copy_future.src.into_inner().write, b);
                    assert_eq!(copy_future.dst.into_inner().write, a);
                    // Success means the limit is disabled or was never exceeded.
                    assert!(max_circuit_bytes == 0 || total_bytes <= max_circuit_bytes);
                }
                Err(error) => {
                    assert_eq!(error.kind(), ErrorKind::Other);
                    assert_eq!(error.to_string(), "Max circuit bytes reached.");
                    // A limit of `0` disables the check, so hitting it implies a non-zero limit.
                    assert!(max_circuit_bytes > 0);
                    assert!(total_bytes > max_circuit_bytes);
                }
            }
        }

        QuickCheck::new().quickcheck(prop as fn(_, _, _))
    }

    /// With both endpoints stuck in `Poll::Pending`, the future can only resolve through the
    /// circuit timer, which yields `ErrorKind::TimedOut`.
    #[test]
    fn max_circuit_duration() {
        /// Endpoint whose every read, write, flush and close stays pending.
        struct PendingConnection {}

        impl AsyncWrite for PendingConnection {
            fn poll_write(
                self: std::pin::Pin<&mut Self>,
                _cx: &mut Context<'_>,
                _buf: &[u8],
            ) -> Poll<std::io::Result<usize>> {
                Poll::Pending
            }

            fn poll_flush(
                self: std::pin::Pin<&mut Self>,
                _cx: &mut Context<'_>,
            ) -> Poll<std::io::Result<()>> {
                Poll::Pending
            }

            fn poll_close(
                self: std::pin::Pin<&mut Self>,
                _cx: &mut Context<'_>,
            ) -> Poll<std::io::Result<()>> {
                Poll::Pending
            }
        }

        impl AsyncRead for PendingConnection {
            fn poll_read(
                self: Pin<&mut Self>,
                _cx: &mut Context<'_>,
                _buf: &mut [u8],
            ) -> Poll<std::io::Result<usize>> {
                Poll::Pending
            }
        }

        let copy_future = CopyFuture::new(
            PendingConnection {},
            PendingConnection {},
            Duration::from_millis(1),
            u64::MAX,
        );

        // Sleep past the 1 ms circuit duration before the future is first polled.
        std::thread::sleep(Duration::from_millis(2));

        let error =
            block_on(copy_future).expect_err("Expect maximum circuit duration to be reached.");
        assert_eq!(error.kind(), ErrorKind::TimedOut);
    }

    #[test]
    fn forward_data_should_flush_on_pending_source() {
        /// Source that hands out the bytes of `read` one at a time (popping from the back) and
        /// then stays pending forever instead of signalling EOF.
        struct NeverEndingSource {
            read: Vec<u8>,
        }

        impl AsyncRead for NeverEndingSource {
            fn poll_read(
                mut self: Pin<&mut Self>,
                _cx: &mut Context<'_>,
                buf: &mut [u8],
            ) -> Poll<std::io::Result<usize>> {
                if let Some(b) = self.read.pop() {
                    buf[0] = b;
                    return Poll::Ready(Ok(1));
                }

                Poll::Pending
            }
        }

        /// Writer that records every call it receives and always completes immediately.
        struct RecordingDestination {
            method_calls: Vec<Method>,
        }

        #[derive(Debug, PartialEq)]
        enum Method {
            Write(Vec<u8>),
            Flush,
            Close,
        }

        impl AsyncWrite for RecordingDestination {
            fn poll_write(
                mut self: std::pin::Pin<&mut Self>,
                _cx: &mut Context<'_>,
                buf: &[u8],
            ) -> Poll<std::io::Result<usize>> {
                self.method_calls.push(Method::Write(buf.to_vec()));
                Poll::Ready(Ok(buf.len()))
            }

            fn poll_flush(
                mut self: std::pin::Pin<&mut Self>,
                _cx: &mut Context<'_>,
            ) -> Poll<std::io::Result<()>> {
                self.method_calls.push(Method::Flush);
                Poll::Ready(Ok(()))
            }

            fn poll_close(
                mut self: std::pin::Pin<&mut Self>,
                _cx: &mut Context<'_>,
            ) -> Poll<std::io::Result<()>> {
                self.method_calls.push(Method::Close);
                Poll::Ready(Ok(()))
            }
        }

        // The source has two reads available and hands them out one at a time on
        // `AsyncRead::poll_read`.
        let mut source = BufReader::new(NeverEndingSource { read: vec![1, 2] });

        // The destination is wrapped in a `BufWriter` with a capacity of `3`, one larger than
        // the number of reads the source can provide. Without an explicit `AsyncWrite::poll_flush`,
        // the two reads would never reach the destination and would instead stay stuck in the
        // buffer of the `BufWriter`.
        let mut destination = BufWriter::with_capacity(
            3,
            RecordingDestination {
                method_calls: vec![],
            },
        );

        let mut cx = Context::from_waker(futures::task::noop_waker_ref());

        assert!(
            matches!(
                forward_data(&mut source, &mut destination, &mut cx),
                Poll::Ready(Ok(1)),
            ),
            "Expect `forward_data` to forward one read from the source to the wrapped destination."
        );
        assert_eq!(
            destination.get_ref().method_calls.as_slice(), &[],
            "Given that destination is wrapped with a `BufWriter`, the write doesn't (yet) make it to \
            the destination. The source might have more data available, thus `forward_data` has not \
            yet flushed.",
        );

        assert!(
            matches!(
                forward_data(&mut source, &mut destination, &mut cx),
                Poll::Ready(Ok(1)),
            ),
            "Expect `forward_data` to forward one read from the source to the wrapped destination."
        );
        assert_eq!(
            destination.get_ref().method_calls.as_slice(), &[],
            "Given that destination is wrapped with a `BufWriter`, the write doesn't (yet) make it to \
            the destination. The source might have more data available, thus `forward_data` has not \
            yet flushed.",
        );

        assert!(
            matches!(
                forward_data(&mut source, &mut destination, &mut cx),
                Poll::Pending,
            ),
            "The source has no more reads available, but does not close i.e. does not return \
            `Poll::Ready(Ok(0))` but instead `Poll::Pending`. Thus `forward_data` returns \
            `Poll::Pending` as well."
        );
        assert_eq!(
            destination.get_ref().method_calls.as_slice(),
            &[Method::Write(vec![2, 1]), Method::Flush],
            "Given that source had no more reads, `forward_data` calls flush, thus instructing the \
            `BufWriter` to flush the two buffered writes down to the destination."
        );
    }
}