libp2p_websocket/
quicksink.rs

1// Copyright (c) 2019-2020 Parity Technologies (UK) Ltd.
2//
3// Licensed under the Apache License, Version 2.0
4// <LICENSE-APACHE or http://www.apache.org/licenses/LICENSE-2.0> or the MIT
5// license <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
6// option. All files in the project carrying such notice may not be copied,
7// modified, or distributed except according to those terms.
8//
9// Forked into rust-libp2p and further distributed under the MIT license.
10
11// Create a [`Sink`] implementation from an initial value and a closure
12// returning a [`Future`].
13//
14// This is very similar to how `futures::stream::unfold` creates a `Stream`
15// implementation from a seed value and a future-returning closure.
16//
17// # Examples
18//
19// ```no_run
20// use async_std::io;
21// use futures::prelude::*;
22//
23// use crate::quicksink::Action;
24//
25// crate::quicksink::make_sink(io::stdout(), |mut stdout, action| async move {
26//     match action {
27//         Action::Send(x) => stdout.write_all(x).await?,
28//         Action::Flush => stdout.flush().await?,
29//         Action::Close => stdout.close().await?,
30//     }
31//     Ok::<_, io::Error>(stdout)
32// });
33// ```
34
35use std::{
36    future::Future,
37    pin::Pin,
38    task::{Context, Poll},
39};
40
41use futures::{ready, sink::Sink};
42use pin_project_lite::pin_project;
43
44/// Returns a `Sink` impl based on the initial value and the given closure.
45///
46/// The closure will be applied to the initial value and an [`Action`] that
47/// informs it about the action it should perform. The returned [`Future`]
48/// will resolve to another value and the process starts over using this
49/// output.
50pub(crate) fn make_sink<S, F, T, A, E>(init: S, f: F) -> SinkImpl<S, F, T, A, E>
51where
52    F: FnMut(S, Action<A>) -> T,
53    T: Future<Output = Result<S, E>>,
54{
55    SinkImpl {
56        lambda: f,
57        future: None,
58        param: Some(init),
59        state: State::Empty,
60        _mark: std::marker::PhantomData,
61    }
62}
63
/// Instruction handed to the closure of [`make_sink`], describing what to do next.
///
/// The closure is expected to encapsulate some resource (typically for I/O). Each variant
/// mirrors one [`Sink`] method, so the closure knows which operation to perform on it.
#[derive(Debug, Clone, PartialEq, Eq)]
pub(crate) enum Action<A> {
    /// Deliver the contained item; issued from [`Sink::start_send`].
    Send(A),
    /// Flush the resource; issued from [`Sink::poll_flush`].
    Flush,
    /// Close the resource; issued from [`Sink::poll_close`].
    Close,
}
81
/// Lifecycle of a [`SinkImpl`].
///
/// In `Sending`, `Flushing` and `Closing` the closure's future for that action is in
/// flight. `Closed` and `Failed` are terminal: once reached, the sink never leaves them.
#[derive(Debug, PartialEq, Eq)]
enum State {
    /// Idle: no action is in flight and a new one may be started.
    Empty,
    /// An `Action::Send` is being carried out.
    Sending,
    /// An `Action::Flush` is being carried out.
    Flushing,
    /// An `Action::Close` is being carried out.
    Closing,
    /// The resource was closed successfully (terminal).
    Closed,
    /// An action resolved to an error (terminal).
    Failed,
}
98
99/// Errors the `Sink` may return.
100#[derive(Debug, thiserror::Error)]
101pub(crate) enum Error<E> {
102    #[error("Error while sending over the sink, {0}")]
103    Send(E),
104    #[error("The Sink has closed")]
105    Closed,
106}
107
108pin_project! {
109    /// `SinkImpl` implements the `Sink` trait.
110    #[derive(Debug)]
111    pub(crate) struct SinkImpl<S, F, T, A, E> {
112        lambda: F,
113        #[pin] future: Option<T>,
114        param: Option<S>,
115        state: State,
116        _mark: std::marker::PhantomData<(A, E)>
117    }
118}
119
120impl<S, F, T, A, E> Sink<A> for SinkImpl<S, F, T, A, E>
121where
122    F: FnMut(S, Action<A>) -> T,
123    T: Future<Output = Result<S, E>>,
124{
125    type Error = Error<E>;
126
127    fn poll_ready(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
128        let mut this = self.project();
129        match this.state {
130            State::Sending | State::Flushing => {
131                match ready!(this.future.as_mut().as_pin_mut().unwrap().poll(cx)) {
132                    Ok(p) => {
133                        this.future.set(None);
134                        *this.param = Some(p);
135                        *this.state = State::Empty;
136                        Poll::Ready(Ok(()))
137                    }
138                    Err(e) => {
139                        this.future.set(None);
140                        *this.state = State::Failed;
141                        Poll::Ready(Err(Error::Send(e)))
142                    }
143                }
144            }
145            State::Closing => match ready!(this.future.as_mut().as_pin_mut().unwrap().poll(cx)) {
146                Ok(_) => {
147                    this.future.set(None);
148                    *this.state = State::Closed;
149                    Poll::Ready(Err(Error::Closed))
150                }
151                Err(e) => {
152                    this.future.set(None);
153                    *this.state = State::Failed;
154                    Poll::Ready(Err(Error::Send(e)))
155                }
156            },
157            State::Empty => {
158                assert!(this.param.is_some());
159                Poll::Ready(Ok(()))
160            }
161            State::Closed | State::Failed => Poll::Ready(Err(Error::Closed)),
162        }
163    }
164
165    fn start_send(self: Pin<&mut Self>, item: A) -> Result<(), Self::Error> {
166        assert_eq!(State::Empty, self.state);
167        let mut this = self.project();
168        let param = this.param.take().unwrap();
169        let future = (this.lambda)(param, Action::Send(item));
170        this.future.set(Some(future));
171        *this.state = State::Sending;
172        Ok(())
173    }
174
175    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
176        loop {
177            let mut this = self.as_mut().project();
178            match this.state {
179                State::Empty => {
180                    if let Some(p) = this.param.take() {
181                        let future = (this.lambda)(p, Action::Flush);
182                        this.future.set(Some(future));
183                        *this.state = State::Flushing
184                    } else {
185                        return Poll::Ready(Ok(()));
186                    }
187                }
188                State::Sending => match ready!(this.future.as_mut().as_pin_mut().unwrap().poll(cx))
189                {
190                    Ok(p) => {
191                        this.future.set(None);
192                        *this.param = Some(p);
193                        *this.state = State::Empty
194                    }
195                    Err(e) => {
196                        this.future.set(None);
197                        *this.state = State::Failed;
198                        return Poll::Ready(Err(Error::Send(e)));
199                    }
200                },
201                State::Flushing => {
202                    match ready!(this.future.as_mut().as_pin_mut().unwrap().poll(cx)) {
203                        Ok(p) => {
204                            this.future.set(None);
205                            *this.param = Some(p);
206                            *this.state = State::Empty;
207                            return Poll::Ready(Ok(()));
208                        }
209                        Err(e) => {
210                            this.future.set(None);
211                            *this.state = State::Failed;
212                            return Poll::Ready(Err(Error::Send(e)));
213                        }
214                    }
215                }
216                State::Closing => match ready!(this.future.as_mut().as_pin_mut().unwrap().poll(cx))
217                {
218                    Ok(_) => {
219                        this.future.set(None);
220                        *this.state = State::Closed;
221                        return Poll::Ready(Ok(()));
222                    }
223                    Err(e) => {
224                        this.future.set(None);
225                        *this.state = State::Failed;
226                        return Poll::Ready(Err(Error::Send(e)));
227                    }
228                },
229                State::Closed | State::Failed => return Poll::Ready(Err(Error::Closed)),
230            }
231        }
232    }
233
234    fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
235        loop {
236            let mut this = self.as_mut().project();
237            match this.state {
238                State::Empty => {
239                    if let Some(p) = this.param.take() {
240                        let future = (this.lambda)(p, Action::Close);
241                        this.future.set(Some(future));
242                        *this.state = State::Closing;
243                    } else {
244                        return Poll::Ready(Ok(()));
245                    }
246                }
247                State::Sending => match ready!(this.future.as_mut().as_pin_mut().unwrap().poll(cx))
248                {
249                    Ok(p) => {
250                        this.future.set(None);
251                        *this.param = Some(p);
252                        *this.state = State::Empty
253                    }
254                    Err(e) => {
255                        this.future.set(None);
256                        *this.state = State::Failed;
257                        return Poll::Ready(Err(Error::Send(e)));
258                    }
259                },
260                State::Flushing => {
261                    match ready!(this.future.as_mut().as_pin_mut().unwrap().poll(cx)) {
262                        Ok(p) => {
263                            this.future.set(None);
264                            *this.param = Some(p);
265                            *this.state = State::Empty
266                        }
267                        Err(e) => {
268                            this.future.set(None);
269                            *this.state = State::Failed;
270                            return Poll::Ready(Err(Error::Send(e)));
271                        }
272                    }
273                }
274                State::Closing => match ready!(this.future.as_mut().as_pin_mut().unwrap().poll(cx))
275                {
276                    Ok(_) => {
277                        this.future.set(None);
278                        *this.state = State::Closed;
279                        return Poll::Ready(Ok(()));
280                    }
281                    Err(e) => {
282                        this.future.set(None);
283                        *this.state = State::Failed;
284                        return Poll::Ready(Err(Error::Send(e)));
285                    }
286                },
287                State::Closed => return Poll::Ready(Ok(())),
288                State::Failed => return Poll::Ready(Err(Error::Closed)),
289            }
290        }
291    }
292}
293
#[cfg(test)]
mod tests {
    use async_std::{io, task};
    use futures::{channel::mpsc, prelude::*};

    use crate::quicksink::{make_sink, Action};

    /// Forwarding two chunks into a stdout-backed sink completes without error.
    #[test]
    fn smoke_test() {
        task::block_on(async {
            let sink = make_sink(io::stdout(), |mut out, action| async move {
                match action {
                    Action::Send(bytes) => out.write_all(bytes).await?,
                    Action::Flush => out.flush().await?,
                    Action::Close => out.close().await?,
                }
                Ok::<_, io::Error>(out)
            });

            let chunks = vec![Ok(&b"hello\n"[..]), Ok(&b"world\n"[..])];
            let outcome = stream::iter(chunks).forward(sink).await;
            assert!(outcome.is_ok())
        })
    }

    /// Every action the sink issues is recorded on a channel and compared with the
    /// expected sequence: each `send` is followed by a flush, and `close` comes last.
    #[test]
    fn replay() {
        task::block_on(async {
            let (sender, receiver) = mpsc::channel(5);

            let sink = make_sink(sender, |mut sender, action| async move {
                sender.send(action.clone()).await?;
                if action == Action::Close {
                    sender.close().await?
                }
                Ok::<_, mpsc::SendError>(sender)
            });

            futures::pin_mut!(sink);

            for item in ["hello\n", "world\n"].iter().copied() {
                sink.send(item).await.unwrap()
            }
            sink.close().await.unwrap();

            let expected = [
                Action::Send("hello\n"),
                Action::Flush,
                Action::Send("world\n"),
                Action::Flush,
                Action::Close,
            ];
            let recorded = receiver.collect::<Vec<_>>().await;
            assert_eq!(&expected[..], &recorded[..])
        });
    }

    /// A failing closure surfaces its error once; afterwards the sink reports itself as
    /// closed instead of panicking.
    #[test]
    fn error_does_not_panic() {
        task::block_on(async {
            let sink = make_sink(io::stdout(), |_stdout, _action| async move {
                Err(io::Error::new(io::ErrorKind::Other, "oh no"))
            });

            futures::pin_mut!(sink);

            match sink.send("hello").await {
                Err(crate::quicksink::Error::Send(err)) => {
                    assert_eq!(err.kind(), io::ErrorKind::Other);
                    assert_eq!(err.to_string(), "oh no")
                }
                other => panic!("unexpected result: {:?}", other),
            }

            // Sending again must not panic; the failed sink reports itself as closed.
            let second = sink.send("hello").await;
            assert!(
                matches!(second, Err(crate::quicksink::Error::Closed)),
                "unexpected result: {:?}",
                second
            );
        })
    }
}