libp2p_noise/io/
framed.rs1use std::{io, mem::size_of};
27
28use asynchronous_codec::{Decoder, Encoder};
29use bytes::{Buf, Bytes, BytesMut};
30use quick_protobuf::{BytesReader, MessageRead, MessageWrite, Writer};
31
32use super::handshake::proto;
33use crate::{protocol::PublicKey, Error};
34
35const MAX_NOISE_MSG_LEN: usize = 65535;
37const EXTRA_ENCRYPT_SPACE: usize = 1024;
39pub(crate) const MAX_FRAME_LEN: usize = MAX_NOISE_MSG_LEN - EXTRA_ENCRYPT_SPACE;
41static_assertions::const_assert! {
42 MAX_FRAME_LEN + EXTRA_ENCRYPT_SPACE <= MAX_NOISE_MSG_LEN
43}
44
45pub(crate) struct Codec<S> {
48 session: S,
49
50 write_buffer: BytesMut,
53 encrypt_buffer: BytesMut,
54}
55
56impl<S> Codec<S> {
57 pub(crate) fn new(session: S) -> Self {
58 Codec {
59 session,
60 write_buffer: BytesMut::default(),
61 encrypt_buffer: BytesMut::default(),
62 }
63 }
64}
65
66impl Codec<snow::HandshakeState> {
67 pub(crate) fn is_initiator(&self) -> bool {
69 self.session.is_initiator()
70 }
71
72 pub(crate) fn is_responder(&self) -> bool {
74 !self.session.is_initiator()
75 }
76
77 pub(crate) fn into_transport(self) -> Result<(PublicKey, Codec<snow::TransportState>), Error> {
88 let dh_remote_pubkey = self.session.get_remote_static().ok_or_else(|| {
89 Error::Io(io::Error::other(
90 "expect key to always be present at end of XX session",
91 ))
92 })?;
93
94 let dh_remote_pubkey = PublicKey::from_slice(dh_remote_pubkey)?;
95 let codec = Codec::new(self.session.into_transport_mode()?);
96
97 Ok((dh_remote_pubkey, codec))
98 }
99}
100
101impl Encoder for Codec<snow::HandshakeState> {
102 type Error = io::Error;
103 type Item<'a> = &'a proto::NoiseHandshakePayload;
104
105 fn encode(&mut self, item: Self::Item<'_>, dst: &mut BytesMut) -> Result<(), Self::Error> {
106 let item_size = item.get_size();
107
108 self.write_buffer.resize(item_size, 0);
109 let mut writer = Writer::new(&mut self.write_buffer[..item_size]);
110 item.write_message(&mut writer)
111 .expect("Protobuf encoding to succeed");
112
113 encrypt(
114 &self.write_buffer[..item_size],
115 dst,
116 &mut self.encrypt_buffer,
117 |item, buffer| self.session.write_message(item, buffer),
118 )?;
119
120 Ok(())
121 }
122}
123
124impl Decoder for Codec<snow::HandshakeState> {
125 type Error = io::Error;
126 type Item = proto::NoiseHandshakePayload;
127
128 fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
129 let Some(cleartext) = decrypt(src, |ciphertext, decrypt_buffer| {
130 self.session.read_message(ciphertext, decrypt_buffer)
131 })?
132 else {
133 return Ok(None);
134 };
135
136 let mut reader = BytesReader::from_bytes(&cleartext[..]);
137 let pb =
138 proto::NoiseHandshakePayload::from_reader(&mut reader, &cleartext).map_err(|_| {
139 io::Error::new(
140 io::ErrorKind::InvalidData,
141 "Failed decoding handshake payload",
142 )
143 })?;
144
145 Ok(Some(pb))
146 }
147}
148
149impl Encoder for Codec<snow::TransportState> {
150 type Error = io::Error;
151 type Item<'a> = &'a [u8];
152
153 fn encode(&mut self, item: Self::Item<'_>, dst: &mut BytesMut) -> Result<(), Self::Error> {
154 encrypt(item, dst, &mut self.encrypt_buffer, |item, buffer| {
155 self.session.write_message(item, buffer)
156 })
157 }
158}
159
160impl Decoder for Codec<snow::TransportState> {
161 type Error = io::Error;
162 type Item = Bytes;
163
164 fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
165 decrypt(src, |ciphertext, decrypt_buffer| {
166 self.session.read_message(ciphertext, decrypt_buffer)
167 })
168 }
169}
170
171fn encrypt(
176 cleartext: &[u8],
177 dst: &mut BytesMut,
178 encrypt_buffer: &mut BytesMut,
179 encrypt_fn: impl FnOnce(&[u8], &mut [u8]) -> Result<usize, snow::Error>,
180) -> io::Result<()> {
181 tracing::trace!("Encrypting {} bytes", cleartext.len());
182
183 encrypt_buffer.resize(cleartext.len() + EXTRA_ENCRYPT_SPACE, 0);
184 let n = encrypt_fn(cleartext, encrypt_buffer).map_err(into_io_error)?;
185
186 tracing::trace!("Outgoing ciphertext has {n} bytes");
187
188 encode_length_prefixed(&encrypt_buffer[..n], dst);
189
190 Ok(())
191}
192
193fn decrypt(
199 ciphertext: &mut BytesMut,
200 decrypt_fn: impl FnOnce(&[u8], &mut [u8]) -> Result<usize, snow::Error>,
201) -> io::Result<Option<Bytes>> {
202 let Some(ciphertext) = decode_length_prefixed(ciphertext) else {
203 return Ok(None);
204 };
205
206 tracing::trace!("Incoming ciphertext has {} bytes", ciphertext.len());
207
208 let mut decrypt_buffer = BytesMut::zeroed(ciphertext.len());
209 let n = decrypt_fn(&ciphertext, &mut decrypt_buffer).map_err(into_io_error)?;
210
211 tracing::trace!("Decrypted cleartext has {n} bytes");
212
213 Ok(Some(decrypt_buffer.split_to(n).freeze()))
214}
215
216fn into_io_error(err: snow::Error) -> io::Error {
217 io::Error::new(io::ErrorKind::InvalidData, err)
218}
219
/// Size in bytes of the `u16` length prefix that precedes every frame.
const U16_LENGTH: usize = std::mem::size_of::<u16>();
221
222fn encode_length_prefixed(src: &[u8], dst: &mut BytesMut) {
223 dst.reserve(U16_LENGTH + src.len());
224 dst.extend_from_slice(&(src.len() as u16).to_be_bytes());
225 dst.extend_from_slice(src);
226}
227
228fn decode_length_prefixed(src: &mut BytesMut) -> Option<Bytes> {
229 if src.len() < size_of::<u16>() {
230 return None;
231 }
232
233 let mut len_bytes = [0u8; U16_LENGTH];
234 len_bytes.copy_from_slice(&src[..U16_LENGTH]);
235 let len = u16::from_be_bytes(len_bytes) as usize;
236
237 if src.len() - U16_LENGTH >= len {
238 src.advance(U16_LENGTH);
240 Some(src.split_to(len).freeze())
241 } else {
242 None
243 }
244}