forked from mirror/Riven
rate limiting wip
This commit is contained in:
parent
4d6c45edcf
commit
5700547c05
12 changed files with 383 additions and 28 deletions
1
.gitignore
vendored
1
.gitignore
vendored
|
@ -1,3 +1,4 @@
|
|||
/target
|
||||
**/*.rs.bk
|
||||
Cargo.lock
|
||||
/doc
|
||||
|
|
|
@ -2,7 +2,7 @@
|
|||
name = "riven"
|
||||
version = "0.0.1"
|
||||
authors = ["Mingwei Samuel <mingwei.samuel@gmail.com>"]
|
||||
description = "RiotApi library (wip)"
|
||||
description = "Riot API Library (WIP)"
|
||||
license = "LGPL-3.0"
|
||||
edition = "2018"
|
||||
exclude = [
|
||||
|
@ -12,3 +12,6 @@ exclude = [
|
|||
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
|
||||
|
||||
[dependencies]
|
||||
async-std = "0.99"
|
||||
reqwest = { version = "0.10.0-alpha.1", features = [ "gzip", "json" ] }
|
||||
serde = "^1.0"
|
||||
|
|
1
src/consts/mod.rs
Normal file
1
src/consts/mod.rs
Normal file
|
@ -0,0 +1 @@
|
|||
pub mod region;
|
73
src/consts/region.rs
Normal file
73
src/consts/region.rs
Normal file
|
@ -0,0 +1,73 @@
|
|||
#[derive(Debug, PartialEq, Eq, Hash, PartialOrd, Ord)]
|
||||
pub struct Region<'a> {
|
||||
pub key: &'a str,
|
||||
pub platform: &'a str,
|
||||
}
|
||||
|
||||
macro_rules! regions {
|
||||
(
|
||||
$(
|
||||
$key:ident => $plat:expr ;
|
||||
)*
|
||||
) => {
|
||||
$(
|
||||
const $key: &'static Region<'static> = &Region {
|
||||
key: stringify!($key),
|
||||
platform: $plat,
|
||||
};
|
||||
)*
|
||||
|
||||
#[doc="Get region by name."]
|
||||
#[doc="# Arguments"]
|
||||
#[doc="* `name` - Case-insensitive ASCII string to match Regions' `key` or `playform`."]
|
||||
#[doc="# Returns"]
|
||||
#[doc="`Some(&Region)` if match found, `None` if no match found."]
|
||||
#[allow(unreachable_patterns)]
|
||||
pub fn get(name: &str) -> Option<&Region> {
|
||||
match &*name.to_ascii_uppercase() {
|
||||
$(
|
||||
stringify!($key) | $plat => Some(Self::$key),
|
||||
)*
|
||||
_ => None
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Region<'_> {
|
||||
// Is this stupid?
|
||||
regions! {
|
||||
BR => "BR1";
|
||||
EUNE => "EUN1";
|
||||
EUW => "EUW1";
|
||||
NA => "NA1";
|
||||
KR => "KR";
|
||||
LAN => "LA1";
|
||||
LAS => "LA2";
|
||||
OCE => "OC1";
|
||||
RU => "RU";
|
||||
TR => "TR1";
|
||||
JP => "JP1";
|
||||
PBE => "PBE1";
|
||||
AMERICAS => "AMERICAS";
|
||||
EUROPE => "EUROPE";
|
||||
ASIA => "ASIA";
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_basic() {
|
||||
assert_eq!("BR1", Region::BR.platform);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_get() {
|
||||
assert_eq!(Some(Region::AMERICAS), Region::get("amEricAs"));
|
||||
assert_eq!(Some(Region::NA), Region::get("na1"));
|
||||
assert_eq!(None, Region::get("LA"));
|
||||
}
|
||||
}
|
|
@ -1,6 +1,9 @@
|
|||
mod req;
|
||||
#![allow(dead_code)] // TODO REMOVE
|
||||
|
||||
use req::rate_limit;
|
||||
pub mod consts;
|
||||
|
||||
mod req;
|
||||
mod riot_api_config;
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
|
|
|
@ -1,2 +1,5 @@
|
|||
pub mod rate_limit;
|
||||
pub mod rate_limit_type;
|
||||
pub mod token_bucket;
|
||||
pub mod regional_requester;
|
||||
pub mod requester_manager;
|
||||
|
|
|
@ -1,6 +1,83 @@
|
|||
use std::time::Duration;
|
||||
use std::cmp;
|
||||
use std::time::{
|
||||
Duration,
|
||||
Instant,
|
||||
};
|
||||
use std::sync::{
|
||||
RwLock,
|
||||
};
|
||||
|
||||
pub trait RateLimit {
|
||||
fn get_retry_after_delay(&self) -> Duration;
|
||||
use super::token_bucket::{
|
||||
TokenBucket,
|
||||
VectorTokenBucket,
|
||||
};
|
||||
use super::rate_limit_type::RateLimitType;
|
||||
|
||||
pub struct RateLimit {
|
||||
rate_limit_type: RateLimitType,
|
||||
// Buckets for this rate limit (synchronized).
|
||||
// Almost always read, written only when rate limit rates are updated
|
||||
// from API response.
|
||||
// TODO: Question of writer starvation.
|
||||
buckets: RwLock<Vec<VectorTokenBucket>>,
|
||||
// Set to when we can retry if a retry-after header is received.
|
||||
retry_after: Option<Instant>,
|
||||
}
|
||||
|
||||
impl RateLimit {
|
||||
/// Header specifying which RateLimitType caused a 429.
|
||||
const HEADER_XRATELIMITTYPE: &'static str = "X-Rate-Limit-Type";
|
||||
/// Header specifying retry after time in seconds after a 429.
|
||||
const HEADER_RETRYAFTER: &'static str = "Retry-After";
|
||||
|
||||
pub fn new(rate_limit_type: RateLimitType) -> Self {
|
||||
let initial_bucket = VectorTokenBucket::new(Duration::from_secs(1), 1);
|
||||
RateLimit {
|
||||
rate_limit_type: rate_limit_type,
|
||||
// Rate limit before getting from response: 1/s.
|
||||
buckets: RwLock::new(vec![initial_bucket]),
|
||||
retry_after: None,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn get_both_or_delay(app_rate_limit: &Self, method_rate_limit: &Self) -> Option<Duration> {
|
||||
// Check retry after.
|
||||
let retry_after_delay = app_rate_limit.get_retry_after_delay()
|
||||
.and_then(|a| method_rate_limit.get_retry_after_delay().map(|m| cmp::max(a, m)));
|
||||
if retry_after_delay.is_some() {
|
||||
return retry_after_delay
|
||||
}
|
||||
// Check buckets.
|
||||
let app_buckets = app_rate_limit.buckets.read().unwrap();
|
||||
let method_buckets = app_rate_limit.buckets.read().unwrap();
|
||||
for bucket in app_buckets.iter().chain(method_buckets.iter()) {
|
||||
let delay = bucket.get_delay();
|
||||
if delay.is_some() {
|
||||
return delay;
|
||||
}
|
||||
}
|
||||
// Success.
|
||||
for bucket in app_buckets.iter().chain(method_buckets.iter()) {
|
||||
bucket.get_tokens(1);
|
||||
}
|
||||
None
|
||||
}
|
||||
|
||||
pub fn get_retry_after_delay(&self) -> Option<Duration> {
|
||||
self.retry_after.and_then(|i| Instant::now().checked_duration_since(i))
|
||||
}
|
||||
|
||||
pub fn on_response(&self, _response: &reqwest::Response) {
|
||||
unimplemented!();
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
fn send_sync() {
|
||||
fn is_send_sync<T: Send + Sync>() {}
|
||||
is_send_sync::<RateLimit>();
|
||||
}
|
||||
}
|
||||
|
|
27
src/req/rate_limit_type.rs
Normal file
27
src/req/rate_limit_type.rs
Normal file
|
@ -0,0 +1,27 @@
|
|||
pub enum RateLimitType {
|
||||
Application,
|
||||
Method,
|
||||
}
|
||||
|
||||
impl RateLimitType {
|
||||
pub fn type_name(self) -> &'static str {
|
||||
match self {
|
||||
Self::Application => "application",
|
||||
Self::Method => "method",
|
||||
}
|
||||
}
|
||||
|
||||
pub fn limit_header(self) -> &'static str {
|
||||
match self {
|
||||
Self::Application => "X-App-Rate-Limit",
|
||||
Self::Method => "X-Method-Rate-Limit",
|
||||
}
|
||||
}
|
||||
|
||||
pub fn count_header(self) -> &'static str {
|
||||
match self {
|
||||
Self::Application => "X-App-Rate-Limit-Count",
|
||||
Self::Method => "X-Method-Rate-Limit-Count",
|
||||
}
|
||||
}
|
||||
}
|
127
src/req/regional_requester.rs
Normal file
127
src/req/regional_requester.rs
Normal file
|
@ -0,0 +1,127 @@
|
|||
use std::collections::HashMap;
|
||||
use std::future::Future;
|
||||
use std::sync::RwLock;
|
||||
|
||||
use async_std::task;
|
||||
use reqwest::{
|
||||
Client,
|
||||
StatusCode,
|
||||
};
|
||||
use serde::de::DeserializeOwned;
|
||||
|
||||
use super::rate_limit::RateLimit;
|
||||
use super::rate_limit_type::RateLimitType;
|
||||
use crate::riot_api_config::RiotApiConfig;
|
||||
use crate::consts::region::Region;
|
||||
|
||||
pub struct RegionalRequester<'a> {
|
||||
/// Configuration settings.
|
||||
riot_api_config: &'a RiotApiConfig<'a>,
|
||||
/// Client for making requests.
|
||||
client: &'a Client,
|
||||
|
||||
/// Represents the app rate limit.
|
||||
app_rate_limit: RateLimit,
|
||||
/// Represents method rate limits.
|
||||
method_rate_limits: HashMap<&'a str, RateLimit>,
|
||||
}
|
||||
|
||||
impl <'a> RegionalRequester<'a> {
|
||||
/// Request header name for the Riot API key.
|
||||
const RIOT_KEY_HEADER: &'static str = "X-Riot-Token";
|
||||
|
||||
/// HttpStatus codes that are considered a success, but will return None.
|
||||
const NONE_STATUS_CODES: [u16; 3] = [ 204, 404, 422 ];
|
||||
|
||||
|
||||
pub fn new(riot_api_config: &'a RiotApiConfig<'a>, client: &'a Client) -> RegionalRequester<'a> {
|
||||
RegionalRequester {
|
||||
riot_api_config: riot_api_config,
|
||||
client: client,
|
||||
app_rate_limit: RateLimit::new(RateLimitType::Application),
|
||||
method_rate_limits: HashMap::new(),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn get<T: DeserializeOwned>(
|
||||
&mut self, method_id: &'a str, relative_url: &'_ str,
|
||||
region: &'_ Region<'_>, query: &[(&'_ str, &'_ str)]) -> Result<Option<T>, reqwest::Error> {
|
||||
|
||||
let mut attempts: u8 = 0;
|
||||
for _ in 0..=self.riot_api_config.retries {
|
||||
attempts += 1;
|
||||
|
||||
// Rate limiting.
|
||||
let app_rate_limit = &self.app_rate_limit;
|
||||
let method_rate_limit = self.method_rate_limits.entry(method_id)
|
||||
.or_insert_with(|| RateLimit::new(RateLimitType::Method));
|
||||
|
||||
while let Some(delay) = RateLimit::get_both_or_delay(app_rate_limit, method_rate_limit) {
|
||||
task::sleep(delay).await;
|
||||
}
|
||||
|
||||
// Send request.
|
||||
let url = &*format!("https://{}.api.riotgames.com{}", region.platform, relative_url);
|
||||
let result = self.client.get(url)
|
||||
.header(Self::RIOT_KEY_HEADER, self.riot_api_config.api_key)
|
||||
.query(query)
|
||||
.send()
|
||||
.await;
|
||||
|
||||
let response = match result {
|
||||
Err(e) => return Err(e),
|
||||
Ok(r) => r,
|
||||
};
|
||||
|
||||
// Update rate limits (if needed).
|
||||
app_rate_limit.on_response(&response);
|
||||
method_rate_limit.on_response(&response);
|
||||
|
||||
// Handle response.
|
||||
let status = response.status();
|
||||
// Success, return None.
|
||||
if Self::is_none_status_code(&status) {
|
||||
return Ok(None);
|
||||
}
|
||||
// Success, return a value.
|
||||
if status.is_success() {
|
||||
let value = response.json::<T>().await;
|
||||
return match value {
|
||||
Err(e) => Err(e),
|
||||
Ok(v) => Ok(Some(v)),
|
||||
}
|
||||
}
|
||||
// Retryable.
|
||||
if StatusCode::TOO_MANY_REQUESTS == status || status.is_server_error() {
|
||||
continue;
|
||||
}
|
||||
// Failure (non-retryable).
|
||||
if status.is_client_error() {
|
||||
break;
|
||||
}
|
||||
panic!("NOT HANDLED: {}!", status);
|
||||
}
|
||||
// TODO: return error.
|
||||
panic!("FAILED AFTER {} ATTEMPTS!", attempts);
|
||||
}
|
||||
|
||||
pub fn get2<T: 'a + DeserializeOwned>(&'a mut self, method_id: &'a str, relative_url: &'a str,
|
||||
region: &'a Region<'_>, query: &'a [(&'a str, &'a str)]) -> impl Future<Output = Result<Option<T>, reqwest::Error>> + 'a {
|
||||
|
||||
self.get(method_id, relative_url, region, query)
|
||||
}
|
||||
|
||||
fn is_none_status_code(status: &StatusCode) -> bool {
|
||||
Self::NONE_STATUS_CODES.contains(&status.as_u16())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
fn send_sync() {
|
||||
fn is_send_sync<T: Send + Sync>() {}
|
||||
is_send_sync::<RegionalRequester>();
|
||||
}
|
||||
}
|
21
src/req/requester_manager.rs
Normal file
21
src/req/requester_manager.rs
Normal file
|
@ -0,0 +1,21 @@
|
|||
use std::collections::HashMap;
|
||||
|
||||
use reqwest::{
|
||||
Client,
|
||||
};
|
||||
|
||||
use super::regional_requester::RegionalRequester;
|
||||
use crate::riot_api_config::RiotApiConfig;
|
||||
use crate::consts::region::Region;
|
||||
|
||||
// pub struct RequesterManager<'a> {
|
||||
// /// Configuration settings.
|
||||
// riot_api_config: &'a RiotApiConfig<'a>,
|
||||
// /// Client for making requests.
|
||||
// client: &'a Client,
|
||||
|
||||
// /// Represents the app rate limit.
|
||||
// app_rate_limit: RateLimit,
|
||||
// /// Represents method rate limits.
|
||||
// method_rate_limits: HashMap<&'a str, RateLimit>,
|
||||
// }
|
|
@ -1,18 +1,19 @@
|
|||
use std::collections::VecDeque;
|
||||
use std::time::{Duration, Instant};
|
||||
use std::sync::{Mutex, MutexGuard};
|
||||
|
||||
pub trait TokenBucket {
|
||||
/// Get the duration til the next available token, or 0 duration if a token is available.
|
||||
/// # Returns
|
||||
/// Duration or 0 duration.
|
||||
fn get_delay(&mut self) -> Duration;
|
||||
fn get_delay(&self) -> Option<Duration>;
|
||||
|
||||
/// Gets n tokens, regardless of whether they are available.
|
||||
/// # Parameters
|
||||
/// * `n` - Number of tokens to take.
|
||||
/// # Returns
|
||||
/// True if the tokens were obtained without violating limits, false otherwise.
|
||||
fn get_tokens(&mut self, n: usize) -> bool;
|
||||
fn get_tokens(&self, n: usize) -> bool;
|
||||
|
||||
/// Get the duration of this bucket.
|
||||
/// # Returns
|
||||
|
@ -25,54 +26,61 @@ pub trait TokenBucket {
|
|||
fn get_total_limit(&self) -> usize;
|
||||
}
|
||||
|
||||
struct VectorTokenBucket {
|
||||
pub struct VectorTokenBucket {
|
||||
/// Duration of this TokenBucket.
|
||||
duration: Duration,
|
||||
// Total tokens available from this TokenBucket.
|
||||
total_limit: usize,
|
||||
// Record of timestamps.
|
||||
timestamps: VecDeque<Instant>,
|
||||
// Record of timestamps (synchronized).
|
||||
timestamps: Mutex<VecDeque<Instant>>,
|
||||
}
|
||||
|
||||
impl VectorTokenBucket {
|
||||
fn create(duration: Duration, total_limit: usize) -> Self {
|
||||
pub fn new(duration: Duration, total_limit: usize) -> Self {
|
||||
VectorTokenBucket {
|
||||
duration: duration,
|
||||
total_limit: total_limit,
|
||||
timestamps: VecDeque::new(),
|
||||
timestamps: Mutex::new(VecDeque::new()),
|
||||
}
|
||||
}
|
||||
|
||||
fn update_state(&mut self) {
|
||||
fn update_get_timestamps(&self) -> MutexGuard<VecDeque<Instant>> {
|
||||
let mut timestamps = self.timestamps.lock().unwrap();
|
||||
let cutoff = Instant::now() - self.duration;
|
||||
while self.timestamps.back().map_or(false, |ts| ts < &cutoff) {
|
||||
self.timestamps.pop_back();
|
||||
while timestamps.back().map_or(false, |ts| ts < &cutoff) {
|
||||
timestamps.pop_back();
|
||||
}
|
||||
return timestamps;
|
||||
}
|
||||
}
|
||||
|
||||
impl TokenBucket for VectorTokenBucket {
|
||||
|
||||
fn get_delay(&mut self) -> Duration {
|
||||
self.update_state();
|
||||
if self.timestamps.len() < self.total_limit {
|
||||
Duration::new(0, 0)
|
||||
fn get_delay(&self) -> Option<Duration> {
|
||||
let timestamps = self.update_get_timestamps();
|
||||
|
||||
if timestamps.len() < self.total_limit {
|
||||
None
|
||||
}
|
||||
else {
|
||||
let ts = *self.timestamps.get(self.total_limit - 1).unwrap();
|
||||
Instant::now().saturating_duration_since(ts)
|
||||
// Timestamp that needs to be popped before
|
||||
// we can enter another timestamp.
|
||||
let ts = *timestamps.get(self.total_limit - 1).unwrap();
|
||||
Instant::now().checked_duration_since(ts)
|
||||
.and_then(|passed_dur| self.duration.checked_sub(passed_dur))
|
||||
}
|
||||
}
|
||||
|
||||
fn get_tokens(&mut self, n: usize) -> bool {
|
||||
self.update_state();
|
||||
fn get_tokens(&self, n: usize) -> bool {
|
||||
let mut timestamps = self.update_get_timestamps();
|
||||
|
||||
let now = Instant::now();
|
||||
|
||||
self.timestamps.reserve(n);
|
||||
timestamps.reserve(n);
|
||||
for _ in 0..n {
|
||||
self.timestamps.push_front(now);
|
||||
timestamps.push_front(now);
|
||||
}
|
||||
self.timestamps.len() <= self.total_limit
|
||||
timestamps.len() <= self.total_limit
|
||||
}
|
||||
|
||||
fn get_bucket_duration(&self) -> Duration {
|
||||
|
@ -86,6 +94,13 @@ impl TokenBucket for VectorTokenBucket {
|
|||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
fn send_sync() {
|
||||
fn is_send_sync<T: Send + Sync>() {}
|
||||
is_send_sync::<VectorTokenBucket>();
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn it_works() {
|
||||
assert_eq!(2 + 2, 4);
|
||||
|
|
4
src/riot_api_config.rs
Normal file
4
src/riot_api_config.rs
Normal file
|
@ -0,0 +1,4 @@
|
|||
pub struct RiotApiConfig<'a> {
|
||||
pub api_key: &'a str,
|
||||
pub retries: u8,
|
||||
}
|
Loading…
Reference in a new issue