1use std::{fmt, io, sync::Arc};
22
23use futures_rustls::{rustls, TlsAcceptor, TlsConnector};
24
25#[derive(Clone)]
27pub struct Config {
28 pub(crate) client: TlsConnector,
29 pub(crate) server: Option<TlsAcceptor>,
30}
31
32impl fmt::Debug for Config {
33 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
34 f.write_str("Config")
35 }
36}
37
38pub struct PrivateKey(rustls::pki_types::PrivateKeyDer<'static>);
40
41impl PrivateKey {
42 pub fn new(bytes: Vec<u8>) -> Self {
44 PrivateKey(
45 rustls::pki_types::PrivateKeyDer::try_from(bytes)
46 .expect("unknown or invalid key format"),
47 )
48 }
49}
50
51impl Clone for PrivateKey {
52 fn clone(&self) -> Self {
53 Self(self.0.clone_key())
54 }
55}
56
57#[derive(Debug, Clone)]
59pub struct Certificate(rustls::pki_types::CertificateDer<'static>);
60
61impl Certificate {
62 pub fn new(bytes: Vec<u8>) -> Self {
64 Certificate(rustls::pki_types::CertificateDer::from(bytes))
65 }
66}
67
68impl Config {
69 pub fn new<I>(key: PrivateKey, certs: I) -> Result<Self, Error>
71 where
72 I: IntoIterator<Item = Certificate>,
73 {
74 let mut builder = Config::builder();
75 builder.server(key, certs)?;
76 Ok(builder.finish())
77 }
78
79 pub fn client() -> Self {
81 let provider = rustls::crypto::ring::default_provider();
82 let client = rustls::ClientConfig::builder_with_provider(provider.into())
83 .with_safe_default_protocol_versions()
84 .unwrap()
85 .with_root_certificates(client_root_store())
86 .with_no_client_auth();
87 Config {
88 client: Arc::new(client).into(),
89 server: None,
90 }
91 }
92
93 pub fn builder() -> Builder {
95 Builder {
96 client_root_store: client_root_store(),
97 server: None,
98 }
99 }
100}
101
102fn client_root_store() -> rustls::RootCertStore {
104 let mut client_root_store = rustls::RootCertStore::empty();
105 client_root_store.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
106 client_root_store
107}
108
109pub struct Builder {
111 client_root_store: rustls::RootCertStore,
112 server: Option<rustls::ServerConfig>,
113}
114
115impl Builder {
116 pub fn server<I>(&mut self, key: PrivateKey, certs: I) -> Result<&mut Self, Error>
118 where
119 I: IntoIterator<Item = Certificate>,
120 {
121 let certs = certs.into_iter().map(|c| c.0).collect();
122 let provider = rustls::crypto::ring::default_provider();
123 let server = rustls::ServerConfig::builder_with_provider(provider.into())
124 .with_safe_default_protocol_versions()
125 .unwrap()
126 .with_no_client_auth()
127 .with_single_cert(certs, key.0)
128 .map_err(|e| Error::Tls(Box::new(e)))?;
129 self.server = Some(server);
130 Ok(self)
131 }
132
133 pub fn add_trust(&mut self, cert: &Certificate) -> Result<&mut Self, Error> {
135 self.client_root_store
136 .add(cert.0.to_owned())
137 .map_err(|e| Error::Tls(Box::new(e)))?;
138 Ok(self)
139 }
140
141 pub fn finish(self) -> Config {
143 let provider = rustls::crypto::ring::default_provider();
144 let client = rustls::ClientConfig::builder_with_provider(provider.into())
145 .with_safe_default_protocol_versions()
146 .unwrap()
147 .with_root_certificates(self.client_root_store)
148 .with_no_client_auth();
149
150 Config {
151 client: Arc::new(client).into(),
152 server: self.server.map(|s| Arc::new(s).into()),
153 }
154 }
155}
156
157pub(crate) fn dns_name_ref(name: &str) -> Result<rustls::pki_types::ServerName<'static>, Error> {
158 rustls::pki_types::ServerName::try_from(String::from(name))
159 .map_err(|_| Error::InvalidDnsName(name.into()))
160}
161
/// Errors produced by the TLS configuration code.
#[derive(Debug)]
#[non_exhaustive]
pub enum Error {
    /// An underlying I/O error.
    Io(io::Error),
    /// An error reported by the TLS layer, boxed to keep its concrete type
    /// out of this enum.
    Tls(Box<dyn std::error::Error + Send + Sync>),
    /// The given name could not be converted into a server name; the
    /// offending input is carried along.
    InvalidDnsName(String),
}
175
176impl fmt::Display for Error {
177 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
178 match self {
179 Error::Io(e) => write!(f, "i/o error: {e}"),
180 Error::Tls(e) => write!(f, "tls error: {e}"),
181 Error::InvalidDnsName(n) => write!(f, "invalid DNS name: {n}"),
182 }
183 }
184}
185
186impl std::error::Error for Error {
187 fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
188 match self {
189 Error::Io(e) => Some(e),
190 Error::Tls(e) => Some(&**e),
191 Error::InvalidDnsName(_) => None,
192 }
193 }
194}
195
196impl From<io::Error> for Error {
197 fn from(e: io::Error) -> Self {
198 Error::Io(e)
199 }
200}