1use std::{fmt, io, sync::Arc};
22
23use futures_rustls::{rustls, TlsAcceptor, TlsConnector};
24
25#[derive(Clone)]
27pub struct Config {
28 pub(crate) client: TlsConnector,
29 pub(crate) server: Option<TlsAcceptor>,
30}
31
32impl fmt::Debug for Config {
33 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
34 f.write_str("Config")
35 }
36}
37
38pub struct PrivateKey(rustls::pki_types::PrivateKeyDer<'static>);
40
41impl PrivateKey {
42 pub fn new(bytes: Vec<u8>) -> Self {
44 PrivateKey(
45 rustls::pki_types::PrivateKeyDer::try_from(bytes)
46 .expect("unknown or invalid key format"),
47 )
48 }
49}
50
51impl Clone for PrivateKey {
52 fn clone(&self) -> Self {
53 Self(self.0.clone_key())
54 }
55}
56
57#[derive(Debug, Clone)]
59pub struct Certificate(rustls::pki_types::CertificateDer<'static>);
60
61impl Certificate {
62 pub fn new(bytes: Vec<u8>) -> Self {
64 Certificate(rustls::pki_types::CertificateDer::from(bytes))
65 }
66}
67
68impl Config {
69 pub fn new<I>(key: PrivateKey, certs: I) -> Result<Self, Error>
71 where
72 I: IntoIterator<Item = Certificate>,
73 {
74 let mut builder = Config::builder();
75 builder.server(key, certs)?;
76 Ok(builder.finish())
77 }
78
79 pub fn client() -> Self {
81 let provider = rustls::crypto::ring::default_provider();
82 let client = rustls::ClientConfig::builder_with_provider(provider.into())
83 .with_safe_default_protocol_versions()
84 .unwrap()
85 .with_root_certificates(client_root_store())
86 .with_no_client_auth();
87 Config {
88 client: Arc::new(client).into(),
89 server: None,
90 }
91 }
92
93 pub fn builder() -> Builder {
95 Builder {
96 client_root_store: client_root_store(),
97 server: None,
98 }
99 }
100}
101
102fn client_root_store() -> rustls::RootCertStore {
104 let mut client_root_store = rustls::RootCertStore::empty();
105 client_root_store.extend(webpki_roots::TLS_SERVER_ROOTS.iter().map(|ta| {
106 rustls::pki_types::TrustAnchor {
107 subject: ta.subject.into(),
108 subject_public_key_info: ta.spki.into(),
109 name_constraints: ta.name_constraints.map(|v| v.into()),
110 }
111 }));
112 client_root_store
113}
114
115pub struct Builder {
117 client_root_store: rustls::RootCertStore,
118 server: Option<rustls::ServerConfig>,
119}
120
121impl Builder {
122 pub fn server<I>(&mut self, key: PrivateKey, certs: I) -> Result<&mut Self, Error>
124 where
125 I: IntoIterator<Item = Certificate>,
126 {
127 let certs = certs.into_iter().map(|c| c.0).collect();
128 let provider = rustls::crypto::ring::default_provider();
129 let server = rustls::ServerConfig::builder_with_provider(provider.into())
130 .with_safe_default_protocol_versions()
131 .unwrap()
132 .with_no_client_auth()
133 .with_single_cert(certs, key.0)
134 .map_err(|e| Error::Tls(Box::new(e)))?;
135 self.server = Some(server);
136 Ok(self)
137 }
138
139 pub fn add_trust(&mut self, cert: &Certificate) -> Result<&mut Self, Error> {
141 self.client_root_store
142 .add(cert.0.to_owned())
143 .map_err(|e| Error::Tls(Box::new(e)))?;
144 Ok(self)
145 }
146
147 pub fn finish(self) -> Config {
149 let provider = rustls::crypto::ring::default_provider();
150 let client = rustls::ClientConfig::builder_with_provider(provider.into())
151 .with_safe_default_protocol_versions()
152 .unwrap()
153 .with_root_certificates(self.client_root_store)
154 .with_no_client_auth();
155
156 Config {
157 client: Arc::new(client).into(),
158 server: self.server.map(|s| Arc::new(s).into()),
159 }
160 }
161}
162
163pub(crate) fn dns_name_ref(name: &str) -> Result<rustls::pki_types::ServerName<'static>, Error> {
164 rustls::pki_types::ServerName::try_from(String::from(name))
165 .map_err(|_| Error::InvalidDnsName(name.into()))
166}
167
/// TLS-related errors.
#[derive(Debug)]
#[non_exhaustive]
pub enum Error {
    /// An underlying I/O error.
    Io(io::Error),
    /// An error reported by the TLS implementation.
    Tls(Box<dyn std::error::Error + Send + Sync>),
    /// The given name is not a valid DNS name.
    InvalidDnsName(String),
}

impl fmt::Display for Error {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::Io(err) => write!(f, "i/o error: {err}"),
            Self::Tls(err) => write!(f, "tls error: {err}"),
            Self::InvalidDnsName(name) => write!(f, "invalid DNS name: {name}"),
        }
    }
}

impl std::error::Error for Error {
    /// Exposes the wrapped I/O or TLS error; `InvalidDnsName` has no cause.
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        let cause: &(dyn std::error::Error + 'static) = match self {
            Self::Io(err) => err,
            Self::Tls(err) => &**err,
            Self::InvalidDnsName(_) => return None,
        };
        Some(cause)
    }
}

impl From<io::Error> for Error {
    fn from(err: io::Error) -> Self {
        Self::Io(err)
    }
}