1use std::{
36 future::Future,
37 pin::Pin,
38 task::{Context, Poll},
39};
40
41use futures::{ready, sink::Sink};
42use pin_project_lite::pin_project;
43
44pub(crate) fn make_sink<S, F, T, A, E>(init: S, f: F) -> SinkImpl<S, F, T, A, E>
51where
52 F: FnMut(S, Action<A>) -> T,
53 T: Future<Output = Result<S, E>>,
54{
55 SinkImpl {
56 lambda: f,
57 future: None,
58 param: Some(init),
59 state: State::Empty,
60 _mark: std::marker::PhantomData,
61 }
62}
63
/// The operation a sink built by [`make_sink`] asks its closure to perform.
#[derive(Debug, Clone, PartialEq, Eq)]
pub(crate) enum Action<A> {
    /// Handle one item that was passed to `start_send`.
    Send(A),
    /// Issued by `poll_flush` once any pending send has completed.
    Flush,
    /// Issued by `poll_close` once any pending send or flush has completed.
    Close,
}
81
/// Lifecycle of a [`SinkImpl`].
///
/// The `Sending`, `Flushing` and `Closing` states always have a future
/// parked in `SinkImpl::future`. The other states never do.
#[derive(Debug, PartialEq, Eq)]
enum State {
    /// Idle: no action future is in flight.
    Empty,
    /// The future for an `Action::Send` is in flight.
    Sending,
    /// The future for an `Action::Flush` is in flight.
    Flushing,
    /// The future for an `Action::Close` is in flight.
    Closing,
    /// The close future completed successfully; no more items are accepted.
    Closed,
    /// An action future resolved to an error; the sink is unusable.
    Failed,
}
98
99#[derive(Debug, thiserror::Error)]
101pub(crate) enum Error<E> {
102 #[error("Error while sending over the sink, {0}")]
103 Send(E),
104 #[error("The Sink has closed")]
105 Closed,
106}
107
pin_project! {
    /// The [`Sink`] returned by [`make_sink`].
    ///
    /// It owns the user closure and drives at most one of the futures that
    /// the closure returns at a time.
    #[derive(Debug)]
    pub(crate) struct SinkImpl<S, F, T, A, E> {
        // The user closure; it is called once per action.
        lambda: F,
        // Future for the action currently in flight, if any. Structurally
        // pinned so that `T` does not need to be `Unpin`.
        #[pin] future: Option<T>,
        // The closure's state while no future is in flight. It is moved into
        // `lambda` when the next action starts.
        param: Option<S>,
        // Current lifecycle state of the sink.
        state: State,
        // `A` and `E` appear only in the bounds of the `Sink` impl, not in any
        // field, so they are tied to the struct through this marker.
        _mark: std::marker::PhantomData<(A, E)>
    }
}
119
120impl<S, F, T, A, E> Sink<A> for SinkImpl<S, F, T, A, E>
121where
122 F: FnMut(S, Action<A>) -> T,
123 T: Future<Output = Result<S, E>>,
124{
125 type Error = Error<E>;
126
127 fn poll_ready(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
128 let mut this = self.project();
129 match this.state {
130 State::Sending | State::Flushing => {
131 match ready!(this.future.as_mut().as_pin_mut().unwrap().poll(cx)) {
132 Ok(p) => {
133 this.future.set(None);
134 *this.param = Some(p);
135 *this.state = State::Empty;
136 Poll::Ready(Ok(()))
137 }
138 Err(e) => {
139 this.future.set(None);
140 *this.state = State::Failed;
141 Poll::Ready(Err(Error::Send(e)))
142 }
143 }
144 }
145 State::Closing => match ready!(this.future.as_mut().as_pin_mut().unwrap().poll(cx)) {
146 Ok(_) => {
147 this.future.set(None);
148 *this.state = State::Closed;
149 Poll::Ready(Err(Error::Closed))
150 }
151 Err(e) => {
152 this.future.set(None);
153 *this.state = State::Failed;
154 Poll::Ready(Err(Error::Send(e)))
155 }
156 },
157 State::Empty => {
158 assert!(this.param.is_some());
159 Poll::Ready(Ok(()))
160 }
161 State::Closed | State::Failed => Poll::Ready(Err(Error::Closed)),
162 }
163 }
164
165 fn start_send(self: Pin<&mut Self>, item: A) -> Result<(), Self::Error> {
166 assert_eq!(State::Empty, self.state);
167 let mut this = self.project();
168 let param = this.param.take().unwrap();
169 let future = (this.lambda)(param, Action::Send(item));
170 this.future.set(Some(future));
171 *this.state = State::Sending;
172 Ok(())
173 }
174
175 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
176 loop {
177 let mut this = self.as_mut().project();
178 match this.state {
179 State::Empty => {
180 if let Some(p) = this.param.take() {
181 let future = (this.lambda)(p, Action::Flush);
182 this.future.set(Some(future));
183 *this.state = State::Flushing
184 } else {
185 return Poll::Ready(Ok(()));
186 }
187 }
188 State::Sending => match ready!(this.future.as_mut().as_pin_mut().unwrap().poll(cx))
189 {
190 Ok(p) => {
191 this.future.set(None);
192 *this.param = Some(p);
193 *this.state = State::Empty
194 }
195 Err(e) => {
196 this.future.set(None);
197 *this.state = State::Failed;
198 return Poll::Ready(Err(Error::Send(e)));
199 }
200 },
201 State::Flushing => {
202 match ready!(this.future.as_mut().as_pin_mut().unwrap().poll(cx)) {
203 Ok(p) => {
204 this.future.set(None);
205 *this.param = Some(p);
206 *this.state = State::Empty;
207 return Poll::Ready(Ok(()));
208 }
209 Err(e) => {
210 this.future.set(None);
211 *this.state = State::Failed;
212 return Poll::Ready(Err(Error::Send(e)));
213 }
214 }
215 }
216 State::Closing => match ready!(this.future.as_mut().as_pin_mut().unwrap().poll(cx))
217 {
218 Ok(_) => {
219 this.future.set(None);
220 *this.state = State::Closed;
221 return Poll::Ready(Ok(()));
222 }
223 Err(e) => {
224 this.future.set(None);
225 *this.state = State::Failed;
226 return Poll::Ready(Err(Error::Send(e)));
227 }
228 },
229 State::Closed | State::Failed => return Poll::Ready(Err(Error::Closed)),
230 }
231 }
232 }
233
234 fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
235 loop {
236 let mut this = self.as_mut().project();
237 match this.state {
238 State::Empty => {
239 if let Some(p) = this.param.take() {
240 let future = (this.lambda)(p, Action::Close);
241 this.future.set(Some(future));
242 *this.state = State::Closing;
243 } else {
244 return Poll::Ready(Ok(()));
245 }
246 }
247 State::Sending => match ready!(this.future.as_mut().as_pin_mut().unwrap().poll(cx))
248 {
249 Ok(p) => {
250 this.future.set(None);
251 *this.param = Some(p);
252 *this.state = State::Empty
253 }
254 Err(e) => {
255 this.future.set(None);
256 *this.state = State::Failed;
257 return Poll::Ready(Err(Error::Send(e)));
258 }
259 },
260 State::Flushing => {
261 match ready!(this.future.as_mut().as_pin_mut().unwrap().poll(cx)) {
262 Ok(p) => {
263 this.future.set(None);
264 *this.param = Some(p);
265 *this.state = State::Empty
266 }
267 Err(e) => {
268 this.future.set(None);
269 *this.state = State::Failed;
270 return Poll::Ready(Err(Error::Send(e)));
271 }
272 }
273 }
274 State::Closing => match ready!(this.future.as_mut().as_pin_mut().unwrap().poll(cx))
275 {
276 Ok(_) => {
277 this.future.set(None);
278 *this.state = State::Closed;
279 return Poll::Ready(Ok(()));
280 }
281 Err(e) => {
282 this.future.set(None);
283 *this.state = State::Failed;
284 return Poll::Ready(Err(Error::Send(e)));
285 }
286 },
287 State::Closed => return Poll::Ready(Ok(())),
288 State::Failed => return Poll::Ready(Err(Error::Closed)),
289 }
290 }
291 }
292}
293
#[cfg(test)]
mod tests {
    use futures::{channel::mpsc, prelude::*};
    use tokio::io::{self, AsyncWriteExt};

    use crate::quicksink::{make_sink, Action};

    /// Writing through a sink backed by stdout succeeds end to end.
    #[tokio::test]
    async fn smoke_test() {
        let sink = make_sink(io::stdout(), |mut stdout, action| async move {
            match action {
                Action::Send(x) => stdout.write_all(x).await?,
                Action::Flush => stdout.flush().await?,
                Action::Close => stdout.shutdown().await?,
            }
            Ok::<_, io::Error>(stdout)
        });

        let values = vec![Ok(&b"hello\n"[..]), Ok(&b"world\n"[..])];
        assert!(stream::iter(values).forward(sink).await.is_ok())
    }

    /// The closure sees exactly the sequence of actions implied by the calls made.
    #[tokio::test]
    async fn replay() {
        let (tx, rx) = mpsc::channel(5);

        let sink = make_sink(tx, |mut tx, action| async move {
            tx.send(action.clone()).await?;
            if action == Action::Close {
                tx.close().await?
            }
            Ok::<_, mpsc::SendError>(tx)
        });

        futures::pin_mut!(sink);

        let expected = [
            Action::Send("hello\n"),
            Action::Flush,
            Action::Send("world\n"),
            Action::Flush,
            Action::Close,
        ];

        for &item in &["hello\n", "world\n"] {
            sink.send(item).await.unwrap()
        }

        sink.close().await.unwrap();

        let actual = rx.collect::<Vec<_>>().await;

        assert_eq!(&expected[..], &actual[..])
    }

    /// A failing closure surfaces its error once; the sink is unusable afterwards.
    #[tokio::test]
    async fn error_does_not_panic() {
        let sink = make_sink(io::stdout(), |mut _stdout, _action| async move {
            Err(io::Error::other("oh no"))
        });

        futures::pin_mut!(sink);

        let result = sink.send("hello").await;
        match result {
            Err(crate::quicksink::Error::Send(e)) => {
                assert_eq!(e.kind(), io::ErrorKind::Other);
                assert_eq!(e.to_string(), "oh no")
            }
            _ => panic!("unexpected result: {result:?}"),
        };

        // After a failure, later sends are rejected as `Closed`, not retried.
        let result = sink.send("hello").await;
        match result {
            Err(crate::quicksink::Error::Closed) => {}
            _ => panic!("unexpected result: {result:?}"),
        };

        // Closing a failed sink also reports `Closed` rather than succeeding.
        assert!(matches!(
            sink.close().await,
            Err(crate::quicksink::Error::Closed)
        ));
    }

    /// Once closed, the sink can be closed again but rejects sends and flushes.
    #[tokio::test]
    async fn closed_sink_rejects_further_use() {
        let sink = make_sink((), |(), _action: Action<u8>| async move {
            Ok::<_, io::Error>(())
        });

        futures::pin_mut!(sink);

        sink.close().await.unwrap();
        // Closing an already-closed sink is a no-op.
        sink.close().await.unwrap();

        assert!(matches!(
            sink.send(1u8).await,
            Err(crate::quicksink::Error::Closed)
        ));
        assert!(matches!(
            sink.flush().await,
            Err(crate::quicksink::Error::Closed)
        ));
    }
}