1#![cfg_attr(docsrs, feature(doc_cfg, doc_auto_cfg))]
27
28mod crypt_writer;
29use std::{
30 error,
31 fmt::{self, Write},
32 io,
33 io::Error as IoError,
34 num::ParseIntError,
35 pin::Pin,
36 str::FromStr,
37 task::{Context, Poll},
38};
39
40use crypt_writer::CryptWriter;
41use futures::prelude::*;
42use pin_project::pin_project;
43use rand::RngCore;
44use salsa20::{
45 cipher::{KeyIvInit, StreamCipher},
46 Salsa20, XSalsa20,
47};
48use sha3::{digest::ExtendableOutput, Shake128};
49
/// Length in bytes of a [`PreSharedKey`]; also the key length handed to Salsa20/XSalsa20.
const KEY_SIZE: usize = 32;
/// Length in bytes of the nonce each side sends during the handshake (the XSalsa20 nonce).
const NONCE_SIZE: usize = 24;
/// Capacity passed to `CryptWriter::with_capacity` for the outgoing (write) side.
const WRITE_BUFFER_SIZE: usize = 1024;
/// Length in bytes of a [`Fingerprint`].
const FINGERPRINT_SIZE: usize = 16;
54
55#[derive(Copy, Clone, PartialEq, Eq)]
57pub struct PreSharedKey([u8; KEY_SIZE]);
58
59impl PreSharedKey {
60 pub fn new(data: [u8; KEY_SIZE]) -> Self {
62 Self(data)
63 }
64
65 pub fn fingerprint(&self) -> Fingerprint {
71 use std::io::{Read, Write};
72 let mut enc = [0u8; 64];
73 let nonce: [u8; 8] = *b"finprint";
74 let mut out = [0u8; 16];
75 let mut cipher = Salsa20::new(&self.0.into(), &nonce.into());
76 cipher.apply_keystream(&mut enc);
77 let mut hasher = Shake128::default();
78 hasher.write_all(&enc).expect("shake128 failed");
79 hasher
80 .finalize_xof()
81 .read_exact(&mut out)
82 .expect("shake128 failed");
83 Fingerprint(out)
84 }
85}
86
87fn parse_hex_key(s: &str) -> Result<[u8; KEY_SIZE], KeyParseError> {
88 if s.len() == KEY_SIZE * 2 {
89 let mut r = [0u8; KEY_SIZE];
90 for i in 0..KEY_SIZE {
91 r[i] = u8::from_str_radix(&s[i * 2..i * 2 + 2], 16)
92 .map_err(KeyParseError::InvalidKeyChar)?;
93 }
94 Ok(r)
95 } else {
96 Err(KeyParseError::InvalidKeyLength)
97 }
98}
99
/// Renders `bytes` as a lowercase hexadecimal string with two digits per byte.
fn to_hex(bytes: &[u8]) -> String {
    bytes
        .iter()
        .fold(String::with_capacity(2 * bytes.len()), |mut acc, b| {
            write!(acc, "{b:02x}").expect("Can't fail on writing to string");
            acc
        })
}
109
110impl FromStr for PreSharedKey {
114 type Err = KeyParseError;
115
116 fn from_str(s: &str) -> Result<Self, Self::Err> {
117 if let [keytype, encoding, key] = *s.lines().take(3).collect::<Vec<_>>().as_slice() {
118 if keytype != "/key/swarm/psk/1.0.0/" {
119 return Err(KeyParseError::InvalidKeyType);
120 }
121 if encoding != "/base16/" {
122 return Err(KeyParseError::InvalidKeyEncoding);
123 }
124 parse_hex_key(key.trim_end()).map(PreSharedKey)
125 } else {
126 Err(KeyParseError::InvalidKeyFile)
127 }
128 }
129}
130
131impl fmt::Debug for PreSharedKey {
132 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
133 f.debug_tuple("PreSharedKey")
134 .field(&to_hex(&self.0))
135 .finish()
136 }
137}
138
139impl fmt::Display for PreSharedKey {
141 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
142 writeln!(f, "/key/swarm/psk/1.0.0/")?;
143 writeln!(f, "/base16/")?;
144 writeln!(f, "{}", to_hex(&self.0))
145 }
146}
147
148#[derive(Copy, Clone, PartialEq, Eq)]
150pub struct Fingerprint([u8; FINGERPRINT_SIZE]);
151
152impl fmt::Display for Fingerprint {
154 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
155 write!(f, "{}", to_hex(&self.0))
156 }
157}
158
159#[derive(Clone, Debug, PartialEq, Eq)]
161#[allow(clippy::enum_variant_names)] pub enum KeyParseError {
163 InvalidKeyFile,
165 InvalidKeyType,
167 InvalidKeyEncoding,
169 InvalidKeyLength,
171 InvalidKeyChar(ParseIntError),
173}
174
175impl fmt::Display for KeyParseError {
176 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
177 write!(f, "{self:?}")
178 }
179}
180
181impl error::Error for KeyParseError {
182 fn source(&self) -> Option<&(dyn error::Error + 'static)> {
183 match *self {
184 KeyParseError::InvalidKeyChar(ref err) => Some(err),
185 _ => None,
186 }
187 }
188}
189
190#[derive(Debug, Copy, Clone)]
192pub struct PnetConfig {
193 key: PreSharedKey,
195}
196impl PnetConfig {
197 pub fn new(key: PreSharedKey) -> Self {
198 Self { key }
199 }
200
201 pub async fn handshake<TSocket>(
206 self,
207 mut socket: TSocket,
208 ) -> Result<PnetOutput<TSocket>, PnetError>
209 where
210 TSocket: AsyncRead + AsyncWrite + Send + Unpin + 'static,
211 {
212 tracing::trace!("exchanging nonces");
213 let mut local_nonce = [0u8; NONCE_SIZE];
214 let mut remote_nonce = [0u8; NONCE_SIZE];
215 rand::thread_rng().fill_bytes(&mut local_nonce);
216 socket
217 .write_all(&local_nonce)
218 .await
219 .map_err(PnetError::HandshakeError)?;
220 socket.flush().await?;
221 socket
222 .read_exact(&mut remote_nonce)
223 .await
224 .map_err(PnetError::HandshakeError)?;
225 tracing::trace!("setting up ciphers");
226 let write_cipher = XSalsa20::new(&self.key.0.into(), &local_nonce.into());
227 let read_cipher = XSalsa20::new(&self.key.0.into(), &remote_nonce.into());
228 Ok(PnetOutput::new(socket, write_cipher, read_cipher))
229 }
230}
231
232#[pin_project]
235pub struct PnetOutput<S> {
236 #[pin]
237 inner: CryptWriter<S>,
238 read_cipher: XSalsa20,
239}
240
241impl<S: AsyncRead + AsyncWrite> PnetOutput<S> {
242 fn new(inner: S, write_cipher: XSalsa20, read_cipher: XSalsa20) -> Self {
243 Self {
244 inner: CryptWriter::with_capacity(WRITE_BUFFER_SIZE, inner, write_cipher),
245 read_cipher,
246 }
247 }
248}
249
250impl<S: AsyncRead + AsyncWrite> AsyncRead for PnetOutput<S> {
251 fn poll_read(
252 self: Pin<&mut Self>,
253 cx: &mut Context<'_>,
254 buf: &mut [u8],
255 ) -> Poll<Result<usize, io::Error>> {
256 let this = self.project();
257 let result = this.inner.get_pin_mut().poll_read(cx, buf);
258 if let Poll::Ready(Ok(size)) = &result {
259 tracing::trace!(bytes=%size, "read bytes");
260 this.read_cipher.apply_keystream(&mut buf[..*size]);
261 tracing::trace!(bytes=%size, "decrypted bytes");
262 }
263 result
264 }
265}
266
267impl<S: AsyncRead + AsyncWrite> AsyncWrite for PnetOutput<S> {
268 fn poll_write(
269 self: Pin<&mut Self>,
270 cx: &mut Context<'_>,
271 buf: &[u8],
272 ) -> Poll<Result<usize, io::Error>> {
273 self.project().inner.poll_write(cx, buf)
274 }
275
276 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
277 self.project().inner.poll_flush(cx)
278 }
279
280 fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
281 self.project().inner.poll_close(cx)
282 }
283}
284
/// Error produced by the private-network layer.
#[derive(Debug)]
pub enum PnetError {
    /// An I/O error raised while exchanging nonces in [`PnetConfig::handshake`].
    HandshakeError(IoError),
    /// Any other I/O error (this is what `From<io::Error>` produces).
    IoError(IoError),
}

impl From<IoError> for PnetError {
    /// Wraps a plain I/O error as [`PnetError::IoError`].
    #[inline]
    fn from(err: IoError) -> Self {
        Self::IoError(err)
    }
}
300
301impl error::Error for PnetError {
302 fn cause(&self) -> Option<&dyn error::Error> {
303 match *self {
304 PnetError::HandshakeError(ref err) => Some(err),
305 PnetError::IoError(ref err) => Some(err),
306 }
307 }
308}
309
310impl fmt::Display for PnetError {
311 #[inline]
312 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> {
313 match self {
314 PnetError::HandshakeError(e) => write!(f, "Handshake error: {e}"),
315 PnetError::IoError(e) => write!(f, "I/O error: {e}"),
316 }
317 }
318}
319
#[cfg(test)]
mod tests {
    use quickcheck::*;

    use super::*;

    impl Arbitrary for PreSharedKey {
        /// Draws each of the `KEY_SIZE` key bytes from the generator in order.
        fn arbitrary(g: &mut Gen) -> PreSharedKey {
            PreSharedKey(core::array::from_fn(|_| u8::arbitrary(g)))
        }
    }

    /// `Display` followed by `FromStr` must reproduce the original key.
    #[test]
    fn psk_tostring_parse() {
        fn roundtrips(key: PreSharedKey) -> bool {
            matches!(key.to_string().parse::<PreSharedKey>(), Ok(parsed) if parsed == key)
        }
        QuickCheck::new()
            .tests(10)
            .quickcheck(roundtrips as fn(PreSharedKey) -> _);
    }

    /// Each malformed input is rejected with the error of the first check it fails.
    #[test]
    fn psk_parse_failure() {
        use KeyParseError::*;
        let cases = [
            ("", InvalidKeyFile),
            ("a\nb\nc", InvalidKeyType),
            ("/key/swarm/psk/1.0.0/\nx\ny", InvalidKeyEncoding),
            ("/key/swarm/psk/1.0.0/\n/base16/\ny", InvalidKeyLength),
        ];
        for (input, expected) in cases.iter() {
            assert_eq!(&input.parse::<PreSharedKey>().unwrap_err(), expected);
        }
    }

    /// Pins the fingerprint of a known key so the derivation cannot change silently.
    #[test]
    fn fingerprint() {
        let text = "/key/swarm/psk/1.0.0/\n/base16/\n6189c5cf0b87fb800c1a9feeda73c6ab5e998db48fb9e6a978575c770ceef683";
        let key: PreSharedKey = text.parse().unwrap();
        assert_eq!("45fc986bbc9388a11d939df26f730f0c", key.fingerprint().to_string());
    }
}