rate limiting wip

This commit is contained in:
Mingwei Samuel 2019-10-13 23:38:22 -07:00
parent 4d6c45edcf
commit 5700547c05
12 changed files with 383 additions and 28 deletions

1
.gitignore vendored
View file

@ -1,3 +1,4 @@
/target /target
**/*.rs.bk **/*.rs.bk
Cargo.lock Cargo.lock
/doc

View file

@ -2,7 +2,7 @@
name = "riven" name = "riven"
version = "0.0.1" version = "0.0.1"
authors = ["Mingwei Samuel <mingwei.samuel@gmail.com>"] authors = ["Mingwei Samuel <mingwei.samuel@gmail.com>"]
description = "RiotApi library (wip)" description = "Riot API Library (WIP)"
license = "LGPL-3.0" license = "LGPL-3.0"
edition = "2018" edition = "2018"
exclude = [ exclude = [
@ -12,3 +12,6 @@ exclude = [
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies] [dependencies]
async-std = "0.99"
reqwest = { version = "0.10.0-alpha.1", features = [ "gzip", "json" ] }
serde = "^1.0"

1
src/consts/mod.rs Normal file
View file

@ -0,0 +1 @@
pub mod region;

73
src/consts/region.rs Normal file
View file

@ -0,0 +1,73 @@
#[derive(Debug, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub struct Region<'a> {
pub key: &'a str,
pub platform: &'a str,
}
macro_rules! regions {
(
$(
$key:ident => $plat:expr ;
)*
) => {
$(
const $key: &'static Region<'static> = &Region {
key: stringify!($key),
platform: $plat,
};
)*
#[doc="Get region by name."]
#[doc="# Arguments"]
#[doc="* `name` - Case-insensitive ASCII string to match Regions' `key` or `playform`."]
#[doc="# Returns"]
#[doc="`Some(&Region)` if match found, `None` if no match found."]
#[allow(unreachable_patterns)]
pub fn get(name: &str) -> Option<&Region> {
match &*name.to_ascii_uppercase() {
$(
stringify!($key) | $plat => Some(Self::$key),
)*
_ => None
}
}
}
}
impl Region<'_> {
// Is this stupid?
regions! {
BR => "BR1";
EUNE => "EUN1";
EUW => "EUW1";
NA => "NA1";
KR => "KR";
LAN => "LA1";
LAS => "LA2";
OCE => "OC1";
RU => "RU";
TR => "TR1";
JP => "JP1";
PBE => "PBE1";
AMERICAS => "AMERICAS";
EUROPE => "EUROPE";
ASIA => "ASIA";
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_basic() {
assert_eq!("BR1", Region::BR.platform);
}
#[test]
fn test_get() {
assert_eq!(Some(Region::AMERICAS), Region::get("amEricAs"));
assert_eq!(Some(Region::NA), Region::get("na1"));
assert_eq!(None, Region::get("LA"));
}
}

View file

@ -1,6 +1,9 @@
mod req; #![allow(dead_code)] // TODO REMOVE
use req::rate_limit; pub mod consts;
mod req;
mod riot_api_config;
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {

View file

@ -1,2 +1,5 @@
pub mod rate_limit; pub mod rate_limit;
pub mod rate_limit_type;
pub mod token_bucket; pub mod token_bucket;
pub mod regional_requester;
pub mod requester_manager;

View file

@ -1,6 +1,83 @@
use std::time::Duration; use std::cmp;
use std::time::{
Duration,
Instant,
};
use std::sync::{
RwLock,
};
pub trait RateLimit { use super::token_bucket::{
fn get_retry_after_delay(&self) -> Duration; TokenBucket,
VectorTokenBucket,
};
use super::rate_limit_type::RateLimitType;
pub struct RateLimit {
rate_limit_type: RateLimitType,
// Buckets for this rate limit (synchronized).
// Almost always read, written only when rate limit rates are updated
// from API response.
// TODO: Question of writer starvation.
buckets: RwLock<Vec<VectorTokenBucket>>,
// Set to when we can retry if a retry-after header is received.
retry_after: Option<Instant>,
} }
impl RateLimit {
/// Header specifying which RateLimitType caused a 429.
const HEADER_XRATELIMITTYPE: &'static str = "X-Rate-Limit-Type";
/// Header specifying retry after time in seconds after a 429.
const HEADER_RETRYAFTER: &'static str = "Retry-After";
pub fn new(rate_limit_type: RateLimitType) -> Self {
let initial_bucket = VectorTokenBucket::new(Duration::from_secs(1), 1);
RateLimit {
rate_limit_type: rate_limit_type,
// Rate limit before getting from response: 1/s.
buckets: RwLock::new(vec![initial_bucket]),
retry_after: None,
}
}
pub fn get_both_or_delay(app_rate_limit: &Self, method_rate_limit: &Self) -> Option<Duration> {
// Check retry after.
let retry_after_delay = app_rate_limit.get_retry_after_delay()
.and_then(|a| method_rate_limit.get_retry_after_delay().map(|m| cmp::max(a, m)));
if retry_after_delay.is_some() {
return retry_after_delay
}
// Check buckets.
let app_buckets = app_rate_limit.buckets.read().unwrap();
let method_buckets = app_rate_limit.buckets.read().unwrap();
for bucket in app_buckets.iter().chain(method_buckets.iter()) {
let delay = bucket.get_delay();
if delay.is_some() {
return delay;
}
}
// Success.
for bucket in app_buckets.iter().chain(method_buckets.iter()) {
bucket.get_tokens(1);
}
None
}
pub fn get_retry_after_delay(&self) -> Option<Duration> {
self.retry_after.and_then(|i| Instant::now().checked_duration_since(i))
}
pub fn on_response(&self, _response: &reqwest::Response) {
unimplemented!();
}
}
#[cfg(test)]
mod tests {
use super::*;
fn send_sync() {
fn is_send_sync<T: Send + Sync>() {}
is_send_sync::<RateLimit>();
}
}

View file

@ -0,0 +1,27 @@
pub enum RateLimitType {
Application,
Method,
}
impl RateLimitType {
pub fn type_name(self) -> &'static str {
match self {
Self::Application => "application",
Self::Method => "method",
}
}
pub fn limit_header(self) -> &'static str {
match self {
Self::Application => "X-App-Rate-Limit",
Self::Method => "X-Method-Rate-Limit",
}
}
pub fn count_header(self) -> &'static str {
match self {
Self::Application => "X-App-Rate-Limit-Count",
Self::Method => "X-Method-Rate-Limit-Count",
}
}
}

View file

@ -0,0 +1,127 @@
use std::collections::HashMap;
use std::future::Future;
use std::sync::RwLock;
use async_std::task;
use reqwest::{
Client,
StatusCode,
};
use serde::de::DeserializeOwned;
use super::rate_limit::RateLimit;
use super::rate_limit_type::RateLimitType;
use crate::riot_api_config::RiotApiConfig;
use crate::consts::region::Region;
pub struct RegionalRequester<'a> {
/// Configuration settings.
riot_api_config: &'a RiotApiConfig<'a>,
/// Client for making requests.
client: &'a Client,
/// Represents the app rate limit.
app_rate_limit: RateLimit,
/// Represents method rate limits.
method_rate_limits: HashMap<&'a str, RateLimit>,
}
impl <'a> RegionalRequester<'a> {
/// Request header name for the Riot API key.
const RIOT_KEY_HEADER: &'static str = "X-Riot-Token";
/// HttpStatus codes that are considered a success, but will return None.
const NONE_STATUS_CODES: [u16; 3] = [ 204, 404, 422 ];
pub fn new(riot_api_config: &'a RiotApiConfig<'a>, client: &'a Client) -> RegionalRequester<'a> {
RegionalRequester {
riot_api_config: riot_api_config,
client: client,
app_rate_limit: RateLimit::new(RateLimitType::Application),
method_rate_limits: HashMap::new(),
}
}
pub async fn get<T: DeserializeOwned>(
&mut self, method_id: &'a str, relative_url: &'_ str,
region: &'_ Region<'_>, query: &[(&'_ str, &'_ str)]) -> Result<Option<T>, reqwest::Error> {
let mut attempts: u8 = 0;
for _ in 0..=self.riot_api_config.retries {
attempts += 1;
// Rate limiting.
let app_rate_limit = &self.app_rate_limit;
let method_rate_limit = self.method_rate_limits.entry(method_id)
.or_insert_with(|| RateLimit::new(RateLimitType::Method));
while let Some(delay) = RateLimit::get_both_or_delay(app_rate_limit, method_rate_limit) {
task::sleep(delay).await;
}
// Send request.
let url = &*format!("https://{}.api.riotgames.com{}", region.platform, relative_url);
let result = self.client.get(url)
.header(Self::RIOT_KEY_HEADER, self.riot_api_config.api_key)
.query(query)
.send()
.await;
let response = match result {
Err(e) => return Err(e),
Ok(r) => r,
};
// Update rate limits (if needed).
app_rate_limit.on_response(&response);
method_rate_limit.on_response(&response);
// Handle response.
let status = response.status();
// Success, return None.
if Self::is_none_status_code(&status) {
return Ok(None);
}
// Success, return a value.
if status.is_success() {
let value = response.json::<T>().await;
return match value {
Err(e) => Err(e),
Ok(v) => Ok(Some(v)),
}
}
// Retryable.
if StatusCode::TOO_MANY_REQUESTS == status || status.is_server_error() {
continue;
}
// Failure (non-retryable).
if status.is_client_error() {
break;
}
panic!("NOT HANDLED: {}!", status);
}
// TODO: return error.
panic!("FAILED AFTER {} ATTEMPTS!", attempts);
}
pub fn get2<T: 'a + DeserializeOwned>(&'a mut self, method_id: &'a str, relative_url: &'a str,
region: &'a Region<'_>, query: &'a [(&'a str, &'a str)]) -> impl Future<Output = Result<Option<T>, reqwest::Error>> + 'a {
self.get(method_id, relative_url, region, query)
}
fn is_none_status_code(status: &StatusCode) -> bool {
Self::NONE_STATUS_CODES.contains(&status.as_u16())
}
}
#[cfg(test)]
mod tests {
use super::*;
fn send_sync() {
fn is_send_sync<T: Send + Sync>() {}
is_send_sync::<RegionalRequester>();
}
}

View file

@ -0,0 +1,21 @@
use std::collections::HashMap;
use reqwest::{
Client,
};
use super::regional_requester::RegionalRequester;
use crate::riot_api_config::RiotApiConfig;
use crate::consts::region::Region;
// pub struct RequesterManager<'a> {
// /// Configuration settings.
// riot_api_config: &'a RiotApiConfig<'a>,
// /// Client for making requests.
// client: &'a Client,
// /// Represents the app rate limit.
// app_rate_limit: RateLimit,
// /// Represents method rate limits.
// method_rate_limits: HashMap<&'a str, RateLimit>,
// }

View file

@ -1,18 +1,19 @@
use std::collections::VecDeque; use std::collections::VecDeque;
use std::time::{Duration, Instant}; use std::time::{Duration, Instant};
use std::sync::{Mutex, MutexGuard};
pub trait TokenBucket { pub trait TokenBucket {
/// Get the duration til the next available token, or 0 duration if a token is available. /// Get the duration til the next available token, or 0 duration if a token is available.
/// # Returns /// # Returns
/// Duration or 0 duration. /// Duration or 0 duration.
fn get_delay(&mut self) -> Duration; fn get_delay(&self) -> Option<Duration>;
/// Gets n tokens, regardless of whether they are available. /// Gets n tokens, regardless of whether they are available.
/// # Parameters /// # Parameters
/// * `n` - Number of tokens to take. /// * `n` - Number of tokens to take.
/// # Returns /// # Returns
/// True if the tokens were obtained without violating limits, false otherwise. /// True if the tokens were obtained without violating limits, false otherwise.
fn get_tokens(&mut self, n: usize) -> bool; fn get_tokens(&self, n: usize) -> bool;
/// Get the duration of this bucket. /// Get the duration of this bucket.
/// # Returns /// # Returns
@ -25,54 +26,61 @@ pub trait TokenBucket {
fn get_total_limit(&self) -> usize; fn get_total_limit(&self) -> usize;
} }
struct VectorTokenBucket { pub struct VectorTokenBucket {
/// Duration of this TokenBucket. /// Duration of this TokenBucket.
duration: Duration, duration: Duration,
// Total tokens available from this TokenBucket. // Total tokens available from this TokenBucket.
total_limit: usize, total_limit: usize,
// Record of timestamps. // Record of timestamps (synchronized).
timestamps: VecDeque<Instant>, timestamps: Mutex<VecDeque<Instant>>,
} }
impl VectorTokenBucket { impl VectorTokenBucket {
fn create(duration: Duration, total_limit: usize) -> Self { pub fn new(duration: Duration, total_limit: usize) -> Self {
VectorTokenBucket { VectorTokenBucket {
duration: duration, duration: duration,
total_limit: total_limit, total_limit: total_limit,
timestamps: VecDeque::new(), timestamps: Mutex::new(VecDeque::new()),
} }
} }
fn update_state(&mut self) { fn update_get_timestamps(&self) -> MutexGuard<VecDeque<Instant>> {
let mut timestamps = self.timestamps.lock().unwrap();
let cutoff = Instant::now() - self.duration; let cutoff = Instant::now() - self.duration;
while self.timestamps.back().map_or(false, |ts| ts < &cutoff) { while timestamps.back().map_or(false, |ts| ts < &cutoff) {
self.timestamps.pop_back(); timestamps.pop_back();
} }
return timestamps;
} }
} }
impl TokenBucket for VectorTokenBucket { impl TokenBucket for VectorTokenBucket {
fn get_delay(&mut self) -> Duration { fn get_delay(&self) -> Option<Duration> {
self.update_state(); let timestamps = self.update_get_timestamps();
if self.timestamps.len() < self.total_limit {
Duration::new(0, 0) if timestamps.len() < self.total_limit {
None
} }
else { else {
let ts = *self.timestamps.get(self.total_limit - 1).unwrap(); // Timestamp that needs to be popped before
Instant::now().saturating_duration_since(ts) // we can enter another timestamp.
let ts = *timestamps.get(self.total_limit - 1).unwrap();
Instant::now().checked_duration_since(ts)
.and_then(|passed_dur| self.duration.checked_sub(passed_dur))
} }
} }
fn get_tokens(&mut self, n: usize) -> bool { fn get_tokens(&self, n: usize) -> bool {
self.update_state(); let mut timestamps = self.update_get_timestamps();
let now = Instant::now(); let now = Instant::now();
self.timestamps.reserve(n); timestamps.reserve(n);
for _ in 0..n { for _ in 0..n {
self.timestamps.push_front(now); timestamps.push_front(now);
} }
self.timestamps.len() <= self.total_limit timestamps.len() <= self.total_limit
} }
fn get_bucket_duration(&self) -> Duration { fn get_bucket_duration(&self) -> Duration {
@ -86,6 +94,13 @@ impl TokenBucket for VectorTokenBucket {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*;
fn send_sync() {
fn is_send_sync<T: Send + Sync>() {}
is_send_sync::<VectorTokenBucket>();
}
#[test] #[test]
fn it_works() { fn it_works() {
assert_eq!(2 + 2, 4); assert_eq!(2 + 2, 4);

4
src/riot_api_config.rs Normal file
View file

@ -0,0 +1,4 @@
pub struct RiotApiConfig<'a> {
pub api_key: &'a str,
pub retries: u8,
}