mirror of
https://github.com/TrueLayer/reqwest-middleware.git
synced 2024-12-26 02:46:30 +00:00
feat: classify io error connection reset by peer (#78)
* feat: classify io error connection reset by peer * doc: explain tests * chore: add changelog * fix: use and_then * chore: docs and clippy * chore: bump reqwest-retry to 0.2.1 * fix: use get_source_error_type * fix: remove unused import * fix: make clippy happy * doc: describe canceled
This commit is contained in:
parent
8763ab1e30
commit
f854725791
5 changed files with 148 additions and 2 deletions
|
@ -6,6 +6,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
|
|||
|
||||
## [Unreleased]
|
||||
|
||||
## [0.2.1] - 2022-12-01
|
||||
|
||||
### Changed
|
||||
- Classify `io::Error`s and `hyper::Error(Canceled)` as transient
|
||||
|
||||
## [0.2.0] - 2022-11-15
|
||||
|
||||
### Changed
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
[package]
|
||||
name = "reqwest-retry"
|
||||
version = "0.2.0"
|
||||
version = "0.2.1"
|
||||
authors = ["Rodrigo Gryzinski <rodrigo.gryzinski@truelayer.com>"]
|
||||
edition = "2018"
|
||||
description = "Retry middleware for reqwest."
|
||||
|
@ -29,3 +29,4 @@ async-std = { version = "1.10"}
|
|||
paste = "1"
|
||||
tokio = { version = "1", features = ["macros"] }
|
||||
wiremock = "0.5"
|
||||
futures = "0.3"
|
||||
|
|
|
@ -1,3 +1,5 @@
|
|||
use std::io;
|
||||
|
||||
use http::StatusCode;
|
||||
use reqwest_middleware::Error;
|
||||
|
||||
|
@ -55,8 +57,17 @@ impl Retryable {
|
|||
// The hyper::Error(IncompleteMessage) is raised if the HTTP response is well formatted but does not contain all the bytes.
|
||||
// This can happen when the server has started sending back the response but the connection is cut halfway thorugh.
|
||||
// We can safely retry the call, hence marking this error as [`Retryable::Transient`].
|
||||
if hyper_error.is_incomplete_message() {
|
||||
// Instead hyper::Error(Canceled) is raised when the connection is
|
||||
// gracefully closed on the server side.
|
||||
if hyper_error.is_incomplete_message() || hyper_error.is_canceled() {
|
||||
Some(Retryable::Transient)
|
||||
|
||||
// Try and downcast the hyper error to io::Error if that is the
|
||||
// underlying error, and try and classify it.
|
||||
} else if let Some(io_error) =
|
||||
get_source_error_type::<io::Error>(hyper_error)
|
||||
{
|
||||
Some(classify_io_error(io_error))
|
||||
} else {
|
||||
Some(Retryable::Fatal)
|
||||
}
|
||||
|
@ -81,6 +92,13 @@ impl From<&reqwest::Error> for Retryable {
|
|||
}
|
||||
}
|
||||
|
||||
fn classify_io_error(error: &io::Error) -> Retryable {
|
||||
match error.kind() {
|
||||
io::ErrorKind::ConnectionReset | io::ErrorKind::ConnectionAborted => Retryable::Transient,
|
||||
_ => Retryable::Fatal,
|
||||
}
|
||||
}
|
||||
|
||||
/// Downcasts the given err source into T.
|
||||
fn get_source_error_type<T: std::error::Error + 'static>(
|
||||
err: &dyn std::error::Error,
|
||||
|
|
|
@ -1,10 +1,15 @@
|
|||
use async_std::io::ReadExt;
|
||||
use async_std::io::WriteExt;
|
||||
use async_std::net::{TcpListener, TcpStream};
|
||||
use futures::future::BoxFuture;
|
||||
use futures::stream::StreamExt;
|
||||
use std::error::Error;
|
||||
use std::fmt;
|
||||
|
||||
type CustomMessageHandler = Box<
|
||||
dyn Fn(TcpStream) -> BoxFuture<'static, Result<(), Box<dyn std::error::Error>>> + Send + Sync,
|
||||
>;
|
||||
|
||||
/// This is a simple server that returns the responses given at creation time: [`self.raw_http_responses`] following a round-robin mechanism.
|
||||
pub struct SimpleServer {
|
||||
listener: TcpListener,
|
||||
|
@ -12,6 +17,7 @@ pub struct SimpleServer {
|
|||
host: String,
|
||||
raw_http_responses: Vec<String>,
|
||||
calls_counter: usize,
|
||||
custom_handler: Option<CustomMessageHandler>,
|
||||
}
|
||||
|
||||
/// Request-Line = Method SP Request-URI SP HTTP-Version CRLF
|
||||
|
@ -46,9 +52,21 @@ impl SimpleServer {
|
|||
host: host.to_string(),
|
||||
raw_http_responses,
|
||||
calls_counter: 0,
|
||||
custom_handler: None,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn set_custom_handler(
|
||||
&mut self,
|
||||
custom_handler: impl Fn(TcpStream) -> BoxFuture<'static, Result<(), Box<dyn std::error::Error>>>
|
||||
+ Send
|
||||
+ Sync
|
||||
+ 'static,
|
||||
) -> &mut Self {
|
||||
self.custom_handler.replace(Box::new(custom_handler));
|
||||
self
|
||||
}
|
||||
|
||||
/// Returns the uri in which the server is listening to.
|
||||
pub fn uri(&self) -> String {
|
||||
format!("http://{}:{}", self.host, self.port)
|
||||
|
@ -79,6 +97,10 @@ impl SimpleServer {
|
|||
///
|
||||
/// Returns a 400 if the request if formatted badly.
|
||||
async fn handle_connection(&self, mut stream: TcpStream) -> Result<(), Box<dyn Error>> {
|
||||
if let Some(ref custom_handler) = self.custom_handler {
|
||||
return custom_handler(stream).await;
|
||||
}
|
||||
|
||||
let mut buffer = vec![0; 1024];
|
||||
|
||||
stream.read(&mut buffer).await.unwrap();
|
||||
|
|
|
@ -1,8 +1,12 @@
|
|||
use async_std::io::ReadExt;
|
||||
use futures::AsyncWriteExt;
|
||||
use futures::FutureExt;
|
||||
use paste::paste;
|
||||
use reqwest::Client;
|
||||
use reqwest::StatusCode;
|
||||
use reqwest_middleware::ClientBuilder;
|
||||
use reqwest_retry::{policies::ExponentialBackoff, RetryTransientMiddleware};
|
||||
use std::sync::atomic::AtomicI8;
|
||||
use std::sync::{
|
||||
atomic::{AtomicU32, Ordering},
|
||||
Arc,
|
||||
|
@ -269,3 +273,99 @@ async fn assert_retry_on_incomplete_message() {
|
|||
|
||||
assert_eq!(resp.status(), 200);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn assert_retry_on_hyper_canceled() {
|
||||
let counter = Arc::new(AtomicI8::new(0));
|
||||
let mut simple_server = SimpleServer::new("127.0.0.1", None, vec![])
|
||||
.await
|
||||
.expect("Error when creating a simple server");
|
||||
simple_server.set_custom_handler(move |mut stream| {
|
||||
let counter = counter.clone();
|
||||
async move {
|
||||
let mut buffer = Vec::new();
|
||||
stream.read(&mut buffer).await.unwrap();
|
||||
if counter.fetch_add(1, Ordering::SeqCst) > 1 {
|
||||
// This triggeres hyper:Error(Canceled).
|
||||
let _res = stream.shutdown(std::net::Shutdown::Both);
|
||||
} else {
|
||||
let _res = stream.write("HTTP/1.1 200 OK\r\n\r\n".as_bytes()).await;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
.boxed()
|
||||
});
|
||||
|
||||
let uri = simple_server.uri();
|
||||
|
||||
tokio::spawn(simple_server.start());
|
||||
|
||||
let reqwest_client = Client::builder().build().unwrap();
|
||||
let client = ClientBuilder::new(reqwest_client)
|
||||
.with(RetryTransientMiddleware::new_with_policy(
|
||||
ExponentialBackoff {
|
||||
max_n_retries: 3,
|
||||
max_retry_interval: std::time::Duration::from_millis(100),
|
||||
min_retry_interval: std::time::Duration::from_millis(30),
|
||||
backoff_exponent: 2,
|
||||
},
|
||||
))
|
||||
.build();
|
||||
|
||||
let resp = client
|
||||
.get(&format!("{}/foo", uri))
|
||||
.timeout(std::time::Duration::from_millis(100))
|
||||
.send()
|
||||
.await
|
||||
.expect("call failed");
|
||||
|
||||
assert_eq!(resp.status(), 200);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn assert_retry_on_connection_reset_by_peer() {
|
||||
let counter = Arc::new(AtomicI8::new(0));
|
||||
let mut simple_server = SimpleServer::new("127.0.0.1", None, vec![])
|
||||
.await
|
||||
.expect("Error when creating a simple server");
|
||||
simple_server.set_custom_handler(move |mut stream| {
|
||||
let counter = counter.clone();
|
||||
async move {
|
||||
let mut buffer = Vec::new();
|
||||
stream.read(&mut buffer).await.unwrap();
|
||||
if counter.fetch_add(1, Ordering::SeqCst) > 1 {
|
||||
// This triggeres hyper:Error(Io, io::Error(ConnectionReset)).
|
||||
drop(stream);
|
||||
} else {
|
||||
let _res = stream.write("HTTP/1.1 200 OK\r\n\r\n".as_bytes()).await;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
.boxed()
|
||||
});
|
||||
|
||||
let uri = simple_server.uri();
|
||||
|
||||
tokio::spawn(simple_server.start());
|
||||
|
||||
let reqwest_client = Client::builder().build().unwrap();
|
||||
let client = ClientBuilder::new(reqwest_client)
|
||||
.with(RetryTransientMiddleware::new_with_policy(
|
||||
ExponentialBackoff {
|
||||
max_n_retries: 3,
|
||||
max_retry_interval: std::time::Duration::from_millis(100),
|
||||
min_retry_interval: std::time::Duration::from_millis(30),
|
||||
backoff_exponent: 2,
|
||||
},
|
||||
))
|
||||
.build();
|
||||
|
||||
let resp = client
|
||||
.get(&format!("{}/foo", uri))
|
||||
.timeout(std::time::Duration::from_millis(100))
|
||||
.send()
|
||||
.await
|
||||
.expect("call failed");
|
||||
|
||||
assert_eq!(resp.status(), 200);
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue