using parking_lot

This commit is contained in:
Mingwei Samuel 2019-10-14 01:00:20 -07:00
parent 5700547c05
commit 346ef13ed4
4 changed files with 41 additions and 15 deletions

View file

@ -13,5 +13,6 @@ exclude = [
[dependencies] [dependencies]
async-std = "0.99" async-std = "0.99"
parking_lot = { version = "0.9", features = [ "nightly" ] }
reqwest = { version = "0.10.0-alpha.1", features = [ "gzip", "json" ] } reqwest = { version = "0.10.0-alpha.1", features = [ "gzip", "json" ] }
serde = "^1.0" serde = "^1.0"

View file

@ -3,7 +3,8 @@ use std::time::{
Duration, Duration,
Instant, Instant,
}; };
use std::sync::{
use parking_lot::{
RwLock, RwLock,
}; };
@ -48,8 +49,8 @@ impl RateLimit {
return retry_after_delay return retry_after_delay
} }
// Check buckets. // Check buckets.
let app_buckets = app_rate_limit.buckets.read().unwrap(); let app_buckets = app_rate_limit.buckets.read();
let method_buckets = app_rate_limit.buckets.read().unwrap(); let method_buckets = method_rate_limit.buckets.read();
for bucket in app_buckets.iter().chain(method_buckets.iter()) { for bucket in app_buckets.iter().chain(method_buckets.iter()) {
let delay = bucket.get_delay(); let delay = bucket.get_delay();
if delay.is_some() { if delay.is_some() {

View file

@ -1,6 +1,12 @@
use std::collections::HashMap; use std::collections::HashMap;
use std::future::Future; use std::future::Future;
use std::sync::RwLock; use parking_lot::{
RwLock,
RwLockReadGuard,
RwLockWriteGuard,
MappedRwLockReadGuard,
MappedRwLockWriteGuard,
};
use async_std::task; use async_std::task;
use reqwest::{ use reqwest::{
@ -23,7 +29,7 @@ pub struct RegionalRequester<'a> {
/// Represents the app rate limit. /// Represents the app rate limit.
app_rate_limit: RateLimit, app_rate_limit: RateLimit,
/// Represents method rate limits. /// Represents method rate limits.
method_rate_limits: HashMap<&'a str, RateLimit>, method_rate_limits: RwLock<HashMap<&'a str, RateLimit>>,
} }
impl <'a> RegionalRequester<'a> { impl <'a> RegionalRequester<'a> {
@ -39,7 +45,7 @@ impl <'a> RegionalRequester<'a> {
riot_api_config: riot_api_config, riot_api_config: riot_api_config,
client: client, client: client,
app_rate_limit: RateLimit::new(RateLimitType::Application), app_rate_limit: RateLimit::new(RateLimitType::Application),
method_rate_limits: HashMap::new(), method_rate_limits: RwLock::new(HashMap::new()),
} }
} }
@ -52,11 +58,10 @@ impl <'a> RegionalRequester<'a> {
attempts += 1; attempts += 1;
// Rate limiting. // Rate limiting.
let app_rate_limit = &self.app_rate_limit; while let Some(delay) = {
let method_rate_limit = self.method_rate_limits.entry(method_id) let method_rate_limit = &self.get_insert_rate_limit(method_id);
.or_insert_with(|| RateLimit::new(RateLimitType::Method)); RateLimit::get_both_or_delay(&self.app_rate_limit, method_rate_limit)
} {
while let Some(delay) = RateLimit::get_both_or_delay(app_rate_limit, method_rate_limit) {
task::sleep(delay).await; task::sleep(delay).await;
} }
@ -74,8 +79,10 @@ impl <'a> RegionalRequester<'a> {
}; };
// Update rate limits (if needed). // Update rate limits (if needed).
app_rate_limit.on_response(&response); {
method_rate_limit.on_response(&response); self.app_rate_limit.on_response(&response);
self.method_rate_limits.read().get(method_id).unwrap().on_response(&response);
}
// Handle response. // Handle response.
let status = response.status(); let status = response.status();
@ -111,6 +118,22 @@ impl <'a> RegionalRequester<'a> {
self.get(method_id, relative_url, region, query) self.get(method_id, relative_url, region, query)
} }
fn get_insert_rate_limit(&self, method_id: &'a str) -> MappedRwLockReadGuard<RateLimit> {
// This is really stupid?
{
let map_guard = self.method_rate_limits.read();
if map_guard.contains_key(method_id) {
return RwLockReadGuard::map(map_guard, |mrl| mrl.get(method_id).unwrap());
}
}
let map_guard = self.method_rate_limits.write();
let val_write = RwLockWriteGuard::map(
map_guard, |mrl| mrl.entry(method_id)
.or_insert(RateLimit::new(RateLimitType::Method))
);
MappedRwLockWriteGuard::downgrade(val_write)
}
fn is_none_status_code(status: &StatusCode) -> bool { fn is_none_status_code(status: &StatusCode) -> bool {
Self::NONE_STATUS_CODES.contains(&status.as_u16()) Self::NONE_STATUS_CODES.contains(&status.as_u16())
} }

View file

@ -1,6 +1,7 @@
use std::collections::VecDeque; use std::collections::VecDeque;
use std::time::{Duration, Instant}; use std::time::{Duration, Instant};
use std::sync::{Mutex, MutexGuard};
use parking_lot::{Mutex, MutexGuard};
pub trait TokenBucket { pub trait TokenBucket {
/// Get the duration til the next available token, or 0 duration if a token is available. /// Get the duration til the next available token, or 0 duration if a token is available.
@ -45,7 +46,7 @@ impl VectorTokenBucket {
} }
fn update_get_timestamps(&self) -> MutexGuard<VecDeque<Instant>> { fn update_get_timestamps(&self) -> MutexGuard<VecDeque<Instant>> {
let mut timestamps = self.timestamps.lock().unwrap(); let mut timestamps = self.timestamps.lock();
let cutoff = Instant::now() - self.duration; let cutoff = Instant::now() - self.duration;
while timestamps.back().map_or(false, |ts| ts < &cutoff) { while timestamps.back().map_or(false, |ts| ts < &cutoff) {
timestamps.pop_back(); timestamps.pop_back();