use std::{fmt, io, sync::Arc};
use futures_rustls::{rustls, TlsAcceptor, TlsConnector};
/// TLS settings: a client-side connector plus an optional server-side acceptor.
#[derive(Clone)]
pub struct Config {
    /// Connector used when this side initiates the TLS handshake.
    pub(crate) client: TlsConnector,
    /// Acceptor used when this side accepts TLS handshakes; `None` for client-only configs.
    pub(crate) server: Option<TlsAcceptor>,
}

// Written by hand and prints only the type name. Presumably this is because the
// connector/acceptor types do not implement `Debug` (TODO confirm).
impl fmt::Debug for Config {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "Config")
    }
}
/// A DER-encoded private key in a form rustls can consume.
pub struct PrivateKey(rustls::pki_types::PrivateKeyDer<'static>);

impl PrivateKey {
    /// Wraps the DER-encoded key contained in `bytes`.
    ///
    /// The key encoding is detected by rustls' `PrivateKeyDer::try_from`.
    ///
    /// # Panics
    ///
    /// Panics with "unknown or invalid key format" if `PrivateKeyDer::try_from`
    /// rejects `bytes`.
    pub fn new(bytes: Vec<u8>) -> Self {
        let der = rustls::pki_types::PrivateKeyDer::try_from(bytes)
            .expect("unknown or invalid key format");
        Self(der)
    }
}

// Implemented by hand through `clone_key` instead of `#[derive(Clone)]`.
// Presumably `PrivateKeyDer` itself does not implement `Clone`.
impl Clone for PrivateKey {
    fn clone(&self) -> Self {
        let PrivateKey(inner) = self;
        PrivateKey(inner.clone_key())
    }
}
/// A DER-encoded certificate in a form rustls can consume.
#[derive(Debug, Clone)]
pub struct Certificate(rustls::pki_types::CertificateDer<'static>);

impl Certificate {
    /// Wraps `bytes` as a certificate; the conversion itself is infallible.
    pub fn new(bytes: Vec<u8>) -> Self {
        let der = rustls::pki_types::CertificateDer::from(bytes);
        Self(der)
    }
}
impl Config {
    /// Creates a configuration that can act both as a TLS client and as a TLS server.
    ///
    /// The server side presents `certs` with `key`. The client side trusts the
    /// `webpki_roots` server roots, as set up by [`Config::builder`].
    ///
    /// # Errors
    ///
    /// Propagates any error returned by [`Builder::server`].
    pub fn new<I>(key: PrivateKey, certs: I) -> Result<Self, Error>
    where
        I: IntoIterator<Item = Certificate>,
    {
        let mut builder = Config::builder();
        builder.server(key, certs)?;
        Ok(builder.finish())
    }

    /// Creates a client-only configuration that trusts the `webpki_roots` server roots.
    ///
    /// This is exactly `Config::builder().finish()`. Delegating to the builder keeps
    /// a single place where the client-side rustls configuration is assembled, so
    /// the two construction paths cannot drift apart.
    pub fn client() -> Self {
        Config::builder().finish()
    }

    /// Starts a [`Builder`] with no server configuration yet.
    ///
    /// The builder's client trust store is pre-populated with the `webpki_roots`
    /// server roots.
    pub fn builder() -> Builder {
        Builder {
            client_root_store: client_root_store(),
            server: None,
        }
    }
}
/// Builds a root store holding every server trust anchor shipped in `webpki_roots`.
///
/// Used to seed the client-side trust store of a [`Config`].
fn client_root_store() -> rustls::RootCertStore {
    // Re-express each `webpki_roots` anchor as a rustls `TrustAnchor`, field by field.
    let anchors = webpki_roots::TLS_SERVER_ROOTS.iter().map(|anchor| {
        rustls::pki_types::TrustAnchor {
            subject: anchor.subject.into(),
            subject_public_key_info: anchor.spki.into(),
            name_constraints: anchor.name_constraints.map(|constraints| constraints.into()),
        }
    });

    let mut store = rustls::RootCertStore::empty();
    store.extend(anchors);
    store
}
pub struct Builder {
client_root_store: rustls::RootCertStore,
server: Option<rustls::ServerConfig>,
}
impl Builder {
pub fn server<I>(&mut self, key: PrivateKey, certs: I) -> Result<&mut Self, Error>
where
I: IntoIterator<Item = Certificate>,
{
let certs = certs.into_iter().map(|c| c.0).collect();
let provider = rustls::crypto::ring::default_provider();
let server = rustls::ServerConfig::builder_with_provider(provider.into())
.with_safe_default_protocol_versions()
.unwrap()
.with_no_client_auth()
.with_single_cert(certs, key.0)
.map_err(|e| Error::Tls(Box::new(e)))?;
self.server = Some(server);
Ok(self)
}
pub fn add_trust(&mut self, cert: &Certificate) -> Result<&mut Self, Error> {
self.client_root_store
.add(cert.0.to_owned())
.map_err(|e| Error::Tls(Box::new(e)))?;
Ok(self)
}
pub fn finish(self) -> Config {
let provider = rustls::crypto::ring::default_provider();
let client = rustls::ClientConfig::builder_with_provider(provider.into())
.with_safe_default_protocol_versions()
.unwrap()
.with_root_certificates(self.client_root_store)
.with_no_client_auth();
Config {
client: Arc::new(client).into(),
server: self.server.map(|s| Arc::new(s).into()),
}
}
}
pub(crate) fn dns_name_ref(name: &str) -> Result<rustls::pki_types::ServerName<'static>, Error> {
rustls::pki_types::ServerName::try_from(String::from(name))
.map_err(|_| Error::InvalidDnsName(name.into()))
}
/// Errors produced while building TLS configurations or handling server names.
#[derive(Debug)]
#[non_exhaustive]
pub enum Error {
    /// A wrapped `std::io::Error`.
    Io(io::Error),
    /// A TLS-layer error; in this module it wraps errors returned by rustls.
    Tls(Box<dyn std::error::Error + Send + Sync>),
    /// A string that could not be converted into a server name; holds that string.
    InvalidDnsName(String),
}

impl fmt::Display for Error {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::Io(err) => write!(f, "i/o error: {err}"),
            Self::Tls(err) => write!(f, "tls error: {err}"),
            Self::InvalidDnsName(name) => write!(f, "invalid DNS name: {name}"),
        }
    }
}

impl std::error::Error for Error {
    /// Exposes the wrapped error for `Io` and `Tls`; `InvalidDnsName` has no source.
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        let inner: &(dyn std::error::Error + 'static) = match self {
            Self::Io(err) => err,
            Self::Tls(err) => &**err,
            Self::InvalidDnsName(_) => return None,
        };
        Some(inner)
    }
}

impl From<io::Error> for Error {
    fn from(err: io::Error) -> Self {
        Self::Io(err)
    }
}