libp2p_websocket/
tls.rs

1// Copyright 2019 Parity Technologies (UK) Ltd.
2//
3// Permission is hereby granted, free of charge, to any person obtaining a
4// copy of this software and associated documentation files (the "Software"),
5// to deal in the Software without restriction, including without limitation
6// the rights to use, copy, modify, merge, publish, distribute, sublicense,
7// and/or sell copies of the Software, and to permit persons to whom the
8// Software is furnished to do so, subject to the following conditions:
9//
10// The above copyright notice and this permission notice shall be included in
11// all copies or substantial portions of the Software.
12//
13// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
14// OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
16// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
17// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
18// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
19// DEALINGS IN THE SOFTWARE.
20
21use std::{fmt, io, sync::Arc};
22
23use futures_rustls::{rustls, TlsAcceptor, TlsConnector};
24
25/// TLS configuration.
26#[derive(Clone)]
27pub struct Config {
28    pub(crate) client: TlsConnector,
29    pub(crate) server: Option<TlsAcceptor>,
30}
31
32impl fmt::Debug for Config {
33    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
34        f.write_str("Config")
35    }
36}
37
38/// Private key, DER-encoded ASN.1 in either PKCS#8 or PKCS#1 format.
39pub struct PrivateKey(rustls::pki_types::PrivateKeyDer<'static>);
40
41impl PrivateKey {
42    /// Assert the given bytes are DER-encoded ASN.1 in either PKCS#8 or PKCS#1 format.
43    pub fn new(bytes: Vec<u8>) -> Self {
44        PrivateKey(
45            rustls::pki_types::PrivateKeyDer::try_from(bytes)
46                .expect("unknown or invalid key format"),
47        )
48    }
49}
50
51impl Clone for PrivateKey {
52    fn clone(&self) -> Self {
53        Self(self.0.clone_key())
54    }
55}
56
57/// Certificate, DER-encoded X.509 format.
58#[derive(Debug, Clone)]
59pub struct Certificate(rustls::pki_types::CertificateDer<'static>);
60
61impl Certificate {
62    /// Assert the given bytes are in DER-encoded X.509 format.
63    pub fn new(bytes: Vec<u8>) -> Self {
64        Certificate(rustls::pki_types::CertificateDer::from(bytes))
65    }
66}
67
68impl Config {
69    /// Create a new TLS configuration with the given server key and certificate chain.
70    pub fn new<I>(key: PrivateKey, certs: I) -> Result<Self, Error>
71    where
72        I: IntoIterator<Item = Certificate>,
73    {
74        let mut builder = Config::builder();
75        builder.server(key, certs)?;
76        Ok(builder.finish())
77    }
78
79    /// Create a client-only configuration.
80    pub fn client() -> Self {
81        let provider = rustls::crypto::ring::default_provider();
82        let client = rustls::ClientConfig::builder_with_provider(provider.into())
83            .with_safe_default_protocol_versions()
84            .unwrap()
85            .with_root_certificates(client_root_store())
86            .with_no_client_auth();
87        Config {
88            client: Arc::new(client).into(),
89            server: None,
90        }
91    }
92
93    /// Create a new TLS configuration builder.
94    pub fn builder() -> Builder {
95        Builder {
96            client_root_store: client_root_store(),
97            server: None,
98        }
99    }
100}
101
102/// Setup the rustls client configuration.
103fn client_root_store() -> rustls::RootCertStore {
104    let mut client_root_store = rustls::RootCertStore::empty();
105    client_root_store.extend(webpki_roots::TLS_SERVER_ROOTS.iter().map(|ta| {
106        rustls::pki_types::TrustAnchor {
107            subject: ta.subject.into(),
108            subject_public_key_info: ta.spki.into(),
109            name_constraints: ta.name_constraints.map(|v| v.into()),
110        }
111    }));
112    client_root_store
113}
114
115/// TLS configuration builder.
116pub struct Builder {
117    client_root_store: rustls::RootCertStore,
118    server: Option<rustls::ServerConfig>,
119}
120
121impl Builder {
122    /// Set server key and certificate chain.
123    pub fn server<I>(&mut self, key: PrivateKey, certs: I) -> Result<&mut Self, Error>
124    where
125        I: IntoIterator<Item = Certificate>,
126    {
127        let certs = certs.into_iter().map(|c| c.0).collect();
128        let provider = rustls::crypto::ring::default_provider();
129        let server = rustls::ServerConfig::builder_with_provider(provider.into())
130            .with_safe_default_protocol_versions()
131            .unwrap()
132            .with_no_client_auth()
133            .with_single_cert(certs, key.0)
134            .map_err(|e| Error::Tls(Box::new(e)))?;
135        self.server = Some(server);
136        Ok(self)
137    }
138
139    /// Add an additional trust anchor.
140    pub fn add_trust(&mut self, cert: &Certificate) -> Result<&mut Self, Error> {
141        self.client_root_store
142            .add(cert.0.to_owned())
143            .map_err(|e| Error::Tls(Box::new(e)))?;
144        Ok(self)
145    }
146
147    /// Finish configuration.
148    pub fn finish(self) -> Config {
149        let provider = rustls::crypto::ring::default_provider();
150        let client = rustls::ClientConfig::builder_with_provider(provider.into())
151            .with_safe_default_protocol_versions()
152            .unwrap()
153            .with_root_certificates(self.client_root_store)
154            .with_no_client_auth();
155
156        Config {
157            client: Arc::new(client).into(),
158            server: self.server.map(|s| Arc::new(s).into()),
159        }
160    }
161}
162
163pub(crate) fn dns_name_ref(name: &str) -> Result<rustls::pki_types::ServerName<'static>, Error> {
164    rustls::pki_types::ServerName::try_from(String::from(name))
165        .map_err(|_| Error::InvalidDnsName(name.into()))
166}
167
168// Error //////////////////////////////////////////////////////////////////////////////////////////
169
/// TLS related errors.
#[non_exhaustive]
#[derive(Debug)]
pub enum Error {
    /// An underlying I/O error.
    Io(io::Error),
    /// An error reported by the TLS layer, boxed as a trait object.
    Tls(Box<dyn std::error::Error + Send + Sync>),
    /// A host name was rejected as a TLS server name; holds the offending name.
    InvalidDnsName(String),
}
181
182impl fmt::Display for Error {
183    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
184        match self {
185            Error::Io(e) => write!(f, "i/o error: {e}"),
186            Error::Tls(e) => write!(f, "tls error: {e}"),
187            Error::InvalidDnsName(n) => write!(f, "invalid DNS name: {n}"),
188        }
189    }
190}
191
192impl std::error::Error for Error {
193    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
194        match self {
195            Error::Io(e) => Some(e),
196            Error::Tls(e) => Some(&**e),
197            Error::InvalidDnsName(_) => None,
198        }
199    }
200}
201
202impl From<io::Error> for Error {
203    fn from(e: io::Error) -> Self {
204        Error::Io(e)
205    }
206}