1use core::fmt;
2use std::{
3 io,
4 pin::Pin,
5 sync::{Arc, Mutex},
6 task::{Context, Poll},
7};
8
9use futures::{
10 channel::{mpsc, oneshot},
11 SinkExt as _, StreamExt as _,
12};
13use libp2p_identity::PeerId;
14use libp2p_swarm::{Stream, StreamProtocol};
15
16use crate::{handler::NewStream, shared::Shared, AlreadyRegistered};
17
18#[derive(Clone)]
22pub struct Control {
23 shared: Arc<Mutex<Shared>>,
24}
25
26impl Control {
27 pub(crate) fn new(shared: Arc<Mutex<Shared>>) -> Self {
28 Self { shared }
29 }
30
31 pub async fn open_stream(
45 &mut self,
46 peer: PeerId,
47 protocol: StreamProtocol,
48 ) -> Result<Stream, OpenStreamError> {
49 tracing::debug!(%peer, "Requesting new stream");
50
51 let mut new_stream_sender = Shared::lock(&self.shared).sender(peer);
52
53 let (sender, receiver) = oneshot::channel();
54
55 new_stream_sender
56 .send(NewStream { protocol, sender })
57 .await
58 .map_err(|e| io::Error::new(io::ErrorKind::ConnectionReset, e))?;
59
60 let stream = receiver
61 .await
62 .map_err(|e| io::Error::new(io::ErrorKind::ConnectionReset, e))??;
63
64 Ok(stream)
65 }
66
67 pub fn accept(
71 &mut self,
72 protocol: StreamProtocol,
73 ) -> Result<IncomingStreams, AlreadyRegistered> {
74 Shared::lock(&self.shared).accept(protocol)
75 }
76}
77
78#[derive(Debug)]
80#[non_exhaustive]
81pub enum OpenStreamError {
82 UnsupportedProtocol(StreamProtocol),
84 Io(std::io::Error),
86}
87
88impl From<std::io::Error> for OpenStreamError {
89 fn from(v: std::io::Error) -> Self {
90 Self::Io(v)
91 }
92}
93
94impl fmt::Display for OpenStreamError {
95 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
96 match self {
97 OpenStreamError::UnsupportedProtocol(p) => {
98 write!(f, "failed to open stream: remote peer does not support {p}")
99 }
100 OpenStreamError::Io(e) => {
101 write!(f, "failed to open stream: io error: {e}")
102 }
103 }
104 }
105}
106
107impl std::error::Error for OpenStreamError {
108 fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
109 match self {
110 Self::Io(error) => Some(error),
111 _ => None,
112 }
113 }
114}
115
116#[must_use = "Streams do nothing unless polled."]
118pub struct IncomingStreams {
119 receiver: mpsc::Receiver<(PeerId, Stream)>,
120}
121
122impl IncomingStreams {
123 pub(crate) fn new(receiver: mpsc::Receiver<(PeerId, Stream)>) -> Self {
124 Self { receiver }
125 }
126}
127
128impl futures::Stream for IncomingStreams {
129 type Item = (PeerId, Stream);
130
131 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
132 self.receiver.poll_next_unpin(cx)
133 }
134}