feat: classify io error connection reset by peer (#78)

* feat: classify io error connection reset by peer

* doc: explain tests

* chore: add changelog

* fix: use and_then

* chore: docs and clippy

* chore: bump reqwest-retry to 0.2.1

* fix: use get_source_error_type

* fix: remove unused import

* fix: make clippy happy

* doc: describe canceled
This commit is contained in:
tl-helge-hoff 2022-12-02 10:57:24 +01:00 committed by GitHub
parent 8763ab1e30
commit f854725791
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 148 additions and 2 deletions

View file

@ -6,6 +6,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
## [Unreleased] ## [Unreleased]
## [0.2.1] - 2022-12-01
### Changed
- Classify `io::Error`s and `hyper::Error(Canceled)` as transient
## [0.2.0] - 2022-11-15 ## [0.2.0] - 2022-11-15
### Changed ### Changed

View file

@ -1,6 +1,6 @@
[package] [package]
name = "reqwest-retry" name = "reqwest-retry"
version = "0.2.0" version = "0.2.1"
authors = ["Rodrigo Gryzinski <rodrigo.gryzinski@truelayer.com>"] authors = ["Rodrigo Gryzinski <rodrigo.gryzinski@truelayer.com>"]
edition = "2018" edition = "2018"
description = "Retry middleware for reqwest." description = "Retry middleware for reqwest."
@ -29,3 +29,4 @@ async-std = { version = "1.10"}
paste = "1" paste = "1"
tokio = { version = "1", features = ["macros"] } tokio = { version = "1", features = ["macros"] }
wiremock = "0.5" wiremock = "0.5"
futures = "0.3"

View file

@ -1,3 +1,5 @@
use std::io;
use http::StatusCode; use http::StatusCode;
use reqwest_middleware::Error; use reqwest_middleware::Error;
@ -55,8 +57,17 @@ impl Retryable {
// The hyper::Error(IncompleteMessage) is raised if the HTTP response is well formatted but does not contain all the bytes. // The hyper::Error(IncompleteMessage) is raised if the HTTP response is well formatted but does not contain all the bytes.
// This can happen when the server has started sending back the response but the connection is cut halfway thorugh. // This can happen when the server has started sending back the response but the connection is cut halfway thorugh.
// We can safely retry the call, hence marking this error as [`Retryable::Transient`]. // We can safely retry the call, hence marking this error as [`Retryable::Transient`].
if hyper_error.is_incomplete_message() { // Instead hyper::Error(Canceled) is raised when the connection is
// gracefully closed on the server side.
if hyper_error.is_incomplete_message() || hyper_error.is_canceled() {
Some(Retryable::Transient) Some(Retryable::Transient)
// Try and downcast the hyper error to io::Error if that is the
// underlying error, and try and classify it.
} else if let Some(io_error) =
get_source_error_type::<io::Error>(hyper_error)
{
Some(classify_io_error(io_error))
} else { } else {
Some(Retryable::Fatal) Some(Retryable::Fatal)
} }
@ -81,6 +92,13 @@ impl From<&reqwest::Error> for Retryable {
} }
} }
fn classify_io_error(error: &io::Error) -> Retryable {
match error.kind() {
io::ErrorKind::ConnectionReset | io::ErrorKind::ConnectionAborted => Retryable::Transient,
_ => Retryable::Fatal,
}
}
/// Downcasts the given err source into T. /// Downcasts the given err source into T.
fn get_source_error_type<T: std::error::Error + 'static>( fn get_source_error_type<T: std::error::Error + 'static>(
err: &dyn std::error::Error, err: &dyn std::error::Error,

View file

@ -1,10 +1,15 @@
use async_std::io::ReadExt; use async_std::io::ReadExt;
use async_std::io::WriteExt; use async_std::io::WriteExt;
use async_std::net::{TcpListener, TcpStream}; use async_std::net::{TcpListener, TcpStream};
use futures::future::BoxFuture;
use futures::stream::StreamExt; use futures::stream::StreamExt;
use std::error::Error; use std::error::Error;
use std::fmt; use std::fmt;
type CustomMessageHandler = Box<
dyn Fn(TcpStream) -> BoxFuture<'static, Result<(), Box<dyn std::error::Error>>> + Send + Sync,
>;
/// This is a simple server that returns the responses given at creation time: [`self.raw_http_responses`] following a round-robin mechanism. /// This is a simple server that returns the responses given at creation time: [`self.raw_http_responses`] following a round-robin mechanism.
pub struct SimpleServer { pub struct SimpleServer {
listener: TcpListener, listener: TcpListener,
@ -12,6 +17,7 @@ pub struct SimpleServer {
host: String, host: String,
raw_http_responses: Vec<String>, raw_http_responses: Vec<String>,
calls_counter: usize, calls_counter: usize,
custom_handler: Option<CustomMessageHandler>,
} }
/// Request-Line = Method SP Request-URI SP HTTP-Version CRLF /// Request-Line = Method SP Request-URI SP HTTP-Version CRLF
@ -46,9 +52,21 @@ impl SimpleServer {
host: host.to_string(), host: host.to_string(),
raw_http_responses, raw_http_responses,
calls_counter: 0, calls_counter: 0,
custom_handler: None,
}) })
} }
pub fn set_custom_handler(
&mut self,
custom_handler: impl Fn(TcpStream) -> BoxFuture<'static, Result<(), Box<dyn std::error::Error>>>
+ Send
+ Sync
+ 'static,
) -> &mut Self {
self.custom_handler.replace(Box::new(custom_handler));
self
}
/// Returns the uri in which the server is listening to. /// Returns the uri in which the server is listening to.
pub fn uri(&self) -> String { pub fn uri(&self) -> String {
format!("http://{}:{}", self.host, self.port) format!("http://{}:{}", self.host, self.port)
@ -79,6 +97,10 @@ impl SimpleServer {
/// ///
/// Returns a 400 if the request if formatted badly. /// Returns a 400 if the request if formatted badly.
async fn handle_connection(&self, mut stream: TcpStream) -> Result<(), Box<dyn Error>> { async fn handle_connection(&self, mut stream: TcpStream) -> Result<(), Box<dyn Error>> {
if let Some(ref custom_handler) = self.custom_handler {
return custom_handler(stream).await;
}
let mut buffer = vec![0; 1024]; let mut buffer = vec![0; 1024];
stream.read(&mut buffer).await.unwrap(); stream.read(&mut buffer).await.unwrap();

View file

@ -1,8 +1,12 @@
use async_std::io::ReadExt;
use futures::AsyncWriteExt;
use futures::FutureExt;
use paste::paste; use paste::paste;
use reqwest::Client; use reqwest::Client;
use reqwest::StatusCode; use reqwest::StatusCode;
use reqwest_middleware::ClientBuilder; use reqwest_middleware::ClientBuilder;
use reqwest_retry::{policies::ExponentialBackoff, RetryTransientMiddleware}; use reqwest_retry::{policies::ExponentialBackoff, RetryTransientMiddleware};
use std::sync::atomic::AtomicI8;
use std::sync::{ use std::sync::{
atomic::{AtomicU32, Ordering}, atomic::{AtomicU32, Ordering},
Arc, Arc,
@ -269,3 +273,99 @@ async fn assert_retry_on_incomplete_message() {
assert_eq!(resp.status(), 200); assert_eq!(resp.status(), 200);
} }
#[tokio::test]
async fn assert_retry_on_hyper_canceled() {
let counter = Arc::new(AtomicI8::new(0));
let mut simple_server = SimpleServer::new("127.0.0.1", None, vec![])
.await
.expect("Error when creating a simple server");
simple_server.set_custom_handler(move |mut stream| {
let counter = counter.clone();
async move {
let mut buffer = Vec::new();
stream.read(&mut buffer).await.unwrap();
if counter.fetch_add(1, Ordering::SeqCst) > 1 {
// This triggeres hyper:Error(Canceled).
let _res = stream.shutdown(std::net::Shutdown::Both);
} else {
let _res = stream.write("HTTP/1.1 200 OK\r\n\r\n".as_bytes()).await;
}
Ok(())
}
.boxed()
});
let uri = simple_server.uri();
tokio::spawn(simple_server.start());
let reqwest_client = Client::builder().build().unwrap();
let client = ClientBuilder::new(reqwest_client)
.with(RetryTransientMiddleware::new_with_policy(
ExponentialBackoff {
max_n_retries: 3,
max_retry_interval: std::time::Duration::from_millis(100),
min_retry_interval: std::time::Duration::from_millis(30),
backoff_exponent: 2,
},
))
.build();
let resp = client
.get(&format!("{}/foo", uri))
.timeout(std::time::Duration::from_millis(100))
.send()
.await
.expect("call failed");
assert_eq!(resp.status(), 200);
}
#[tokio::test]
async fn assert_retry_on_connection_reset_by_peer() {
let counter = Arc::new(AtomicI8::new(0));
let mut simple_server = SimpleServer::new("127.0.0.1", None, vec![])
.await
.expect("Error when creating a simple server");
simple_server.set_custom_handler(move |mut stream| {
let counter = counter.clone();
async move {
let mut buffer = Vec::new();
stream.read(&mut buffer).await.unwrap();
if counter.fetch_add(1, Ordering::SeqCst) > 1 {
// This triggeres hyper:Error(Io, io::Error(ConnectionReset)).
drop(stream);
} else {
let _res = stream.write("HTTP/1.1 200 OK\r\n\r\n".as_bytes()).await;
}
Ok(())
}
.boxed()
});
let uri = simple_server.uri();
tokio::spawn(simple_server.start());
let reqwest_client = Client::builder().build().unwrap();
let client = ClientBuilder::new(reqwest_client)
.with(RetryTransientMiddleware::new_with_policy(
ExponentialBackoff {
max_n_retries: 3,
max_retry_interval: std::time::Duration::from_millis(100),
min_retry_interval: std::time::Duration::from_millis(30),
backoff_exponent: 2,
},
))
.build();
let resp = client
.get(&format!("{}/foo", uri))
.timeout(std::time::Duration::from_millis(100))
.send()
.await
.expect("call failed");
assert_eq!(resp.status(), 200);
}