1use std::{
22 mem,
23 pin::Pin,
24 task::{Context, Poll},
25};
26
27use futures::{future::Either, prelude::*};
28pub(crate) use multistream_select::Version;
29use multistream_select::{DialerSelectFuture, ListenerSelectFuture};
30
31use crate::{
32 connection::ConnectedPoint,
33 upgrade::{InboundConnectionUpgrade, OutboundConnectionUpgrade, UpgradeError},
34 Negotiated,
35};
36
37pub(crate) fn apply<C, U>(
40 conn: C,
41 up: U,
42 cp: ConnectedPoint,
43 v: Version,
44) -> Either<InboundUpgradeApply<C, U>, OutboundUpgradeApply<C, U>>
45where
46 C: AsyncRead + AsyncWrite + Unpin,
47 U: InboundConnectionUpgrade<Negotiated<C>> + OutboundConnectionUpgrade<Negotiated<C>>,
48{
49 match cp {
50 ConnectedPoint::Dialer { role_override, .. } if role_override.is_dialer() => {
51 Either::Right(apply_outbound(conn, up, v))
52 }
53 _ => Either::Left(apply_inbound(conn, up)),
54 }
55}
56
57pub(crate) fn apply_inbound<C, U>(conn: C, up: U) -> InboundUpgradeApply<C, U>
59where
60 C: AsyncRead + AsyncWrite + Unpin,
61 U: InboundConnectionUpgrade<Negotiated<C>>,
62{
63 InboundUpgradeApply {
64 inner: InboundUpgradeApplyState::Init {
65 future: multistream_select::listener_select_proto(conn, up.protocol_info()),
66 upgrade: up,
67 },
68 }
69}
70
71pub(crate) fn apply_outbound<C, U>(conn: C, up: U, v: Version) -> OutboundUpgradeApply<C, U>
73where
74 C: AsyncRead + AsyncWrite + Unpin,
75 U: OutboundConnectionUpgrade<Negotiated<C>>,
76{
77 OutboundUpgradeApply {
78 inner: OutboundUpgradeApplyState::Init {
79 future: multistream_select::dialer_select_proto(conn, up.protocol_info(), v),
80 upgrade: up,
81 },
82 }
83}
84
85pub struct InboundUpgradeApply<C, U>
87where
88 C: AsyncRead + AsyncWrite + Unpin,
89 U: InboundConnectionUpgrade<Negotiated<C>>,
90{
91 inner: InboundUpgradeApplyState<C, U>,
92}
93
94#[allow(clippy::large_enum_variant)]
95enum InboundUpgradeApplyState<C, U>
96where
97 C: AsyncRead + AsyncWrite + Unpin,
98 U: InboundConnectionUpgrade<Negotiated<C>>,
99{
100 Init {
101 future: ListenerSelectFuture<C, U::Info>,
102 upgrade: U,
103 },
104 Upgrade {
105 future: Pin<Box<U::Future>>,
106 name: String,
107 },
108 Undefined,
109}
110
111impl<C, U> Unpin for InboundUpgradeApply<C, U>
112where
113 C: AsyncRead + AsyncWrite + Unpin,
114 U: InboundConnectionUpgrade<Negotiated<C>>,
115{
116}
117
118impl<C, U> Future for InboundUpgradeApply<C, U>
119where
120 C: AsyncRead + AsyncWrite + Unpin,
121 U: InboundConnectionUpgrade<Negotiated<C>>,
122{
123 type Output = Result<U::Output, UpgradeError<U::Error>>;
124
125 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
126 loop {
127 match mem::replace(&mut self.inner, InboundUpgradeApplyState::Undefined) {
128 InboundUpgradeApplyState::Init {
129 mut future,
130 upgrade,
131 } => {
132 let (info, io) = match Future::poll(Pin::new(&mut future), cx)? {
133 Poll::Ready(x) => x,
134 Poll::Pending => {
135 self.inner = InboundUpgradeApplyState::Init { future, upgrade };
136 return Poll::Pending;
137 }
138 };
139 self.inner = InboundUpgradeApplyState::Upgrade {
140 future: Box::pin(upgrade.upgrade_inbound(io, info.clone())),
141 name: info.as_ref().to_owned(),
142 };
143 }
144 InboundUpgradeApplyState::Upgrade { mut future, name } => {
145 match Future::poll(Pin::new(&mut future), cx) {
146 Poll::Pending => {
147 self.inner = InboundUpgradeApplyState::Upgrade { future, name };
148 return Poll::Pending;
149 }
150 Poll::Ready(Ok(x)) => {
151 tracing::trace!(upgrade=%name, "Upgraded inbound stream");
152 return Poll::Ready(Ok(x));
153 }
154 Poll::Ready(Err(e)) => {
155 tracing::debug!(upgrade=%name, "Failed to upgrade inbound stream");
156 return Poll::Ready(Err(UpgradeError::Apply(e)));
157 }
158 }
159 }
160 InboundUpgradeApplyState::Undefined => {
161 panic!("InboundUpgradeApplyState::poll called after completion")
162 }
163 }
164 }
165 }
166}
167
168pub struct OutboundUpgradeApply<C, U>
170where
171 C: AsyncRead + AsyncWrite + Unpin,
172 U: OutboundConnectionUpgrade<Negotiated<C>>,
173{
174 inner: OutboundUpgradeApplyState<C, U>,
175}
176
177enum OutboundUpgradeApplyState<C, U>
178where
179 C: AsyncRead + AsyncWrite + Unpin,
180 U: OutboundConnectionUpgrade<Negotiated<C>>,
181{
182 Init {
183 future: DialerSelectFuture<C, <U::InfoIter as IntoIterator>::IntoIter>,
184 upgrade: U,
185 },
186 Upgrade {
187 future: Pin<Box<U::Future>>,
188 name: String,
189 },
190 Undefined,
191}
192
193impl<C, U> Unpin for OutboundUpgradeApply<C, U>
194where
195 C: AsyncRead + AsyncWrite + Unpin,
196 U: OutboundConnectionUpgrade<Negotiated<C>>,
197{
198}
199
200impl<C, U> Future for OutboundUpgradeApply<C, U>
201where
202 C: AsyncRead + AsyncWrite + Unpin,
203 U: OutboundConnectionUpgrade<Negotiated<C>>,
204{
205 type Output = Result<U::Output, UpgradeError<U::Error>>;
206
207 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
208 loop {
209 match mem::replace(&mut self.inner, OutboundUpgradeApplyState::Undefined) {
210 OutboundUpgradeApplyState::Init {
211 mut future,
212 upgrade,
213 } => {
214 let (info, connection) = match Future::poll(Pin::new(&mut future), cx)? {
215 Poll::Ready(x) => x,
216 Poll::Pending => {
217 self.inner = OutboundUpgradeApplyState::Init { future, upgrade };
218 return Poll::Pending;
219 }
220 };
221 self.inner = OutboundUpgradeApplyState::Upgrade {
222 future: Box::pin(upgrade.upgrade_outbound(connection, info.clone())),
223 name: info.as_ref().to_owned(),
224 };
225 }
226 OutboundUpgradeApplyState::Upgrade { mut future, name } => {
227 match Future::poll(Pin::new(&mut future), cx) {
228 Poll::Pending => {
229 self.inner = OutboundUpgradeApplyState::Upgrade { future, name };
230 return Poll::Pending;
231 }
232 Poll::Ready(Ok(x)) => {
233 tracing::trace!(upgrade=%name, "Upgraded outbound stream");
234 return Poll::Ready(Ok(x));
235 }
236 Poll::Ready(Err(e)) => {
237 tracing::debug!(upgrade=%name, "Failed to upgrade outbound stream",);
238 return Poll::Ready(Err(UpgradeError::Apply(e)));
239 }
240 }
241 }
242 OutboundUpgradeApplyState::Undefined => {
243 panic!("OutboundUpgradeApplyState::poll called after completion")
244 }
245 }
246 }
247 }
248}