libp2p_noise/io/
framed.rs1use std::{io, mem::size_of};
27
28use asynchronous_codec::{Decoder, Encoder};
29use bytes::{Buf, Bytes, BytesMut};
30use quick_protobuf::{BytesReader, MessageRead, MessageWrite, Writer};
31
32use super::handshake::proto;
33use crate::{protocol::PublicKey, Error};
34
35const MAX_NOISE_MSG_LEN: usize = 65535;
37const EXTRA_ENCRYPT_SPACE: usize = 1024;
39pub(crate) const MAX_FRAME_LEN: usize = MAX_NOISE_MSG_LEN - EXTRA_ENCRYPT_SPACE;
41static_assertions::const_assert! {
42 MAX_FRAME_LEN + EXTRA_ENCRYPT_SPACE <= MAX_NOISE_MSG_LEN
43}
44
45pub(crate) struct Codec<S> {
48 session: S,
49
50 write_buffer: BytesMut,
53 encrypt_buffer: BytesMut,
54}
55
56impl<S> Codec<S> {
57 pub(crate) fn new(session: S) -> Self {
58 Codec {
59 session,
60 write_buffer: BytesMut::default(),
61 encrypt_buffer: BytesMut::default(),
62 }
63 }
64}
65
66impl Codec<snow::HandshakeState> {
67 pub(crate) fn is_initiator(&self) -> bool {
69 self.session.is_initiator()
70 }
71
72 pub(crate) fn is_responder(&self) -> bool {
74 !self.session.is_initiator()
75 }
76
77 pub(crate) fn into_transport(self) -> Result<(PublicKey, Codec<snow::TransportState>), Error> {
88 let dh_remote_pubkey = self.session.get_remote_static().ok_or_else(|| {
89 Error::Io(io::Error::new(
90 io::ErrorKind::Other,
91 "expect key to always be present at end of XX session",
92 ))
93 })?;
94
95 let dh_remote_pubkey = PublicKey::from_slice(dh_remote_pubkey)?;
96 let codec = Codec::new(self.session.into_transport_mode()?);
97
98 Ok((dh_remote_pubkey, codec))
99 }
100}
101
102impl Encoder for Codec<snow::HandshakeState> {
103 type Error = io::Error;
104 type Item<'a> = &'a proto::NoiseHandshakePayload;
105
106 fn encode(&mut self, item: Self::Item<'_>, dst: &mut BytesMut) -> Result<(), Self::Error> {
107 let item_size = item.get_size();
108
109 self.write_buffer.resize(item_size, 0);
110 let mut writer = Writer::new(&mut self.write_buffer[..item_size]);
111 item.write_message(&mut writer)
112 .expect("Protobuf encoding to succeed");
113
114 encrypt(
115 &self.write_buffer[..item_size],
116 dst,
117 &mut self.encrypt_buffer,
118 |item, buffer| self.session.write_message(item, buffer),
119 )?;
120
121 Ok(())
122 }
123}
124
125impl Decoder for Codec<snow::HandshakeState> {
126 type Error = io::Error;
127 type Item = proto::NoiseHandshakePayload;
128
129 fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
130 let cleartext = match decrypt(src, |ciphertext, decrypt_buffer| {
131 self.session.read_message(ciphertext, decrypt_buffer)
132 })? {
133 None => return Ok(None),
134 Some(cleartext) => cleartext,
135 };
136
137 let mut reader = BytesReader::from_bytes(&cleartext[..]);
138 let pb =
139 proto::NoiseHandshakePayload::from_reader(&mut reader, &cleartext).map_err(|_| {
140 io::Error::new(
141 io::ErrorKind::InvalidData,
142 "Failed decoding handshake payload",
143 )
144 })?;
145
146 Ok(Some(pb))
147 }
148}
149
150impl Encoder for Codec<snow::TransportState> {
151 type Error = io::Error;
152 type Item<'a> = &'a [u8];
153
154 fn encode(&mut self, item: Self::Item<'_>, dst: &mut BytesMut) -> Result<(), Self::Error> {
155 encrypt(item, dst, &mut self.encrypt_buffer, |item, buffer| {
156 self.session.write_message(item, buffer)
157 })
158 }
159}
160
161impl Decoder for Codec<snow::TransportState> {
162 type Error = io::Error;
163 type Item = Bytes;
164
165 fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
166 decrypt(src, |ciphertext, decrypt_buffer| {
167 self.session.read_message(ciphertext, decrypt_buffer)
168 })
169 }
170}
171
172fn encrypt(
177 cleartext: &[u8],
178 dst: &mut BytesMut,
179 encrypt_buffer: &mut BytesMut,
180 encrypt_fn: impl FnOnce(&[u8], &mut [u8]) -> Result<usize, snow::Error>,
181) -> io::Result<()> {
182 tracing::trace!("Encrypting {} bytes", cleartext.len());
183
184 encrypt_buffer.resize(cleartext.len() + EXTRA_ENCRYPT_SPACE, 0);
185 let n = encrypt_fn(cleartext, encrypt_buffer).map_err(into_io_error)?;
186
187 tracing::trace!("Outgoing ciphertext has {n} bytes");
188
189 encode_length_prefixed(&encrypt_buffer[..n], dst);
190
191 Ok(())
192}
193
194fn decrypt(
200 ciphertext: &mut BytesMut,
201 decrypt_fn: impl FnOnce(&[u8], &mut [u8]) -> Result<usize, snow::Error>,
202) -> io::Result<Option<Bytes>> {
203 let Some(ciphertext) = decode_length_prefixed(ciphertext) else {
204 return Ok(None);
205 };
206
207 tracing::trace!("Incoming ciphertext has {} bytes", ciphertext.len());
208
209 let mut decrypt_buffer = BytesMut::zeroed(ciphertext.len());
210 let n = decrypt_fn(&ciphertext, &mut decrypt_buffer).map_err(into_io_error)?;
211
212 tracing::trace!("Decrypted cleartext has {n} bytes");
213
214 Ok(Some(decrypt_buffer.split_to(n).freeze()))
215}
216
217fn into_io_error(err: snow::Error) -> io::Error {
218 io::Error::new(io::ErrorKind::InvalidData, err)
219}
220
/// Width in bytes of the big-endian `u16` length prefix that precedes every frame.
const U16_LENGTH: usize = (u16::BITS / 8) as usize;
222
223fn encode_length_prefixed(src: &[u8], dst: &mut BytesMut) {
224 dst.reserve(U16_LENGTH + src.len());
225 dst.extend_from_slice(&(src.len() as u16).to_be_bytes());
226 dst.extend_from_slice(src);
227}
228
229fn decode_length_prefixed(src: &mut BytesMut) -> Option<Bytes> {
230 if src.len() < size_of::<u16>() {
231 return None;
232 }
233
234 let mut len_bytes = [0u8; U16_LENGTH];
235 len_bytes.copy_from_slice(&src[..U16_LENGTH]);
236 let len = u16::from_be_bytes(len_bytes) as usize;
237
238 if src.len() - U16_LENGTH >= len {
239 src.advance(U16_LENGTH);
241 Some(src.split_to(len).freeze())
242 } else {
243 None
244 }
245}