use futures::FutureExt;
use paste::paste;
use reqwest::Client;
use reqwest::StatusCode;
use reqwest_middleware::ClientBuilder;
use reqwest_retry::{policies::ExponentialBackoff, RetryTransientMiddleware};
use std::sync::atomic::AtomicI8;
use std::sync::{
    atomic::{AtomicU32, Ordering},
    Arc,
};
use tokio::io::AsyncReadExt;
use tokio::io::AsyncWriteExt;
use wiremock::matchers::{method, path};
use wiremock::{Mock, MockServer, Respond, ResponseTemplate};

use crate::helpers::SimpleServer;
/// Wiremock responder that answers with a fixed error status until enough
/// requests have arrived, and with `200 OK` afterwards (the exact rule lives in
/// `impl Respond for RetryResponder`).
pub struct RetryResponder(
    /// Shared count of requests answered so far.
    Arc<AtomicU32>,
    /// Threshold compared against the request count to decide when to return `200`.
    u32,
    /// Status code returned while the threshold has not been reached.
    u16,
);

impl RetryResponder {
    /// Builds a responder whose request counter starts at zero.
    fn new(retries: u32, status_code: u16) -> Self {
        let counter = Arc::new(AtomicU32::new(0));
        Self(counter, retries, status_code)
    }
}

impl Respond for RetryResponder {
    /// Answers with the configured error status until enough requests have been
    /// seen, then with `200 OK`.
    ///
    /// With threshold `t = self.1`, request number `k` (1-based) receives `200`
    /// once `k + 1 >= t`. The first `t.saturating_sub(2)` requests therefore get
    /// `self.2`.
    ///
    /// The counter is advanced with a single atomic `fetch_add`. A separate
    /// load/store pair could lose increments if requests were answered
    /// concurrently.
    fn respond(&self, _request: &wiremock::Request) -> ResponseTemplate {
        // `fetch_add` returns the previous value; `+ 1` yields this request's 1-based number.
        let request_number = self.0.fetch_add(1, Ordering::SeqCst) + 1;

        if request_number + 1 >= self.1 {
            ResponseTemplate::new(200)
        } else {
            ResponseTemplate::new(self.2)
        }
    }
}

/// Generates a `#[tokio::test]` named `$name` that performs three steps:
///
/// 1. Mounts `$responder` on `GET /foo` of a fresh [`MockServer`], expecting
///    exactly `$exact` requests (`Mock::expect`).
/// 2. Sends a single request through a client wrapped in
///    [`RetryTransientMiddleware`], using an exponential backoff (30–100 ms) that
///    allows `$retry` retries.
/// 3. Asserts that the final response status equals `$status`.
///
/// `$x` (the raw status code) is not used in the expansion. It is accepted so
/// that the wrapper macros can forward their arguments with one common shape.
macro_rules! assert_retry_succeeds_inner {
    ($x:tt, $name:ident, $status:expr, $retry:tt, $exact:tt, $responder:expr) => {
        #[tokio::test]
        async fn $name() {
            let server = MockServer::start().await;
            let retry_amount: u32 = $retry;
            Mock::given(method("GET"))
                .and(path("/foo"))
                .respond_with($responder)
                .expect($exact)
                .mount(&server)
                .await;

            let reqwest_client = Client::builder().build().unwrap();
            let client = ClientBuilder::new(reqwest_client)
                .with(RetryTransientMiddleware::new_with_policy(
                    ExponentialBackoff::builder()
                        .retry_bounds(
                            std::time::Duration::from_millis(30),
                            std::time::Duration::from_millis(100),
                        )
                        .build_with_max_retries(retry_amount),
                ))
                .build();

            // `get` accepts any `IntoUrl`; pass the owned `String` directly
            // (no needless borrow), as the other tests in this file do.
            let resp = client
                .get(format!("{}/foo", server.uri()))
                .send()
                .await
                .expect("call failed");

            assert_eq!(resp.status(), $status);
        }
    };
}

/// Generates `assert_retry_succeeds_on_<code>`.
///
/// The mock answers `<code>` to the first request and `200 OK` afterwards
/// (`RetryResponder` with threshold 3). The client, allowed 3 retries, must
/// therefore issue exactly 2 requests and end with `$status`.
macro_rules! assert_retry_succeeds {
    ($x:tt, $status:expr) => {
        paste! {
            assert_retry_succeeds_inner!(
                $x,
                [<assert_retry_succeeds_on_ $x>],
                $status,
                3,
                2,
                // The literal already infers to `u32`; no cast needed.
                RetryResponder::new(3, $x)
            );
        }
    };
}

/// Generates `assert_no_retry_on_<code>`.
///
/// The mock always answers `<code>`. The client must send exactly one request
/// and hand back a response whose status equals `$expected`.
macro_rules! assert_no_retry {
    ($code:tt, $expected:expr) => {
        paste! {
            assert_retry_succeeds_inner!(
                $code,
                [<assert_no_retry_on_ $code>],
                $expected,
                1,
                1,
                ResponseTemplate::new($code)
            );
        }
    };
}

// 2xx: successful responses are returned as-is, without retrying.
assert_no_retry!(200, StatusCode::OK);
assert_no_retry!(201, StatusCode::CREATED);
assert_no_retry!(202, StatusCode::ACCEPTED);
assert_no_retry!(203, StatusCode::NON_AUTHORITATIVE_INFORMATION);
assert_no_retry!(204, StatusCode::NO_CONTENT);
assert_no_retry!(205, StatusCode::RESET_CONTENT);
assert_no_retry!(206, StatusCode::PARTIAL_CONTENT);
assert_no_retry!(207, StatusCode::MULTI_STATUS);
assert_no_retry!(226, StatusCode::IM_USED);

// 3xx: redirection statuses are not retried either.
assert_no_retry!(300, StatusCode::MULTIPLE_CHOICES);
assert_no_retry!(301, StatusCode::MOVED_PERMANENTLY);
assert_no_retry!(302, StatusCode::FOUND);
assert_no_retry!(303, StatusCode::SEE_OTHER);
assert_no_retry!(304, StatusCode::NOT_MODIFIED);
assert_no_retry!(307, StatusCode::TEMPORARY_REDIRECT);
assert_no_retry!(308, StatusCode::PERMANENT_REDIRECT);

// 4xx: client errors are final, except 408 (Request Timeout) and
// 429 (Too Many Requests), which are expected to be retried.
assert_no_retry!(400, StatusCode::BAD_REQUEST);
assert_no_retry!(401, StatusCode::UNAUTHORIZED);
assert_no_retry!(402, StatusCode::PAYMENT_REQUIRED);
assert_no_retry!(403, StatusCode::FORBIDDEN);
assert_no_retry!(404, StatusCode::NOT_FOUND);
assert_no_retry!(405, StatusCode::METHOD_NOT_ALLOWED);
assert_no_retry!(406, StatusCode::NOT_ACCEPTABLE);
assert_no_retry!(407, StatusCode::PROXY_AUTHENTICATION_REQUIRED);
assert_retry_succeeds!(408, StatusCode::OK);
assert_no_retry!(409, StatusCode::CONFLICT);
assert_no_retry!(410, StatusCode::GONE);
assert_no_retry!(411, StatusCode::LENGTH_REQUIRED);
assert_no_retry!(412, StatusCode::PRECONDITION_FAILED);
assert_no_retry!(413, StatusCode::PAYLOAD_TOO_LARGE);
assert_no_retry!(414, StatusCode::URI_TOO_LONG);
assert_no_retry!(415, StatusCode::UNSUPPORTED_MEDIA_TYPE);
assert_no_retry!(416, StatusCode::RANGE_NOT_SATISFIABLE);
assert_no_retry!(417, StatusCode::EXPECTATION_FAILED);
assert_no_retry!(418, StatusCode::IM_A_TEAPOT);
assert_no_retry!(421, StatusCode::MISDIRECTED_REQUEST);
assert_no_retry!(422, StatusCode::UNPROCESSABLE_ENTITY);
assert_no_retry!(423, StatusCode::LOCKED);
assert_no_retry!(424, StatusCode::FAILED_DEPENDENCY);
assert_no_retry!(426, StatusCode::UPGRADE_REQUIRED);
assert_no_retry!(428, StatusCode::PRECONDITION_REQUIRED);
assert_retry_succeeds!(429, StatusCode::OK);
assert_no_retry!(431, StatusCode::REQUEST_HEADER_FIELDS_TOO_LARGE);
assert_no_retry!(451, StatusCode::UNAVAILABLE_FOR_LEGAL_REASONS);

// 5xx: every server error listed here is retried until the mock answers 200.
assert_retry_succeeds!(500, StatusCode::OK);
assert_retry_succeeds!(501, StatusCode::OK);
assert_retry_succeeds!(502, StatusCode::OK);
assert_retry_succeeds!(503, StatusCode::OK);
assert_retry_succeeds!(504, StatusCode::OK);
assert_retry_succeeds!(505, StatusCode::OK);
assert_retry_succeeds!(506, StatusCode::OK);
assert_retry_succeeds!(507, StatusCode::OK);
assert_retry_succeeds!(508, StatusCode::OK);
assert_retry_succeeds!(510, StatusCode::OK);
assert_retry_succeeds!(511, StatusCode::OK);

/// Wiremock responder that answers with a delayed `500` until enough requests
/// have arrived, and with an immediate `200 OK` afterwards (the exact rule lives
/// in `impl Respond for RetryTimeoutResponder`).
pub struct RetryTimeoutResponder(
    /// Shared count of requests answered so far.
    Arc<AtomicU32>,
    /// Threshold compared against the request count to decide when to return `200`.
    u32,
    /// Delay applied to each `500` response.
    std::time::Duration,
);

impl RetryTimeoutResponder {
    /// Builds a responder whose request counter starts at zero.
    fn new(retries: u32, initial_timeout: std::time::Duration) -> Self {
        let counter = Arc::new(AtomicU32::new(0));
        Self(counter, retries, initial_timeout)
    }
}

impl Respond for RetryTimeoutResponder {
    /// Answers with a `500` delayed by `self.2` until enough requests have been
    /// seen, then with an immediate `200 OK`.
    ///
    /// With threshold `t = self.1`, request number `k` (1-based) receives `200`
    /// once `k + 1 >= t`.
    ///
    /// The counter is advanced with a single atomic `fetch_add`. A separate
    /// load/store pair could lose increments if requests were answered
    /// concurrently.
    fn respond(&self, _request: &wiremock::Request) -> ResponseTemplate {
        // `fetch_add` returns the previous value; `+ 1` yields this request's 1-based number.
        let request_number = self.0.fetch_add(1, Ordering::SeqCst) + 1;

        if request_number + 1 >= self.1 {
            ResponseTemplate::new(200)
        } else {
            ResponseTemplate::new(500).set_delay(self.2)
        }
    }
}

/// A response delayed far beyond the client's per-request timeout must be
/// retried. The first attempt times out and the retry receives `200 OK`.
#[tokio::test]
async fn assert_retry_on_request_timeout() {
    let server = MockServer::start().await;

    // Threshold 3: the first request gets a 500 delayed by 1 s; the second gets an immediate 200.
    let responder = RetryTimeoutResponder::new(3, std::time::Duration::from_millis(1000));
    Mock::given(method("GET"))
        .and(path("/foo"))
        .respond_with(responder)
        .expect(2)
        .mount(&server)
        .await;

    let backoff = ExponentialBackoff::builder()
        .retry_bounds(
            std::time::Duration::from_millis(30),
            std::time::Duration::from_millis(100),
        )
        .build_with_max_retries(3);
    let client = ClientBuilder::new(Client::builder().build().unwrap())
        .with(RetryTransientMiddleware::new_with_policy(backoff))
        .build();

    // The 10 ms timeout is far shorter than the 1 s delay of the first response.
    let url = format!("{}/foo", server.uri());
    let response = client
        .get(url)
        .timeout(std::time::Duration::from_millis(10))
        .send()
        .await
        .expect("call failed");

    assert_eq!(response.status(), 200);
}

/// Truncated HTTP responses must be retried until a complete one arrives.
#[tokio::test]
async fn assert_retry_on_incomplete_message() {
    // Per HTTP/1.1 (https://en.wikipedia.org/wiki/HTTP_message_body), a valid response is:
    // - a status line
    // - headers
    // - an empty line
    // - an optional message body
    //
    // Experimentally:
    // - "message_that_makes_no_sense" yields a hyper::ParseError (the format is wrong outright);
    // - "HTTP/1.1" yields a hyper::IncompleteMessage (well-formed so far, but mandatory parts
    //   are missing).
    let incomplete_message = "HTTP/1.1";
    let complete_message = "HTTP/1.1 200 OK\r\n\r\n";

    // Three truncated replies (each surfacing internally as
    // [`hyper::Error(IncompleteMessage)`]) followed by a valid one.
    let mut responses = vec![incomplete_message.to_string(); 3];
    responses.push(complete_message.to_string());

    let simple_server = SimpleServer::new("127.0.0.1", None, responses)
        .await
        .expect("Error when creating a simple server");
    let uri = simple_server.uri();
    tokio::spawn(simple_server.start());

    let backoff = ExponentialBackoff::builder()
        .retry_bounds(
            std::time::Duration::from_millis(30),
            std::time::Duration::from_millis(100),
        )
        .build_with_max_retries(3);
    let client = ClientBuilder::new(Client::builder().build().unwrap())
        .with(RetryTransientMiddleware::new_with_policy(backoff))
        .build();

    let url = format!("{}/foo", uri);
    let response = client
        .get(url)
        .timeout(std::time::Duration::from_millis(100))
        .send()
        .await
        .expect("call failed");

    assert_eq!(response.status(), 200);
}

/// The middleware must retry when the server shuts the connection down after
/// reading the request but before sending any response.
///
/// The first `FAILING_ATTEMPTS` connections are shut down; later ones receive
/// `200 OK`.
///
/// NOTE(review): this assumes `SimpleServer` invokes the custom handler once per
/// accepted connection, so the counter counts request attempts.
#[tokio::test]
async fn assert_retry_on_hyper_canceled() {
    // Number of leading attempts that get no response. It is kept below the retry
    // budget configured further down (`build_with_max_retries(3)`).
    const FAILING_ATTEMPTS: i8 = 2;

    let counter = Arc::new(AtomicI8::new(0));
    let mut simple_server = SimpleServer::new("127.0.0.1", None, vec![])
        .await
        .expect("Error when creating a simple server");
    simple_server.set_custom_handler(move |mut stream| {
        let counter = counter.clone();
        async move {
            let mut buffer = Vec::new();
            stream.read_buf(&mut buffer).await.unwrap();
            // Fail the *first* attempts so that a retry is actually required.
            // A `> 1` check would let the very first attempt succeed.
            if counter.fetch_add(1, Ordering::SeqCst) < FAILING_ATTEMPTS {
                // This triggers hyper:Error(Canceled).
                let _res = stream
                    .into_std()
                    .unwrap()
                    .shutdown(std::net::Shutdown::Both);
            } else {
                // `write_all` sends the whole response; `write` may stop after a partial write.
                let _res = stream
                    .write_all("HTTP/1.1 200 OK\r\n\r\n".as_bytes())
                    .await;
            }
            Ok(())
        }
        .boxed()
    });

    let uri = simple_server.uri();

    tokio::spawn(simple_server.start());

    let reqwest_client = Client::builder().build().unwrap();
    let client = ClientBuilder::new(reqwest_client)
        .with(RetryTransientMiddleware::new_with_policy(
            ExponentialBackoff::builder()
                .retry_bounds(
                    std::time::Duration::from_millis(30),
                    std::time::Duration::from_millis(100),
                )
                .build_with_max_retries(3),
        ))
        .build();

    let resp = client
        .get(format!("{}/foo", uri))
        .timeout(std::time::Duration::from_millis(100))
        .send()
        .await
        .expect("call failed");

    assert_eq!(resp.status(), 200);
}

/// The middleware must retry when the server drops the connection after reading
/// the request, without sending any response.
///
/// The first `FAILING_ATTEMPTS` connections are dropped; later ones receive
/// `200 OK`.
///
/// NOTE(review): this assumes `SimpleServer` invokes the custom handler once per
/// accepted connection, so the counter counts request attempts.
#[tokio::test]
async fn assert_retry_on_connection_reset_by_peer() {
    // Number of leading attempts that get no response. It is kept below the retry
    // budget configured further down (`build_with_max_retries(3)`).
    const FAILING_ATTEMPTS: i8 = 2;

    let counter = Arc::new(AtomicI8::new(0));
    let mut simple_server = SimpleServer::new("127.0.0.1", None, vec![])
        .await
        .expect("Error when creating a simple server");
    simple_server.set_custom_handler(move |mut stream| {
        let counter = counter.clone();
        async move {
            let mut buffer = Vec::new();
            stream.read_buf(&mut buffer).await.unwrap();
            // Fail the *first* attempts so that a retry is actually required.
            // A `> 1` check would let the very first attempt succeed.
            if counter.fetch_add(1, Ordering::SeqCst) < FAILING_ATTEMPTS {
                // Intended to trigger hyper:Error(Io, io::Error(ConnectionReset)) on the client.
                drop(stream);
            } else {
                // `write_all` sends the whole response; `write` may stop after a partial write.
                let _res = stream
                    .write_all("HTTP/1.1 200 OK\r\n\r\n".as_bytes())
                    .await;
            }
            Ok(())
        }
        .boxed()
    });

    let uri = simple_server.uri();

    tokio::spawn(simple_server.start());

    let reqwest_client = Client::builder().build().unwrap();
    let client = ClientBuilder::new(reqwest_client)
        .with(RetryTransientMiddleware::new_with_policy(
            ExponentialBackoff::builder()
                .retry_bounds(
                    std::time::Duration::from_millis(30),
                    std::time::Duration::from_millis(100),
                )
                .build_with_max_retries(3),
        ))
        .build();

    let resp = client
        .get(format!("{}/foo", uri))
        .timeout(std::time::Duration::from_millis(100))
        .send()
        .await
        .expect("call failed");

    assert_eq!(resp.status(), 200);
}