From 594075583c87041521ed0342ecf008de386b0dac Mon Sep 17 00:00:00 2001 From: Rutger <35080130+Rutgerdj@users.noreply.github.com> Date: Mon, 22 May 2023 12:53:31 +0200 Subject: [PATCH] Added a way to specify custom functions which decide whether a request should be retried or not (#33) * Add a generic function to the middleware struct for the `Retryable` decision. The generic function can be used to define custom behaviour to decide whether to retry a request or not. By default, this function is `Retryable::from_reqwest_response` which is the same as it was before. * Add a way to create custom retry policies. A RetryStrategy will dictate what decision will be made based on the result of the sent request. * Add RetryableStrategy in the `RetryTransientMiddleware` struct instead of the seperate functions * Add constructor to create a `RetryTransientMiddleware` with a custom `RetryableStrategy` * Run `cargo fmt` * Add example code to the `RetryableStrategy` struct * Run `cargo fmt` * Updated changelog * use a trait * docs * include latest changes Co-authored-by: Conrad Ludgate --- reqwest-retry/CHANGELOG.md | 3 +- reqwest-retry/src/lib.rs | 5 + reqwest-retry/src/middleware.rs | 44 ++++- reqwest-retry/src/retryable.rs | 102 +----------- reqwest-retry/src/retryable_strategy.rs | 213 ++++++++++++++++++++++++ 5 files changed, 257 insertions(+), 110 deletions(-) create mode 100644 reqwest-retry/src/retryable_strategy.rs diff --git a/reqwest-retry/CHANGELOG.md b/reqwest-retry/CHANGELOG.md index cf1d023..73f76f2 100644 --- a/reqwest-retry/CHANGELOG.md +++ b/reqwest-retry/CHANGELOG.md @@ -5,6 +5,8 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). ## [Unreleased] +### Added +- `RetryableStrategy` which allows for custom retry decisions based on the response that a request got ## [0.2.1] - 2022-12-01 @@ -12,7 +14,6 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Classify `io::Error`s and `hyper::Error(Canceled)` as transient ## [0.2.0] - 2022-11-15 - ### Changed - Updated `reqwest-middleware` to `0.2.0` diff --git a/reqwest-retry/src/lib.rs b/reqwest-retry/src/lib.rs index bef3763..7ab0ee0 100644 --- a/reqwest-retry/src/lib.rs +++ b/reqwest-retry/src/lib.rs @@ -27,8 +27,13 @@ mod middleware; mod retryable; +mod retryable_strategy; pub use retry_policies::{policies, RetryPolicy}; pub use middleware::RetryTransientMiddleware; pub use retryable::Retryable; +pub use retryable_strategy::{ + default_on_request_failure, default_on_request_success, DefaultRetryableStrategy, + RetryableStrategy, +}; diff --git a/reqwest-retry/src/middleware.rs b/reqwest-retry/src/middleware.rs index 4f1151a..84b71da 100644 --- a/reqwest-retry/src/middleware.rs +++ b/reqwest-retry/src/middleware.rs @@ -1,6 +1,6 @@ //! `RetryTransientMiddleware` implements retrying requests on transient errors. - -use crate::retryable::Retryable; +use crate::retryable_strategy::RetryableStrategy; +use crate::{retryable::Retryable, retryable_strategy::DefaultRetryableStrategy}; use anyhow::anyhow; use chrono::Utc; use reqwest::{Request, Response}; @@ -44,20 +44,42 @@ use task_local_extensions::Extensions; /// * You can wrap this middleware in a custom one which skips retries for streaming requests. /// * You can write a custom retry middleware that builds new streaming requests from the data /// source directly, avoiding the issue of streaming requests not being clonable. -pub struct RetryTransientMiddleware { +pub struct RetryTransientMiddleware< + T: RetryPolicy + Send + Sync + 'static, + R: RetryableStrategy + Send + Sync + 'static = DefaultRetryableStrategy, +> { retry_policy: T, + retryable_strategy: R, } -impl RetryTransientMiddleware { - /// Construct `RetryTransientMiddleware` with a [retry_policy][retry_policies::RetryPolicy]. +impl RetryTransientMiddleware { + /// Construct `RetryTransientMiddleware` with a [retry_policy][RetryPolicy]. pub fn new_with_policy(retry_policy: T) -> Self { - Self { retry_policy } + Self::new_with_policy_and_strategy(retry_policy, DefaultRetryableStrategy) + } +} + +impl RetryTransientMiddleware +where + T: RetryPolicy + Send + Sync, + R: RetryableStrategy + Send + Sync, +{ + /// Construct `RetryTransientMiddleware` with a [retry_policy][RetryPolicy] and [retryable_strategy](RetryableStrategy). + pub fn new_with_policy_and_strategy(retry_policy: T, retryable_strategy: R) -> Self { + Self { + retry_policy, + retryable_strategy, + } } } #[cfg_attr(not(target_arch = "wasm32"), async_trait::async_trait)] #[cfg_attr(target_arch = "wasm32", async_trait::async_trait(?Send))] -impl Middleware for RetryTransientMiddleware { +impl Middleware for RetryTransientMiddleware +where + T: RetryPolicy + Send + Sync, + R: RetryableStrategy + Send + Sync + 'static, +{ async fn handle( &self, req: Request, @@ -72,7 +94,11 @@ impl Middleware for RetryTransientMiddleware { } } -impl RetryTransientMiddleware { +impl RetryTransientMiddleware +where + T: RetryPolicy + Send + Sync, + R: RetryableStrategy + Send + Sync, +{ /// This function will try to execute the request, if it fails /// with an error classified as transient it will call itself /// to retry the request. @@ -98,7 +124,7 @@ impl RetryTransientMiddleware { // We classify the response which will return None if not // errors were returned. - break match Retryable::from_reqwest_response(&result) { + break match self.retryable_strategy.handle(&result) { Some(Retryable::Transient) => { // If the response failed and the error type was transient // we can safely try to retry the request. diff --git a/reqwest-retry/src/retryable.rs b/reqwest-retry/src/retryable.rs index cdf864e..dfeed24 100644 --- a/reqwest-retry/src/retryable.rs +++ b/reqwest-retry/src/retryable.rs @@ -1,4 +1,4 @@ -use http::StatusCode; +use crate::retryable_strategy::{DefaultRetryableStrategy, RetryableStrategy}; use reqwest_middleware::Error; /// Classification of an error/status returned by request. @@ -16,78 +16,7 @@ impl Retryable { /// Returns `None` if the response object does not contain any errors. /// pub fn from_reqwest_response(res: &Result) -> Option { - match res { - Ok(success) => { - let status = success.status(); - if status.is_server_error() { - Some(Retryable::Transient) - } else if status.is_client_error() - && status != StatusCode::REQUEST_TIMEOUT - && status != StatusCode::TOO_MANY_REQUESTS - { - Some(Retryable::Fatal) - } else if status.is_success() { - None - } else if status == StatusCode::REQUEST_TIMEOUT - || status == StatusCode::TOO_MANY_REQUESTS - { - Some(Retryable::Transient) - } else { - Some(Retryable::Fatal) - } - } - Err(error) => match error { - // If something fails in the middleware we're screwed. - Error::Middleware(_) => Some(Retryable::Fatal), - Error::Reqwest(error) => { - #[cfg(not(target_arch = "wasm32"))] - let is_connect = error.is_connect(); - #[cfg(target_arch = "wasm32")] - let is_connect = false; - if error.is_timeout() || is_connect { - Some(Retryable::Transient) - } else if error.is_body() - || error.is_decode() - || error.is_builder() - || error.is_redirect() - { - Some(Retryable::Fatal) - } else if error.is_request() { - // It seems that hyper::Error(IncompleteMessage) is not correctly handled by reqwest. - // Here we check if the Reqwest error was originated by hyper and map it consistently. - #[cfg(not(target_arch = "wasm32"))] - if let Some(hyper_error) = get_source_error_type::(&error) { - // The hyper::Error(IncompleteMessage) is raised if the HTTP response is well formatted but does not contain all the bytes. - // This can happen when the server has started sending back the response but the connection is cut halfway thorugh. - // We can safely retry the call, hence marking this error as [`Retryable::Transient`]. - // Instead hyper::Error(Canceled) is raised when the connection is - // gracefully closed on the server side. - if hyper_error.is_incomplete_message() || hyper_error.is_canceled() { - Some(Retryable::Transient) - - // Try and downcast the hyper error to io::Error if that is the - // underlying error, and try and classify it. - } else if let Some(io_error) = - get_source_error_type::(hyper_error) - { - Some(classify_io_error(io_error)) - } else { - Some(Retryable::Fatal) - } - } else { - Some(Retryable::Fatal) - } - #[cfg(target_arch = "wasm32")] - Some(Retryable::Fatal) - } else { - // We omit checking if error.is_status() since we check that already. - // However, if Response::error_for_status is used the status will still - // remain in the response object. - None - } - } - }, - } + DefaultRetryableStrategy.handle(res) } } @@ -96,30 +25,3 @@ impl From<&reqwest::Error> for Retryable { Retryable::Transient } } - -#[cfg(not(target_arch = "wasm32"))] -fn classify_io_error(error: &std::io::Error) -> Retryable { - match error.kind() { - std::io::ErrorKind::ConnectionReset | std::io::ErrorKind::ConnectionAborted => { - Retryable::Transient - } - _ => Retryable::Fatal, - } -} - -/// Downcasts the given err source into T. -#[cfg(not(target_arch = "wasm32"))] -fn get_source_error_type( - err: &dyn std::error::Error, -) -> Option<&T> { - let mut source = err.source(); - - while let Some(err) = source { - if let Some(hyper_err) = err.downcast_ref::() { - return Some(hyper_err); - } - - source = err.source(); - } - None -} diff --git a/reqwest-retry/src/retryable_strategy.rs b/reqwest-retry/src/retryable_strategy.rs new file mode 100644 index 0000000..dddce05 --- /dev/null +++ b/reqwest-retry/src/retryable_strategy.rs @@ -0,0 +1,213 @@ +use crate::retryable::Retryable; +use http::StatusCode; +use reqwest_middleware::Error; + +/// A strategy to create a [`Retryable`] from a [`Result`] +/// +/// A [`RetryableStrategy`] has a single `handler` functions. +/// The result of calling the request could be: +/// - [`reqwest::Response`] In case the request has been sent and received correctly +/// This could however still mean that the server responded with a erroneous response. +/// For example a HTTP statuscode of 500 +/// - [`reqwest_middleware::Error`] In this case the request actually failed. +/// This could, for example, be caused by a timeout on the connection. +/// +/// Example: +/// +/// ``` +/// use reqwest_retry::{default_on_request_failure, policies::ExponentialBackoff, Retryable, RetryableStrategy, RetryTransientMiddleware}; +/// use reqwest::{Request, Response}; +/// use reqwest_middleware::{ClientBuilder, Middleware, Next, Result}; +/// use task_local_extensions::Extensions; +/// +/// // Log each request to show that the requests will be retried +/// struct LoggingMiddleware; +/// +/// #[async_trait::async_trait] +/// impl Middleware for LoggingMiddleware { +/// async fn handle( +/// &self, +/// req: Request, +/// extensions: &mut Extensions, +/// next: Next<'_>, +/// ) -> Result { +/// println!("Request started {}", req.url()); +/// let res = next.run(req, extensions).await; +/// println!("Request finished"); +/// res +/// } +/// } +/// +/// // Just a toy example, retry when the successful response code is 201, else do nothing. +/// struct Retry201; +/// impl RetryableStrategy for Retry201 { +/// fn handle(&self, res: &Result) -> Option { +/// match res { +/// // retry if 201 +/// Ok(success) if success.status() == 201 => Some(Retryable::Transient), +/// // otherwise do not retry a successful request +/// Ok(success) => None, +/// // but maybe retry a request failure +/// Err(error) => default_on_request_failure(error), +/// } +/// } +/// } +/// +/// #[tokio::main] +/// async fn main() { +/// // Exponential backoff with max 2 retries +/// let retry_policy = ExponentialBackoff::builder() +/// .build_with_max_retries(2); +/// +/// // Create the actual middleware, with the exponential backoff and custom retry stategy. +/// let ret_s = RetryTransientMiddleware::new_with_policy_and_strategy( +/// retry_policy, +/// Retry201, +/// ); +/// +/// let client = ClientBuilder::new(reqwest::Client::new()) +/// // Retry failed requests. +/// .with(ret_s) +/// // Log the requests +/// .with(LoggingMiddleware) +/// .build(); +/// +/// // Send request which should get a 201 response. So it will be retried +/// let r = client +/// .get("https://httpbin.org/status/201") +/// .send() +/// .await; +/// println!("{:?}", r); +/// +/// // Send request which should get a 200 response. So it will not be retried +/// let r = client +/// .get("https://httpbin.org/status/200") +/// .send() +/// .await; +/// println!("{:?}", r); +/// } +/// ``` +pub trait RetryableStrategy { + fn handle(&self, res: &Result) -> Option; +} + +/// The default [`RetryableStrategy`] for [`RetryTransientMiddleware`](crate::RetryTransientMiddleware). +pub struct DefaultRetryableStrategy; + +impl RetryableStrategy for DefaultRetryableStrategy { + fn handle(&self, res: &Result) -> Option { + match res { + Ok(success) => default_on_request_success(success), + Err(error) => default_on_request_failure(error), + } + } +} + +/// Default request success retry strategy. +/// +/// Will only retry if: +/// * The status was 5XX (server error) +/// * The status was 408 (request timeout) or 429 (too many requests) +/// +/// Note that success here means that the request finished without interruption, not that it was logically OK. +pub fn default_on_request_success(success: &reqwest::Response) -> Option { + let status = success.status(); + if status.is_server_error() { + Some(Retryable::Transient) + } else if status.is_client_error() + && status != StatusCode::REQUEST_TIMEOUT + && status != StatusCode::TOO_MANY_REQUESTS + { + Some(Retryable::Fatal) + } else if status.is_success() { + None + } else if status == StatusCode::REQUEST_TIMEOUT || status == StatusCode::TOO_MANY_REQUESTS { + Some(Retryable::Transient) + } else { + Some(Retryable::Fatal) + } +} + +/// Default request failure retry strategy. +/// +/// Will only retry if the request failed due to a network error +pub fn default_on_request_failure(error: &Error) -> Option { + match error { + // If something fails in the middleware we're screwed. + Error::Middleware(_) => Some(Retryable::Fatal), + Error::Reqwest(error) => { + #[cfg(not(target_arch = "wasm32"))] + let is_connect = error.is_connect(); + #[cfg(target_arch = "wasm32")] + let is_connect = false; + if error.is_timeout() || is_connect { + Some(Retryable::Transient) + } else if error.is_body() + || error.is_decode() + || error.is_builder() + || error.is_redirect() + { + Some(Retryable::Fatal) + } else if error.is_request() { + // It seems that hyper::Error(IncompleteMessage) is not correctly handled by reqwest. + // Here we check if the Reqwest error was originated by hyper and map it consistently. + #[cfg(not(target_arch = "wasm32"))] + if let Some(hyper_error) = get_source_error_type::(&error) { + // The hyper::Error(IncompleteMessage) is raised if the HTTP response is well formatted but does not contain all the bytes. + // This can happen when the server has started sending back the response but the connection is cut halfway thorugh. + // We can safely retry the call, hence marking this error as [`Retryable::Transient`]. + // Instead hyper::Error(Canceled) is raised when the connection is + // gracefully closed on the server side. + if hyper_error.is_incomplete_message() || hyper_error.is_canceled() { + Some(Retryable::Transient) + + // Try and downcast the hyper error to io::Error if that is the + // underlying error, and try and classify it. + } else if let Some(io_error) = + get_source_error_type::(hyper_error) + { + Some(classify_io_error(io_error)) + } else { + Some(Retryable::Fatal) + } + } else { + Some(Retryable::Fatal) + } + #[cfg(target_arch = "wasm32")] + Some(Retryable::Fatal) + } else { + // We omit checking if error.is_status() since we check that already. + // However, if Response::error_for_status is used the status will still + // remain in the response object. + None + } + } + } +} + +#[cfg(not(target_arch = "wasm32"))] +fn classify_io_error(error: &std::io::Error) -> Retryable { + match error.kind() { + std::io::ErrorKind::ConnectionReset | std::io::ErrorKind::ConnectionAborted => { + Retryable::Transient + } + _ => Retryable::Fatal, + } +} + +/// Downcasts the given err source into T. +#[cfg(not(target_arch = "wasm32"))] +fn get_source_error_type( + err: &dyn std::error::Error, +) -> Option<&T> { + let mut source = err.source(); + + while let Some(err) = source { + if let Some(err) = err.downcast_ref::() { + return Some(err); + } + + source = err.source(); + } + None +}