diff --git a/reqwest-retry/Cargo.toml b/reqwest-retry/Cargo.toml index b611c81..491dfb2 100644 --- a/reqwest-retry/Cargo.toml +++ b/reqwest-retry/Cargo.toml @@ -10,13 +10,14 @@ keywords = ["reqwest", "http", "middleware", "retry"] categories = ["web-programming::http-client"] [dependencies] -reqwest-middleware = { version = "0.1.2", path = "../reqwest-middleware" } +reqwest-middleware = { version = "0.1.2" } # TODO remove this change: it was just to make rust-analyzer happy locally anyhow = "1" async-trait = "0.1.51" chrono = "0.4" futures = "0.3" http = "0.2" +hyper = "0.14" retry-policies = "0.1" reqwest = { version = "0.11", default-features = false } tokio = { version = "1.6", features = ["time"] } diff --git a/reqwest-retry/src/retryable.rs b/reqwest-retry/src/retryable.rs index 0087f3e..387e860 100644 --- a/reqwest-retry/src/retryable.rs +++ b/reqwest-retry/src/retryable.rs @@ -1,4 +1,5 @@ use http::StatusCode; +use hyper; use reqwest_middleware::Error; /// Classification of an error/status returned by request. @@ -49,6 +50,14 @@ impl Retryable { || error.is_redirect() { Some(Retryable::Fatal) + } else if let Some(hyper_error) = get_source_error_type::(&error) + { + if hyper_error.is_incomplete_message() { + Some(Retryable::Fatal) + } else { + Some(Retryable::Transient) + } + // TODO: map all the hyper_error types } else { // We omit checking if error.is_status() since we check that already. // However, if Response::error_for_status is used the status will still @@ -66,3 +75,18 @@ impl From<&reqwest::Error> for Retryable { Retryable::Transient } } + +fn get_source_error_type( + err: &dyn std::error::Error, +) -> Option<&T> { + let mut source = err.source(); + + while let Some(err) = source { + if let Some(hyper_err) = err.downcast_ref::() { + return Some(hyper_err); + } + + source = err.source(); + } + None +} diff --git a/reqwest-retry/tests/all/helpers/mod.rs b/reqwest-retry/tests/all/helpers/mod.rs new file mode 100644 index 0000000..78bd5ca --- /dev/null +++ b/reqwest-retry/tests/all/helpers/mod.rs @@ -0,0 +1,3 @@ +mod simple_server; + +pub use simple_server::SimpleServer; diff --git a/reqwest-retry/tests/all/helpers/simple_server.rs b/reqwest-retry/tests/all/helpers/simple_server.rs new file mode 100644 index 0000000..5a739b8 --- /dev/null +++ b/reqwest-retry/tests/all/helpers/simple_server.rs @@ -0,0 +1,113 @@ +use std::error::Error; +use std::io::{Read, Write}; +use std::net::TcpStream; +use std::path::Path; +use std::{ + error::Error, + fmt, + net::{TcpListener, TcpStream}, +}; +use std::{fmt, fs}; + +pub struct SimpleServer { + listener: TcpListener, + port: u16, + host: String, + raw_http_response: String, +} + +/// Request-Line = Method SP Request-URI SP HTTP-Version CRLF +struct Request<'a> { + method: &'a str, + uri: &'a str, + http_version: &'a str, +} + +impl<'a> fmt::Display for Request<'a> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{} {} {}\r\n", self.method, self.uri, self.http_version) + } +} + +/// This is a simple server that +impl SimpleServer { + pub fn new( + host: &str, + port: Option, + raw_http_response: &str, + ) -> Result { + let listener = TcpListener::bind(format!("{}:{}", host, port.ok_or(0).unwrap()))?; + + let port = listener.local_addr()?.port(); + + Ok(Self { + listener, + port, + host: host.to_string(), + raw_http_response: raw_http_response.to_string(), + }) + } + + pub fn uri(&self) -> String { + format!("http://{}:{}", self.host, self.port) + } + + pub fn start(&self) { + for stream in self.listener.incoming() { + match stream { + Ok(stream) => { + match self.handle_connection(stream, self.raw_http_response.clone()) { + Ok(_) => (), + Err(e) => println!("Error handling connection: {}", e), + } + } + Err(e) => println!("Connection failed: {}", e), + } + } + } + + fn handle_connection( + &self, + mut stream: TcpStream, + raw_http_response: String, + ) -> Result<(), Box> { + // 512 bytes is enough for a toy HTTP server + let mut buffer = [0; 512]; + + // writes stream into buffer + stream.read(&mut buffer).unwrap(); + + let request = String::from_utf8_lossy(&buffer[..]); + let request_line = request.lines().next().unwrap(); + + match Self::parse_request_line(&request_line) { + Ok(request) => { + println!("Request: {}", &request); + + let response = format!("{}", raw_http_response.clone()); + + stream.write(response.as_bytes()).unwrap(); + stream.flush().unwrap(); + } + Err(e) => print!("Bad request: {}", e), + } + + Ok(()) + } + + fn parse_request_line(request: &str) -> Result> { + let mut parts = request.split_whitespace(); + + let method = parts.next().ok_or("Method not specified")?; + + let uri = parts.next().ok_or("URI not specified")?; + + let http_version = parts.next().ok_or("HTTP version not specified")?; + + Ok(Request { + method, + uri, + http_version, + }) + } +} diff --git a/reqwest-retry/tests/all/main.rs b/reqwest-retry/tests/all/main.rs new file mode 100644 index 0000000..28b7da7 --- /dev/null +++ b/reqwest-retry/tests/all/main.rs @@ -0,0 +1,2 @@ +mod helpers; +mod retry; diff --git a/reqwest-retry/tests/retry.rs b/reqwest-retry/tests/all/retry.rs similarity index 87% rename from reqwest-retry/tests/retry.rs rename to reqwest-retry/tests/all/retry.rs index ff721f4..9fc6d71 100644 --- a/reqwest-retry/tests/retry.rs +++ b/reqwest-retry/tests/all/retry.rs @@ -10,6 +10,8 @@ use std::sync::{ use wiremock::matchers::{method, path}; use wiremock::{Mock, MockServer, Respond, ResponseTemplate}; +use crate::helpers::SimpleServer; + pub struct RetryResponder(Arc, u32, u16); impl RetryResponder { @@ -213,3 +215,33 @@ async fn assert_retry_on_request_timeout() { assert_eq!(resp.status(), 200); } + +#[tokio::test] +async fn assert_retry_on_incomplete_message() { + let raw_response = "HTTP/1.1 200"; // the full working response is: "HTTP/1.1 200 OK\r\n\r\n" + + let simple_server = SimpleServer::new("127.0.0.1", None, raw_response) + .expect("Error when creating a simple server"); + tokio::spawn(simple_server.start()); + + let reqwest_client = Client::builder().build().unwrap(); + let client = ClientBuilder::new(reqwest_client) + .with(RetryTransientMiddleware::new_with_policy( + ExponentialBackoff { + max_n_retries: 3, + max_retry_interval: std::time::Duration::from_millis(30), + min_retry_interval: std::time::Duration::from_millis(100), + backoff_exponent: 2, + }, + )) + .build(); + + let resp = client + .get(&format!("{}/foo", simple_server.uri())) + .timeout(std::time::Duration::from_millis(10)) + .send() + .await + .expect("call failed"); + + assert_eq!(resp.status(), 200); +}