libp2p/
bandwidth.rs

1// Copyright 2019 Parity Technologies (UK) Ltd.
2//
3// Permission is hereby granted, free of charge, to any person obtaining a
4// copy of this software and associated documentation files (the "Software"),
5// to deal in the Software without restriction, including without limitation
6// the rights to use, copy, modify, merge, publish, distribute, sublicense,
7// and/or sell copies of the Software, and to permit persons to whom the
8// Software is furnished to do so, subject to the following conditions:
9//
10// The above copyright notice and this permission notice shall be included in
11// all copies or substantial portions of the Software.
12//
13// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
14// OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
16// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
17// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
18// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
19// DEALINGS IN THE SOFTWARE.
20
21#![allow(deprecated)]
22
23use std::{
24    convert::TryFrom as _,
25    io,
26    pin::Pin,
27    sync::{
28        atomic::{AtomicU64, Ordering},
29        Arc,
30    },
31    task::{Context, Poll},
32};
33
34use futures::{
35    io::{IoSlice, IoSliceMut},
36    prelude::*,
37    ready,
38};
39
40use crate::core::muxing::{StreamMuxer, StreamMuxerEvent};
41
42/// Wraps around a [`StreamMuxer`] and counts the number of bytes that go through all the opened
43/// streams.
44#[derive(Clone)]
45#[pin_project::pin_project]
46pub(crate) struct BandwidthLogging<SMInner> {
47    #[pin]
48    inner: SMInner,
49    sinks: Arc<BandwidthSinks>,
50}
51
52impl<SMInner> BandwidthLogging<SMInner> {
53    /// Creates a new [`BandwidthLogging`] around the stream muxer.
54    pub(crate) fn new(inner: SMInner, sinks: Arc<BandwidthSinks>) -> Self {
55        Self { inner, sinks }
56    }
57}
58
59impl<SMInner> StreamMuxer for BandwidthLogging<SMInner>
60where
61    SMInner: StreamMuxer,
62{
63    type Substream = InstrumentedStream<SMInner::Substream>;
64    type Error = SMInner::Error;
65
66    fn poll(
67        self: Pin<&mut Self>,
68        cx: &mut Context<'_>,
69    ) -> Poll<Result<StreamMuxerEvent, Self::Error>> {
70        let this = self.project();
71        this.inner.poll(cx)
72    }
73
74    fn poll_inbound(
75        self: Pin<&mut Self>,
76        cx: &mut Context<'_>,
77    ) -> Poll<Result<Self::Substream, Self::Error>> {
78        let this = self.project();
79        let inner = ready!(this.inner.poll_inbound(cx)?);
80        let logged = InstrumentedStream {
81            inner,
82            sinks: this.sinks.clone(),
83        };
84        Poll::Ready(Ok(logged))
85    }
86
87    fn poll_outbound(
88        self: Pin<&mut Self>,
89        cx: &mut Context<'_>,
90    ) -> Poll<Result<Self::Substream, Self::Error>> {
91        let this = self.project();
92        let inner = ready!(this.inner.poll_outbound(cx)?);
93        let logged = InstrumentedStream {
94            inner,
95            sinks: this.sinks.clone(),
96        };
97        Poll::Ready(Ok(logged))
98    }
99
100    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
101        let this = self.project();
102        this.inner.poll_close(cx)
103    }
104}
105
/// Shared byte counters fed by instrumented streams.
///
/// This type does not compute a rate. It holds two running totals, bytes read (inbound) and
/// bytes written (outbound), which both start at zero. To derive a bandwidth, sample
/// [`BandwidthSinks::total_inbound`] / [`BandwidthSinks::total_outbound`] at two points in time
/// and divide the difference by the elapsed time.
#[deprecated(
    note = "Use `libp2p::SwarmBuilder::with_bandwidth_metrics` or `libp2p_metrics::BandwidthTransport` instead."
)]
#[derive(Debug)]
pub struct BandwidthSinks {
    /// Running total of bytes read from the instrumented streams.
    inbound: AtomicU64,
    /// Running total of bytes written to the instrumented streams.
    outbound: AtomicU64,
}

impl BandwidthSinks {
    /// Returns a new [`BandwidthSinks`] with both counters at zero.
    ///
    /// The value comes already wrapped in an [`Arc`] so the same counters can be shared by
    /// several owners.
    pub(crate) fn new() -> Arc<Self> {
        Arc::new(Self {
            inbound: AtomicU64::new(0),
            outbound: AtomicU64::new(0),
        })
    }

    /// Returns the total number of bytes that have been downloaded on all the streams.
    ///
    /// > **Note**: This method is by design subject to race conditions. The returned value should
    /// > only ever be used for statistics purposes.
    pub fn total_inbound(&self) -> u64 {
        // Relaxed: the value is a statistic and does not order any other memory accesses.
        self.inbound.load(Ordering::Relaxed)
    }

    /// Returns the total number of bytes that have been uploaded on all the streams.
    ///
    /// > **Note**: This method is by design subject to race conditions. The returned value should
    /// > only ever be used for statistics purposes.
    pub fn total_outbound(&self) -> u64 {
        // Relaxed: the value is a statistic and does not order any other memory accesses.
        self.outbound.load(Ordering::Relaxed)
    }
}
140
141/// Wraps around an [`AsyncRead`] + [`AsyncWrite`] and logs the bandwidth that goes through it.
142#[pin_project::pin_project]
143pub(crate) struct InstrumentedStream<SMInner> {
144    #[pin]
145    inner: SMInner,
146    sinks: Arc<BandwidthSinks>,
147}
148
149impl<SMInner: AsyncRead> AsyncRead for InstrumentedStream<SMInner> {
150    fn poll_read(
151        self: Pin<&mut Self>,
152        cx: &mut Context<'_>,
153        buf: &mut [u8],
154    ) -> Poll<io::Result<usize>> {
155        let this = self.project();
156        let num_bytes = ready!(this.inner.poll_read(cx, buf))?;
157        this.sinks.inbound.fetch_add(
158            u64::try_from(num_bytes).unwrap_or(u64::MAX),
159            Ordering::Relaxed,
160        );
161        Poll::Ready(Ok(num_bytes))
162    }
163
164    fn poll_read_vectored(
165        self: Pin<&mut Self>,
166        cx: &mut Context<'_>,
167        bufs: &mut [IoSliceMut<'_>],
168    ) -> Poll<io::Result<usize>> {
169        let this = self.project();
170        let num_bytes = ready!(this.inner.poll_read_vectored(cx, bufs))?;
171        this.sinks.inbound.fetch_add(
172            u64::try_from(num_bytes).unwrap_or(u64::MAX),
173            Ordering::Relaxed,
174        );
175        Poll::Ready(Ok(num_bytes))
176    }
177}
178
179impl<SMInner: AsyncWrite> AsyncWrite for InstrumentedStream<SMInner> {
180    fn poll_write(
181        self: Pin<&mut Self>,
182        cx: &mut Context<'_>,
183        buf: &[u8],
184    ) -> Poll<io::Result<usize>> {
185        let this = self.project();
186        let num_bytes = ready!(this.inner.poll_write(cx, buf))?;
187        this.sinks.outbound.fetch_add(
188            u64::try_from(num_bytes).unwrap_or(u64::MAX),
189            Ordering::Relaxed,
190        );
191        Poll::Ready(Ok(num_bytes))
192    }
193
194    fn poll_write_vectored(
195        self: Pin<&mut Self>,
196        cx: &mut Context<'_>,
197        bufs: &[IoSlice<'_>],
198    ) -> Poll<io::Result<usize>> {
199        let this = self.project();
200        let num_bytes = ready!(this.inner.poll_write_vectored(cx, bufs))?;
201        this.sinks.outbound.fetch_add(
202            u64::try_from(num_bytes).unwrap_or(u64::MAX),
203            Ordering::Relaxed,
204        );
205        Poll::Ready(Ok(num_bytes))
206    }
207
208    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
209        let this = self.project();
210        this.inner.poll_flush(cx)
211    }
212
213    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
214        let this = self.project();
215        this.inner.poll_close(cx)
216    }
217}