From dff8eb432de71741b42986558320f492e49aff35 Mon Sep 17 00:00:00 2001 From: Mingwei Samuel Date: Mon, 14 Oct 2019 21:29:32 -0700 Subject: [PATCH] use arc to minimize critical section --- src/req/regional_requester.rs | 45 +++++++++++++---------------------- src/req/requester_manager.rs | 19 +++++++-------- 2 files changed, 25 insertions(+), 39 deletions(-) diff --git a/src/req/regional_requester.rs b/src/req/regional_requester.rs index 90f09f4..9dcbf1b 100644 --- a/src/req/regional_requester.rs +++ b/src/req/regional_requester.rs @@ -1,18 +1,15 @@ use std::collections::HashMap; use std::future::Future; -use parking_lot::{ - RwLock, - RwLockReadGuard, - RwLockWriteGuard, - MappedRwLockReadGuard, - MappedRwLockWriteGuard, -}; +use std::sync::Arc; use async_std::task; use reqwest::{ Client, StatusCode, }; +use parking_lot::{ + Mutex, +}; use serde::de::DeserializeOwned; use super::rate_limit::RateLimit; @@ -29,10 +26,10 @@ pub struct RegionalRequester<'a> { /// Represents the app rate limit. app_rate_limit: RateLimit, /// Represents method rate limits. - method_rate_limits: RwLock>, + method_rate_limits: Mutex>>, } -impl <'a> RegionalRequester<'a> { +impl<'a> RegionalRequester<'a> { /// Request header name for the Riot API key. const RIOT_KEY_HEADER: &'static str = "X-Riot-Token"; @@ -40,12 +37,12 @@ impl <'a> RegionalRequester<'a> { const NONE_STATUS_CODES: [u16; 3] = [ 204, 404, 422 ]; - pub fn new(riot_api_config: &'a RiotApiConfig<'a>, client: &'a Client) -> RegionalRequester<'a> { - RegionalRequester { + pub fn new(riot_api_config: &'a RiotApiConfig<'a>, client: &'a Client) -> Self { + Self { riot_api_config: riot_api_config, client: client, app_rate_limit: RateLimit::new(RateLimitType::Application), - method_rate_limits: RwLock::new(HashMap::new()), + method_rate_limits: Mutex::new(HashMap::new()), } } @@ -59,8 +56,8 @@ impl <'a> RegionalRequester<'a> { // Rate limiting. while let Some(delay) = { - let method_rate_limit = &self.get_insert_rate_limit(method_id); - RateLimit::get_both_or_delay(&self.app_rate_limit, method_rate_limit) + let method_rate_limit = self.get_method_rate_limit(method_id); + RateLimit::get_both_or_delay(&self.app_rate_limit, &*method_rate_limit) } { task::sleep(delay).await; } @@ -81,7 +78,7 @@ impl <'a> RegionalRequester<'a> { // Update rate limits (if needed). { self.app_rate_limit.on_response(&response); - self.method_rate_limits.read().get(method_id).unwrap().on_response(&response); + self.get_method_rate_limit(method_id).on_response(&response); } // Handle response. @@ -118,20 +115,10 @@ impl <'a> RegionalRequester<'a> { self.get(method_id, relative_url, region, query) } - fn get_insert_rate_limit(&self, method_id: &'a str) -> MappedRwLockReadGuard { - // This is really stupid? - { - let map_guard = self.method_rate_limits.read(); - if map_guard.contains_key(method_id) { - return RwLockReadGuard::map(map_guard, |mrl| mrl.get(method_id).unwrap()); - } - } - let map_guard = self.method_rate_limits.write(); - let val_write = RwLockWriteGuard::map( - map_guard, |mrl| mrl.entry(method_id) - .or_insert(RateLimit::new(RateLimitType::Method)) - ); - MappedRwLockWriteGuard::downgrade(val_write) + fn get_method_rate_limit(&self, method_id: &'a str) -> Arc { + Arc::clone(self.method_rate_limits.lock() + .entry(method_id) + .or_insert_with(|| Arc::new(RateLimit::new(RateLimitType::Method)))) } fn is_none_status_code(status: &StatusCode) -> bool { diff --git a/src/req/requester_manager.rs b/src/req/requester_manager.rs index f239f08..e0dda73 100644 --- a/src/req/requester_manager.rs +++ b/src/req/requester_manager.rs @@ -1,4 +1,5 @@ use std::collections::HashMap; +use std::sync::Arc; use reqwest::{ Client, @@ -8,14 +9,12 @@ use super::regional_requester::RegionalRequester; use crate::riot_api_config::RiotApiConfig; use crate::consts::region::Region; -// pub struct RequesterManager<'a> { -// /// Configuration settings. -// riot_api_config: &'a RiotApiConfig<'a>, -// /// Client for making requests. -// client: &'a Client, +pub struct RequesterManager<'a> { + /// Configuration settings. + riot_api_config: &'a RiotApiConfig<'a>, + /// Client for making requests. + client: &'a Client, -// /// Represents the app rate limit. -// app_rate_limit: RateLimit, -// /// Represents method rate limits. -// method_rate_limits: HashMap<&'a str, RateLimit>, -// } \ No newline at end of file + /// Per-region requesters. + regional_requesters: HashMap<&'a Region<'a>, Arc>>, +} \ No newline at end of file