1use std::{
22 error::Error,
23 fmt, io, mem,
24 pin::Pin,
25 task::{Context, Poll},
26};
27
28use futures::{
29 io::{IoSlice, IoSliceMut},
30 prelude::*,
31 ready,
32};
33use pin_project::pin_project;
34
35use crate::protocol::{HeaderLine, Message, MessageReader, Protocol, ProtocolError};
36
/// An I/O stream on which protocol negotiation has either finished or is still
/// awaiting confirmation from the remote.
///
/// Writes are forwarded to the wrapped I/O object in either state: the
/// `MessageReader` while expecting, the raw `TInner` once completed.
///
/// Reads first drive any outstanding negotiation (see `Negotiated::poll`).
/// They reach the inner `TInner` only once the state is `Completed`.
#[pin_project]
#[derive(Debug)]
pub struct Negotiated<TInner> {
    /// Current negotiation state (structurally pinned via `#[pin]`).
    #[pin]
    state: State<TInner>,
}
54
55#[derive(Debug)]
57pub struct NegotiatedComplete<TInner> {
58 inner: Option<Negotiated<TInner>>,
59}
60
61impl<TInner> Future for NegotiatedComplete<TInner>
62where
63 TInner: AsyncRead + AsyncWrite + Unpin,
68{
69 type Output = Result<Negotiated<TInner>, NegotiationError>;
70
71 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
72 let mut io = self
73 .inner
74 .take()
75 .expect("NegotiatedFuture called after completion.");
76 match Negotiated::poll(Pin::new(&mut io), cx) {
77 Poll::Pending => {
78 self.inner = Some(io);
79 Poll::Pending
80 }
81 Poll::Ready(Ok(())) => Poll::Ready(Ok(io)),
82 Poll::Ready(Err(err)) => {
83 self.inner = Some(io);
84 Poll::Ready(Err(err))
85 }
86 }
87 }
88}
89
90impl<TInner> Negotiated<TInner> {
91 pub(crate) fn completed(io: TInner) -> Self {
93 Negotiated {
94 state: State::Completed { io },
95 }
96 }
97
98 pub(crate) fn expecting(
101 io: MessageReader<TInner>,
102 protocol: Protocol,
103 header: Option<HeaderLine>,
104 ) -> Self {
105 Negotiated {
106 state: State::Expecting {
107 io,
108 protocol,
109 header,
110 },
111 }
112 }
113
114 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), NegotiationError>>
116 where
117 TInner: AsyncRead + AsyncWrite + Unpin,
118 {
119 match self.as_mut().poll_flush(cx) {
121 Poll::Ready(Ok(())) => {}
122 Poll::Pending => return Poll::Pending,
123 Poll::Ready(Err(e)) => {
124 if e.kind() != io::ErrorKind::WriteZero {
127 return Poll::Ready(Err(e.into()));
128 }
129 }
130 }
131
132 let mut this = self.project();
133
134 if let StateProj::Completed { .. } = this.state.as_mut().project() {
135 return Poll::Ready(Ok(()));
136 }
137
138 loop {
140 match mem::replace(&mut *this.state, State::Invalid) {
141 State::Expecting {
142 mut io,
143 header,
144 protocol,
145 } => {
146 let msg = match Pin::new(&mut io).poll_next(cx)? {
147 Poll::Ready(Some(msg)) => msg,
148 Poll::Pending => {
149 *this.state = State::Expecting {
150 io,
151 header,
152 protocol,
153 };
154 return Poll::Pending;
155 }
156 Poll::Ready(None) => {
157 return Poll::Ready(Err(ProtocolError::IoError(
158 io::ErrorKind::UnexpectedEof.into(),
159 )
160 .into()));
161 }
162 };
163
164 if let Message::Header(h) = &msg {
165 if Some(h) == header.as_ref() {
166 *this.state = State::Expecting {
167 io,
168 protocol,
169 header: None,
170 };
171 continue;
172 }
173 }
174
175 if let Message::Protocol(p) = &msg {
176 if p.as_ref() == protocol.as_ref() {
177 tracing::debug!(protocol=%p, "Negotiated: Received confirmation for protocol");
178 *this.state = State::Completed {
179 io: io.into_inner(),
180 };
181 return Poll::Ready(Ok(()));
182 }
183 }
184
185 return Poll::Ready(Err(NegotiationError::Failed));
186 }
187
188 _ => panic!("Negotiated: Invalid state"),
189 }
190 }
191 }
192
193 pub fn complete(self) -> NegotiatedComplete<TInner> {
196 NegotiatedComplete { inner: Some(self) }
197 }
198}
199
/// Negotiation state of a `Negotiated` stream.
#[pin_project(project = StateProj)]
#[derive(Debug)]
enum State<R> {
    /// Waiting for the remote to confirm `protocol`.
    Expecting {
        /// Reader used to receive negotiation messages from the remote. Writes
        /// also go through it while in this state.
        #[pin]
        io: MessageReader<R>,
        /// A header line the remote may send before the confirmation. It is
        /// reset to `None` once a matching header has been received.
        header: Option<HeaderLine>,
        /// The protocol whose echo from the remote completes negotiation.
        protocol: Protocol,
    },

    /// Negotiation finished; I/O is forwarded directly to the inner stream.
    Completed {
        /// The underlying stream, recovered from the `MessageReader`.
        #[pin]
        io: R,
    },

    /// Placeholder stored while `Negotiated::poll` temporarily owns the state.
    /// It stays in place if negotiation ends in an error, after which the
    /// I/O methods panic when they encounter it.
    Invalid,
}
229
230impl<TInner> AsyncRead for Negotiated<TInner>
231where
232 TInner: AsyncRead + AsyncWrite + Unpin,
233{
234 fn poll_read(
235 mut self: Pin<&mut Self>,
236 cx: &mut Context<'_>,
237 buf: &mut [u8],
238 ) -> Poll<Result<usize, io::Error>> {
239 loop {
240 if let StateProj::Completed { io } = self.as_mut().project().state.project() {
241 return io.poll_read(cx, buf);
243 }
244
245 match self.as_mut().poll(cx) {
248 Poll::Ready(Ok(())) => {}
249 Poll::Pending => return Poll::Pending,
250 Poll::Ready(Err(err)) => return Poll::Ready(Err(From::from(err))),
251 }
252 }
253 }
254
255 fn poll_read_vectored(
265 mut self: Pin<&mut Self>,
266 cx: &mut Context<'_>,
267 bufs: &mut [IoSliceMut<'_>],
268 ) -> Poll<Result<usize, io::Error>> {
269 loop {
270 if let StateProj::Completed { io } = self.as_mut().project().state.project() {
271 return io.poll_read_vectored(cx, bufs);
273 }
274
275 match self.as_mut().poll(cx) {
278 Poll::Ready(Ok(())) => {}
279 Poll::Pending => return Poll::Pending,
280 Poll::Ready(Err(err)) => return Poll::Ready(Err(From::from(err))),
281 }
282 }
283 }
284}
285
286impl<TInner> AsyncWrite for Negotiated<TInner>
287where
288 TInner: AsyncWrite + AsyncRead + Unpin,
289{
290 fn poll_write(
291 self: Pin<&mut Self>,
292 cx: &mut Context<'_>,
293 buf: &[u8],
294 ) -> Poll<Result<usize, io::Error>> {
295 match self.project().state.project() {
296 StateProj::Completed { io } => io.poll_write(cx, buf),
297 StateProj::Expecting { io, .. } => io.poll_write(cx, buf),
298 StateProj::Invalid => panic!("Negotiated: Invalid state"),
299 }
300 }
301
302 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
303 match self.project().state.project() {
304 StateProj::Completed { io } => io.poll_flush(cx),
305 StateProj::Expecting { io, .. } => io.poll_flush(cx),
306 StateProj::Invalid => panic!("Negotiated: Invalid state"),
307 }
308 }
309
310 fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
311 ready!(self
313 .as_mut()
314 .poll_flush(cx)
315 .map_err(Into::<io::Error>::into)?);
316
317 match self.project().state.project() {
319 StateProj::Completed { io, .. } => io.poll_close(cx),
320 StateProj::Expecting { io, .. } => {
321 let close_poll = io.poll_close(cx);
322 if let Poll::Ready(Ok(())) = close_poll {
323 tracing::debug!("Stream closed. Confirmation from remote for optimstic protocol negotiation still pending")
324 }
325 close_poll
326 }
327 StateProj::Invalid => panic!("Negotiated: Invalid state"),
328 }
329 }
330
331 fn poll_write_vectored(
332 self: Pin<&mut Self>,
333 cx: &mut Context<'_>,
334 bufs: &[IoSlice<'_>],
335 ) -> Poll<Result<usize, io::Error>> {
336 match self.project().state.project() {
337 StateProj::Completed { io } => io.poll_write_vectored(cx, bufs),
338 StateProj::Expecting { io, .. } => io.poll_write_vectored(cx, bufs),
339 StateProj::Invalid => panic!("Negotiated: Invalid state"),
340 }
341 }
342}
343
/// Error that can occur during protocol negotiation.
#[derive(Debug)]
pub enum NegotiationError {
    /// A protocol-level error. I/O errors are wrapped into this variant via
    /// `ProtocolError` (see the `From<io::Error>` impl).
    ProtocolError(ProtocolError),

    /// Protocol negotiation failed. In `Negotiated::poll` this is returned
    /// when the remote sends something other than the expected header or
    /// protocol confirmation.
    Failed,
}
353
354impl From<ProtocolError> for NegotiationError {
355 fn from(err: ProtocolError) -> NegotiationError {
356 NegotiationError::ProtocolError(err)
357 }
358}
359
360impl From<io::Error> for NegotiationError {
361 fn from(err: io::Error) -> NegotiationError {
362 ProtocolError::from(err).into()
363 }
364}
365
366impl From<NegotiationError> for io::Error {
367 fn from(err: NegotiationError) -> io::Error {
368 if let NegotiationError::ProtocolError(e) = err {
369 return e.into();
370 }
371 io::Error::new(io::ErrorKind::Other, err)
372 }
373}
374
375impl Error for NegotiationError {
376 fn source(&self) -> Option<&(dyn Error + 'static)> {
377 match self {
378 NegotiationError::ProtocolError(err) => Some(err),
379 _ => None,
380 }
381 }
382}
383
384impl fmt::Display for NegotiationError {
385 fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> {
386 match self {
387 NegotiationError::ProtocolError(p) => {
388 fmt.write_fmt(format_args!("Protocol error: {p}"))
389 }
390 NegotiationError::Failed => fmt.write_str("Protocol negotiation failed."),
391 }
392 }
393}