1use std::{
22    error::Error,
23    fmt, io, mem,
24    pin::Pin,
25    task::{Context, Poll},
26};
27
28use futures::{
29    io::{IoSlice, IoSliceMut},
30    prelude::*,
31    ready,
32};
33use pin_project::pin_project;
34
35use crate::protocol::{HeaderLine, Message, MessageReader, Protocol, ProtocolError};
36
/// An I/O stream wrapper that tracks whether protocol negotiation has finished.
///
/// Internally this is a small state machine (see `State`): the stream is either
/// still *expecting* the remote's confirmation of a protocol, or the negotiation
/// is *completed* and only the bare `TInner` stream remains.
///
/// Writes are forwarded to the underlying stream in both states, whereas reads
/// first drive any outstanding negotiation to completion.
#[pin_project]
#[derive(Debug)]
pub struct Negotiated<TInner> {
    /// Current negotiation state. Structurally pinned so that the I/O object it
    /// contains can be polled through `Pin` projections.
    #[pin]
    state: State<TInner>,
}
54
/// A future that resolves once the protocol negotiation of the wrapped
/// `Negotiated` stream has completed.
///
/// Created by `Negotiated::complete`. On success it yields the `Negotiated`
/// stream back; on failure it yields a `NegotiationError`.
#[derive(Debug)]
pub struct NegotiatedComplete<TInner> {
    /// The stream being driven. It is taken out for the duration of each poll
    /// and stays `None` once it has been handed out on success.
    inner: Option<Negotiated<TInner>>,
}
60
61impl<TInner> Future for NegotiatedComplete<TInner>
62where
63    TInner: AsyncRead + AsyncWrite + Unpin,
68{
69    type Output = Result<Negotiated<TInner>, NegotiationError>;
70
71    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
72        let mut io = self
73            .inner
74            .take()
75            .expect("NegotiatedFuture called after completion.");
76        match Negotiated::poll(Pin::new(&mut io), cx) {
77            Poll::Pending => {
78                self.inner = Some(io);
79                Poll::Pending
80            }
81            Poll::Ready(Ok(())) => Poll::Ready(Ok(io)),
82            Poll::Ready(Err(err)) => {
83                self.inner = Some(io);
84                Poll::Ready(Err(err))
85            }
86        }
87    }
88}
89
90impl<TInner> Negotiated<TInner> {
91    pub(crate) fn completed(io: TInner) -> Self {
93        Negotiated {
94            state: State::Completed { io },
95        }
96    }
97
98    pub(crate) fn expecting(
101        io: MessageReader<TInner>,
102        protocol: Protocol,
103        header: Option<HeaderLine>,
104    ) -> Self {
105        Negotiated {
106            state: State::Expecting {
107                io,
108                protocol,
109                header,
110            },
111        }
112    }
113
114    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), NegotiationError>>
116    where
117        TInner: AsyncRead + AsyncWrite + Unpin,
118    {
119        match self.as_mut().poll_flush(cx) {
121            Poll::Ready(Ok(())) => {}
122            Poll::Pending => return Poll::Pending,
123            Poll::Ready(Err(e)) => {
124                if e.kind() != io::ErrorKind::WriteZero {
127                    return Poll::Ready(Err(e.into()));
128                }
129            }
130        }
131
132        let mut this = self.project();
133
134        if let StateProj::Completed { .. } = this.state.as_mut().project() {
135            return Poll::Ready(Ok(()));
136        }
137
138        loop {
140            match mem::replace(&mut *this.state, State::Invalid) {
141                State::Expecting {
142                    mut io,
143                    header,
144                    protocol,
145                } => {
146                    let msg = match Pin::new(&mut io).poll_next(cx)? {
147                        Poll::Ready(Some(msg)) => msg,
148                        Poll::Pending => {
149                            *this.state = State::Expecting {
150                                io,
151                                header,
152                                protocol,
153                            };
154                            return Poll::Pending;
155                        }
156                        Poll::Ready(None) => {
157                            return Poll::Ready(Err(ProtocolError::IoError(
158                                io::ErrorKind::UnexpectedEof.into(),
159                            )
160                            .into()));
161                        }
162                    };
163
164                    if let Message::Header(h) = &msg {
165                        if Some(h) == header.as_ref() {
166                            *this.state = State::Expecting {
167                                io,
168                                protocol,
169                                header: None,
170                            };
171                            continue;
172                        }
173                    }
174
175                    if let Message::Protocol(p) = &msg {
176                        if p.as_ref() == protocol.as_ref() {
177                            tracing::debug!(protocol=%p, "Negotiated: Received confirmation for protocol");
178                            *this.state = State::Completed {
179                                io: io.into_inner(),
180                            };
181                            return Poll::Ready(Ok(()));
182                        }
183                    }
184
185                    return Poll::Ready(Err(NegotiationError::Failed));
186                }
187
188                _ => panic!("Negotiated: Invalid state"),
189            }
190        }
191    }
192
193    pub fn complete(self) -> NegotiatedComplete<TInner> {
196        NegotiatedComplete { inner: Some(self) }
197    }
198}
199
/// Negotiation state of a `Negotiated` stream.
#[pin_project(project = StateProj)]
#[derive(Debug)]
enum State<R> {
    /// The remote's confirmation of the protocol has not been received yet.
    Expecting {
        /// Reader that yields negotiation `Message`s and wraps the underlying
        /// I/O stream (recovered via `into_inner` once negotiation completes).
        #[pin]
        io: MessageReader<R>,
        /// Header line that may still arrive from the remote; set to `None`
        /// once a matching header has been received.
        header: Option<HeaderLine>,
        /// Protocol whose confirmation is awaited.
        protocol: Protocol,
    },

    /// Negotiation has finished; only the bare I/O stream is left.
    Completed {
        /// The underlying I/O stream.
        #[pin]
        io: R,
    },

    /// Placeholder installed while `Negotiated::poll` temporarily moves the
    /// contents out of `Expecting`. Encountering it anywhere else panics.
    Invalid,
}
229
230impl<TInner> AsyncRead for Negotiated<TInner>
231where
232    TInner: AsyncRead + AsyncWrite + Unpin,
233{
234    fn poll_read(
235        mut self: Pin<&mut Self>,
236        cx: &mut Context<'_>,
237        buf: &mut [u8],
238    ) -> Poll<Result<usize, io::Error>> {
239        loop {
240            if let StateProj::Completed { io } = self.as_mut().project().state.project() {
241                return io.poll_read(cx, buf);
243            }
244
245            match self.as_mut().poll(cx) {
248                Poll::Ready(Ok(())) => {}
249                Poll::Pending => return Poll::Pending,
250                Poll::Ready(Err(err)) => return Poll::Ready(Err(From::from(err))),
251            }
252        }
253    }
254
255    fn poll_read_vectored(
265        mut self: Pin<&mut Self>,
266        cx: &mut Context<'_>,
267        bufs: &mut [IoSliceMut<'_>],
268    ) -> Poll<Result<usize, io::Error>> {
269        loop {
270            if let StateProj::Completed { io } = self.as_mut().project().state.project() {
271                return io.poll_read_vectored(cx, bufs);
273            }
274
275            match self.as_mut().poll(cx) {
278                Poll::Ready(Ok(())) => {}
279                Poll::Pending => return Poll::Pending,
280                Poll::Ready(Err(err)) => return Poll::Ready(Err(From::from(err))),
281            }
282        }
283    }
284}
285
286impl<TInner> AsyncWrite for Negotiated<TInner>
287where
288    TInner: AsyncWrite + AsyncRead + Unpin,
289{
290    fn poll_write(
291        self: Pin<&mut Self>,
292        cx: &mut Context<'_>,
293        buf: &[u8],
294    ) -> Poll<Result<usize, io::Error>> {
295        match self.project().state.project() {
296            StateProj::Completed { io } => io.poll_write(cx, buf),
297            StateProj::Expecting { io, .. } => io.poll_write(cx, buf),
298            StateProj::Invalid => panic!("Negotiated: Invalid state"),
299        }
300    }
301
302    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
303        match self.project().state.project() {
304            StateProj::Completed { io } => io.poll_flush(cx),
305            StateProj::Expecting { io, .. } => io.poll_flush(cx),
306            StateProj::Invalid => panic!("Negotiated: Invalid state"),
307        }
308    }
309
310    fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
311        ready!(self
313            .as_mut()
314            .poll_flush(cx)
315            .map_err(Into::<io::Error>::into)?);
316
317        match self.project().state.project() {
319            StateProj::Completed { io, .. } => io.poll_close(cx),
320            StateProj::Expecting { io, .. } => {
321                let close_poll = io.poll_close(cx);
322                if let Poll::Ready(Ok(())) = close_poll {
323                    tracing::debug!("Stream closed. Confirmation from remote for optimstic protocol negotiation still pending")
324                }
325                close_poll
326            }
327            StateProj::Invalid => panic!("Negotiated: Invalid state"),
328        }
329    }
330
331    fn poll_write_vectored(
332        self: Pin<&mut Self>,
333        cx: &mut Context<'_>,
334        bufs: &[IoSlice<'_>],
335    ) -> Poll<Result<usize, io::Error>> {
336        match self.project().state.project() {
337            StateProj::Completed { io } => io.poll_write_vectored(cx, bufs),
338            StateProj::Expecting { io, .. } => io.poll_write_vectored(cx, bufs),
339            StateProj::Invalid => panic!("Negotiated: Invalid state"),
340        }
341    }
342}
343
/// Error produced while negotiating a protocol with the remote.
#[derive(Debug)]
pub enum NegotiationError {
    /// A lower-level protocol error. I/O errors also end up here, since they
    /// are converted through `ProtocolError`.
    ProtocolError(ProtocolError),

    /// The remote's reply matched neither the expected header nor the
    /// expected protocol.
    Failed,
}
353
354impl From<ProtocolError> for NegotiationError {
355    fn from(err: ProtocolError) -> NegotiationError {
356        NegotiationError::ProtocolError(err)
357    }
358}
359
360impl From<io::Error> for NegotiationError {
361    fn from(err: io::Error) -> NegotiationError {
362        ProtocolError::from(err).into()
363    }
364}
365
366impl From<NegotiationError> for io::Error {
367    fn from(err: NegotiationError) -> io::Error {
368        if let NegotiationError::ProtocolError(e) = err {
369            return e.into();
370        }
371        io::Error::other(err)
372    }
373}
374
375impl Error for NegotiationError {
376    fn source(&self) -> Option<&(dyn Error + 'static)> {
377        match self {
378            NegotiationError::ProtocolError(err) => Some(err),
379            _ => None,
380        }
381    }
382}
383
384impl fmt::Display for NegotiationError {
385    fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> {
386        match self {
387            NegotiationError::ProtocolError(p) => {
388                fmt.write_fmt(format_args!("Protocol error: {p}"))
389            }
390            NegotiationError::Failed => fmt.write_str("Protocol negotiation failed."),
391        }
392    }
393}