libp2p_websocket/
tls.rs

// Copyright 2019 Parity Technologies (UK) Ltd.
//
// Permission is hereby granted, free of charge, to any person obtaining a
// copy of this software and associated documentation files (the "Software"),
// to deal in the Software without restriction, including without limitation
// the rights to use, copy, modify, merge, publish, distribute, sublicense,
// and/or sell copies of the Software, and to permit persons to whom the
// Software is furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
// OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
// DEALINGS IN THE SOFTWARE.
20
21use std::{fmt, io, sync::Arc};
22
23use futures_rustls::{rustls, TlsAcceptor, TlsConnector};
24
25/// TLS configuration.
26#[derive(Clone)]
27pub struct Config {
28    pub(crate) client: TlsConnector,
29    pub(crate) server: Option<TlsAcceptor>,
30}
31
32impl fmt::Debug for Config {
33    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
34        f.write_str("Config")
35    }
36}
37
38/// Private key, DER-encoded ASN.1 in either PKCS#8 or PKCS#1 format.
39pub struct PrivateKey(rustls::pki_types::PrivateKeyDer<'static>);
40
41impl PrivateKey {
42    /// Assert the given bytes are DER-encoded ASN.1 in either PKCS#8 or PKCS#1 format.
43    pub fn new(bytes: Vec<u8>) -> Self {
44        PrivateKey(
45            rustls::pki_types::PrivateKeyDer::try_from(bytes)
46                .expect("unknown or invalid key format"),
47        )
48    }
49}
50
51impl Clone for PrivateKey {
52    fn clone(&self) -> Self {
53        Self(self.0.clone_key())
54    }
55}
56
57/// Certificate, DER-encoded X.509 format.
58#[derive(Debug, Clone)]
59pub struct Certificate(rustls::pki_types::CertificateDer<'static>);
60
61impl Certificate {
62    /// Assert the given bytes are in DER-encoded X.509 format.
63    pub fn new(bytes: Vec<u8>) -> Self {
64        Certificate(rustls::pki_types::CertificateDer::from(bytes))
65    }
66}
67
68impl Config {
69    /// Create a new TLS configuration with the given server key and certificate chain.
70    pub fn new<I>(key: PrivateKey, certs: I) -> Result<Self, Error>
71    where
72        I: IntoIterator<Item = Certificate>,
73    {
74        let mut builder = Config::builder();
75        builder.server(key, certs)?;
76        Ok(builder.finish())
77    }
78
79    /// Create a client-only configuration.
80    pub fn client() -> Self {
81        let provider = rustls::crypto::ring::default_provider();
82        let client = rustls::ClientConfig::builder_with_provider(provider.into())
83            .with_safe_default_protocol_versions()
84            .unwrap()
85            .with_root_certificates(client_root_store())
86            .with_no_client_auth();
87        Config {
88            client: Arc::new(client).into(),
89            server: None,
90        }
91    }
92
93    /// Create a new TLS configuration builder.
94    pub fn builder() -> Builder {
95        Builder {
96            client_root_store: client_root_store(),
97            server: None,
98        }
99    }
100}
101
102/// Setup the rustls client configuration.
103fn client_root_store() -> rustls::RootCertStore {
104    let mut client_root_store = rustls::RootCertStore::empty();
105    client_root_store.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
106    client_root_store
107}
108
109/// TLS configuration builder.
110pub struct Builder {
111    client_root_store: rustls::RootCertStore,
112    server: Option<rustls::ServerConfig>,
113}
114
115impl Builder {
116    /// Set server key and certificate chain.
117    pub fn server<I>(&mut self, key: PrivateKey, certs: I) -> Result<&mut Self, Error>
118    where
119        I: IntoIterator<Item = Certificate>,
120    {
121        let certs = certs.into_iter().map(|c| c.0).collect();
122        let provider = rustls::crypto::ring::default_provider();
123        let server = rustls::ServerConfig::builder_with_provider(provider.into())
124            .with_safe_default_protocol_versions()
125            .unwrap()
126            .with_no_client_auth()
127            .with_single_cert(certs, key.0)
128            .map_err(|e| Error::Tls(Box::new(e)))?;
129        self.server = Some(server);
130        Ok(self)
131    }
132
133    /// Add an additional trust anchor.
134    pub fn add_trust(&mut self, cert: &Certificate) -> Result<&mut Self, Error> {
135        self.client_root_store
136            .add(cert.0.to_owned())
137            .map_err(|e| Error::Tls(Box::new(e)))?;
138        Ok(self)
139    }
140
141    /// Finish configuration.
142    pub fn finish(self) -> Config {
143        let provider = rustls::crypto::ring::default_provider();
144        let client = rustls::ClientConfig::builder_with_provider(provider.into())
145            .with_safe_default_protocol_versions()
146            .unwrap()
147            .with_root_certificates(self.client_root_store)
148            .with_no_client_auth();
149
150        Config {
151            client: Arc::new(client).into(),
152            server: self.server.map(|s| Arc::new(s).into()),
153        }
154    }
155}
156
157pub(crate) fn dns_name_ref(name: &str) -> Result<rustls::pki_types::ServerName<'static>, Error> {
158    rustls::pki_types::ServerName::try_from(String::from(name))
159        .map_err(|_| Error::InvalidDnsName(name.into()))
160}
161
// Error //////////////////////////////////////////////////////////////////////////////////////////
163
/// TLS related errors.
#[derive(Debug)]
#[non_exhaustive]
pub enum Error {
    /// An underlying I/O error.
    Io(io::Error),
    /// An error reported by the TLS layer, boxed as a thread-safe trait object.
    Tls(Box<dyn std::error::Error + Send + Sync>),
    /// The given name could not be converted into a TLS server name.
    InvalidDnsName(String),
}
175
176impl fmt::Display for Error {
177    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
178        match self {
179            Error::Io(e) => write!(f, "i/o error: {e}"),
180            Error::Tls(e) => write!(f, "tls error: {e}"),
181            Error::InvalidDnsName(n) => write!(f, "invalid DNS name: {n}"),
182        }
183    }
184}
185
186impl std::error::Error for Error {
187    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
188        match self {
189            Error::Io(e) => Some(e),
190            Error::Tls(e) => Some(&**e),
191            Error::InvalidDnsName(_) => None,
192        }
193    }
194}
195
196impl From<io::Error> for Error {
197    fn from(e: io::Error) -> Self {
198        Error::Io(e)
199    }
200}