From f854725791ccf4a02c401a26cab3d9db753f468c Mon Sep 17 00:00:00 2001 From: tl-helge-hoff <71260463+tl-helge-hoff@users.noreply.github.com> Date: Fri, 2 Dec 2022 10:57:24 +0100 Subject: [PATCH] feat: classify io error connection reset by peer (#78) * feat: classify io error connection reset by peer * doc: explain tests * chore: add changelog * fix: use and_then * chore: docs and clippy * chore: bump reqwest-retry to 0.2.1 * fix: use get_source_error_type * fix: remove unused import * fix: make clippy happy * doc: describe canceled --- reqwest-retry/CHANGELOG.md | 5 + reqwest-retry/Cargo.toml | 3 +- reqwest-retry/src/retryable.rs | 20 +++- .../tests/all/helpers/simple_server.rs | 22 ++++ reqwest-retry/tests/all/retry.rs | 100 ++++++++++++++++++ 5 files changed, 148 insertions(+), 2 deletions(-) diff --git a/reqwest-retry/CHANGELOG.md b/reqwest-retry/CHANGELOG.md index 3beb93d..cf1d023 100644 --- a/reqwest-retry/CHANGELOG.md +++ b/reqwest-retry/CHANGELOG.md @@ -6,6 +6,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] +## [0.2.1] - 2022-12-01 + +### Changed +- Classify `io::Error`s and `hyper::Error(Canceled)` as transient + ## [0.2.0] - 2022-11-15 ### Changed diff --git a/reqwest-retry/Cargo.toml b/reqwest-retry/Cargo.toml index ef77147..6a86198 100644 --- a/reqwest-retry/Cargo.toml +++ b/reqwest-retry/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "reqwest-retry" -version = "0.2.0" +version = "0.2.1" authors = ["Rodrigo Gryzinski "] edition = "2018" description = "Retry middleware for reqwest." @@ -29,3 +29,4 @@ async-std = { version = "1.10"} paste = "1" tokio = { version = "1", features = ["macros"] } wiremock = "0.5" +futures = "0.3" diff --git a/reqwest-retry/src/retryable.rs b/reqwest-retry/src/retryable.rs index d6e4cae..b599038 100644 --- a/reqwest-retry/src/retryable.rs +++ b/reqwest-retry/src/retryable.rs @@ -1,3 +1,5 @@ +use std::io; + use http::StatusCode; use reqwest_middleware::Error; @@ -55,8 +57,17 @@ impl Retryable { // The hyper::Error(IncompleteMessage) is raised if the HTTP response is well formatted but does not contain all the bytes. // This can happen when the server has started sending back the response but the connection is cut halfway thorugh. // We can safely retry the call, hence marking this error as [`Retryable::Transient`]. - if hyper_error.is_incomplete_message() { + // Instead hyper::Error(Canceled) is raised when the connection is + // gracefully closed on the server side. + if hyper_error.is_incomplete_message() || hyper_error.is_canceled() { Some(Retryable::Transient) + + // Try and downcast the hyper error to io::Error if that is the + // underlying error, and try and classify it. + } else if let Some(io_error) = + get_source_error_type::(hyper_error) + { + Some(classify_io_error(io_error)) } else { Some(Retryable::Fatal) } @@ -81,6 +92,13 @@ impl From<&reqwest::Error> for Retryable { } } +fn classify_io_error(error: &io::Error) -> Retryable { + match error.kind() { + io::ErrorKind::ConnectionReset | io::ErrorKind::ConnectionAborted => Retryable::Transient, + _ => Retryable::Fatal, + } +} + /// Downcasts the given err source into T. fn get_source_error_type( err: &dyn std::error::Error, diff --git a/reqwest-retry/tests/all/helpers/simple_server.rs b/reqwest-retry/tests/all/helpers/simple_server.rs index 0fdf5f3..fe698c0 100644 --- a/reqwest-retry/tests/all/helpers/simple_server.rs +++ b/reqwest-retry/tests/all/helpers/simple_server.rs @@ -1,10 +1,15 @@ use async_std::io::ReadExt; use async_std::io::WriteExt; use async_std::net::{TcpListener, TcpStream}; +use futures::future::BoxFuture; use futures::stream::StreamExt; use std::error::Error; use std::fmt; +type CustomMessageHandler = Box< + dyn Fn(TcpStream) -> BoxFuture<'static, Result<(), Box>> + Send + Sync, +>; + /// This is a simple server that returns the responses given at creation time: [`self.raw_http_responses`] following a round-robin mechanism. pub struct SimpleServer { listener: TcpListener, @@ -12,6 +17,7 @@ pub struct SimpleServer { host: String, raw_http_responses: Vec, calls_counter: usize, + custom_handler: Option, } /// Request-Line = Method SP Request-URI SP HTTP-Version CRLF @@ -46,9 +52,21 @@ impl SimpleServer { host: host.to_string(), raw_http_responses, calls_counter: 0, + custom_handler: None, }) } + pub fn set_custom_handler( + &mut self, + custom_handler: impl Fn(TcpStream) -> BoxFuture<'static, Result<(), Box>> + + Send + + Sync + + 'static, + ) -> &mut Self { + self.custom_handler.replace(Box::new(custom_handler)); + self + } + /// Returns the uri in which the server is listening to. pub fn uri(&self) -> String { format!("http://{}:{}", self.host, self.port) @@ -79,6 +97,10 @@ impl SimpleServer { /// /// Returns a 400 if the request if formatted badly. async fn handle_connection(&self, mut stream: TcpStream) -> Result<(), Box> { + if let Some(ref custom_handler) = self.custom_handler { + return custom_handler(stream).await; + } + let mut buffer = vec![0; 1024]; stream.read(&mut buffer).await.unwrap(); diff --git a/reqwest-retry/tests/all/retry.rs b/reqwest-retry/tests/all/retry.rs index 9bfcbf3..79b98a0 100644 --- a/reqwest-retry/tests/all/retry.rs +++ b/reqwest-retry/tests/all/retry.rs @@ -1,8 +1,12 @@ +use async_std::io::ReadExt; +use futures::AsyncWriteExt; +use futures::FutureExt; use paste::paste; use reqwest::Client; use reqwest::StatusCode; use reqwest_middleware::ClientBuilder; use reqwest_retry::{policies::ExponentialBackoff, RetryTransientMiddleware}; +use std::sync::atomic::AtomicI8; use std::sync::{ atomic::{AtomicU32, Ordering}, Arc, @@ -269,3 +273,99 @@ async fn assert_retry_on_incomplete_message() { assert_eq!(resp.status(), 200); } + +#[tokio::test] +async fn assert_retry_on_hyper_canceled() { + let counter = Arc::new(AtomicI8::new(0)); + let mut simple_server = SimpleServer::new("127.0.0.1", None, vec![]) + .await + .expect("Error when creating a simple server"); + simple_server.set_custom_handler(move |mut stream| { + let counter = counter.clone(); + async move { + let mut buffer = Vec::new(); + stream.read(&mut buffer).await.unwrap(); + if counter.fetch_add(1, Ordering::SeqCst) > 1 { + // This triggeres hyper:Error(Canceled). + let _res = stream.shutdown(std::net::Shutdown::Both); + } else { + let _res = stream.write("HTTP/1.1 200 OK\r\n\r\n".as_bytes()).await; + } + Ok(()) + } + .boxed() + }); + + let uri = simple_server.uri(); + + tokio::spawn(simple_server.start()); + + let reqwest_client = Client::builder().build().unwrap(); + let client = ClientBuilder::new(reqwest_client) + .with(RetryTransientMiddleware::new_with_policy( + ExponentialBackoff { + max_n_retries: 3, + max_retry_interval: std::time::Duration::from_millis(100), + min_retry_interval: std::time::Duration::from_millis(30), + backoff_exponent: 2, + }, + )) + .build(); + + let resp = client + .get(&format!("{}/foo", uri)) + .timeout(std::time::Duration::from_millis(100)) + .send() + .await + .expect("call failed"); + + assert_eq!(resp.status(), 200); +} + +#[tokio::test] +async fn assert_retry_on_connection_reset_by_peer() { + let counter = Arc::new(AtomicI8::new(0)); + let mut simple_server = SimpleServer::new("127.0.0.1", None, vec![]) + .await + .expect("Error when creating a simple server"); + simple_server.set_custom_handler(move |mut stream| { + let counter = counter.clone(); + async move { + let mut buffer = Vec::new(); + stream.read(&mut buffer).await.unwrap(); + if counter.fetch_add(1, Ordering::SeqCst) > 1 { + // This triggeres hyper:Error(Io, io::Error(ConnectionReset)). + drop(stream); + } else { + let _res = stream.write("HTTP/1.1 200 OK\r\n\r\n".as_bytes()).await; + } + Ok(()) + } + .boxed() + }); + + let uri = simple_server.uri(); + + tokio::spawn(simple_server.start()); + + let reqwest_client = Client::builder().build().unwrap(); + let client = ClientBuilder::new(reqwest_client) + .with(RetryTransientMiddleware::new_with_policy( + ExponentialBackoff { + max_n_retries: 3, + max_retry_interval: std::time::Duration::from_millis(100), + min_retry_interval: std::time::Duration::from_millis(30), + backoff_exponent: 2, + }, + )) + .build(); + + let resp = client + .get(&format!("{}/foo", uri)) + .timeout(std::time::Duration::from_millis(100)) + .send() + .await + .expect("call failed"); + + assert_eq!(resp.status(), 200); +}