forked from mirror/reqwest-middleware
feat: classify io error connection reset by peer (#78)
* feat: classify io error connection reset by peer * doc: explain tests * chore: add changelog * fix: use and_then * chore: docs and clippy * chore: bump reqwest-retry to 0.2.1 * fix: use get_source_error_type * fix: remove unused import * fix: make clippy happy * doc: describe canceled
This commit is contained in:
parent
8763ab1e30
commit
f854725791
5 changed files with 148 additions and 2 deletions
|
@ -6,6 +6,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
|
||||||
|
|
||||||
## [Unreleased]
|
## [Unreleased]
|
||||||
|
|
||||||
|
## [0.2.1] - 2022-12-01
|
||||||
|
|
||||||
|
### Changed
|
||||||
|
- Classify `io::Error`s and `hyper::Error(Canceled)` as transient
|
||||||
|
|
||||||
## [0.2.0] - 2022-11-15
|
## [0.2.0] - 2022-11-15
|
||||||
|
|
||||||
### Changed
|
### Changed
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
[package]
|
[package]
|
||||||
name = "reqwest-retry"
|
name = "reqwest-retry"
|
||||||
version = "0.2.0"
|
version = "0.2.1"
|
||||||
authors = ["Rodrigo Gryzinski <rodrigo.gryzinski@truelayer.com>"]
|
authors = ["Rodrigo Gryzinski <rodrigo.gryzinski@truelayer.com>"]
|
||||||
edition = "2018"
|
edition = "2018"
|
||||||
description = "Retry middleware for reqwest."
|
description = "Retry middleware for reqwest."
|
||||||
|
@ -29,3 +29,4 @@ async-std = { version = "1.10"}
|
||||||
paste = "1"
|
paste = "1"
|
||||||
tokio = { version = "1", features = ["macros"] }
|
tokio = { version = "1", features = ["macros"] }
|
||||||
wiremock = "0.5"
|
wiremock = "0.5"
|
||||||
|
futures = "0.3"
|
||||||
|
|
|
@ -1,3 +1,5 @@
|
||||||
|
use std::io;
|
||||||
|
|
||||||
use http::StatusCode;
|
use http::StatusCode;
|
||||||
use reqwest_middleware::Error;
|
use reqwest_middleware::Error;
|
||||||
|
|
||||||
|
@ -55,8 +57,17 @@ impl Retryable {
|
||||||
// The hyper::Error(IncompleteMessage) is raised if the HTTP response is well formatted but does not contain all the bytes.
|
// The hyper::Error(IncompleteMessage) is raised if the HTTP response is well formatted but does not contain all the bytes.
|
||||||
// This can happen when the server has started sending back the response but the connection is cut halfway thorugh.
|
// This can happen when the server has started sending back the response but the connection is cut halfway thorugh.
|
||||||
// We can safely retry the call, hence marking this error as [`Retryable::Transient`].
|
// We can safely retry the call, hence marking this error as [`Retryable::Transient`].
|
||||||
if hyper_error.is_incomplete_message() {
|
// Instead hyper::Error(Canceled) is raised when the connection is
|
||||||
|
// gracefully closed on the server side.
|
||||||
|
if hyper_error.is_incomplete_message() || hyper_error.is_canceled() {
|
||||||
Some(Retryable::Transient)
|
Some(Retryable::Transient)
|
||||||
|
|
||||||
|
// Try and downcast the hyper error to io::Error if that is the
|
||||||
|
// underlying error, and try and classify it.
|
||||||
|
} else if let Some(io_error) =
|
||||||
|
get_source_error_type::<io::Error>(hyper_error)
|
||||||
|
{
|
||||||
|
Some(classify_io_error(io_error))
|
||||||
} else {
|
} else {
|
||||||
Some(Retryable::Fatal)
|
Some(Retryable::Fatal)
|
||||||
}
|
}
|
||||||
|
@ -81,6 +92,13 @@ impl From<&reqwest::Error> for Retryable {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn classify_io_error(error: &io::Error) -> Retryable {
|
||||||
|
match error.kind() {
|
||||||
|
io::ErrorKind::ConnectionReset | io::ErrorKind::ConnectionAborted => Retryable::Transient,
|
||||||
|
_ => Retryable::Fatal,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/// Downcasts the given err source into T.
|
/// Downcasts the given err source into T.
|
||||||
fn get_source_error_type<T: std::error::Error + 'static>(
|
fn get_source_error_type<T: std::error::Error + 'static>(
|
||||||
err: &dyn std::error::Error,
|
err: &dyn std::error::Error,
|
||||||
|
|
|
@ -1,10 +1,15 @@
|
||||||
use async_std::io::ReadExt;
|
use async_std::io::ReadExt;
|
||||||
use async_std::io::WriteExt;
|
use async_std::io::WriteExt;
|
||||||
use async_std::net::{TcpListener, TcpStream};
|
use async_std::net::{TcpListener, TcpStream};
|
||||||
|
use futures::future::BoxFuture;
|
||||||
use futures::stream::StreamExt;
|
use futures::stream::StreamExt;
|
||||||
use std::error::Error;
|
use std::error::Error;
|
||||||
use std::fmt;
|
use std::fmt;
|
||||||
|
|
||||||
|
type CustomMessageHandler = Box<
|
||||||
|
dyn Fn(TcpStream) -> BoxFuture<'static, Result<(), Box<dyn std::error::Error>>> + Send + Sync,
|
||||||
|
>;
|
||||||
|
|
||||||
/// This is a simple server that returns the responses given at creation time: [`self.raw_http_responses`] following a round-robin mechanism.
|
/// This is a simple server that returns the responses given at creation time: [`self.raw_http_responses`] following a round-robin mechanism.
|
||||||
pub struct SimpleServer {
|
pub struct SimpleServer {
|
||||||
listener: TcpListener,
|
listener: TcpListener,
|
||||||
|
@ -12,6 +17,7 @@ pub struct SimpleServer {
|
||||||
host: String,
|
host: String,
|
||||||
raw_http_responses: Vec<String>,
|
raw_http_responses: Vec<String>,
|
||||||
calls_counter: usize,
|
calls_counter: usize,
|
||||||
|
custom_handler: Option<CustomMessageHandler>,
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Request-Line = Method SP Request-URI SP HTTP-Version CRLF
|
/// Request-Line = Method SP Request-URI SP HTTP-Version CRLF
|
||||||
|
@ -46,9 +52,21 @@ impl SimpleServer {
|
||||||
host: host.to_string(),
|
host: host.to_string(),
|
||||||
raw_http_responses,
|
raw_http_responses,
|
||||||
calls_counter: 0,
|
calls_counter: 0,
|
||||||
|
custom_handler: None,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn set_custom_handler(
|
||||||
|
&mut self,
|
||||||
|
custom_handler: impl Fn(TcpStream) -> BoxFuture<'static, Result<(), Box<dyn std::error::Error>>>
|
||||||
|
+ Send
|
||||||
|
+ Sync
|
||||||
|
+ 'static,
|
||||||
|
) -> &mut Self {
|
||||||
|
self.custom_handler.replace(Box::new(custom_handler));
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
/// Returns the uri in which the server is listening to.
|
/// Returns the uri in which the server is listening to.
|
||||||
pub fn uri(&self) -> String {
|
pub fn uri(&self) -> String {
|
||||||
format!("http://{}:{}", self.host, self.port)
|
format!("http://{}:{}", self.host, self.port)
|
||||||
|
@ -79,6 +97,10 @@ impl SimpleServer {
|
||||||
///
|
///
|
||||||
/// Returns a 400 if the request if formatted badly.
|
/// Returns a 400 if the request if formatted badly.
|
||||||
async fn handle_connection(&self, mut stream: TcpStream) -> Result<(), Box<dyn Error>> {
|
async fn handle_connection(&self, mut stream: TcpStream) -> Result<(), Box<dyn Error>> {
|
||||||
|
if let Some(ref custom_handler) = self.custom_handler {
|
||||||
|
return custom_handler(stream).await;
|
||||||
|
}
|
||||||
|
|
||||||
let mut buffer = vec![0; 1024];
|
let mut buffer = vec![0; 1024];
|
||||||
|
|
||||||
stream.read(&mut buffer).await.unwrap();
|
stream.read(&mut buffer).await.unwrap();
|
||||||
|
|
|
@ -1,8 +1,12 @@
|
||||||
|
use async_std::io::ReadExt;
|
||||||
|
use futures::AsyncWriteExt;
|
||||||
|
use futures::FutureExt;
|
||||||
use paste::paste;
|
use paste::paste;
|
||||||
use reqwest::Client;
|
use reqwest::Client;
|
||||||
use reqwest::StatusCode;
|
use reqwest::StatusCode;
|
||||||
use reqwest_middleware::ClientBuilder;
|
use reqwest_middleware::ClientBuilder;
|
||||||
use reqwest_retry::{policies::ExponentialBackoff, RetryTransientMiddleware};
|
use reqwest_retry::{policies::ExponentialBackoff, RetryTransientMiddleware};
|
||||||
|
use std::sync::atomic::AtomicI8;
|
||||||
use std::sync::{
|
use std::sync::{
|
||||||
atomic::{AtomicU32, Ordering},
|
atomic::{AtomicU32, Ordering},
|
||||||
Arc,
|
Arc,
|
||||||
|
@ -269,3 +273,99 @@ async fn assert_retry_on_incomplete_message() {
|
||||||
|
|
||||||
assert_eq!(resp.status(), 200);
|
assert_eq!(resp.status(), 200);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn assert_retry_on_hyper_canceled() {
|
||||||
|
let counter = Arc::new(AtomicI8::new(0));
|
||||||
|
let mut simple_server = SimpleServer::new("127.0.0.1", None, vec![])
|
||||||
|
.await
|
||||||
|
.expect("Error when creating a simple server");
|
||||||
|
simple_server.set_custom_handler(move |mut stream| {
|
||||||
|
let counter = counter.clone();
|
||||||
|
async move {
|
||||||
|
let mut buffer = Vec::new();
|
||||||
|
stream.read(&mut buffer).await.unwrap();
|
||||||
|
if counter.fetch_add(1, Ordering::SeqCst) > 1 {
|
||||||
|
// This triggeres hyper:Error(Canceled).
|
||||||
|
let _res = stream.shutdown(std::net::Shutdown::Both);
|
||||||
|
} else {
|
||||||
|
let _res = stream.write("HTTP/1.1 200 OK\r\n\r\n".as_bytes()).await;
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
.boxed()
|
||||||
|
});
|
||||||
|
|
||||||
|
let uri = simple_server.uri();
|
||||||
|
|
||||||
|
tokio::spawn(simple_server.start());
|
||||||
|
|
||||||
|
let reqwest_client = Client::builder().build().unwrap();
|
||||||
|
let client = ClientBuilder::new(reqwest_client)
|
||||||
|
.with(RetryTransientMiddleware::new_with_policy(
|
||||||
|
ExponentialBackoff {
|
||||||
|
max_n_retries: 3,
|
||||||
|
max_retry_interval: std::time::Duration::from_millis(100),
|
||||||
|
min_retry_interval: std::time::Duration::from_millis(30),
|
||||||
|
backoff_exponent: 2,
|
||||||
|
},
|
||||||
|
))
|
||||||
|
.build();
|
||||||
|
|
||||||
|
let resp = client
|
||||||
|
.get(&format!("{}/foo", uri))
|
||||||
|
.timeout(std::time::Duration::from_millis(100))
|
||||||
|
.send()
|
||||||
|
.await
|
||||||
|
.expect("call failed");
|
||||||
|
|
||||||
|
assert_eq!(resp.status(), 200);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn assert_retry_on_connection_reset_by_peer() {
|
||||||
|
let counter = Arc::new(AtomicI8::new(0));
|
||||||
|
let mut simple_server = SimpleServer::new("127.0.0.1", None, vec![])
|
||||||
|
.await
|
||||||
|
.expect("Error when creating a simple server");
|
||||||
|
simple_server.set_custom_handler(move |mut stream| {
|
||||||
|
let counter = counter.clone();
|
||||||
|
async move {
|
||||||
|
let mut buffer = Vec::new();
|
||||||
|
stream.read(&mut buffer).await.unwrap();
|
||||||
|
if counter.fetch_add(1, Ordering::SeqCst) > 1 {
|
||||||
|
// This triggeres hyper:Error(Io, io::Error(ConnectionReset)).
|
||||||
|
drop(stream);
|
||||||
|
} else {
|
||||||
|
let _res = stream.write("HTTP/1.1 200 OK\r\n\r\n".as_bytes()).await;
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
.boxed()
|
||||||
|
});
|
||||||
|
|
||||||
|
let uri = simple_server.uri();
|
||||||
|
|
||||||
|
tokio::spawn(simple_server.start());
|
||||||
|
|
||||||
|
let reqwest_client = Client::builder().build().unwrap();
|
||||||
|
let client = ClientBuilder::new(reqwest_client)
|
||||||
|
.with(RetryTransientMiddleware::new_with_policy(
|
||||||
|
ExponentialBackoff {
|
||||||
|
max_n_retries: 3,
|
||||||
|
max_retry_interval: std::time::Duration::from_millis(100),
|
||||||
|
min_retry_interval: std::time::Duration::from_millis(30),
|
||||||
|
backoff_exponent: 2,
|
||||||
|
},
|
||||||
|
))
|
||||||
|
.build();
|
||||||
|
|
||||||
|
let resp = client
|
||||||
|
.get(&format!("{}/foo", uri))
|
||||||
|
.timeout(std::time::Duration::from_millis(100))
|
||||||
|
.send()
|
||||||
|
.await
|
||||||
|
.expect("call failed");
|
||||||
|
|
||||||
|
assert_eq!(resp.status(), 200);
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in a new issue