libp2p_pnet/crypt_writer.rs
1// Copyright 2019 Parity Technologies (UK) Ltd.
2//
3// Permission is hereby granted, free of charge, to any person obtaining a
4// copy of this software and associated documentation files (the "Software"),
5// to deal in the Software without restriction, including without limitation
6// the rights to use, copy, modify, merge, publish, distribute, sublicense,
7// and/or sell copies of the Software, and to permit persons to whom the
8// Software is furnished to do so, subject to the following conditions:
9//
10// The above copyright notice and this permission notice shall be included in
11// all copies or substantial portions of the Software.
12//
13// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
14// OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
16// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
17// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
18// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
19// DEALINGS IN THE SOFTWARE.
20
21use std::{fmt, pin::Pin};
22
23use futures::{
24 io::{self, AsyncWrite},
25 ready,
26 task::{Context, Poll},
27};
28use pin_project::pin_project;
29use salsa20::{cipher::StreamCipher, XSalsa20};
30
31/// A writer that encrypts and forwards to an inner writer
32#[pin_project]
33pub(crate) struct CryptWriter<W> {
34 #[pin]
35 inner: W,
36 buf: Vec<u8>,
37 cipher: XSalsa20,
38}
39
40impl<W: AsyncWrite> CryptWriter<W> {
41 /// Creates a new `CryptWriter` with the specified buffer capacity.
42 pub(crate) fn with_capacity(capacity: usize, inner: W, cipher: XSalsa20) -> CryptWriter<W> {
43 CryptWriter {
44 inner,
45 buf: Vec::with_capacity(capacity),
46 cipher,
47 }
48 }
49
50 /// Gets a pinned mutable reference to the inner writer.
51 ///
52 /// It is inadvisable to directly write to the inner writer.
53 pub(crate) fn get_pin_mut(self: Pin<&mut Self>) -> Pin<&mut W> {
54 self.project().inner
55 }
56}
57
58/// Write the contents of a [`Vec<u8>`] into an [`AsyncWrite`].
59///
60/// The handling 0 byte progress and the Interrupted error was taken from BufWriter in async_std.
61///
62/// If this fn returns Ready(Ok(())), the buffer has been completely flushed and is empty.
63fn poll_flush_buf<W: AsyncWrite>(
64 inner: &mut Pin<&mut W>,
65 buf: &mut Vec<u8>,
66 cx: &mut Context<'_>,
67) -> Poll<io::Result<()>> {
68 let mut ret = Poll::Ready(Ok(()));
69 let mut written = 0;
70 let len = buf.len();
71 while written < len {
72 match inner.as_mut().poll_write(cx, &buf[written..]) {
73 Poll::Ready(Ok(n)) => {
74 if n > 0 {
75 // we made progress, so try again
76 written += n;
77 } else {
78 // we got Ok but got no progress whatsoever,
79 // so bail out so we don't spin writing 0 bytes.
80 ret = Poll::Ready(Err(io::Error::new(
81 io::ErrorKind::WriteZero,
82 "Failed to write buffered data",
83 )));
84 break;
85 }
86 }
87 Poll::Ready(Err(e)) => {
88 // Interrupted is the only error that we consider to be recoverable by trying again
89 if e.kind() != io::ErrorKind::Interrupted {
90 // for any other error, don't try again
91 ret = Poll::Ready(Err(e));
92 break;
93 }
94 }
95 Poll::Pending => {
96 ret = Poll::Pending;
97 break;
98 }
99 }
100 }
101 if written > 0 {
102 buf.drain(..written);
103 }
104 if let Poll::Ready(Ok(())) = ret {
105 debug_assert!(buf.is_empty());
106 }
107 ret
108}
109
110impl<W: AsyncWrite> AsyncWrite for CryptWriter<W> {
111 fn poll_write(
112 self: Pin<&mut Self>,
113 cx: &mut Context<'_>,
114 buf: &[u8],
115 ) -> Poll<io::Result<usize>> {
116 let mut this = self.project();
117 // completely flush the buffer, returning pending if not possible
118 ready!(poll_flush_buf(&mut this.inner, this.buf, cx))?;
119 // if we get here, the buffer is empty
120 debug_assert!(this.buf.is_empty());
121 let res = Pin::new(&mut *this.buf).poll_write(cx, buf);
122 if let Poll::Ready(Ok(count)) = res {
123 this.cipher.apply_keystream(&mut this.buf[0..count]);
124 tracing::trace!(bytes=%count, "encrypted bytes");
125 } else {
126 debug_assert!(false);
127 };
128 // flush immediately afterwards, but if we get a pending we don't care
129 if let Poll::Ready(Err(e)) = poll_flush_buf(&mut this.inner, this.buf, cx) {
130 Poll::Ready(Err(e))
131 } else {
132 res
133 }
134 }
135
136 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
137 let mut this = self.project();
138 ready!(poll_flush_buf(&mut this.inner, this.buf, cx))?;
139 this.inner.poll_flush(cx)
140 }
141
142 fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
143 let mut this = self.project();
144 ready!(poll_flush_buf(&mut this.inner, this.buf, cx))?;
145 this.inner.poll_close(cx)
146 }
147}
148
149impl<W: AsyncWrite + fmt::Debug> fmt::Debug for CryptWriter<W> {
150 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
151 f.debug_struct("CryptWriter")
152 .field("writer", &self.inner)
153 .field("buf", &self.buf)
154 .finish()
155 }
156}