diff --git a/CHANGELOG.md b/CHANGELOG.md index 42fa4bb..77b0ece 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,6 +6,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] +## [0.1.3] - 2021-10-13 +### Fixed +- Handle the `hyper::Error(IncompleteMessage)` as a `Retryable::Transient`. + ## [0.1.2] - 2021-09-28 ### Changed - Disabled default features on `reqwest` diff --git a/reqwest-retry/Cargo.toml b/reqwest-retry/Cargo.toml index b611c81..d11984d 100644 --- a/reqwest-retry/Cargo.toml +++ b/reqwest-retry/Cargo.toml @@ -17,6 +17,7 @@ async-trait = "0.1.51" chrono = "0.4" futures = "0.3" http = "0.2" +hyper = "0.14" retry-policies = "0.1" reqwest = { version = "0.11", default-features = false } tokio = { version = "1.6", features = ["time"] } @@ -27,3 +28,4 @@ task-local-extensions = "0.1.1" wiremock = "0.5" tokio = { version = "1", features = ["macros"] } paste = "1" +async-std = { version = "1.10"} \ No newline at end of file diff --git a/reqwest-retry/src/retryable.rs b/reqwest-retry/src/retryable.rs index 0087f3e..d6e4cae 100644 --- a/reqwest-retry/src/retryable.rs +++ b/reqwest-retry/src/retryable.rs @@ -44,11 +44,25 @@ impl Retryable { Some(Retryable::Transient) } else if error.is_body() || error.is_decode() - || error.is_request() || error.is_builder() || error.is_redirect() { Some(Retryable::Fatal) + } else if error.is_request() { + // It seems that hyper::Error(IncompleteMessage) is not correctly handled by reqwest. + // Here we check if the Reqwest error was originated by hyper and map it consistently. + if let Some(hyper_error) = get_source_error_type::(&error) { + // The hyper::Error(IncompleteMessage) is raised if the HTTP response is well formatted but does not contain all the bytes. + // This can happen when the server has started sending back the response but the connection is cut halfway thorugh. + // We can safely retry the call, hence marking this error as [`Retryable::Transient`]. + if hyper_error.is_incomplete_message() { + Some(Retryable::Transient) + } else { + Some(Retryable::Fatal) + } + } else { + Some(Retryable::Fatal) + } } else { // We omit checking if error.is_status() since we check that already. // However, if Response::error_for_status is used the status will still @@ -66,3 +80,19 @@ impl From<&reqwest::Error> for Retryable { Retryable::Transient } } + +/// Downcasts the given err source into T. +fn get_source_error_type( + err: &dyn std::error::Error, +) -> Option<&T> { + let mut source = err.source(); + + while let Some(err) = source { + if let Some(hyper_err) = err.downcast_ref::() { + return Some(hyper_err); + } + + source = err.source(); + } + None +} diff --git a/reqwest-retry/tests/all/helpers/mod.rs b/reqwest-retry/tests/all/helpers/mod.rs new file mode 100644 index 0000000..78bd5ca --- /dev/null +++ b/reqwest-retry/tests/all/helpers/mod.rs @@ -0,0 +1,3 @@ +mod simple_server; + +pub use simple_server::SimpleServer; diff --git a/reqwest-retry/tests/all/helpers/simple_server.rs b/reqwest-retry/tests/all/helpers/simple_server.rs new file mode 100644 index 0000000..0fdf5f3 --- /dev/null +++ b/reqwest-retry/tests/all/helpers/simple_server.rs @@ -0,0 +1,140 @@ +use async_std::io::ReadExt; +use async_std::io::WriteExt; +use async_std::net::{TcpListener, TcpStream}; +use futures::stream::StreamExt; +use std::error::Error; +use std::fmt; + +/// This is a simple server that returns the responses given at creation time: [`self.raw_http_responses`] following a round-robin mechanism. +pub struct SimpleServer { + listener: TcpListener, + port: u16, + host: String, + raw_http_responses: Vec, + calls_counter: usize, +} + +/// Request-Line = Method SP Request-URI SP HTTP-Version CRLF +struct Request<'a> { + method: &'a str, + uri: &'a str, + http_version: &'a str, +} + +impl<'a> fmt::Display for Request<'a> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{} {} {}\r\n", self.method, self.uri, self.http_version) + } +} + +impl SimpleServer { + /// Creates an instance of a [`SimpleServer`] + /// If [`port`] is None os Some(0), it gets randomly chosen between the available ones. + pub async fn new( + host: &str, + port: Option, + raw_http_responses: Vec, + ) -> Result { + let port = port.unwrap_or(0); + let listener = TcpListener::bind(format!("{}:{}", host, port)).await?; + + let port = listener.local_addr()?.port(); + + Ok(Self { + listener, + port, + host: host.to_string(), + raw_http_responses, + calls_counter: 0, + }) + } + + /// Returns the uri in which the server is listening to. + pub fn uri(&self) -> String { + format!("http://{}:{}", self.host, self.port) + } + + /// Starts the TcpListener and handles the requests. + pub async fn start(mut self) { + while let Some(stream) = self.listener.incoming().next().await { + match stream { + Ok(stream) => { + match self.handle_connection(stream).await { + Ok(_) => (), + Err(e) => { + println!("Error handling connection: {}", e); + } + } + self.calls_counter += 1; + } + Err(e) => { + println!("Connection failed: {}", e); + } + } + } + } + + /// Asyncrounously reads from the buffer and handle the request. + /// It first checks that the format is correct, then returns the response. + /// + /// Returns a 400 if the request if formatted badly. + async fn handle_connection(&self, mut stream: TcpStream) -> Result<(), Box> { + let mut buffer = vec![0; 1024]; + + stream.read(&mut buffer).await.unwrap(); + + let request = String::from_utf8_lossy(&buffer[..]); + let request_line = request.lines().next().unwrap(); + + let response = match Self::parse_request_line(request_line) { + Ok(request) => { + println!("== Request == \n{}\n=============", request); + self.get_response().clone() + } + Err(e) => { + println!("++ Bad request: {} ++++++", e); + self.get_bad_request_response() + } + }; + + println!("-- Response --\n{}\n--------------", response.clone()); + stream.write(response.as_bytes()).await.unwrap(); + stream.flush().await.unwrap(); + + Ok(()) + } + + /// Parses the request line and checks that it contains the method, uri and http_version parts. + /// It does not check if the content of the checked parts is correct. It just checks the format (it contains enough parts) of the request. + fn parse_request_line(request: &str) -> Result> { + let mut parts = request.split_whitespace(); + + let method = parts.next().ok_or("Method not specified")?; + + let uri = parts.next().ok_or("URI not specified")?; + + let http_version = parts.next().ok_or("HTTP version not specified")?; + + Ok(Request { + method, + uri, + http_version, + }) + } + + /// Returns the response to use based on the calls counter. + /// It uses a round-robin mechanism. + fn get_response(&self) -> String { + let index = if self.calls_counter >= self.raw_http_responses.len() { + self.raw_http_responses.len() % self.calls_counter + } else { + self.calls_counter + }; + self.raw_http_responses[index].clone() + } + + /// Returns the raw HTTP response in case of a 400 Bad Request. + fn get_bad_request_response(&self) -> String { + "HTTP/1.1 400 Bad Request\r\n\r\n".to_string() + } +} diff --git a/reqwest-retry/tests/all/main.rs b/reqwest-retry/tests/all/main.rs new file mode 100644 index 0000000..28b7da7 --- /dev/null +++ b/reqwest-retry/tests/all/main.rs @@ -0,0 +1,2 @@ +mod helpers; +mod retry; diff --git a/reqwest-retry/tests/retry.rs b/reqwest-retry/tests/all/retry.rs similarity index 78% rename from reqwest-retry/tests/retry.rs rename to reqwest-retry/tests/all/retry.rs index ff721f4..9bfcbf3 100644 --- a/reqwest-retry/tests/retry.rs +++ b/reqwest-retry/tests/all/retry.rs @@ -10,6 +10,7 @@ use std::sync::{ use wiremock::matchers::{method, path}; use wiremock::{Mock, MockServer, Respond, ResponseTemplate}; +use crate::helpers::SimpleServer; pub struct RetryResponder(Arc, u32, u16); impl RetryResponder { @@ -197,8 +198,8 @@ async fn assert_retry_on_request_timeout() { .with(RetryTransientMiddleware::new_with_policy( ExponentialBackoff { max_n_retries: 3, - max_retry_interval: std::time::Duration::from_millis(30), - min_retry_interval: std::time::Duration::from_millis(100), + max_retry_interval: std::time::Duration::from_millis(100), + min_retry_interval: std::time::Duration::from_millis(30), backoff_exponent: 2, }, )) @@ -213,3 +214,58 @@ async fn assert_retry_on_request_timeout() { assert_eq!(resp.status(), 200); } + +#[tokio::test] +async fn assert_retry_on_incomplete_message() { + // Following the HTTP/1.1 specification (https://en.wikipedia.org/wiki/HTTP_message_body) a valid response contains: + // - status line + // - headers + // - empty line + // - optional message body + // + // After a few tries we have noticed that: + // - "message_that_makes_no_sense" triggers a hyper::ParseError because the format is completely wrong + // - "HTTP/1.1" triggers a hyper::IncompleteMessage because the format is correct until that point but misses mandatory parts + let incomplete_message = "HTTP/1.1"; + let complete_message = "HTTP/1.1 200 OK\r\n\r\n"; + + // create a SimpleServer that returns the correct response after 3 attempts. + // the first 3 attempts are incomplete http response and internally they result in a [`hyper::Error(IncompleteMessage)`] error. + let simple_server = SimpleServer::new( + "127.0.0.1", + None, + vec![ + incomplete_message.to_string(), + incomplete_message.to_string(), + incomplete_message.to_string(), + complete_message.to_string(), + ], + ) + .await + .expect("Error when creating a simple server"); + + let uri = simple_server.uri(); + + tokio::spawn(simple_server.start()); + + let reqwest_client = Client::builder().build().unwrap(); + let client = ClientBuilder::new(reqwest_client) + .with(RetryTransientMiddleware::new_with_policy( + ExponentialBackoff { + max_n_retries: 3, + max_retry_interval: std::time::Duration::from_millis(100), + min_retry_interval: std::time::Duration::from_millis(30), + backoff_exponent: 2, + }, + )) + .build(); + + let resp = client + .get(&format!("{}/foo", uri)) + .timeout(std::time::Duration::from_millis(100)) + .send() + .await + .expect("call failed"); + + assert_eq!(resp.status(), 200); +}