diff --git a/reqwest-retry/Cargo.toml b/reqwest-retry/Cargo.toml index 491dfb2..a055d71 100644 --- a/reqwest-retry/Cargo.toml +++ b/reqwest-retry/Cargo.toml @@ -28,3 +28,4 @@ task-local-extensions = "0.1.1" wiremock = "0.5" tokio = { version = "1", features = ["macros"] } paste = "1" +async-std = { version = "1.10"} \ No newline at end of file diff --git a/reqwest-retry/src/retryable.rs b/reqwest-retry/src/retryable.rs index 387e860..a904e1c 100644 --- a/reqwest-retry/src/retryable.rs +++ b/reqwest-retry/src/retryable.rs @@ -45,19 +45,20 @@ impl Retryable { Some(Retryable::Transient) } else if error.is_body() || error.is_decode() - || error.is_request() || error.is_builder() || error.is_redirect() { Some(Retryable::Fatal) - } else if let Some(hyper_error) = get_source_error_type::(&error) - { - if hyper_error.is_incomplete_message() { - Some(Retryable::Fatal) + } else if error.is_request() { + if let Some(hyper_error) = get_source_error_type::(&error) { + if hyper_error.is_incomplete_message() { + Some(Retryable::Transient) + } else { + Some(Retryable::Fatal) + } } else { - Some(Retryable::Transient) + Some(Retryable::Fatal) } - // TODO: map all the hyper_error types } else { // We omit checking if error.is_status() since we check that already. // However, if Response::error_for_status is used the status will still diff --git a/reqwest-retry/tests/all/helpers/simple_server.rs b/reqwest-retry/tests/all/helpers/simple_server.rs index 5a739b8..f33d3a2 100644 --- a/reqwest-retry/tests/all/helpers/simple_server.rs +++ b/reqwest-retry/tests/all/helpers/simple_server.rs @@ -1,19 +1,17 @@ +use async_std::io::ReadExt; +use async_std::io::WriteExt; +use async_std::net::{TcpListener, TcpStream}; +use futures::stream::StreamExt; use std::error::Error; -use std::io::{Read, Write}; -use std::net::TcpStream; -use std::path::Path; -use std::{ - error::Error, - fmt, - net::{TcpListener, TcpStream}, -}; -use std::{fmt, fs}; +use std::fmt; +/// This is a simple server that returns the responses given at creation time: [`self.raw_http_responses`] following a round-robin mechanism. pub struct SimpleServer { listener: TcpListener, port: u16, host: String, - raw_http_response: String, + raw_http_responses: Vec, + calls_counter: usize, } /// Request-Line = Method SP Request-URI SP HTTP-Version CRLF @@ -29,14 +27,16 @@ impl<'a> fmt::Display for Request<'a> { } } -/// This is a simple server that impl SimpleServer { - pub fn new( + /// Creates an instance of a [`SimpleServer`] + /// If [`port`] is None os Some(0), it gets randomly chosen between the available ones. + pub async fn new( host: &str, port: Option, - raw_http_response: &str, + raw_http_responses: Vec, ) -> Result { - let listener = TcpListener::bind(format!("{}:{}", host, port.ok_or(0).unwrap()))?; + let port = port.unwrap_or(0); + let listener = TcpListener::bind(format!("{}:{}", host, port)).await?; let port = listener.local_addr()?.port(); @@ -44,57 +44,70 @@ impl SimpleServer { listener, port, host: host.to_string(), - raw_http_response: raw_http_response.to_string(), + raw_http_responses, + calls_counter: 0, }) } + /// Returns the uri in which the server is listening to. pub fn uri(&self) -> String { format!("http://{}:{}", self.host, self.port) } - pub fn start(&self) { - for stream in self.listener.incoming() { + /// Starts the TcpListener and handles the requests. + pub async fn start(mut self) -> () { + while let Some(stream) = self.listener.incoming().next().await { match stream { Ok(stream) => { - match self.handle_connection(stream, self.raw_http_response.clone()) { + match self.handle_connection(stream).await { Ok(_) => (), - Err(e) => println!("Error handling connection: {}", e), + Err(e) => { + println!("Error handling connection: {}", e); + () + } } + self.calls_counter = self.calls_counter + 1; + } + Err(e) => { + println!("Connection failed: {}", e); + () } - Err(e) => println!("Connection failed: {}", e), } } } - fn handle_connection( - &self, - mut stream: TcpStream, - raw_http_response: String, - ) -> Result<(), Box> { - // 512 bytes is enough for a toy HTTP server - let mut buffer = [0; 512]; + /// Asyncrounously reads from the buffer and handle the request. + /// It first checks that the format is correct, then returns the response. + /// + /// Returns a 400 if the request if formatted badly. + async fn handle_connection(&self, mut stream: TcpStream) -> Result<(), Box> { + let mut buffer = vec![0; 1024]; - // writes stream into buffer - stream.read(&mut buffer).unwrap(); + stream.read(&mut buffer).await.unwrap(); let request = String::from_utf8_lossy(&buffer[..]); let request_line = request.lines().next().unwrap(); - match Self::parse_request_line(&request_line) { + let response = match Self::parse_request_line(&request_line) { Ok(request) => { - println!("Request: {}", &request); - - let response = format!("{}", raw_http_response.clone()); - - stream.write(response.as_bytes()).unwrap(); - stream.flush().unwrap(); + println!("== Request == \n{}\n=============", request); + self.get_response().clone() } - Err(e) => print!("Bad request: {}", e), - } + Err(e) => { + println!("++ Bad request: {} ++++++", e); + self.get_bad_request_response() + } + }; + + println!("-- Response --\n{}\n--------------", response.clone()); + stream.write(response.as_bytes()).await.unwrap(); + stream.flush().await.unwrap(); Ok(()) } + /// Parses the request line and checks that it contains the method, uri and http_version parts. + /// It does not check if the content of the checked parts is correct. It just checks the format (it contains enough parts) of the request. fn parse_request_line(request: &str) -> Result> { let mut parts = request.split_whitespace(); @@ -110,4 +123,20 @@ impl SimpleServer { http_version, }) } + + /// Returns the response to use based on the calls counter. + /// It uses a round-robin mechanism. + fn get_response(&self) -> String { + let index = if self.calls_counter >= self.raw_http_responses.len() { + self.raw_http_responses.len() % self.calls_counter + } else { + self.calls_counter + }; + self.raw_http_responses[index].clone() + } + + /// Returns the raw HTTP response in case of a 400 Bad Request. + fn get_bad_request_response(&self) -> String { + "HTTP/1.1 400 Bad Request\r\n\r\n".to_string() + } } diff --git a/reqwest-retry/tests/all/retry.rs b/reqwest-retry/tests/all/retry.rs index 9fc6d71..0ddf0c7 100644 --- a/reqwest-retry/tests/all/retry.rs +++ b/reqwest-retry/tests/all/retry.rs @@ -11,7 +11,6 @@ use wiremock::matchers::{method, path}; use wiremock::{Mock, MockServer, Respond, ResponseTemplate}; use crate::helpers::SimpleServer; - pub struct RetryResponder(Arc, u32, u16); impl RetryResponder { @@ -218,10 +217,24 @@ async fn assert_retry_on_request_timeout() { #[tokio::test] async fn assert_retry_on_incomplete_message() { - let raw_response = "HTTP/1.1 200"; // the full working response is: "HTTP/1.1 200 OK\r\n\r\n" + let incomplete_message = "HTTP/1.1"; + let complete_message = "HTTP/1.1 200 OK\r\n\r\n"; + + let simple_server = SimpleServer::new( + "127.0.0.1", + None, + vec![ + incomplete_message.to_string(), + incomplete_message.to_string(), + incomplete_message.to_string(), + complete_message.to_string(), + ], + ) + .await + .expect("Error when creating a simple server"); + + let uri = simple_server.uri(); - let simple_server = SimpleServer::new("127.0.0.1", None, raw_response) - .expect("Error when creating a simple server"); tokio::spawn(simple_server.start()); let reqwest_client = Client::builder().build().unwrap(); @@ -237,8 +250,8 @@ async fn assert_retry_on_incomplete_message() { .build(); let resp = client - .get(&format!("{}/foo", simple_server.uri())) - .timeout(std::time::Duration::from_millis(10)) + .get(&format!("{}/foo", uri)) + .timeout(std::time::Duration::from_millis(100)) .send() .await .expect("call failed");