libp2p_metrics/
bandwidth.rs

1use std::{
2    convert::TryFrom as _,
3    io,
4    pin::Pin,
5    task::{Context, Poll},
6};
7
8use futures::{
9    future::{MapOk, TryFutureExt},
10    io::{IoSlice, IoSliceMut},
11    prelude::*,
12    ready,
13};
14use libp2p_core::{
15    muxing::{StreamMuxer, StreamMuxerEvent},
16    transport::{DialOpts, ListenerId, TransportError, TransportEvent},
17    Multiaddr,
18};
19use libp2p_identity::PeerId;
20use prometheus_client::{
21    encoding::{EncodeLabelSet, EncodeLabelValue},
22    metrics::{counter::Counter, family::Family},
23    registry::{Registry, Unit},
24};
25
26use crate::protocol_stack;
27
28#[derive(Debug, Clone)]
29#[pin_project::pin_project]
30pub struct Transport<T> {
31    #[pin]
32    transport: T,
33    metrics: Family<Labels, Counter>,
34}
35
36impl<T> Transport<T> {
37    pub fn new(transport: T, registry: &mut Registry) -> Self {
38        let metrics = Family::<Labels, Counter>::default();
39        registry
40            .sub_registry_with_prefix("libp2p")
41            .register_with_unit(
42                "bandwidth",
43                "Bandwidth usage by direction and transport protocols",
44                Unit::Bytes,
45                metrics.clone(),
46            );
47
48        Transport { transport, metrics }
49    }
50}
51
52#[derive(EncodeLabelSet, Hash, Clone, Eq, PartialEq, Debug)]
53struct Labels {
54    protocols: String,
55    direction: Direction,
56}
57
58#[derive(Clone, Hash, PartialEq, Eq, EncodeLabelValue, Debug)]
59enum Direction {
60    Inbound,
61    Outbound,
62}
63
64impl<T, M> libp2p_core::Transport for Transport<T>
65where
66    T: libp2p_core::Transport<Output = (PeerId, M)>,
67    M: StreamMuxer + Send + 'static,
68    M::Substream: Send + 'static,
69    M::Error: Send + Sync + 'static,
70{
71    type Output = (PeerId, Muxer<M>);
72    type Error = T::Error;
73    type ListenerUpgrade =
74        MapOk<T::ListenerUpgrade, Box<dyn FnOnce((PeerId, M)) -> (PeerId, Muxer<M>) + Send>>;
75    type Dial = MapOk<T::Dial, Box<dyn FnOnce((PeerId, M)) -> (PeerId, Muxer<M>) + Send>>;
76
77    fn listen_on(
78        &mut self,
79        id: ListenerId,
80        addr: Multiaddr,
81    ) -> Result<(), TransportError<Self::Error>> {
82        self.transport.listen_on(id, addr)
83    }
84
85    fn remove_listener(&mut self, id: ListenerId) -> bool {
86        self.transport.remove_listener(id)
87    }
88
89    fn dial(
90        &mut self,
91        addr: Multiaddr,
92        dial_opts: DialOpts,
93    ) -> Result<Self::Dial, TransportError<Self::Error>> {
94        let metrics = ConnectionMetrics::from_family_and_addr(&self.metrics, &addr);
95        Ok(self
96            .transport
97            .dial(addr.clone(), dial_opts)?
98            .map_ok(Box::new(|(peer_id, stream_muxer)| {
99                (peer_id, Muxer::new(stream_muxer, metrics))
100            })))
101    }
102
103    fn poll(
104        self: Pin<&mut Self>,
105        cx: &mut Context<'_>,
106    ) -> Poll<TransportEvent<Self::ListenerUpgrade, Self::Error>> {
107        let this = self.project();
108        match this.transport.poll(cx) {
109            Poll::Ready(TransportEvent::Incoming {
110                listener_id,
111                upgrade,
112                local_addr,
113                send_back_addr,
114            }) => {
115                let metrics =
116                    ConnectionMetrics::from_family_and_addr(this.metrics, &send_back_addr);
117                Poll::Ready(TransportEvent::Incoming {
118                    listener_id,
119                    upgrade: upgrade.map_ok(Box::new(|(peer_id, stream_muxer)| {
120                        (peer_id, Muxer::new(stream_muxer, metrics))
121                    })),
122                    local_addr,
123                    send_back_addr,
124                })
125            }
126            Poll::Ready(other) => {
127                let mapped = other.map_upgrade(|_upgrade| unreachable!("case already matched"));
128                Poll::Ready(mapped)
129            }
130            Poll::Pending => Poll::Pending,
131        }
132    }
133}
134
135#[derive(Clone, Debug)]
136struct ConnectionMetrics {
137    outbound: Counter,
138    inbound: Counter,
139}
140
141impl ConnectionMetrics {
142    fn from_family_and_addr(family: &Family<Labels, Counter>, protocols: &Multiaddr) -> Self {
143        let protocols = protocol_stack::as_string(protocols);
144
145        // Additional scope to make sure to drop the lock guard from `get_or_create`.
146        let outbound = {
147            let m = family.get_or_create(&Labels {
148                protocols: protocols.clone(),
149                direction: Direction::Outbound,
150            });
151            m.clone()
152        };
153        // Additional scope to make sure to drop the lock guard from `get_or_create`.
154        let inbound = {
155            let m = family.get_or_create(&Labels {
156                protocols,
157                direction: Direction::Inbound,
158            });
159            m.clone()
160        };
161        ConnectionMetrics { outbound, inbound }
162    }
163}
164
165/// Wraps around a [`StreamMuxer`] and counts the number of bytes that go through all the opened
166/// streams.
167#[derive(Clone)]
168#[pin_project::pin_project]
169pub struct Muxer<SMInner> {
170    #[pin]
171    inner: SMInner,
172    metrics: ConnectionMetrics,
173}
174
175impl<SMInner> Muxer<SMInner> {
176    /// Creates a new [`Muxer`] wrapping around the provided stream muxer.
177    fn new(inner: SMInner, metrics: ConnectionMetrics) -> Self {
178        Self { inner, metrics }
179    }
180}
181
182impl<SMInner> StreamMuxer for Muxer<SMInner>
183where
184    SMInner: StreamMuxer,
185{
186    type Substream = InstrumentedStream<SMInner::Substream>;
187    type Error = SMInner::Error;
188
189    fn poll(
190        self: Pin<&mut Self>,
191        cx: &mut Context<'_>,
192    ) -> Poll<Result<StreamMuxerEvent, Self::Error>> {
193        let this = self.project();
194        this.inner.poll(cx)
195    }
196
197    fn poll_inbound(
198        self: Pin<&mut Self>,
199        cx: &mut Context<'_>,
200    ) -> Poll<Result<Self::Substream, Self::Error>> {
201        let this = self.project();
202        let inner = ready!(this.inner.poll_inbound(cx)?);
203        let logged = InstrumentedStream {
204            inner,
205            metrics: this.metrics.clone(),
206        };
207        Poll::Ready(Ok(logged))
208    }
209
210    fn poll_outbound(
211        self: Pin<&mut Self>,
212        cx: &mut Context<'_>,
213    ) -> Poll<Result<Self::Substream, Self::Error>> {
214        let this = self.project();
215        let inner = ready!(this.inner.poll_outbound(cx)?);
216        let logged = InstrumentedStream {
217            inner,
218            metrics: this.metrics.clone(),
219        };
220        Poll::Ready(Ok(logged))
221    }
222
223    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
224        let this = self.project();
225        this.inner.poll_close(cx)
226    }
227}
228
229/// Wraps around an [`AsyncRead`] + [`AsyncWrite`] and logs the bandwidth that goes through it.
230#[pin_project::pin_project]
231pub struct InstrumentedStream<SMInner> {
232    #[pin]
233    inner: SMInner,
234    metrics: ConnectionMetrics,
235}
236
237impl<SMInner: AsyncRead> AsyncRead for InstrumentedStream<SMInner> {
238    fn poll_read(
239        self: Pin<&mut Self>,
240        cx: &mut Context<'_>,
241        buf: &mut [u8],
242    ) -> Poll<io::Result<usize>> {
243        let this = self.project();
244        let num_bytes = ready!(this.inner.poll_read(cx, buf))?;
245        this.metrics
246            .inbound
247            .inc_by(u64::try_from(num_bytes).unwrap_or(u64::MAX));
248        Poll::Ready(Ok(num_bytes))
249    }
250
251    fn poll_read_vectored(
252        self: Pin<&mut Self>,
253        cx: &mut Context<'_>,
254        bufs: &mut [IoSliceMut<'_>],
255    ) -> Poll<io::Result<usize>> {
256        let this = self.project();
257        let num_bytes = ready!(this.inner.poll_read_vectored(cx, bufs))?;
258        this.metrics
259            .inbound
260            .inc_by(u64::try_from(num_bytes).unwrap_or(u64::MAX));
261        Poll::Ready(Ok(num_bytes))
262    }
263}
264
265impl<SMInner: AsyncWrite> AsyncWrite for InstrumentedStream<SMInner> {
266    fn poll_write(
267        self: Pin<&mut Self>,
268        cx: &mut Context<'_>,
269        buf: &[u8],
270    ) -> Poll<io::Result<usize>> {
271        let this = self.project();
272        let num_bytes = ready!(this.inner.poll_write(cx, buf))?;
273        this.metrics
274            .outbound
275            .inc_by(u64::try_from(num_bytes).unwrap_or(u64::MAX));
276        Poll::Ready(Ok(num_bytes))
277    }
278
279    fn poll_write_vectored(
280        self: Pin<&mut Self>,
281        cx: &mut Context<'_>,
282        bufs: &[IoSlice<'_>],
283    ) -> Poll<io::Result<usize>> {
284        let this = self.project();
285        let num_bytes = ready!(this.inner.poll_write_vectored(cx, bufs))?;
286        this.metrics
287            .outbound
288            .inc_by(u64::try_from(num_bytes).unwrap_or(u64::MAX));
289        Poll::Ready(Ok(num_bytes))
290    }
291
292    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
293        let this = self.project();
294        this.inner.poll_flush(cx)
295    }
296
297    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
298        let this = self.project();
299        this.inner.poll_close(cx)
300    }
301}