libp2p_webrtc_utils/stream/
state.rs

1// Copyright 2022 Parity Technologies (UK) Ltd.
2//
3// Permission is hereby granted, free of charge, to any person obtaining a
4// copy of this software and associated documentation files (the "Software"),
5// to deal in the Software without restriction, including without limitation
6// the rights to use, copy, modify, merge, publish, distribute, sublicense,
7// and/or sell copies of the Software, and to permit persons to whom the
8// Software is furnished to do so, subject to the following conditions:
9//
10// The above copyright notice and this permission notice shall be included in
11// all copies or substantial portions of the Software.
12//
13// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
14// OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
16// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
17// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
18// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
19// DEALINGS IN THE SOFTWARE.
20
21use std::io;
22
23use bytes::Bytes;
24
25use crate::proto::Flag;
26
27#[derive(Debug, Copy, Clone)]
28pub(crate) enum State {
29    Open,
30    ReadClosed,
31    WriteClosed,
32    ClosingRead {
33        /// Whether the write side of our channel was already closed.
34        write_closed: bool,
35        inner: Closing,
36    },
37    ClosingWrite {
38        /// Whether the write side of our channel was already closed.
39        read_closed: bool,
40        inner: Closing,
41    },
42    BothClosed {
43        reset: bool,
44    },
45}
46
/// Progress of gracefully closing one half (read or write) of the connection.
///
/// Closing the read half means sending a `STOP_SENDING` flag, closing the write half means
/// sending a `FIN` flag; either way the underlying connection must be flushed afterwards.
#[derive(Clone, Copy, Debug)]
pub(crate) enum Closing {
    /// The close has been requested but the flag has not been sent yet.
    Requested,
    /// The flag has been sent; the close is finished once the connection is flushed.
    MessageSent,
}
56
57impl State {
58    /// Performs a state transition for a flag contained in an inbound message.
59    pub(crate) fn handle_inbound_flag(&mut self, flag: Flag, buffer: &mut Bytes) {
60        let current = *self;
61
62        match (current, flag) {
63            (Self::Open, Flag::FIN) => {
64                *self = Self::ReadClosed;
65            }
66            (Self::WriteClosed, Flag::FIN) => {
67                *self = Self::BothClosed { reset: false };
68            }
69            (Self::Open, Flag::STOP_SENDING) => {
70                *self = Self::WriteClosed;
71            }
72            (Self::ReadClosed, Flag::STOP_SENDING) => {
73                *self = Self::BothClosed { reset: false };
74            }
75            (_, Flag::RESET) => {
76                buffer.clear();
77                *self = Self::BothClosed { reset: true };
78            }
79            _ => {}
80        }
81    }
82
83    pub(crate) fn write_closed(&mut self) {
84        match self {
85            State::ClosingWrite {
86                read_closed: true,
87                inner,
88            } => {
89                debug_assert!(matches!(inner, Closing::MessageSent));
90
91                *self = State::BothClosed { reset: false };
92            }
93            State::ClosingWrite {
94                read_closed: false,
95                inner,
96            } => {
97                debug_assert!(matches!(inner, Closing::MessageSent));
98
99                *self = State::WriteClosed;
100            }
101            State::Open
102            | State::ReadClosed
103            | State::WriteClosed
104            | State::ClosingRead { .. }
105            | State::BothClosed { .. } => {
106                unreachable!("bad state machine impl")
107            }
108        }
109    }
110
111    pub(crate) fn close_write_message_sent(&mut self) {
112        match self {
113            State::ClosingWrite { inner, read_closed } => {
114                debug_assert!(matches!(inner, Closing::Requested));
115
116                *self = State::ClosingWrite {
117                    read_closed: *read_closed,
118                    inner: Closing::MessageSent,
119                };
120            }
121            State::Open
122            | State::ReadClosed
123            | State::WriteClosed
124            | State::ClosingRead { .. }
125            | State::BothClosed { .. } => {
126                unreachable!("bad state machine impl")
127            }
128        }
129    }
130
131    pub(crate) fn read_closed(&mut self) {
132        match self {
133            State::ClosingRead {
134                write_closed: true,
135                inner,
136            } => {
137                debug_assert!(matches!(inner, Closing::MessageSent));
138
139                *self = State::BothClosed { reset: false };
140            }
141            State::ClosingRead {
142                write_closed: false,
143                inner,
144            } => {
145                debug_assert!(matches!(inner, Closing::MessageSent));
146
147                *self = State::ReadClosed;
148            }
149            State::Open
150            | State::ReadClosed
151            | State::WriteClosed
152            | State::ClosingWrite { .. }
153            | State::BothClosed { .. } => {
154                unreachable!("bad state machine impl")
155            }
156        }
157    }
158
159    pub(crate) fn close_read_message_sent(&mut self) {
160        match self {
161            State::ClosingRead {
162                inner,
163                write_closed,
164            } => {
165                debug_assert!(matches!(inner, Closing::Requested));
166
167                *self = State::ClosingRead {
168                    write_closed: *write_closed,
169                    inner: Closing::MessageSent,
170                };
171            }
172            State::Open
173            | State::ReadClosed
174            | State::WriteClosed
175            | State::ClosingWrite { .. }
176            | State::BothClosed { .. } => {
177                unreachable!("bad state machine impl")
178            }
179        }
180    }
181
182    /// Whether we should read from the stream in the [`futures::AsyncWrite`] implementation.
183    ///
184    /// This is necessary for read-closed streams because we would otherwise
185    /// not read any more flags from the socket.
186    pub(crate) fn read_flags_in_async_write(&self) -> bool {
187        matches!(self, Self::ReadClosed)
188    }
189
190    /// Acts as a "barrier" for [`futures::AsyncRead::poll_read`].
191    pub(crate) fn read_barrier(&self) -> io::Result<()> {
192        use State::*;
193
194        let kind = match self {
195            Open
196            | WriteClosed
197            | ClosingWrite {
198                read_closed: false, ..
199            } => return Ok(()),
200            ClosingWrite {
201                read_closed: true, ..
202            }
203            | ReadClosed
204            | ClosingRead { .. }
205            | BothClosed { reset: false } => io::ErrorKind::BrokenPipe,
206            BothClosed { reset: true } => io::ErrorKind::ConnectionReset,
207        };
208
209        Err(kind.into())
210    }
211
212    /// Acts as a "barrier" for [`futures::AsyncWrite::poll_write`].
213    pub(crate) fn write_barrier(&self) -> io::Result<()> {
214        use State::*;
215
216        let kind = match self {
217            Open
218            | ReadClosed
219            | ClosingRead {
220                write_closed: false,
221                ..
222            } => return Ok(()),
223            ClosingRead {
224                write_closed: true, ..
225            }
226            | WriteClosed
227            | ClosingWrite { .. }
228            | BothClosed { reset: false } => io::ErrorKind::BrokenPipe,
229            BothClosed { reset: true } => io::ErrorKind::ConnectionReset,
230        };
231
232        Err(kind.into())
233    }
234
235    /// Acts as a "barrier" for [`futures::AsyncWrite::poll_close`].
236    pub(crate) fn close_write_barrier(&mut self) -> io::Result<Option<Closing>> {
237        loop {
238            match &self {
239                State::WriteClosed => return Ok(None),
240
241                State::ClosingWrite { inner, .. } => return Ok(Some(*inner)),
242
243                State::Open => {
244                    *self = Self::ClosingWrite {
245                        read_closed: false,
246                        inner: Closing::Requested,
247                    };
248                }
249                State::ReadClosed => {
250                    *self = Self::ClosingWrite {
251                        read_closed: true,
252                        inner: Closing::Requested,
253                    };
254                }
255
256                State::ClosingRead {
257                    write_closed: true, ..
258                }
259                | State::BothClosed { reset: false } => {
260                    return Err(io::ErrorKind::BrokenPipe.into())
261                }
262
263                State::ClosingRead {
264                    write_closed: false,
265                    ..
266                } => {
267                    return Err(io::Error::new(
268                        io::ErrorKind::Other,
269                        "cannot close read half while closing write half",
270                    ))
271                }
272
273                State::BothClosed { reset: true } => {
274                    return Err(io::ErrorKind::ConnectionReset.into())
275                }
276            }
277        }
278    }
279
280    /// Acts as a "barrier" for [`Stream::poll_close_read`](super::Stream::poll_close_read).
281    pub(crate) fn close_read_barrier(&mut self) -> io::Result<Option<Closing>> {
282        loop {
283            match self {
284                State::ReadClosed => return Ok(None),
285
286                State::ClosingRead { inner, .. } => return Ok(Some(*inner)),
287
288                State::Open => {
289                    *self = Self::ClosingRead {
290                        write_closed: false,
291                        inner: Closing::Requested,
292                    };
293                }
294                State::WriteClosed => {
295                    *self = Self::ClosingRead {
296                        write_closed: true,
297                        inner: Closing::Requested,
298                    };
299                }
300
301                State::ClosingWrite {
302                    read_closed: true, ..
303                }
304                | State::BothClosed { reset: false } => {
305                    return Err(io::ErrorKind::BrokenPipe.into())
306                }
307
308                State::ClosingWrite {
309                    read_closed: false, ..
310                } => {
311                    return Err(io::Error::new(
312                        io::ErrorKind::Other,
313                        "cannot close write half while closing read half",
314                    ))
315                }
316
317                State::BothClosed { reset: true } => {
318                    return Err(io::ErrorKind::ConnectionReset.into())
319                }
320            }
321        }
322    }
323}
324
#[cfg(test)]
mod tests {
    use std::io::ErrorKind;

    use super::*;

    /// Feeds one inbound flag into `state`, using a throwaway empty buffer.
    fn receive(state: &mut State, flag: Flag) {
        state.handle_inbound_flag(flag, &mut Bytes::default());
    }

    /// Runs an open stream through a complete graceful close of its write half.
    fn write_closed_stream() -> State {
        let mut state = State::Open;
        state.close_write_barrier().unwrap();
        state.close_write_message_sent();
        state.write_closed();
        state
    }

    /// Runs an open stream through a complete graceful close of its read half.
    fn read_closed_stream() -> State {
        let mut state = State::Open;
        state.close_read_barrier().unwrap();
        state.close_read_message_sent();
        state.read_closed();
        state
    }

    #[test]
    fn cannot_read_after_receiving_fin() {
        let mut state = State::Open;
        receive(&mut state, Flag::FIN);

        assert_eq!(state.read_barrier().unwrap_err().kind(), ErrorKind::BrokenPipe)
    }

    #[test]
    fn cannot_read_after_closing_read() {
        let state = read_closed_stream();

        assert_eq!(state.read_barrier().unwrap_err().kind(), ErrorKind::BrokenPipe)
    }

    #[test]
    fn cannot_write_after_receiving_stop_sending() {
        let mut state = State::Open;
        receive(&mut state, Flag::STOP_SENDING);

        assert_eq!(state.write_barrier().unwrap_err().kind(), ErrorKind::BrokenPipe)
    }

    #[test]
    fn cannot_write_after_closing_write() {
        let state = write_closed_stream();

        assert_eq!(state.write_barrier().unwrap_err().kind(), ErrorKind::BrokenPipe)
    }

    #[test]
    fn everything_broken_after_receiving_reset() {
        let mut state = State::Open;
        receive(&mut state, Flag::RESET);

        let kinds = [
            state.read_barrier().unwrap_err().kind(),
            state.write_barrier().unwrap_err().kind(),
            state.close_write_barrier().unwrap_err().kind(),
            state.close_read_barrier().unwrap_err().kind(),
        ];

        assert_eq!(kinds, [ErrorKind::ConnectionReset; 4]);
    }

    #[test]
    fn should_read_flags_in_async_write_after_read_closed() {
        let mut state = State::Open;
        receive(&mut state, Flag::FIN);

        assert!(state.read_flags_in_async_write())
    }

    #[test]
    fn cannot_read_or_write_after_receiving_fin_and_stop_sending() {
        let mut state = State::Open;
        receive(&mut state, Flag::FIN);
        receive(&mut state, Flag::STOP_SENDING);

        assert_eq!(state.read_barrier().unwrap_err().kind(), ErrorKind::BrokenPipe);
        assert_eq!(state.write_barrier().unwrap_err().kind(), ErrorKind::BrokenPipe);
    }

    #[test]
    fn can_read_after_closing_write() {
        write_closed_stream().read_barrier().unwrap();
    }

    #[test]
    fn can_write_after_closing_read() {
        read_closed_stream().write_barrier().unwrap();
    }

    #[test]
    fn cannot_write_after_starting_close() {
        let mut state = State::Open;
        state.close_write_barrier().expect("to close in open");

        assert_eq!(state.write_barrier().unwrap_err().kind(), ErrorKind::BrokenPipe);
    }

    #[test]
    fn cannot_read_after_starting_close() {
        let mut state = State::Open;
        state.close_read_barrier().expect("to close in open");

        assert_eq!(state.read_barrier().unwrap_err().kind(), ErrorKind::BrokenPipe);
    }

    #[test]
    fn can_read_in_open() {
        State::Open.read_barrier().unwrap();
    }

    #[test]
    fn can_write_in_open() {
        State::Open.write_barrier().unwrap();
    }

    #[test]
    fn write_close_barrier_returns_ok_when_closed() {
        let mut state = write_closed_stream();

        assert!(state.close_write_barrier().unwrap().is_none())
    }

    #[test]
    fn read_close_barrier_returns_ok_when_closed() {
        let mut state = read_closed_stream();

        assert!(state.close_read_barrier().unwrap().is_none())
    }

    #[test]
    fn reset_flag_clears_buffer() {
        let mut state = State::Open;
        let mut pending = Bytes::copy_from_slice(b"foobar");

        state.handle_inbound_flag(Flag::RESET, &mut pending);

        assert!(pending.is_empty());
    }
}