1use std::{
2 convert::TryFrom as _,
3 io,
4 pin::Pin,
5 task::{Context, Poll},
6};
7
8use futures::{
9 future::{MapOk, TryFutureExt},
10 io::{IoSlice, IoSliceMut},
11 prelude::*,
12 ready,
13};
14use libp2p_core::{
15 muxing::{StreamMuxer, StreamMuxerEvent},
16 transport::{DialOpts, ListenerId, TransportError, TransportEvent},
17 Multiaddr,
18};
19use libp2p_identity::PeerId;
20use prometheus_client::{
21 encoding::{EncodeLabelSet, EncodeLabelValue},
22 metrics::{counter::Counter, family::Family},
23 registry::{Registry, Unit},
24};
25
26use crate::protocol_stack;
27
28#[derive(Debug, Clone)]
29#[pin_project::pin_project]
30pub struct Transport<T> {
31 #[pin]
32 transport: T,
33 metrics: Family<Labels, Counter>,
34}
35
36impl<T> Transport<T> {
37 pub fn new(transport: T, registry: &mut Registry) -> Self {
38 let metrics = Family::<Labels, Counter>::default();
39 registry
40 .sub_registry_with_prefix("libp2p")
41 .register_with_unit(
42 "bandwidth",
43 "Bandwidth usage by direction and transport protocols",
44 Unit::Bytes,
45 metrics.clone(),
46 );
47
48 Transport { transport, metrics }
49 }
50}
51
52#[derive(EncodeLabelSet, Hash, Clone, Eq, PartialEq, Debug)]
53struct Labels {
54 protocols: String,
55 direction: Direction,
56}
57
58#[derive(Clone, Hash, PartialEq, Eq, EncodeLabelValue, Debug)]
59enum Direction {
60 Inbound,
61 Outbound,
62}
63
64impl<T, M> libp2p_core::Transport for Transport<T>
65where
66 T: libp2p_core::Transport<Output = (PeerId, M)>,
67 M: StreamMuxer + Send + 'static,
68 M::Substream: Send + 'static,
69 M::Error: Send + Sync + 'static,
70{
71 type Output = (PeerId, Muxer<M>);
72 type Error = T::Error;
73 type ListenerUpgrade =
74 MapOk<T::ListenerUpgrade, Box<dyn FnOnce((PeerId, M)) -> (PeerId, Muxer<M>) + Send>>;
75 type Dial = MapOk<T::Dial, Box<dyn FnOnce((PeerId, M)) -> (PeerId, Muxer<M>) + Send>>;
76
77 fn listen_on(
78 &mut self,
79 id: ListenerId,
80 addr: Multiaddr,
81 ) -> Result<(), TransportError<Self::Error>> {
82 self.transport.listen_on(id, addr)
83 }
84
85 fn remove_listener(&mut self, id: ListenerId) -> bool {
86 self.transport.remove_listener(id)
87 }
88
89 fn dial(
90 &mut self,
91 addr: Multiaddr,
92 dial_opts: DialOpts,
93 ) -> Result<Self::Dial, TransportError<Self::Error>> {
94 let metrics = ConnectionMetrics::from_family_and_addr(&self.metrics, &addr);
95 Ok(self
96 .transport
97 .dial(addr.clone(), dial_opts)?
98 .map_ok(Box::new(|(peer_id, stream_muxer)| {
99 (peer_id, Muxer::new(stream_muxer, metrics))
100 })))
101 }
102
103 fn poll(
104 self: Pin<&mut Self>,
105 cx: &mut Context<'_>,
106 ) -> Poll<TransportEvent<Self::ListenerUpgrade, Self::Error>> {
107 let this = self.project();
108 match this.transport.poll(cx) {
109 Poll::Ready(TransportEvent::Incoming {
110 listener_id,
111 upgrade,
112 local_addr,
113 send_back_addr,
114 }) => {
115 let metrics =
116 ConnectionMetrics::from_family_and_addr(this.metrics, &send_back_addr);
117 Poll::Ready(TransportEvent::Incoming {
118 listener_id,
119 upgrade: upgrade.map_ok(Box::new(|(peer_id, stream_muxer)| {
120 (peer_id, Muxer::new(stream_muxer, metrics))
121 })),
122 local_addr,
123 send_back_addr,
124 })
125 }
126 Poll::Ready(other) => {
127 let mapped = other.map_upgrade(|_upgrade| unreachable!("case already matched"));
128 Poll::Ready(mapped)
129 }
130 Poll::Pending => Poll::Pending,
131 }
132 }
133}
134
135#[derive(Clone, Debug)]
136struct ConnectionMetrics {
137 outbound: Counter,
138 inbound: Counter,
139}
140
141impl ConnectionMetrics {
142 fn from_family_and_addr(family: &Family<Labels, Counter>, protocols: &Multiaddr) -> Self {
143 let protocols = protocol_stack::as_string(protocols);
144
145 let outbound = {
147 let m = family.get_or_create(&Labels {
148 protocols: protocols.clone(),
149 direction: Direction::Outbound,
150 });
151 m.clone()
152 };
153 let inbound = {
155 let m = family.get_or_create(&Labels {
156 protocols,
157 direction: Direction::Inbound,
158 });
159 m.clone()
160 };
161 ConnectionMetrics { outbound, inbound }
162 }
163}
164
165#[derive(Clone)]
168#[pin_project::pin_project]
169pub struct Muxer<SMInner> {
170 #[pin]
171 inner: SMInner,
172 metrics: ConnectionMetrics,
173}
174
175impl<SMInner> Muxer<SMInner> {
176 fn new(inner: SMInner, metrics: ConnectionMetrics) -> Self {
178 Self { inner, metrics }
179 }
180}
181
182impl<SMInner> StreamMuxer for Muxer<SMInner>
183where
184 SMInner: StreamMuxer,
185{
186 type Substream = InstrumentedStream<SMInner::Substream>;
187 type Error = SMInner::Error;
188
189 fn poll(
190 self: Pin<&mut Self>,
191 cx: &mut Context<'_>,
192 ) -> Poll<Result<StreamMuxerEvent, Self::Error>> {
193 let this = self.project();
194 this.inner.poll(cx)
195 }
196
197 fn poll_inbound(
198 self: Pin<&mut Self>,
199 cx: &mut Context<'_>,
200 ) -> Poll<Result<Self::Substream, Self::Error>> {
201 let this = self.project();
202 let inner = ready!(this.inner.poll_inbound(cx)?);
203 let logged = InstrumentedStream {
204 inner,
205 metrics: this.metrics.clone(),
206 };
207 Poll::Ready(Ok(logged))
208 }
209
210 fn poll_outbound(
211 self: Pin<&mut Self>,
212 cx: &mut Context<'_>,
213 ) -> Poll<Result<Self::Substream, Self::Error>> {
214 let this = self.project();
215 let inner = ready!(this.inner.poll_outbound(cx)?);
216 let logged = InstrumentedStream {
217 inner,
218 metrics: this.metrics.clone(),
219 };
220 Poll::Ready(Ok(logged))
221 }
222
223 fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
224 let this = self.project();
225 this.inner.poll_close(cx)
226 }
227}
228
229#[pin_project::pin_project]
231pub struct InstrumentedStream<SMInner> {
232 #[pin]
233 inner: SMInner,
234 metrics: ConnectionMetrics,
235}
236
237impl<SMInner: AsyncRead> AsyncRead for InstrumentedStream<SMInner> {
238 fn poll_read(
239 self: Pin<&mut Self>,
240 cx: &mut Context<'_>,
241 buf: &mut [u8],
242 ) -> Poll<io::Result<usize>> {
243 let this = self.project();
244 let num_bytes = ready!(this.inner.poll_read(cx, buf))?;
245 this.metrics
246 .inbound
247 .inc_by(u64::try_from(num_bytes).unwrap_or(u64::MAX));
248 Poll::Ready(Ok(num_bytes))
249 }
250
251 fn poll_read_vectored(
252 self: Pin<&mut Self>,
253 cx: &mut Context<'_>,
254 bufs: &mut [IoSliceMut<'_>],
255 ) -> Poll<io::Result<usize>> {
256 let this = self.project();
257 let num_bytes = ready!(this.inner.poll_read_vectored(cx, bufs))?;
258 this.metrics
259 .inbound
260 .inc_by(u64::try_from(num_bytes).unwrap_or(u64::MAX));
261 Poll::Ready(Ok(num_bytes))
262 }
263}
264
265impl<SMInner: AsyncWrite> AsyncWrite for InstrumentedStream<SMInner> {
266 fn poll_write(
267 self: Pin<&mut Self>,
268 cx: &mut Context<'_>,
269 buf: &[u8],
270 ) -> Poll<io::Result<usize>> {
271 let this = self.project();
272 let num_bytes = ready!(this.inner.poll_write(cx, buf))?;
273 this.metrics
274 .outbound
275 .inc_by(u64::try_from(num_bytes).unwrap_or(u64::MAX));
276 Poll::Ready(Ok(num_bytes))
277 }
278
279 fn poll_write_vectored(
280 self: Pin<&mut Self>,
281 cx: &mut Context<'_>,
282 bufs: &[IoSlice<'_>],
283 ) -> Poll<io::Result<usize>> {
284 let this = self.project();
285 let num_bytes = ready!(this.inner.poll_write_vectored(cx, bufs))?;
286 this.metrics
287 .outbound
288 .inc_by(u64::try_from(num_bytes).unwrap_or(u64::MAX));
289 Poll::Ready(Ok(num_bytes))
290 }
291
292 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
293 let this = self.project();
294 this.inner.poll_flush(cx)
295 }
296
297 fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
298 let this = self.project();
299 this.inner.poll_close(cx)
300 }
301}