libp2p_pnet/
crypt_writer.rs

1// Copyright 2019 Parity Technologies (UK) Ltd.
2//
3// Permission is hereby granted, free of charge, to any person obtaining a
4// copy of this software and associated documentation files (the "Software"),
5// to deal in the Software without restriction, including without limitation
6// the rights to use, copy, modify, merge, publish, distribute, sublicense,
7// and/or sell copies of the Software, and to permit persons to whom the
8// Software is furnished to do so, subject to the following conditions:
9//
10// The above copyright notice and this permission notice shall be included in
11// all copies or substantial portions of the Software.
12//
13// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
14// OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
16// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
17// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
18// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
19// DEALINGS IN THE SOFTWARE.
20
21use std::{fmt, pin::Pin};
22
23use futures::{
24    io::{self, AsyncWrite},
25    ready,
26    task::{Context, Poll},
27};
28use pin_project::pin_project;
29use salsa20::{cipher::StreamCipher, XSalsa20};
30
31/// A writer that encrypts and forwards to an inner writer
32#[pin_project]
33pub(crate) struct CryptWriter<W> {
34    #[pin]
35    inner: W,
36    buf: Vec<u8>,
37    cipher: XSalsa20,
38}
39
40impl<W: AsyncWrite> CryptWriter<W> {
41    /// Creates a new `CryptWriter` with the specified buffer capacity.
42    pub(crate) fn with_capacity(capacity: usize, inner: W, cipher: XSalsa20) -> CryptWriter<W> {
43        CryptWriter {
44            inner,
45            buf: Vec::with_capacity(capacity),
46            cipher,
47        }
48    }
49
50    /// Gets a pinned mutable reference to the inner writer.
51    ///
52    /// It is inadvisable to directly write to the inner writer.
53    pub(crate) fn get_pin_mut(self: Pin<&mut Self>) -> Pin<&mut W> {
54        self.project().inner
55    }
56}
57
58/// Write the contents of a [`Vec<u8>`] into an [`AsyncWrite`].
59///
60/// The handling 0 byte progress and the Interrupted error was taken from BufWriter in async_std.
61///
62/// If this fn returns Ready(Ok(())), the buffer has been completely flushed and is empty.
63fn poll_flush_buf<W: AsyncWrite>(
64    inner: &mut Pin<&mut W>,
65    buf: &mut Vec<u8>,
66    cx: &mut Context<'_>,
67) -> Poll<io::Result<()>> {
68    let mut ret = Poll::Ready(Ok(()));
69    let mut written = 0;
70    let len = buf.len();
71    while written < len {
72        match inner.as_mut().poll_write(cx, &buf[written..]) {
73            Poll::Ready(Ok(n)) => {
74                if n > 0 {
75                    // we made progress, so try again
76                    written += n;
77                } else {
78                    // we got Ok but got no progress whatsoever,
79                    // so bail out so we don't spin writing 0 bytes.
80                    ret = Poll::Ready(Err(io::Error::new(
81                        io::ErrorKind::WriteZero,
82                        "Failed to write buffered data",
83                    )));
84                    break;
85                }
86            }
87            Poll::Ready(Err(e)) => {
88                // Interrupted is the only error that we consider to be recoverable by trying again
89                if e.kind() != io::ErrorKind::Interrupted {
90                    // for any other error, don't try again
91                    ret = Poll::Ready(Err(e));
92                    break;
93                }
94            }
95            Poll::Pending => {
96                ret = Poll::Pending;
97                break;
98            }
99        }
100    }
101    if written > 0 {
102        buf.drain(..written);
103    }
104    if let Poll::Ready(Ok(())) = ret {
105        debug_assert!(buf.is_empty());
106    }
107    ret
108}
109
110impl<W: AsyncWrite> AsyncWrite for CryptWriter<W> {
111    fn poll_write(
112        self: Pin<&mut Self>,
113        cx: &mut Context<'_>,
114        buf: &[u8],
115    ) -> Poll<io::Result<usize>> {
116        let mut this = self.project();
117        // completely flush the buffer, returning pending if not possible
118        ready!(poll_flush_buf(&mut this.inner, this.buf, cx))?;
119        // if we get here, the buffer is empty
120        debug_assert!(this.buf.is_empty());
121        let res = Pin::new(&mut *this.buf).poll_write(cx, buf);
122        if let Poll::Ready(Ok(count)) = res {
123            this.cipher.apply_keystream(&mut this.buf[0..count]);
124            tracing::trace!(bytes=%count, "encrypted bytes");
125        } else {
126            debug_assert!(false);
127        };
128        // flush immediately afterwards, but if we get a pending we don't care
129        if let Poll::Ready(Err(e)) = poll_flush_buf(&mut this.inner, this.buf, cx) {
130            Poll::Ready(Err(e))
131        } else {
132            res
133        }
134    }
135
136    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
137        let mut this = self.project();
138        ready!(poll_flush_buf(&mut this.inner, this.buf, cx))?;
139        this.inner.poll_flush(cx)
140    }
141
142    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
143        let mut this = self.project();
144        ready!(poll_flush_buf(&mut this.inner, this.buf, cx))?;
145        this.inner.poll_close(cx)
146    }
147}
148
149impl<W: AsyncWrite + fmt::Debug> fmt::Debug for CryptWriter<W> {
150    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
151        f.debug_struct("CryptWriter")
152            .field("writer", &self.inner)
153            .field("buf", &self.buf)
154            .finish()
155    }
156}