libp2p_relay/
copy_future.rs

1// Copyright 2020 Parity Technologies (UK) Ltd.
2// Copyright 2021 Protocol Labs.
3//
4// Permission is hereby granted, free of charge, to any person obtaining a
5// copy of this software and associated documentation files (the "Software"),
6// to deal in the Software without restriction, including without limitation
7// the rights to use, copy, modify, merge, publish, distribute, sublicense,
8// and/or sell copies of the Software, and to permit persons to whom the
9// Software is furnished to do so, subject to the following conditions:
10//
11// The above copyright notice and this permission notice shall be included in
12// all copies or substantial portions of the Software.
13//
14// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
15// OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
16// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
17// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
18// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
19// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
20// DEALINGS IN THE SOFTWARE.
21
22//! Helper to interconnect two substreams, connecting the receiver side of A with the sender side of
23//! B and vice versa.
24//!
25//! Inspired by [`futures::io::Copy`].
26
27use std::{
28    io,
29    pin::Pin,
30    task::{Context, Poll},
31    time::Duration,
32};
33
34use futures::{
35    future::{Future, FutureExt},
36    io::{AsyncBufRead, AsyncRead, AsyncWrite, BufReader},
37    ready,
38};
39use futures_timer::Delay;
40
/// Future relaying data between two streams in both directions.
///
/// Data read from `src` is written to `dst`, and data read from `dst` is written to `src`. The
/// relay is bounded by a maximum duration and by a maximum number of bytes (see the fields).
pub(crate) struct CopyFuture<S, D> {
    /// First stream. It is wrapped in a `BufReader` so data can be forwarded directly out of its
    /// read buffer (`poll_fill_buf` / `consume`). It is also the write target for data read from
    /// `dst`.
    src: BufReader<S>,
    /// Second stream, wrapped for the same reason as `src`. It is the write target for data read
    /// from `src`.
    dst: BufReader<D>,

    /// Timer created in `new`. Once it fires, the future resolves to an `io::ErrorKind::TimedOut`
    /// error.
    max_circuit_duration: Delay,
    /// Upper bound on `bytes_sent`. A value of `0` disables the limit.
    max_circuit_bytes: u64,
    /// Number of bytes forwarded so far, summed over both directions.
    bytes_sent: u64,
}
49
50impl<S: AsyncRead, D: AsyncRead> CopyFuture<S, D> {
51    pub(crate) fn new(
52        src: S,
53        dst: D,
54        max_circuit_duration: Duration,
55        max_circuit_bytes: u64,
56    ) -> Self {
57        CopyFuture {
58            src: BufReader::new(src),
59            dst: BufReader::new(dst),
60            max_circuit_duration: Delay::new(max_circuit_duration),
61            max_circuit_bytes,
62            bytes_sent: Default::default(),
63        }
64    }
65}
66
67impl<S, D> Future for CopyFuture<S, D>
68where
69    S: AsyncRead + AsyncWrite + Unpin,
70    D: AsyncRead + AsyncWrite + Unpin,
71{
72    type Output = io::Result<()>;
73
74    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
75        let this = &mut *self;
76
77        loop {
78            if this.max_circuit_bytes > 0 && this.bytes_sent > this.max_circuit_bytes {
79                return Poll::Ready(Err(io::Error::new(
80                    io::ErrorKind::Other,
81                    "Max circuit bytes reached.",
82                )));
83            }
84
85            enum Status {
86                Pending,
87                Done,
88                Progressed,
89            }
90
91            let src_status = match forward_data(&mut this.src, &mut this.dst, cx) {
92                Poll::Ready(Err(e)) => return Poll::Ready(Err(e)),
93                Poll::Ready(Ok(0)) => Status::Done,
94                Poll::Ready(Ok(i)) => {
95                    this.bytes_sent += i;
96                    Status::Progressed
97                }
98                Poll::Pending => Status::Pending,
99            };
100
101            let dst_status = match forward_data(&mut this.dst, &mut this.src, cx) {
102                Poll::Ready(Err(e)) => return Poll::Ready(Err(e)),
103                Poll::Ready(Ok(0)) => Status::Done,
104                Poll::Ready(Ok(i)) => {
105                    this.bytes_sent += i;
106                    Status::Progressed
107                }
108                Poll::Pending => Status::Pending,
109            };
110
111            match (src_status, dst_status) {
112                // Both source and destination are done sending data.
113                (Status::Done, Status::Done) => return Poll::Ready(Ok(())),
114                // Either source or destination made progress.
115                (Status::Progressed, _) | (_, Status::Progressed) => {}
116                // Both are pending. Check if max circuit duration timer fired, otherwise return
117                // Poll::Pending.
118                (Status::Pending, Status::Pending) => break,
119                // One is done sending data, the other is pending. Check if timer fired, otherwise
120                // return Poll::Pending.
121                (Status::Pending, Status::Done) | (Status::Done, Status::Pending) => break,
122            }
123        }
124
125        if let Poll::Ready(()) = this.max_circuit_duration.poll_unpin(cx) {
126            return Poll::Ready(Err(io::ErrorKind::TimedOut.into()));
127        }
128
129        Poll::Pending
130    }
131}
132
133/// Forwards data from `source` to `destination`.
134///
135/// Returns `0` when done, i.e. `source` having reached EOF, returns number of bytes sent otherwise,
136/// thus indicating progress.
137fn forward_data<S: AsyncBufRead + Unpin, D: AsyncWrite + Unpin>(
138    mut src: &mut S,
139    mut dst: &mut D,
140    cx: &mut Context<'_>,
141) -> Poll<io::Result<u64>> {
142    let buffer = match Pin::new(&mut src).poll_fill_buf(cx)? {
143        Poll::Ready(buffer) => buffer,
144        Poll::Pending => {
145            let _ = Pin::new(&mut dst).poll_flush(cx)?;
146            return Poll::Pending;
147        }
148    };
149
150    if buffer.is_empty() {
151        ready!(Pin::new(&mut dst).poll_flush(cx))?;
152        ready!(Pin::new(&mut dst).poll_close(cx))?;
153        return Poll::Ready(Ok(0));
154    }
155
156    let i = ready!(Pin::new(dst).poll_write(cx, buffer))?;
157    if i == 0 {
158        return Poll::Ready(Err(io::ErrorKind::WriteZero.into()));
159    }
160    Pin::new(src).consume(i);
161
162    Poll::Ready(Ok(i.try_into().expect("usize to fit into u64.")))
163}
164
#[cfg(test)]
mod tests {
    use std::io::ErrorKind;

    use futures::{executor::block_on, io::BufWriter};
    use quickcheck::QuickCheck;

    use super::*;

    /// Property test: relaying two in-memory connections either copies both directions
    /// completely, or fails with the byte-limit error, consistently with `max_circuit_bytes`.
    #[test]
    fn quickcheck() {
        /// In-memory connection: reads are served from `read`, writes are appended to `write`.
        struct Connection {
            read: Vec<u8>,
            write: Vec<u8>,
        }

        impl AsyncWrite for Connection {
            fn poll_write(
                mut self: Pin<&mut Self>,
                cx: &mut Context<'_>,
                buf: &[u8],
            ) -> Poll<std::io::Result<usize>> {
                Pin::new(&mut self.write).poll_write(cx, buf)
            }

            fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
                Pin::new(&mut self.write).poll_flush(cx)
            }

            fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
                Pin::new(&mut self.write).poll_close(cx)
            }
        }

        impl AsyncRead for Connection {
            fn poll_read(
                mut self: Pin<&mut Self>,
                _cx: &mut Context<'_>,
                buf: &mut [u8],
            ) -> Poll<std::io::Result<usize>> {
                let n = std::cmp::min(self.read.len(), buf.len());
                buf[0..n].copy_from_slice(&self.read[0..n]);
                self.read = self.read.split_off(n);
                Poll::Ready(Ok(n))
            }
        }

        fn prop(a: Vec<u8>, b: Vec<u8>, max_circuit_bytes: u64) {
            // Total payload across both directions. The comparison is done in `u64` so it is
            // lossless on every target; casting `max_circuit_bytes` to `usize` could truncate.
            let total_bytes =
                u64::try_from(a.len() + b.len()).expect("usize to fit into u64.");

            let connection_a = Connection {
                read: a.clone(),
                write: Vec::new(),
            };

            let connection_b = Connection {
                read: b.clone(),
                write: Vec::new(),
            };

            let mut copy_future = CopyFuture::new(
                connection_a,
                connection_b,
                Duration::from_secs(60),
                max_circuit_bytes,
            );

            match block_on(&mut copy_future) {
                Ok(()) => {
                    assert_eq!(copy_future.src.into_inner().write, b);
                    assert_eq!(copy_future.dst.into_inner().write, a);
                    // Completing successfully means the limit is disabled or was never exceeded.
                    assert!(max_circuit_bytes == 0 || total_bytes <= max_circuit_bytes);
                }
                Err(error) => {
                    assert_eq!(error.kind(), ErrorKind::Other);
                    assert_eq!(error.to_string(), "Max circuit bytes reached.");
                    // The byte limit can only trigger when it is enabled and actually exceeded.
                    assert!(max_circuit_bytes > 0);
                    assert!(total_bytes > max_circuit_bytes);
                }
            }
        }

        QuickCheck::new().quickcheck(prop as fn(_, _, _))
    }

    /// With both connections permanently pending, the future must resolve with `TimedOut` once
    /// the maximum circuit duration has elapsed.
    #[test]
    fn max_circuit_duration() {
        /// Connection whose every read, write, flush and close is pending forever.
        struct PendingConnection {}

        impl AsyncWrite for PendingConnection {
            fn poll_write(
                self: Pin<&mut Self>,
                _cx: &mut Context<'_>,
                _buf: &[u8],
            ) -> Poll<std::io::Result<usize>> {
                Poll::Pending
            }

            fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
                Poll::Pending
            }

            fn poll_close(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
                Poll::Pending
            }
        }

        impl AsyncRead for PendingConnection {
            fn poll_read(
                self: Pin<&mut Self>,
                _cx: &mut Context<'_>,
                _buf: &mut [u8],
            ) -> Poll<std::io::Result<usize>> {
                Poll::Pending
            }
        }

        let copy_future = CopyFuture::new(
            PendingConnection {},
            PendingConnection {},
            Duration::from_millis(1),
            u64::MAX,
        );

        std::thread::sleep(Duration::from_millis(2));

        let error =
            block_on(copy_future).expect_err("Expect maximum circuit duration to be reached.");
        assert_eq!(error.kind(), ErrorKind::TimedOut);
    }

    /// `forward_data` must flush the destination when the source goes pending, so that data
    /// held by a buffering destination is not stuck while waiting for more input.
    #[test]
    fn forward_data_should_flush_on_pending_source() {
        /// Source that hands out the bytes of `read` one at a time (popped from the back), then
        /// stays pending forever instead of signalling EOF.
        struct NeverEndingSource {
            read: Vec<u8>,
        }

        impl AsyncRead for NeverEndingSource {
            fn poll_read(
                mut self: Pin<&mut Self>,
                _cx: &mut Context<'_>,
                buf: &mut [u8],
            ) -> Poll<std::io::Result<usize>> {
                if let Some(b) = self.read.pop() {
                    buf[0] = b;
                    return Poll::Ready(Ok(1));
                }

                Poll::Pending
            }
        }

        /// Destination that records every `AsyncWrite` call it receives.
        struct RecordingDestination {
            method_calls: Vec<Method>,
        }

        #[derive(Debug, PartialEq)]
        enum Method {
            Write(Vec<u8>),
            Flush,
            Close,
        }

        impl AsyncWrite for RecordingDestination {
            fn poll_write(
                mut self: Pin<&mut Self>,
                _cx: &mut Context<'_>,
                buf: &[u8],
            ) -> Poll<std::io::Result<usize>> {
                self.method_calls.push(Method::Write(buf.to_vec()));
                Poll::Ready(Ok(buf.len()))
            }

            fn poll_flush(mut self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
                self.method_calls.push(Method::Flush);
                Poll::Ready(Ok(()))
            }

            fn poll_close(mut self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
                self.method_calls.push(Method::Close);
                Poll::Ready(Ok(()))
            }
        }

        // The source has two reads available and hands them out one by one on
        // `AsyncRead::poll_read`.
        let mut source = BufReader::new(NeverEndingSource { read: vec![1, 2] });

        // The destination is wrapped in a `BufWriter` with a capacity of `3`, one larger than the
        // number of reads the source has available. Without an explicit `AsyncWrite::poll_flush`,
        // the two reads would never reach the destination; they would stay stuck in the
        // `BufWriter`'s buffer.
        let mut destination = BufWriter::with_capacity(
            3,
            RecordingDestination {
                method_calls: vec![],
            },
        );

        let mut cx = Context::from_waker(futures::task::noop_waker_ref());

        assert!(
            matches!(
                forward_data(&mut source, &mut destination, &mut cx),
                Poll::Ready(Ok(1)),
            ),
            "Expect `forward_data` to forward one read from the source to the wrapped destination."
        );
        assert_eq!(
            destination.get_ref().method_calls.as_slice(), &[],
            "Given that destination is wrapped with a `BufWrite`, the write doesn't (yet) make it to \
            the destination. The source might have more data available, thus `forward_data` has not \
            yet flushed.",
        );

        assert!(
            matches!(
                forward_data(&mut source, &mut destination, &mut cx),
                Poll::Ready(Ok(1)),
            ),
            "Expect `forward_data` to forward one read from the source to the wrapped destination."
        );
        assert_eq!(
            destination.get_ref().method_calls.as_slice(), &[],
            "Given that destination is wrapped with a `BufWrite`, the write doesn't (yet) make it to \
            the destination. The source might have more data available, thus `forward_data` has not \
            yet flushed.",
        );

        assert!(
            matches!(
                forward_data(&mut source, &mut destination, &mut cx),
                Poll::Pending,
            ),
            "The source has no more reads available, but does not close i.e. does not return \
            `Poll::Ready(Ok(1))` but instead `Poll::Pending`. Thus `forward_data` returns \
            `Poll::Pending` as well."
        );
        assert_eq!(
            destination.get_ref().method_calls.as_slice(),
            &[Method::Write(vec![2, 1]), Method::Flush],
            "Given that source had no more reads, `forward_data` calls flush, thus instructing the \
            `BufWriter` to flush the two buffered writes down to the destination."
        );
    }
}
424}