From 346ef13ed44aeecf0577eab6e6bb293b35129913 Mon Sep 17 00:00:00 2001 From: Mingwei Samuel Date: Mon, 14 Oct 2019 01:00:20 -0700 Subject: [PATCH] using parking_lot --- Cargo.toml | 1 + src/req/rate_limit.rs | 7 +++--- src/req/regional_requester.rs | 43 +++++++++++++++++++++++++++-------- src/req/token_bucket.rs | 5 ++-- 4 files changed, 41 insertions(+), 15 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 62d76ad..5180364 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -13,5 +13,6 @@ exclude = [ [dependencies] async-std = "0.99" +parking_lot = { version = "0.9", features = [ "nightly" ] } reqwest = { version = "0.10.0-alpha.1", features = [ "gzip", "json" ] } serde = "^1.0" diff --git a/src/req/rate_limit.rs b/src/req/rate_limit.rs index 1e0ab94..1a92601 100644 --- a/src/req/rate_limit.rs +++ b/src/req/rate_limit.rs @@ -3,7 +3,8 @@ use std::time::{ Duration, Instant, }; -use std::sync::{ + +use parking_lot::{ RwLock, }; @@ -48,8 +49,8 @@ impl RateLimit { return retry_after_delay } // Check buckets. - let app_buckets = app_rate_limit.buckets.read().unwrap(); - let method_buckets = app_rate_limit.buckets.read().unwrap(); + let app_buckets = app_rate_limit.buckets.read(); + let method_buckets = method_rate_limit.buckets.read(); for bucket in app_buckets.iter().chain(method_buckets.iter()) { let delay = bucket.get_delay(); if delay.is_some() { diff --git a/src/req/regional_requester.rs b/src/req/regional_requester.rs index fc84255..90f09f4 100644 --- a/src/req/regional_requester.rs +++ b/src/req/regional_requester.rs @@ -1,6 +1,12 @@ use std::collections::HashMap; use std::future::Future; -use std::sync::RwLock; +use parking_lot::{ + RwLock, + RwLockReadGuard, + RwLockWriteGuard, + MappedRwLockReadGuard, + MappedRwLockWriteGuard, +}; use async_std::task; use reqwest::{ @@ -23,7 +29,7 @@ pub struct RegionalRequester<'a> { /// Represents the app rate limit. app_rate_limit: RateLimit, /// Represents method rate limits. - method_rate_limits: HashMap<&'a str, RateLimit>, + method_rate_limits: RwLock>, } impl <'a> RegionalRequester<'a> { @@ -39,7 +45,7 @@ impl <'a> RegionalRequester<'a> { riot_api_config: riot_api_config, client: client, app_rate_limit: RateLimit::new(RateLimitType::Application), - method_rate_limits: HashMap::new(), + method_rate_limits: RwLock::new(HashMap::new()), } } @@ -52,11 +58,10 @@ impl <'a> RegionalRequester<'a> { attempts += 1; // Rate limiting. - let app_rate_limit = &self.app_rate_limit; - let method_rate_limit = self.method_rate_limits.entry(method_id) - .or_insert_with(|| RateLimit::new(RateLimitType::Method)); - - while let Some(delay) = RateLimit::get_both_or_delay(app_rate_limit, method_rate_limit) { + while let Some(delay) = { + let method_rate_limit = &self.get_insert_rate_limit(method_id); + RateLimit::get_both_or_delay(&self.app_rate_limit, method_rate_limit) + } { task::sleep(delay).await; } @@ -74,8 +79,10 @@ impl <'a> RegionalRequester<'a> { }; // Update rate limits (if needed). - app_rate_limit.on_response(&response); - method_rate_limit.on_response(&response); + { + self.app_rate_limit.on_response(&response); + self.method_rate_limits.read().get(method_id).unwrap().on_response(&response); + } // Handle response. let status = response.status(); @@ -111,6 +118,22 @@ impl <'a> RegionalRequester<'a> { self.get(method_id, relative_url, region, query) } + fn get_insert_rate_limit(&self, method_id: &'a str) -> MappedRwLockReadGuard { + // This is really stupid? + { + let map_guard = self.method_rate_limits.read(); + if map_guard.contains_key(method_id) { + return RwLockReadGuard::map(map_guard, |mrl| mrl.get(method_id).unwrap()); + } + } + let map_guard = self.method_rate_limits.write(); + let val_write = RwLockWriteGuard::map( + map_guard, |mrl| mrl.entry(method_id) + .or_insert(RateLimit::new(RateLimitType::Method)) + ); + MappedRwLockWriteGuard::downgrade(val_write) + } + fn is_none_status_code(status: &StatusCode) -> bool { Self::NONE_STATUS_CODES.contains(&status.as_u16()) } diff --git a/src/req/token_bucket.rs b/src/req/token_bucket.rs index 557ed5c..acf0805 100644 --- a/src/req/token_bucket.rs +++ b/src/req/token_bucket.rs @@ -1,6 +1,7 @@ use std::collections::VecDeque; use std::time::{Duration, Instant}; -use std::sync::{Mutex, MutexGuard}; + +use parking_lot::{Mutex, MutexGuard}; pub trait TokenBucket { /// Get the duration til the next available token, or 0 duration if a token is available. @@ -45,7 +46,7 @@ impl VectorTokenBucket { } fn update_get_timestamps(&self) -> MutexGuard> { - let mut timestamps = self.timestamps.lock().unwrap(); + let mut timestamps = self.timestamps.lock(); let cutoff = Instant::now() - self.duration; while timestamps.back().map_or(false, |ts| ts < &cutoff) { timestamps.pop_back();