libp2p_core/transport/
timeout.rs1use std::{
28 error, fmt, io,
29 pin::Pin,
30 task::{Context, Poll},
31 time::Duration,
32};
33
34use futures::prelude::*;
35use futures_timer::Delay;
36
37use crate::{
38 transport::{DialOpts, ListenerId, TransportError, TransportEvent},
39 Multiaddr, Transport,
40};
41
42#[derive(Debug, Copy, Clone)]
48#[pin_project::pin_project]
49pub struct TransportTimeout<InnerTrans> {
50 #[pin]
51 inner: InnerTrans,
52 outgoing_timeout: Duration,
53 incoming_timeout: Duration,
54}
55
56impl<InnerTrans> TransportTimeout<InnerTrans> {
57 pub fn new(trans: InnerTrans, timeout: Duration) -> Self {
59 TransportTimeout {
60 inner: trans,
61 outgoing_timeout: timeout,
62 incoming_timeout: timeout,
63 }
64 }
65
66 pub fn with_outgoing_timeout(trans: InnerTrans, timeout: Duration) -> Self {
68 TransportTimeout {
69 inner: trans,
70 outgoing_timeout: timeout,
71 incoming_timeout: Duration::from_secs(100 * 365 * 24 * 3600), }
73 }
74
75 pub fn with_ingoing_timeout(trans: InnerTrans, timeout: Duration) -> Self {
77 TransportTimeout {
78 inner: trans,
79 outgoing_timeout: Duration::from_secs(100 * 365 * 24 * 3600), incoming_timeout: timeout,
81 }
82 }
83}
84
85impl<InnerTrans> Transport for TransportTimeout<InnerTrans>
86where
87 InnerTrans: Transport,
88 InnerTrans::Error: 'static,
89{
90 type Output = InnerTrans::Output;
91 type Error = TransportTimeoutError<InnerTrans::Error>;
92 type ListenerUpgrade = Timeout<InnerTrans::ListenerUpgrade>;
93 type Dial = Timeout<InnerTrans::Dial>;
94
95 fn listen_on(
96 &mut self,
97 id: ListenerId,
98 addr: Multiaddr,
99 ) -> Result<(), TransportError<Self::Error>> {
100 self.inner
101 .listen_on(id, addr)
102 .map_err(|err| err.map(TransportTimeoutError::Other))
103 }
104
105 fn remove_listener(&mut self, id: ListenerId) -> bool {
106 self.inner.remove_listener(id)
107 }
108
109 fn dial(
110 &mut self,
111 addr: Multiaddr,
112 opts: DialOpts,
113 ) -> Result<Self::Dial, TransportError<Self::Error>> {
114 let dial = self
115 .inner
116 .dial(addr, opts)
117 .map_err(|err| err.map(TransportTimeoutError::Other))?;
118 Ok(Timeout {
119 inner: dial,
120 timer: Delay::new(self.outgoing_timeout),
121 })
122 }
123
124 fn poll(
125 self: Pin<&mut Self>,
126 cx: &mut Context<'_>,
127 ) -> Poll<TransportEvent<Self::ListenerUpgrade, Self::Error>> {
128 let this = self.project();
129 let timeout = *this.incoming_timeout;
130 this.inner.poll(cx).map(|event| {
131 event
132 .map_upgrade(move |inner_fut| Timeout {
133 inner: inner_fut,
134 timer: Delay::new(timeout),
135 })
136 .map_err(TransportTimeoutError::Other)
137 })
138 }
139}
140
141#[pin_project::pin_project]
146#[must_use = "futures do nothing unless polled"]
147pub struct Timeout<InnerFut> {
148 #[pin]
149 inner: InnerFut,
150 timer: Delay,
151}
152
153impl<InnerFut> Future for Timeout<InnerFut>
154where
155 InnerFut: TryFuture,
156{
157 type Output = Result<InnerFut::Ok, TransportTimeoutError<InnerFut::Error>>;
158
159 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
160 let mut this = self.project();
166
167 match TryFuture::try_poll(this.inner, cx) {
168 Poll::Pending => {}
169 Poll::Ready(Ok(v)) => return Poll::Ready(Ok(v)),
170 Poll::Ready(Err(err)) => return Poll::Ready(Err(TransportTimeoutError::Other(err))),
171 }
172
173 match Pin::new(&mut this.timer).poll(cx) {
174 Poll::Pending => Poll::Pending,
175 Poll::Ready(()) => Poll::Ready(Err(TransportTimeoutError::Timeout)),
176 }
177 }
178}
179
/// Error returned by `TransportTimeout` and by its `Timeout` futures.
#[derive(Debug)]
pub enum TransportTimeoutError<TErr> {
    /// The deadline elapsed before the wrapped future completed.
    Timeout,
    /// An error reported by the timer.
    ///
    /// NOTE(review): nothing in this module constructs this variant; presumably it is
    /// kept for API compatibility — confirm before relying on it.
    TimerError(io::Error),
    /// An error passed through from the wrapped transport or future.
    Other(TErr),
}
190
191impl<TErr> fmt::Display for TransportTimeoutError<TErr>
192where
193 TErr: fmt::Display,
194{
195 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
196 match self {
197 TransportTimeoutError::Timeout => write!(f, "Timeout has been reached"),
198 TransportTimeoutError::TimerError(err) => write!(f, "Error in the timer: {err}"),
199 TransportTimeoutError::Other(err) => write!(f, "{err}"),
200 }
201 }
202}
203
204impl<TErr> error::Error for TransportTimeoutError<TErr>
205where
206 TErr: error::Error + 'static,
207{
208 fn source(&self) -> Option<&(dyn error::Error + 'static)> {
209 match self {
210 TransportTimeoutError::Timeout => None,
211 TransportTimeoutError::TimerError(err) => Some(err),
212 TransportTimeoutError::Other(err) => Some(err),
213 }
214 }
215}