1use std::{fmt, sync::Arc};
24
25use asn1_der::{
26 typed::{DerDecodable, DerEncodable, DerTypeView, Sequence},
27 Asn1DerError, Asn1DerErrorVariant, DerObject, Sink, VecBacking,
28};
29use ring::{
30 rand::SystemRandom,
31 signature::{self, KeyPair, RsaKeyPair, RSA_PKCS1_2048_8192_SHA256, RSA_PKCS1_SHA256},
32};
33use zeroize::Zeroize;
34
35use super::error::*;
36
37#[derive(Clone)]
39pub struct Keypair(Arc<RsaKeyPair>);
40
41impl std::fmt::Debug for Keypair {
42 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
43 f.debug_struct("Keypair")
44 .field("public", self.0.public_key())
45 .finish()
46 }
47}
48
49impl Keypair {
50 pub fn try_decode_pkcs1(der: &mut [u8]) -> Result<Keypair, DecodingError> {
55 let kp = RsaKeyPair::from_der(der)
56 .map_err(|e| DecodingError::failed_to_parse("RSA DER PKCS#1 RSAPrivateKey", e))?;
57 der.zeroize();
58 Ok(Keypair(Arc::new(kp)))
59 }
60
61 pub fn try_decode_pkcs8(der: &mut [u8]) -> Result<Keypair, DecodingError> {
66 let kp = RsaKeyPair::from_pkcs8(der)
67 .map_err(|e| DecodingError::failed_to_parse("RSA PKCS#8 PrivateKeyInfo", e))?;
68 der.zeroize();
69 Ok(Keypair(Arc::new(kp)))
70 }
71
72 pub fn public(&self) -> PublicKey {
74 PublicKey(self.0.public_key().as_ref().to_vec())
75 }
76
77 pub fn sign(&self, data: &[u8]) -> Result<Vec<u8>, SigningError> {
79 let mut signature = vec![0; self.0.public().modulus_len()];
80 let rng = SystemRandom::new();
81 match self.0.sign(&RSA_PKCS1_SHA256, &rng, data, &mut signature) {
82 Ok(()) => Ok(signature),
83 Err(e) => Err(SigningError::new("RSA").source(e)),
84 }
85 }
86}
87
/// An RSA public key, held as the DER encoding of a PKCS#1 `RSAPublicKey`.
///
/// Equality, hashing and ordering all operate directly on those encoded bytes.
#[derive(Clone, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub struct PublicKey(Vec<u8>);
91
92impl PublicKey {
93 pub fn verify(&self, msg: &[u8], sig: &[u8]) -> bool {
95 let key = signature::UnparsedPublicKey::new(&RSA_PKCS1_2048_8192_SHA256, &self.0);
96 key.verify(msg, sig).is_ok()
97 }
98
99 pub fn encode_pkcs1(&self) -> Vec<u8> {
104 self.0.clone()
106 }
107
108 pub fn encode_x509(&self) -> Vec<u8> {
113 let spki = Asn1SubjectPublicKeyInfo {
114 algorithmIdentifier: Asn1RsaEncryption {
115 algorithm: Asn1OidRsaEncryption,
116 parameters: (),
117 },
118 subjectPublicKey: Asn1SubjectPublicKey(self.clone()),
119 };
120 let mut buf = Vec::new();
121 spki.encode(&mut buf)
122 .map(|_| buf)
123 .expect("RSA X.509 public key encoding failed.")
124 }
125
126 pub fn try_decode_x509(pk: &[u8]) -> Result<PublicKey, DecodingError> {
129 Asn1SubjectPublicKeyInfo::decode(pk)
130 .map_err(|e| DecodingError::failed_to_parse("RSA X.509", e))
131 .map(|spki| spki.subjectPublicKey.0)
132 }
133}
134
135impl fmt::Debug for PublicKey {
136 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
137 f.write_str("PublicKey(PKCS1): ")?;
138 for byte in &self.0 {
139 write!(f, "{byte:x}")?;
140 }
141 Ok(())
142 }
143}
144
145#[derive(Copy, Clone)]
153struct Asn1RawOid<'a> {
154 object: DerObject<'a>,
155}
156
157impl Asn1RawOid<'_> {
158 pub(crate) fn oid(&self) -> &[u8] {
160 self.object.value()
161 }
162
163 pub(crate) fn write<S: Sink>(value: &[u8], sink: &mut S) -> Result<(), Asn1DerError> {
165 DerObject::write(Self::TAG, value.len(), &mut value.iter(), sink)
166 }
167}
168
169impl<'a> DerTypeView<'a> for Asn1RawOid<'a> {
170 const TAG: u8 = 6;
171
172 fn object(&self) -> DerObject<'a> {
173 self.object
174 }
175}
176
177impl DerEncodable for Asn1RawOid<'_> {
178 fn encode<S: Sink>(&self, sink: &mut S) -> Result<(), Asn1DerError> {
179 self.object.encode(sink)
180 }
181}
182
183impl<'a> DerDecodable<'a> for Asn1RawOid<'a> {
184 fn load(object: DerObject<'a>) -> Result<Self, Asn1DerError> {
185 if object.tag() != Self::TAG {
186 return Err(Asn1DerError::new(Asn1DerErrorVariant::InvalidData(
187 "DER object tag is not the object identifier tag.",
188 )));
189 }
190
191 Ok(Self { object })
192 }
193}
194
/// Zero-sized marker for the `rsaEncryption` algorithm OID (1.2.840.113549.1.1.1).
#[derive(Clone)]
struct Asn1OidRsaEncryption;
198
199impl Asn1OidRsaEncryption {
200 const OID: [u8; 9] = [0x2A, 0x86, 0x48, 0x86, 0xF7, 0x0D, 0x01, 0x01, 0x01];
207}
208
209impl DerEncodable for Asn1OidRsaEncryption {
210 fn encode<S: Sink>(&self, sink: &mut S) -> Result<(), Asn1DerError> {
211 Asn1RawOid::write(&Self::OID, sink)
212 }
213}
214
215impl DerDecodable<'_> for Asn1OidRsaEncryption {
216 fn load(object: DerObject<'_>) -> Result<Self, Asn1DerError> {
217 match Asn1RawOid::load(object)?.oid() {
218 oid if oid == Self::OID => Ok(Self),
219 _ => Err(Asn1DerError::new(Asn1DerErrorVariant::InvalidData(
220 "DER object is not the 'rsaEncryption' identifier.",
221 ))),
222 }
223 }
224}
225
226struct Asn1RsaEncryption {
228 algorithm: Asn1OidRsaEncryption,
229 parameters: (),
230}
231
232impl DerEncodable for Asn1RsaEncryption {
233 fn encode<S: Sink>(&self, sink: &mut S) -> Result<(), Asn1DerError> {
234 let mut algorithm_buf = Vec::new();
235 let algorithm = self.algorithm.der_object(VecBacking(&mut algorithm_buf))?;
236
237 let mut parameters_buf = Vec::new();
238 let parameters = self
239 .parameters
240 .der_object(VecBacking(&mut parameters_buf))?;
241
242 Sequence::write(&[algorithm, parameters], sink)
243 }
244}
245
246impl DerDecodable<'_> for Asn1RsaEncryption {
247 fn load(object: DerObject<'_>) -> Result<Self, Asn1DerError> {
248 let seq: Sequence = Sequence::load(object)?;
249
250 Ok(Self {
251 algorithm: seq.get_as(0)?,
252 parameters: seq.get_as(1)?,
253 })
254 }
255}
256
257struct Asn1SubjectPublicKey(PublicKey);
260
261impl DerEncodable for Asn1SubjectPublicKey {
262 fn encode<S: Sink>(&self, sink: &mut S) -> Result<(), Asn1DerError> {
263 let pk_der = &(self.0).0;
264 let mut bit_string = Vec::with_capacity(pk_der.len() + 1);
265 bit_string.push(0u8);
268 bit_string.extend(pk_der);
269 DerObject::write(3, bit_string.len(), &mut bit_string.iter(), sink)?;
270 Ok(())
271 }
272}
273
274impl DerDecodable<'_> for Asn1SubjectPublicKey {
275 fn load(object: DerObject<'_>) -> Result<Self, Asn1DerError> {
276 if object.tag() != 3 {
277 return Err(Asn1DerError::new(Asn1DerErrorVariant::InvalidData(
278 "DER object tag is not the bit string tag.",
279 )));
280 }
281
282 let pk_der: Vec<u8> = object.value().iter().skip(1).cloned().collect();
283 Ok(Self(PublicKey(pk_der)))
286 }
287}
288
289#[allow(non_snake_case)]
291struct Asn1SubjectPublicKeyInfo {
292 algorithmIdentifier: Asn1RsaEncryption,
293 subjectPublicKey: Asn1SubjectPublicKey,
294}
295
296impl DerEncodable for Asn1SubjectPublicKeyInfo {
297 fn encode<S: Sink>(&self, sink: &mut S) -> Result<(), Asn1DerError> {
298 let mut identifier_buf = Vec::new();
299 let identifier = self
300 .algorithmIdentifier
301 .der_object(VecBacking(&mut identifier_buf))?;
302
303 let mut key_buf = Vec::new();
304 let key = self.subjectPublicKey.der_object(VecBacking(&mut key_buf))?;
305
306 Sequence::write(&[identifier, key], sink)
307 }
308}
309
310impl DerDecodable<'_> for Asn1SubjectPublicKeyInfo {
311 fn load(object: DerObject<'_>) -> Result<Self, Asn1DerError> {
312 let seq: Sequence = Sequence::load(object)?;
313
314 Ok(Self {
315 algorithmIdentifier: seq.get_as(0)?,
316 subjectPublicKey: seq.get_as(1)?,
317 })
318 }
319}
320
#[cfg(test)]
mod tests {
    use quickcheck::*;

    use super::*;

    const KEY1: &[u8] = include_bytes!("test/rsa-2048.pk8");
    const KEY2: &[u8] = include_bytes!("test/rsa-3072.pk8");
    const KEY3: &[u8] = include_bytes!("test/rsa-4096.pk8");

    /// A keypair decoded from one of the bundled PKCS#8 fixtures, picked by quickcheck.
    #[derive(Clone, Debug)]
    struct SomeKeypair(Keypair);

    impl Arbitrary for SomeKeypair {
        fn arbitrary(g: &mut Gen) -> SomeKeypair {
            let fixture: &[u8] = *g.choose(&[KEY1, KEY2, KEY3]).unwrap();
            let mut der = fixture.to_vec();
            SomeKeypair(Keypair::try_decode_pkcs8(&mut der).unwrap())
        }
    }

    #[test]
    fn rsa_from_pkcs8() {
        for key in [KEY1, KEY2, KEY3] {
            assert!(Keypair::try_decode_pkcs8(&mut key.to_vec()).is_ok());
        }
    }

    #[test]
    fn rsa_x509_encode_decode() {
        fn prop(SomeKeypair(kp): SomeKeypair) -> Result<bool, String> {
            let original = kp.public();
            let encoded = original.encode_x509();
            let decoded = PublicKey::try_decode_x509(&encoded).map_err(|e| e.to_string())?;
            Ok(decoded == original)
        }
        QuickCheck::new().tests(10).quickcheck(prop as fn(_) -> _);
    }

    #[test]
    fn rsa_sign_verify() {
        fn prop(SomeKeypair(kp): SomeKeypair, msg: Vec<u8>) -> Result<bool, SigningError> {
            let signature = kp.sign(&msg)?;
            Ok(kp.public().verify(&msg, &signature))
        }
        QuickCheck::new()
            .tests(10)
            .quickcheck(prop as fn(_, _) -> _);
    }
}