From d745aaa8682bbf4d75935b7bbd8e71f1c08f2305 Mon Sep 17 00:00:00 2001 From: Mingwei Samuel Date: Sun, 5 May 2024 23:47:40 -0700 Subject: [PATCH] feat: Add RSO support --- riven/src/config.rs | 21 ++++++++++++ riven/src/endpoints.rs | 62 +++++++++++++++++++++++++----------- riven/src/meta.rs | 2 +- riven/src/models.rs | 2 +- riven/src/riot_api.rs | 5 +++ riven/srcgen/endpoints.rs.dt | 46 +++++++++++++------------- riven/tests/tests_rso.rs | 25 +++++++++++++++ 7 files changed, 118 insertions(+), 45 deletions(-) create mode 100644 riven/tests/tests_rso.rs diff --git a/riven/src/config.rs b/riven/src/config.rs index e1c7fc9..48614da 100644 --- a/riven/src/config.rs +++ b/riven/src/config.rs @@ -14,6 +14,7 @@ pub struct RiotApiConfig { pub(crate) burst_factor: f32, pub(crate) duration_overhead: Duration, pub(crate) client_builder: Option, + pub(crate) rso_clear_header: Option, } impl RiotApiConfig { @@ -81,6 +82,7 @@ impl RiotApiConfig { burst_factor: Self::PRECONFIG_BURST_BURST_FACTOR, duration_overhead: Self::PRECONFIG_BURST_DURATION_OVERHEAD, client_builder: Some(ClientBuilder::new().default_headers(default_headers)), + rso_clear_header: Some(Self::RIOT_KEY_HEADER.to_owned()), } } @@ -101,6 +103,7 @@ impl RiotApiConfig { burst_factor: Self::PRECONFIG_BURST_BURST_FACTOR, duration_overhead: Self::PRECONFIG_BURST_DURATION_OVERHEAD, client_builder: Some(client_builder), + rso_clear_header: Some(Self::RIOT_KEY_HEADER.to_owned()), } } @@ -295,6 +298,24 @@ impl RiotApiConfig { self.duration_overhead = duration_overhead; self } + + /// Sets the header to clear for RSO requests (if `Some`), or will not override any headers (if + /// `None`). + /// + /// This is a bit of a hack. The client used by Riven is expected to include the API key as a + /// default header. However, if the API key is included in an [RSO](https://developer.riotgames.com/docs/lol#rso-integration) + /// request the server responds with a 400 "Bad request - Invalid authorization specified" + /// error. To avoid this the `rso_clear_header` header is overridden to be empty for RSO + /// requests. + /// + /// This is set to `Some(`[`Self::RIOT_KEY_HEADER`]`)` by default. + /// + /// # Returns + /// `self`, for chaining. + pub fn set_rso_clear_header(mut self, rso_clear_header: Option) -> Self { + self.rso_clear_header = rso_clear_header; + self + } } impl> From for RiotApiConfig { diff --git a/riven/src/endpoints.rs b/riven/src/endpoints.rs index 6a1098c..6ce7695 100644 --- a/riven/src/endpoints.rs +++ b/riven/src/endpoints.rs @@ -8,7 +8,7 @@ /////////////////////////////////////////////// // http://www.mingweisamuel.com/riotapi-schema/tool/ -// Version a70746fcf353ba0ad0aceceafcc70d4ba8de4431 +// Version 92f57e3e7279cc02ec6a5ce6665ca08354d6a178 //! Automatically generated endpoint handles. #![allow(clippy::let_and_return, clippy::too_many_arguments)] @@ -324,17 +324,21 @@ impl<'a> AccountV1<'a> { /// Get account by access token /// # Parameters /// * `route` - Route to query. - /// * `authorization` (required, in header) + /// * `access_token` - RSO access token. + /// # RSO + /// This endpoint uses [Riot Sign On](https://developer.riotgames.com/docs/lol#rso-integration) + /// via the `access_token` parameter, instead of the Riot API key. /// # Riot Developer API Reference /// `account-v1.getByAccessToken` /// /// Note: this method is automatically generated. - pub fn get_by_access_token(&self, route: RegionalRoute, authorization: &str) + pub fn get_by_access_token(&self, route: RegionalRoute, access_token: impl std::fmt::Display) -> impl Future> + 'a { let route_str = route.into(); let request = self.base.request(Method::GET, route_str, "/riot/account/v1/accounts/me"); - let request = request.header("Authorization", authorization); + let mut request = request.bearer_auth(access_token); + if let Some(clear) = self.base.get_rso_clear_header() { request = request.header(clear, "") } let future = self.base.execute_val::("account-v1.getByAccessToken", route_str, request); #[cfg(feature = "tracing")] let future = future.instrument(tracing::info_span!("account-v1.getByAccessToken")); @@ -927,17 +931,21 @@ impl<'a> LorDeckV1<'a> { /// Get a list of the calling user's decks. /// # Parameters /// * `route` - Route to query. - /// * `authorization` (required, in header) + /// * `access_token` - RSO access token. + /// # RSO + /// This endpoint uses [Riot Sign On](https://developer.riotgames.com/docs/lol#rso-integration) + /// via the `access_token` parameter, instead of the Riot API key. /// # Riot Developer API Reference /// `lor-deck-v1.getDecks` /// /// Note: this method is automatically generated. - pub fn get_decks(&self, route: RegionalRoute, authorization: &str) + pub fn get_decks(&self, route: RegionalRoute, access_token: impl std::fmt::Display) -> impl Future>> + 'a { let route_str = route.into(); let request = self.base.request(Method::GET, route_str, "/lor/deck/v1/decks/me"); - let request = request.header("Authorization", authorization); + let mut request = request.bearer_auth(access_token); + if let Some(clear) = self.base.get_rso_clear_header() { request = request.header(clear, "") } let future = self.base.execute_val::>("lor-deck-v1.getDecks", route_str, request); #[cfg(feature = "tracing")] let future = future.instrument(tracing::info_span!("lor-deck-v1.getDecks")); @@ -947,17 +955,21 @@ impl<'a> LorDeckV1<'a> { /// Create a new deck for the calling user. /// # Parameters /// * `route` - Route to query. - /// * `authorization` (required, in header) + /// * `access_token` - RSO access token. + /// # RSO + /// This endpoint uses [Riot Sign On](https://developer.riotgames.com/docs/lol#rso-integration) + /// via the `access_token` parameter, instead of the Riot API key. /// # Riot Developer API Reference /// `lor-deck-v1.createDeck` /// /// Note: this method is automatically generated. - pub fn create_deck(&self, route: RegionalRoute, body: &lor_deck_v1::NewDeck, authorization: &str) + pub fn create_deck(&self, route: RegionalRoute, access_token: impl std::fmt::Display, body: &lor_deck_v1::NewDeck) -> impl Future> + 'a { let route_str = route.into(); let request = self.base.request(Method::POST, route_str, "/lor/deck/v1/decks/me"); - let request = request.header("Authorization", authorization); + let mut request = request.bearer_auth(access_token); + if let Some(clear) = self.base.get_rso_clear_header() { request = request.header(clear, "") } let request = request.body(serde_json::ser::to_vec(body).unwrap()); let future = self.base.execute_val::("lor-deck-v1.createDeck", route_str, request); #[cfg(feature = "tracing")] @@ -980,17 +992,21 @@ impl<'a> LorInventoryV1<'a> { /// Return a list of cards owned by the calling user. /// # Parameters /// * `route` - Route to query. - /// * `authorization` (required, in header) + /// * `access_token` - RSO access token. + /// # RSO + /// This endpoint uses [Riot Sign On](https://developer.riotgames.com/docs/lol#rso-integration) + /// via the `access_token` parameter, instead of the Riot API key. /// # Riot Developer API Reference /// `lor-inventory-v1.getCards` /// /// Note: this method is automatically generated. - pub fn get_cards(&self, route: RegionalRoute, authorization: &str) + pub fn get_cards(&self, route: RegionalRoute, access_token: impl std::fmt::Display) -> impl Future>> + 'a { let route_str = route.into(); let request = self.base.request(Method::GET, route_str, "/lor/inventory/v1/cards/me"); - let request = request.header("Authorization", authorization); + let mut request = request.bearer_auth(access_token); + if let Some(clear) = self.base.get_rso_clear_header() { request = request.header(clear, "") } let future = self.base.execute_val::>("lor-inventory-v1.getCards", route_str, request); #[cfg(feature = "tracing")] let future = future.instrument(tracing::info_span!("lor-inventory-v1.getCards")); @@ -1358,17 +1374,21 @@ impl<'a> SummonerV4<'a> { /// Get a summoner by access token. /// # Parameters /// * `route` - Route to query. - /// * `authorization` (optional, in header) - Bearer token + /// * `access_token` - RSO access token. + /// # RSO + /// This endpoint uses [Riot Sign On](https://developer.riotgames.com/docs/lol#rso-integration) + /// via the `access_token` parameter, instead of the Riot API key. /// # Riot Developer API Reference /// `summoner-v4.getByAccessToken` /// /// Note: this method is automatically generated. - pub fn get_by_access_token(&self, route: PlatformRoute, authorization: Option<&str>) + pub fn get_by_access_token(&self, route: PlatformRoute, access_token: impl std::fmt::Display) -> impl Future> + 'a { let route_str = route.into(); let request = self.base.request(Method::GET, route_str, "/lol/summoner/v4/summoners/me"); - let mut request = request; if let Some(authorization) = authorization { request = request.header("Authorization", authorization); } + let mut request = request.bearer_auth(access_token); + if let Some(clear) = self.base.get_rso_clear_header() { request = request.header(clear, "") } let future = self.base.execute_val::("summoner-v4.getByAccessToken", route_str, request); #[cfg(feature = "tracing")] let future = future.instrument(tracing::info_span!("summoner-v4.getByAccessToken")); @@ -1688,17 +1708,21 @@ impl<'a> TftSummonerV1<'a> { /// Get a summoner by access token. /// # Parameters /// * `route` - Route to query. - /// * `authorization` (optional, in header) - Bearer token. + /// * `access_token` - RSO access token. + /// # RSO + /// This endpoint uses [Riot Sign On](https://developer.riotgames.com/docs/lol#rso-integration) + /// via the `access_token` parameter, instead of the Riot API key. /// # Riot Developer API Reference /// `tft-summoner-v1.getByAccessToken` /// /// Note: this method is automatically generated. - pub fn get_by_access_token(&self, route: PlatformRoute, authorization: Option<&str>) + pub fn get_by_access_token(&self, route: PlatformRoute, access_token: impl std::fmt::Display) -> impl Future> + 'a { let route_str = route.into(); let request = self.base.request(Method::GET, route_str, "/tft/summoner/v1/summoners/me"); - let mut request = request; if let Some(authorization) = authorization { request = request.header("Authorization", authorization); } + let mut request = request.bearer_auth(access_token); + if let Some(clear) = self.base.get_rso_clear_header() { request = request.header(clear, "") } let future = self.base.execute_val::("tft-summoner-v1.getByAccessToken", route_str, request); #[cfg(feature = "tracing")] let future = future.instrument(tracing::info_span!("tft-summoner-v1.getByAccessToken")); diff --git a/riven/src/meta.rs b/riven/src/meta.rs index 5a127a8..4615430 100644 --- a/riven/src/meta.rs +++ b/riven/src/meta.rs @@ -8,7 +8,7 @@ /////////////////////////////////////////////// // http://www.mingweisamuel.com/riotapi-schema/tool/ -// Version a70746fcf353ba0ad0aceceafcc70d4ba8de4431 +// Version 92f57e3e7279cc02ec6a5ce6665ca08354d6a178 //! Metadata about the Riot API and Riven. //! diff --git a/riven/src/models.rs b/riven/src/models.rs index 20aa408..899dd6f 100644 --- a/riven/src/models.rs +++ b/riven/src/models.rs @@ -8,7 +8,7 @@ /////////////////////////////////////////////// // http://www.mingweisamuel.com/riotapi-schema/tool/ -// Version a70746fcf353ba0ad0aceceafcc70d4ba8de4431 +// Version 92f57e3e7279cc02ec6a5ce6665ca08354d6a178 #![allow(missing_docs)] diff --git a/riven/src/riot_api.rs b/riven/src/riot_api.rs index f09a600..8eca341 100644 --- a/riven/src/riot_api.rs +++ b/riven/src/riot_api.rs @@ -187,6 +187,11 @@ impl RiotApi { .execute(&self.config, method_id, request) } + /// Gets the [`RiotApiConfig::rso_clear_header`] for use in RSO endpoints. + pub(crate) fn get_rso_clear_header(&self) -> Option<&str> { + self.config.rso_clear_header.as_deref() + } + /// Get or create the RegionalRequester for the given region. fn regional_requester(&self, region_platform: &'static str) -> Arc { self.regional_requesters diff --git a/riven/srcgen/endpoints.rs.dt b/riven/srcgen/endpoints.rs.dt index ea76b33..d06d0d8 100644 --- a/riven/srcgen/endpoints.rs.dt +++ b/riven/srcgen/endpoints.rs.dt @@ -77,6 +77,7 @@ impl<'a> {{= endpoint }}<'a> { const method = dotUtils.changeCase.snakeCase(operationId.slice(operationId.indexOf('.') + 1)); const resp200 = operation.responses['200']; + const isRso = (null != operation.security) && (null != operation.security[0]['rso']); /* Return type checks. */ let hasReturn = false; @@ -109,6 +110,9 @@ impl<'a> {{= endpoint }}<'a> { const argBuilder = [ 'route: ', dotUtils.changeCase.pascalCase(operation['x-route-enum']), 'Route' ]; + if (isRso) { + argBuilder.push(', access_token: impl std::fmt::Display'); + } /* Add body params before path/query. */ if (bodyType) { @@ -161,17 +165,17 @@ impl<'a> {{= endpoint }}<'a> { }} /// # Parameters /// * `route` - Route to query. -{{ - if (allParams) - { - for (let param of allParams) - { -}} +{{? isRso }} + /// * `access_token` - RSO access token. +{{?}} +{{~ allParams || [] :param }} /// * `{{= dotUtils.changeCase.snakeCase(param.name) }}` ({{= param.required ? 'required' : 'optional' }}, in {{= param.in }}){{= param.description ? ' - ' + param.description : ''}} -{{ - } - } -}} +{{~}} +{{? isRso }} + /// # RSO + /// This endpoint uses [Riot Sign On](https://developer.riotgames.com/docs/lol#rso-integration) + /// via the `access_token` parameter, instead of the Riot API key. +{{?}} /// # Riot Developer API Reference /// `{{= operationId }}` /// @@ -181,22 +185,16 @@ impl<'a> {{= endpoint }}<'a> { { let route_str = route.into(); let request = self.base.request(Method::{{= verb.toUpperCase() }}, route_str, {{= routeArgument }}); -{{ - for (let queryParam of queryParams) - { -}} +{{? isRso }} + let mut request = request.bearer_auth(access_token); + if let Some(clear) = self.base.get_rso_clear_header() { request = request.header(clear, "") } +{{?}} +{{~ queryParams :queryParam }} {{= dotUtils.formatAddQueryParam(queryParam) }} -{{ - } -}} -{{ - for (const headerParam of headerParams) - { -}} +{{~}} +{{~ headerParams :headerParam }} {{= dotUtils.formatAddHeaderParam(headerParam) }} -{{ - } -}} +{{~}} {{? bodyType }} let request = request.body(serde_json::ser::to_vec(body).unwrap()); {{?}} diff --git a/riven/tests/tests_rso.rs b/riven/tests/tests_rso.rs new file mode 100644 index 0000000..58baf6f --- /dev/null +++ b/riven/tests/tests_rso.rs @@ -0,0 +1,25 @@ +mod testutils; + +use riven::consts::*; +use testutils::*; + +const ROUTE: RegionalRoute = RegionalRoute::AMERICAS; + +/// https://developer.riotgames.com/apis#account-v1/GET_getByAccessToken +#[riven_test] +async fn account_v1_getbyaccesstoken() -> Result<(), String> { + let Ok(access_token) = std::env::var("RGAPI_ACCESS_TOKEN") else { + eprintln!("`RGAPI_ACCESS_TOKEN` env var not set, cannot test RSO."); + return Ok(()); + }; + + let account = riot_api() + .account_v1() + .get_by_access_token(ROUTE, access_token) + .await + .map_err(|e| format!("Failed to get account by riot ID: {}", e))?; + + println!("{:#?}", account); + + Ok(()) +}