mirror of
https://github.com/TrueLayer/reqwest-middleware.git
synced 2025-01-13 20:37:27 -08:00
363 lines
12 KiB
Rust
363 lines
12 KiB
Rust
use futures::FutureExt;
|
|
use paste::paste;
|
|
use reqwest::Client;
|
|
use reqwest::StatusCode;
|
|
use reqwest_middleware::ClientBuilder;
|
|
use reqwest_retry::{policies::ExponentialBackoff, RetryTransientMiddleware};
|
|
use std::sync::atomic::AtomicI8;
|
|
use std::sync::{
|
|
atomic::{AtomicU32, Ordering},
|
|
Arc,
|
|
};
|
|
use tokio::io::AsyncReadExt;
|
|
use tokio::io::AsyncWriteExt;
|
|
use wiremock::matchers::{method, path};
|
|
use wiremock::{Mock, MockServer, Respond, ResponseTemplate};
|
|
|
|
use crate::helpers::SimpleServer;
|
|
pub struct RetryResponder(Arc<AtomicU32>, u32, u16);
|
|
|
|
impl RetryResponder {
|
|
fn new(retries: u32, status_code: u16) -> Self {
|
|
Self(Arc::new(AtomicU32::new(0)), retries, status_code)
|
|
}
|
|
}
|
|
|
|
impl Respond for RetryResponder {
|
|
fn respond(&self, _request: &wiremock::Request) -> ResponseTemplate {
|
|
let mut retries = self.0.load(Ordering::SeqCst);
|
|
retries += 1;
|
|
self.0.store(retries, Ordering::SeqCst);
|
|
|
|
if retries + 1 >= self.1 {
|
|
ResponseTemplate::new(200)
|
|
} else {
|
|
ResponseTemplate::new(self.2)
|
|
}
|
|
}
|
|
}
|
|
|
|
/// Generates a `#[tokio::test]` named `$name` that checks how the retry
/// middleware reacts to the responses produced by `$responder`.
///
/// - `$x`: raw status code; accepted so the wrapper macros can forward it, unused here.
/// - `$name`: name of the generated test function.
/// - `$status`: status the final response must have.
/// - `$retry`: maximum number of retries allowed by the backoff policy.
/// - `$exact`: number of requests the mock expects to receive (`Mock::expect`).
/// - `$responder`: wiremock responder mounted on `GET /foo`.
macro_rules! assert_retry_succeeds_inner {
    ($x:tt, $name:ident, $status:expr, $retry:tt, $exact:tt, $responder:expr) => {
        #[tokio::test]
        async fn $name() {
            let server = MockServer::start().await;
            let retry_amount: u32 = $retry;
            Mock::given(method("GET"))
                .and(path("/foo"))
                .respond_with($responder)
                .expect($exact)
                .mount(&server)
                .await;

            let reqwest_client = Client::builder().build().unwrap();
            let client = ClientBuilder::new(reqwest_client)
                .with(RetryTransientMiddleware::new_with_policy(
                    ExponentialBackoff::builder()
                        .retry_bounds(
                            std::time::Duration::from_millis(30),
                            std::time::Duration::from_millis(100),
                        )
                        .build_with_max_retries(retry_amount),
                ))
                .build();

            // `get` is generic over `IntoUrl`, so the owned `String` can be
            // passed directly, as the other tests in this file do.
            let resp = client
                .get(format!("{}/foo", server.uri()))
                .send()
                .await
                .expect("call failed");

            assert_eq!(resp.status(), $status);
        }
    };
}
|
|
|
|
/// Generates `assert_retry_succeeds_on_<code>`.
///
/// In the generated test:
/// - the mock answers `<code>` to the first request and `200 OK` afterwards;
/// - the client may retry up to 3 times;
/// - exactly 2 requests must reach the server;
/// - the final status must equal `$status`.
macro_rules! assert_retry_succeeds {
    ($x:tt, $status:expr) => {
        paste! {
            // `RetryResponder::new(3, ..)` fails only the first attempt, which is
            // why exactly 2 requests are expected. The literal `3` already infers
            // as `u32` from `RetryResponder::new`, so no cast is needed.
            assert_retry_succeeds_inner!($x, [<assert_retry_succeeds_on_ $x>], $status, 3, 2, RetryResponder::new(3, $x));
        }
    };
}
|
|
|
|
/// Generates `assert_no_retry_on_<code>`.
///
/// The mock answers `<code>` to every request and may be hit exactly once.
/// The generated test checks that the final status equals `$expected`, i.e.
/// that the response is passed through without a retry.
macro_rules! assert_no_retry {
    ($code:tt, $expected:expr) => {
        paste! {
            assert_retry_succeeds_inner!(
                $code,
                [<assert_no_retry_on_ $code>],
                $expected,
                1,
                1,
                ResponseTemplate::new($code)
            );
        }
    };
}
|
|
|
|
// 2xx.
|
|
assert_no_retry!(200, StatusCode::OK);
|
|
assert_no_retry!(201, StatusCode::CREATED);
|
|
assert_no_retry!(202, StatusCode::ACCEPTED);
|
|
assert_no_retry!(203, StatusCode::NON_AUTHORITATIVE_INFORMATION);
|
|
assert_no_retry!(204, StatusCode::NO_CONTENT);
|
|
assert_no_retry!(205, StatusCode::RESET_CONTENT);
|
|
assert_no_retry!(206, StatusCode::PARTIAL_CONTENT);
|
|
assert_no_retry!(207, StatusCode::MULTI_STATUS);
|
|
assert_no_retry!(226, StatusCode::IM_USED);
|
|
|
|
// 3xx.
|
|
assert_no_retry!(300, StatusCode::MULTIPLE_CHOICES);
|
|
assert_no_retry!(301, StatusCode::MOVED_PERMANENTLY);
|
|
assert_no_retry!(302, StatusCode::FOUND);
|
|
assert_no_retry!(303, StatusCode::SEE_OTHER);
|
|
assert_no_retry!(304, StatusCode::NOT_MODIFIED);
|
|
assert_no_retry!(307, StatusCode::TEMPORARY_REDIRECT);
|
|
assert_no_retry!(308, StatusCode::PERMANENT_REDIRECT);
|
|
|
|
// 5xx.
|
|
assert_retry_succeeds!(500, StatusCode::OK);
|
|
assert_retry_succeeds!(501, StatusCode::OK);
|
|
assert_retry_succeeds!(502, StatusCode::OK);
|
|
assert_retry_succeeds!(503, StatusCode::OK);
|
|
assert_retry_succeeds!(504, StatusCode::OK);
|
|
assert_retry_succeeds!(505, StatusCode::OK);
|
|
assert_retry_succeeds!(506, StatusCode::OK);
|
|
assert_retry_succeeds!(507, StatusCode::OK);
|
|
assert_retry_succeeds!(508, StatusCode::OK);
|
|
assert_retry_succeeds!(510, StatusCode::OK);
|
|
assert_retry_succeeds!(511, StatusCode::OK);
|
|
// 4xx.
|
|
assert_no_retry!(400, StatusCode::BAD_REQUEST);
|
|
assert_no_retry!(401, StatusCode::UNAUTHORIZED);
|
|
assert_no_retry!(402, StatusCode::PAYMENT_REQUIRED);
|
|
assert_no_retry!(403, StatusCode::FORBIDDEN);
|
|
assert_no_retry!(404, StatusCode::NOT_FOUND);
|
|
assert_no_retry!(405, StatusCode::METHOD_NOT_ALLOWED);
|
|
assert_no_retry!(406, StatusCode::NOT_ACCEPTABLE);
|
|
assert_no_retry!(407, StatusCode::PROXY_AUTHENTICATION_REQUIRED);
|
|
assert_retry_succeeds!(408, StatusCode::OK);
|
|
assert_no_retry!(409, StatusCode::CONFLICT);
|
|
assert_no_retry!(410, StatusCode::GONE);
|
|
assert_no_retry!(411, StatusCode::LENGTH_REQUIRED);
|
|
assert_no_retry!(412, StatusCode::PRECONDITION_FAILED);
|
|
assert_no_retry!(413, StatusCode::PAYLOAD_TOO_LARGE);
|
|
assert_no_retry!(414, StatusCode::URI_TOO_LONG);
|
|
assert_no_retry!(415, StatusCode::UNSUPPORTED_MEDIA_TYPE);
|
|
assert_no_retry!(416, StatusCode::RANGE_NOT_SATISFIABLE);
|
|
assert_no_retry!(417, StatusCode::EXPECTATION_FAILED);
|
|
assert_no_retry!(418, StatusCode::IM_A_TEAPOT);
|
|
assert_no_retry!(421, StatusCode::MISDIRECTED_REQUEST);
|
|
assert_no_retry!(422, StatusCode::UNPROCESSABLE_ENTITY);
|
|
assert_no_retry!(423, StatusCode::LOCKED);
|
|
assert_no_retry!(424, StatusCode::FAILED_DEPENDENCY);
|
|
assert_no_retry!(426, StatusCode::UPGRADE_REQUIRED);
|
|
assert_no_retry!(428, StatusCode::PRECONDITION_REQUIRED);
|
|
assert_retry_succeeds!(429, StatusCode::OK);
|
|
assert_no_retry!(431, StatusCode::REQUEST_HEADER_FIELDS_TOO_LARGE);
|
|
assert_no_retry!(451, StatusCode::UNAVAILABLE_FOR_LEGAL_REASONS);
|
|
|
|
pub struct RetryTimeoutResponder(Arc<AtomicU32>, u32, std::time::Duration);
|
|
|
|
impl RetryTimeoutResponder {
|
|
fn new(retries: u32, initial_timeout: std::time::Duration) -> Self {
|
|
Self(Arc::new(AtomicU32::new(0)), retries, initial_timeout)
|
|
}
|
|
}
|
|
|
|
impl Respond for RetryTimeoutResponder {
|
|
fn respond(&self, _request: &wiremock::Request) -> ResponseTemplate {
|
|
let mut retries = self.0.load(Ordering::SeqCst);
|
|
retries += 1;
|
|
self.0.store(retries, Ordering::SeqCst);
|
|
|
|
if retries + 1 >= self.1 {
|
|
ResponseTemplate::new(200)
|
|
} else {
|
|
ResponseTemplate::new(500).set_delay(self.2)
|
|
}
|
|
}
|
|
}
|
|
|
|
#[tokio::test]
|
|
async fn assert_retry_on_request_timeout() {
|
|
let server = MockServer::start().await;
|
|
Mock::given(method("GET"))
|
|
.and(path("/foo"))
|
|
.respond_with(RetryTimeoutResponder::new(
|
|
3,
|
|
std::time::Duration::from_millis(1000),
|
|
))
|
|
.expect(2)
|
|
.mount(&server)
|
|
.await;
|
|
|
|
let reqwest_client = Client::builder().build().unwrap();
|
|
let client = ClientBuilder::new(reqwest_client)
|
|
.with(RetryTransientMiddleware::new_with_policy(
|
|
ExponentialBackoff::builder()
|
|
.retry_bounds(
|
|
std::time::Duration::from_millis(30),
|
|
std::time::Duration::from_millis(100),
|
|
)
|
|
.build_with_max_retries(3),
|
|
))
|
|
.build();
|
|
|
|
let resp = client
|
|
.get(format!("{}/foo", server.uri()))
|
|
.timeout(std::time::Duration::from_millis(10))
|
|
.send()
|
|
.await
|
|
.expect("call failed");
|
|
|
|
assert_eq!(resp.status(), 200);
|
|
}
|
|
|
|
#[tokio::test]
|
|
async fn assert_retry_on_incomplete_message() {
|
|
// Following the HTTP/1.1 specification (https://en.wikipedia.org/wiki/HTTP_message_body) a valid response contains:
|
|
// - status line
|
|
// - headers
|
|
// - empty line
|
|
// - optional message body
|
|
//
|
|
// After a few tries we have noticed that:
|
|
// - "message_that_makes_no_sense" triggers a hyper::ParseError because the format is completely wrong
|
|
// - "HTTP/1.1" triggers a hyper::IncompleteMessage because the format is correct until that point but misses mandatory parts
|
|
let incomplete_message = "HTTP/1.1";
|
|
let complete_message = "HTTP/1.1 200 OK\r\n\r\n";
|
|
|
|
// create a SimpleServer that returns the correct response after 3 attempts.
|
|
// the first 3 attempts are incomplete http response and internally they result in a [`hyper::Error(IncompleteMessage)`] error.
|
|
let simple_server = SimpleServer::new(
|
|
"127.0.0.1",
|
|
None,
|
|
vec![
|
|
incomplete_message.to_string(),
|
|
incomplete_message.to_string(),
|
|
incomplete_message.to_string(),
|
|
complete_message.to_string(),
|
|
],
|
|
)
|
|
.await
|
|
.expect("Error when creating a simple server");
|
|
|
|
let uri = simple_server.uri();
|
|
|
|
tokio::spawn(simple_server.start());
|
|
|
|
let reqwest_client = Client::builder().build().unwrap();
|
|
let client = ClientBuilder::new(reqwest_client)
|
|
.with(RetryTransientMiddleware::new_with_policy(
|
|
ExponentialBackoff::builder()
|
|
.retry_bounds(
|
|
std::time::Duration::from_millis(30),
|
|
std::time::Duration::from_millis(100),
|
|
)
|
|
.build_with_max_retries(3),
|
|
))
|
|
.build();
|
|
|
|
let resp = client
|
|
.get(format!("{}/foo", uri))
|
|
.timeout(std::time::Duration::from_millis(100))
|
|
.send()
|
|
.await
|
|
.expect("call failed");
|
|
|
|
assert_eq!(resp.status(), 200);
|
|
}
|
|
|
|
#[tokio::test]
|
|
async fn assert_retry_on_hyper_canceled() {
|
|
let counter = Arc::new(AtomicI8::new(0));
|
|
let mut simple_server = SimpleServer::new("127.0.0.1", None, vec![])
|
|
.await
|
|
.expect("Error when creating a simple server");
|
|
simple_server.set_custom_handler(move |mut stream| {
|
|
let counter = counter.clone();
|
|
async move {
|
|
let mut buffer = Vec::new();
|
|
stream.read_buf(&mut buffer).await.unwrap();
|
|
if counter.fetch_add(1, Ordering::SeqCst) > 1 {
|
|
// This triggers hyper:Error(Canceled).
|
|
let _res = stream
|
|
.into_std()
|
|
.unwrap()
|
|
.shutdown(std::net::Shutdown::Both);
|
|
} else {
|
|
let _res = stream.write("HTTP/1.1 200 OK\r\n\r\n".as_bytes()).await;
|
|
}
|
|
Ok(())
|
|
}
|
|
.boxed()
|
|
});
|
|
|
|
let uri = simple_server.uri();
|
|
|
|
tokio::spawn(simple_server.start());
|
|
|
|
let reqwest_client = Client::builder().build().unwrap();
|
|
let client = ClientBuilder::new(reqwest_client)
|
|
.with(RetryTransientMiddleware::new_with_policy(
|
|
ExponentialBackoff::builder()
|
|
.retry_bounds(
|
|
std::time::Duration::from_millis(30),
|
|
std::time::Duration::from_millis(100),
|
|
)
|
|
.build_with_max_retries(3),
|
|
))
|
|
.build();
|
|
|
|
let resp = client
|
|
.get(format!("{}/foo", uri))
|
|
.timeout(std::time::Duration::from_millis(100))
|
|
.send()
|
|
.await
|
|
.expect("call failed");
|
|
|
|
assert_eq!(resp.status(), 200);
|
|
}
|
|
|
|
#[tokio::test]
|
|
async fn assert_retry_on_connection_reset_by_peer() {
|
|
let counter = Arc::new(AtomicI8::new(0));
|
|
let mut simple_server = SimpleServer::new("127.0.0.1", None, vec![])
|
|
.await
|
|
.expect("Error when creating a simple server");
|
|
simple_server.set_custom_handler(move |mut stream| {
|
|
let counter = counter.clone();
|
|
async move {
|
|
let mut buffer = Vec::new();
|
|
stream.read_buf(&mut buffer).await.unwrap();
|
|
if counter.fetch_add(1, Ordering::SeqCst) > 1 {
|
|
// This triggers hyper:Error(Io, io::Error(ConnectionReset)).
|
|
drop(stream);
|
|
} else {
|
|
let _res = stream.write("HTTP/1.1 200 OK\r\n\r\n".as_bytes()).await;
|
|
}
|
|
Ok(())
|
|
}
|
|
.boxed()
|
|
});
|
|
|
|
let uri = simple_server.uri();
|
|
|
|
tokio::spawn(simple_server.start());
|
|
|
|
let reqwest_client = Client::builder().build().unwrap();
|
|
let client = ClientBuilder::new(reqwest_client)
|
|
.with(RetryTransientMiddleware::new_with_policy(
|
|
ExponentialBackoff::builder()
|
|
.retry_bounds(
|
|
std::time::Duration::from_millis(30),
|
|
std::time::Duration::from_millis(100),
|
|
)
|
|
.build_with_max_retries(3),
|
|
))
|
|
.build();
|
|
|
|
let resp = client
|
|
.get(format!("{}/foo", uri))
|
|
.timeout(std::time::Duration::from_millis(100))
|
|
.send()
|
|
.await
|
|
.expect("call failed");
|
|
|
|
assert_eq!(resp.status(), 200);
|
|
}
|