feat: classify io error connection reset by peer (#78)

* feat: classify io error connection reset by peer

* doc: explain tests

* chore: add changelog

* fix: use and_then

* chore: docs and clippy

* chore: bump reqwest-retry to 0.2.1

* fix: use get_source_error_type

* fix: remove unused import

* fix: make clippy happy

* doc: describe canceled
pull/80/head
tl-helge-hoff 2022-12-02 10:57:24 +01:00 committed by GitHub
parent 8763ab1e30
commit f854725791
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 148 additions and 2 deletions

View File

@ -6,6 +6,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
## [Unreleased]
## [0.2.1] - 2022-12-01
### Changed
- Classify `io::Error`s and `hyper::Error(Canceled)` as transient
## [0.2.0] - 2022-11-15
### Changed

View File

@ -1,6 +1,6 @@
[package]
name = "reqwest-retry"
version = "0.2.0"
version = "0.2.1"
authors = ["Rodrigo Gryzinski <rodrigo.gryzinski@truelayer.com>"]
edition = "2018"
description = "Retry middleware for reqwest."
@ -29,3 +29,4 @@ async-std = { version = "1.10"}
paste = "1"
tokio = { version = "1", features = ["macros"] }
wiremock = "0.5"
futures = "0.3"

View File

@ -1,3 +1,5 @@
use std::io;
use http::StatusCode;
use reqwest_middleware::Error;
@ -55,8 +57,17 @@ impl Retryable {
// The hyper::Error(IncompleteMessage) is raised if the HTTP response is well formatted but does not contain all the bytes.
// This can happen when the server has started sending back the response but the connection is cut halfway thorugh.
// We can safely retry the call, hence marking this error as [`Retryable::Transient`].
if hyper_error.is_incomplete_message() {
// Instead hyper::Error(Canceled) is raised when the connection is
// gracefully closed on the server side.
if hyper_error.is_incomplete_message() || hyper_error.is_canceled() {
Some(Retryable::Transient)
// Try and downcast the hyper error to io::Error if that is the
// underlying error, and try and classify it.
} else if let Some(io_error) =
get_source_error_type::<io::Error>(hyper_error)
{
Some(classify_io_error(io_error))
} else {
Some(Retryable::Fatal)
}
@ -81,6 +92,13 @@ impl From<&reqwest::Error> for Retryable {
}
}
fn classify_io_error(error: &io::Error) -> Retryable {
match error.kind() {
io::ErrorKind::ConnectionReset | io::ErrorKind::ConnectionAborted => Retryable::Transient,
_ => Retryable::Fatal,
}
}
/// Downcasts the given err source into T.
fn get_source_error_type<T: std::error::Error + 'static>(
err: &dyn std::error::Error,

View File

@ -1,10 +1,15 @@
use async_std::io::ReadExt;
use async_std::io::WriteExt;
use async_std::net::{TcpListener, TcpStream};
use futures::future::BoxFuture;
use futures::stream::StreamExt;
use std::error::Error;
use std::fmt;
type CustomMessageHandler = Box<
dyn Fn(TcpStream) -> BoxFuture<'static, Result<(), Box<dyn std::error::Error>>> + Send + Sync,
>;
/// This is a simple server that returns the responses given at creation time: [`self.raw_http_responses`] following a round-robin mechanism.
pub struct SimpleServer {
listener: TcpListener,
@ -12,6 +17,7 @@ pub struct SimpleServer {
host: String,
raw_http_responses: Vec<String>,
calls_counter: usize,
custom_handler: Option<CustomMessageHandler>,
}
/// Request-Line = Method SP Request-URI SP HTTP-Version CRLF
@ -46,9 +52,21 @@ impl SimpleServer {
host: host.to_string(),
raw_http_responses,
calls_counter: 0,
custom_handler: None,
})
}
pub fn set_custom_handler(
&mut self,
custom_handler: impl Fn(TcpStream) -> BoxFuture<'static, Result<(), Box<dyn std::error::Error>>>
+ Send
+ Sync
+ 'static,
) -> &mut Self {
self.custom_handler.replace(Box::new(custom_handler));
self
}
/// Returns the uri in which the server is listening to.
pub fn uri(&self) -> String {
format!("http://{}:{}", self.host, self.port)
@ -79,6 +97,10 @@ impl SimpleServer {
///
/// Returns a 400 if the request if formatted badly.
async fn handle_connection(&self, mut stream: TcpStream) -> Result<(), Box<dyn Error>> {
if let Some(ref custom_handler) = self.custom_handler {
return custom_handler(stream).await;
}
let mut buffer = vec![0; 1024];
stream.read(&mut buffer).await.unwrap();

View File

@ -1,8 +1,12 @@
use async_std::io::ReadExt;
use futures::AsyncWriteExt;
use futures::FutureExt;
use paste::paste;
use reqwest::Client;
use reqwest::StatusCode;
use reqwest_middleware::ClientBuilder;
use reqwest_retry::{policies::ExponentialBackoff, RetryTransientMiddleware};
use std::sync::atomic::AtomicI8;
use std::sync::{
atomic::{AtomicU32, Ordering},
Arc,
@ -269,3 +273,99 @@ async fn assert_retry_on_incomplete_message() {
assert_eq!(resp.status(), 200);
}
#[tokio::test]
async fn assert_retry_on_hyper_canceled() {
let counter = Arc::new(AtomicI8::new(0));
let mut simple_server = SimpleServer::new("127.0.0.1", None, vec![])
.await
.expect("Error when creating a simple server");
simple_server.set_custom_handler(move |mut stream| {
let counter = counter.clone();
async move {
let mut buffer = Vec::new();
stream.read(&mut buffer).await.unwrap();
if counter.fetch_add(1, Ordering::SeqCst) > 1 {
// This triggeres hyper:Error(Canceled).
let _res = stream.shutdown(std::net::Shutdown::Both);
} else {
let _res = stream.write("HTTP/1.1 200 OK\r\n\r\n".as_bytes()).await;
}
Ok(())
}
.boxed()
});
let uri = simple_server.uri();
tokio::spawn(simple_server.start());
let reqwest_client = Client::builder().build().unwrap();
let client = ClientBuilder::new(reqwest_client)
.with(RetryTransientMiddleware::new_with_policy(
ExponentialBackoff {
max_n_retries: 3,
max_retry_interval: std::time::Duration::from_millis(100),
min_retry_interval: std::time::Duration::from_millis(30),
backoff_exponent: 2,
},
))
.build();
let resp = client
.get(&format!("{}/foo", uri))
.timeout(std::time::Duration::from_millis(100))
.send()
.await
.expect("call failed");
assert_eq!(resp.status(), 200);
}
#[tokio::test]
async fn assert_retry_on_connection_reset_by_peer() {
let counter = Arc::new(AtomicI8::new(0));
let mut simple_server = SimpleServer::new("127.0.0.1", None, vec![])
.await
.expect("Error when creating a simple server");
simple_server.set_custom_handler(move |mut stream| {
let counter = counter.clone();
async move {
let mut buffer = Vec::new();
stream.read(&mut buffer).await.unwrap();
if counter.fetch_add(1, Ordering::SeqCst) > 1 {
// This triggeres hyper:Error(Io, io::Error(ConnectionReset)).
drop(stream);
} else {
let _res = stream.write("HTTP/1.1 200 OK\r\n\r\n".as_bytes()).await;
}
Ok(())
}
.boxed()
});
let uri = simple_server.uri();
tokio::spawn(simple_server.start());
let reqwest_client = Client::builder().build().unwrap();
let client = ClientBuilder::new(reqwest_client)
.with(RetryTransientMiddleware::new_with_policy(
ExponentialBackoff {
max_n_retries: 3,
max_retry_interval: std::time::Duration::from_millis(100),
min_retry_interval: std::time::Duration::from_millis(30),
backoff_exponent: 2,
},
))
.build();
let resp = client
.get(&format!("{}/foo", uri))
.timeout(std::time::Duration::from_millis(100))
.send()
.await
.expect("call failed");
assert_eq!(resp.status(), 200);
}