1use std::{
22 collections::{HashMap, VecDeque},
23 hash::Hash,
24 net::IpAddr,
25 num::NonZeroU32,
26 time::Duration,
27};
28
29use libp2p_core::multiaddr::{Multiaddr, Protocol};
30use libp2p_identity::PeerId;
31use web_time::Instant;
32
33pub trait RateLimiter: Send {
39 fn try_next(&mut self, peer: PeerId, addr: &Multiaddr, now: Instant) -> bool;
40}
41
42pub(crate) fn new_per_peer(config: GenericRateLimiterConfig) -> Box<dyn RateLimiter> {
43 let mut limiter = GenericRateLimiter::new(config);
44 Box::new(move |peer_id, _addr: &Multiaddr, now| limiter.try_next(peer_id, now))
45}
46
47pub(crate) fn new_per_ip(config: GenericRateLimiterConfig) -> Box<dyn RateLimiter> {
48 let mut limiter = GenericRateLimiter::new(config);
49 Box::new(move |_peer_id, addr: &Multiaddr, now| {
50 multiaddr_to_ip(addr)
51 .map(|a| limiter.try_next(a, now))
52 .unwrap_or(true)
53 })
54}
55
56impl<T: FnMut(PeerId, &Multiaddr, Instant) -> bool + Send> RateLimiter for T {
57 fn try_next(&mut self, peer: PeerId, addr: &Multiaddr, now: Instant) -> bool {
58 self(peer, addr, now)
59 }
60}
61
62fn multiaddr_to_ip(addr: &Multiaddr) -> Option<IpAddr> {
63 addr.iter().find_map(|p| match p {
64 Protocol::Ip4(addr) => Some(addr.into()),
65 Protocol::Ip6(addr) => Some(addr.into()),
66 _ => None,
67 })
68}
69
/// Token-bucket rate limiter that keeps one bucket per `Id`.
///
/// A full bucket holds `limit` tokens and every allowed request spends one. While a
/// bucket is not full it earns one token back per elapsed `interval`. Only buckets that
/// are not full are stored; an id without an entry is treated as having a full bucket.
pub(crate) struct GenericRateLimiter<Id> {
    /// Capacity of every bucket.
    limit: u32,
    /// Time it takes for a bucket to earn back one token.
    interval: Duration,

    /// Stored buckets together with the time they were created or last refilled, in
    /// insertion order. `refill` only ever inspects the front, so this relies on callers
    /// passing non-decreasing `now` values to keep the oldest entry first.
    refill_schedule: VecDeque<(Instant, Id)>,
    /// Remaining tokens of every bucket that is currently not full.
    buckets: HashMap<Id, u32>,
}

/// Configuration for a [`GenericRateLimiter`].
#[derive(Debug, Clone, Copy)]
pub(crate) struct GenericRateLimiterConfig {
    /// Capacity of a bucket, i.e. how many requests a full bucket allows back to back.
    pub(crate) limit: NonZeroU32,
    /// Time that has to pass for a bucket to earn back one token. Must not be zero.
    pub(crate) interval: Duration,
}

impl<Id: Eq + Hash + Clone> GenericRateLimiter<Id> {
    /// Creates a limiter in which every id starts with a full bucket.
    ///
    /// # Panics
    ///
    /// Panics if `config.interval` is zero.
    pub(crate) fn new(config: GenericRateLimiterConfig) -> Self {
        assert!(
            !config.interval.is_zero(),
            "rate limiter interval must not be zero"
        );

        Self {
            limit: config.limit.into(),
            interval: config.interval,
            refill_schedule: Default::default(),
            buckets: Default::default(),
        }
    }

    /// Tries to spend one token from the bucket of `id` at time `now`.
    ///
    /// All due refills are applied first. Returns `true` if a token was spent and
    /// `false` if the bucket is empty.
    pub(crate) fn try_next(&mut self, id: Id, now: Instant) -> bool {
        self.refill(now);

        match self.buckets.get_mut(&id) {
            // Known bucket: spend a token if one is left.
            Some(balance) => match balance.checked_sub(1) {
                Some(remaining) => {
                    *balance = remaining;
                    true
                }
                None => false,
            },
            // No entry means the bucket is full. Start tracking it with one token
            // spent; `limit` originates from a `NonZeroU32`, so this cannot underflow.
            None => {
                self.buckets.insert(id.clone(), self.limit - 1);
                self.refill_schedule.push_back((now, id));
                true
            }
        }
    }

    /// Credits every bucket whose front-of-queue entry is at least one `interval` old.
    ///
    /// A bucket earns one token per whole interval elapsed since its last refill. A
    /// bucket that reaches `limit` is dropped from the bookkeeping altogether; any other
    /// bucket is re-queued with `now` as its new refill time, so a fraction of an
    /// interval left over at that point is not carried forward.
    fn refill(&mut self, now: Instant) {
        while let Some((last_refill, _)) = self.refill_schedule.front() {
            let elapsed = now.duration_since(*last_refill);
            // The queue is only checked from the front: once the first entry is not
            // due yet, stop.
            if elapsed < self.interval {
                return;
            }

            let (_, id) = self
                .refill_schedule
                .pop_front()
                .expect("Queue not to be empty.");

            // Divide in nanoseconds, the full resolution of `Duration`, so that
            // intervals which are not a whole number of microseconds are not rounded
            // down (which would over-credit tokens). Results that do not fit into a
            // `u32` saturate.
            let earned = elapsed
                .as_nanos()
                .checked_div(self.interval.as_nanos())
                .and_then(|tokens| u32::try_from(tokens).ok())
                .unwrap_or(u32::MAX);

            let balance = self
                .buckets
                .get_mut(&id)
                .expect("Entry can only be removed via refill.");
            let refilled = balance.saturating_add(earned);

            if refilled < self.limit {
                *balance = refilled;
                self.refill_schedule.push_back((now, id));
            } else {
                // The bucket is full again, which is the implicit default state.
                self.buckets.remove(&id);
            }
        }
    }
}
172
#[cfg(test)]
mod tests {
    use quickcheck::{QuickCheck, TestResult};

    use super::*;

    /// Shorthand for a config with a bucket of `limit` tokens and the given `interval`.
    fn config(limit: u32, interval: Duration) -> GenericRateLimiterConfig {
        GenericRateLimiterConfig {
            limit: NonZeroU32::new(limit).unwrap(),
            interval,
        }
    }

    /// Asserts that `count` back-to-back requests for `id` at `now` are all allowed.
    fn assert_allowed(l: &mut GenericRateLimiter<i32>, id: i32, now: Instant, count: usize) {
        for _ in 0..count {
            assert!(l.try_next(id, now));
        }
    }

    /// The very first request of an id is allowed.
    #[test]
    fn first() {
        let id = 1;
        let mut l = GenericRateLimiter::new(config(10, Duration::from_secs(1)));
        assert!(l.try_next(id, Instant::now()));
    }

    /// Exactly `limit` requests are allowed at a single point in time.
    #[test]
    fn limits() {
        let id = 1;
        let start = Instant::now();
        let mut l = GenericRateLimiter::new(config(10, Duration::from_secs(1)));

        assert_allowed(&mut l, id, start, 10);

        assert!(!l.try_next(id, start));
    }

    /// One token comes back per interval, and a long pause refills the whole bucket.
    #[test]
    fn refills() {
        let id = 1;
        let start = Instant::now();
        let mut l = GenericRateLimiter::new(config(10, Duration::from_secs(1)));

        assert_allowed(&mut l, id, start, 10);
        assert!(!l.try_next(id, start));

        let one_interval_later = start + Duration::from_secs(1);
        assert!(l.try_next(id, one_interval_later));
        assert!(!l.try_next(id, one_interval_later));

        let much_later = one_interval_later + Duration::from_secs(10);
        assert_allowed(&mut l, id, much_later, 10);
    }

    /// A request made half-way through an interval must not delay the refill.
    #[test]
    fn move_at_half_interval_steps() {
        let id = 1;
        let start = Instant::now();
        let mut l = GenericRateLimiter::new(config(1, Duration::from_secs(2)));

        assert!(l.try_next(id, start));
        assert!(!l.try_next(id, start));

        let half_way = start + Duration::from_secs(1);
        assert!(!l.try_next(id, half_way));

        let full_interval = half_way + Duration::from_secs(1);
        assert!(l.try_next(id, full_interval));
    }

    /// Buckets that are full again are dropped from the bookkeeping.
    #[test]
    fn garbage_collects() {
        let start = Instant::now();
        let mut l = GenericRateLimiter::new(config(1, Duration::from_secs(1)));

        assert!(l.try_next(1, start));

        let later = start + Duration::from_secs(1);
        assert!(l.try_next(2, later));

        assert_eq!(l.buckets.len(), 1);
        assert_eq!(l.refill_schedule.len(), 1);
    }

    /// After arbitrary traffic followed by `limit * interval` of silence, every bucket
    /// has been collected and only the bucket of the final request remains.
    #[test]
    fn quick_check() {
        fn prop(limit: NonZeroU32, interval: Duration, events: Vec<(u32, Duration)>) -> TestResult {
            if interval.is_zero() {
                return TestResult::discard();
            }

            let mut now = Instant::now();
            let mut l = GenericRateLimiter::new(GenericRateLimiterConfig { limit, interval });

            for (id, d) in events {
                now = match now.checked_add(d) {
                    Some(next) => next,
                    None => return TestResult::discard(),
                };
                l.try_next(id, now);
            }

            // Jump far enough ahead for every bucket to be completely refilled.
            let after_full_refill = interval
                .checked_mul(limit.get())
                .and_then(|full_interval| now.checked_add(full_interval));
            now = match after_full_refill {
                Some(later) => later,
                None => return TestResult::discard(),
            };
            assert!(l.try_next(1, now));

            assert_eq!(l.buckets.len(), 1);
            assert_eq!(l.refill_schedule.len(), 1);

            TestResult::passed()
        }

        QuickCheck::new().quickcheck(prop as fn(_, _, _) -> _)
    }
}