diff --git a/reqwest-middleware/Cargo.toml b/reqwest-middleware/Cargo.toml index 5b55510..c1034cb 100644 --- a/reqwest-middleware/Cargo.toml +++ b/reqwest-middleware/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "reqwest-middleware" -version = "0.2.0" +version = "0.3.0" authors = ["Rodrigo Gryzinski "] edition = "2018" description = "Wrapper around reqwest to allow for client middleware chains." @@ -18,6 +18,7 @@ reqwest = { version = "0.11", default-features = false, features = ["json", "mul serde = "1" task-local-extensions = "0.1.1" thiserror = "1" +tower = { version = "0.4", features = ["util"] } [dev-dependencies] reqwest = "0.11" diff --git a/reqwest-middleware/src/client.rs b/reqwest-middleware/src/client.rs index a8a4ac5..e0348d6 100644 --- a/reqwest-middleware/src/client.rs +++ b/reqwest-middleware/src/client.rs @@ -4,78 +4,81 @@ use reqwest::{Body, Client, IntoUrl, Method, Request, Response}; use serde::Serialize; use std::convert::TryFrom; use std::fmt::{self, Display}; -use std::sync::Arc; use std::time::Duration; use task_local_extensions::Extensions; +use tower::layer::util::{Identity, Stack}; +use tower::{Layer, Service, ServiceBuilder, ServiceExt}; use crate::error::Result; -use crate::middleware::{Middleware, Next}; -use crate::RequestInitialiser; +use crate::{Error, MiddlewareRequest, RequestInitialiser}; /// A `ClientBuilder` is used to build a [`ClientWithMiddleware`]. /// /// [`ClientWithMiddleware`]: crate::ClientWithMiddleware -pub struct ClientBuilder { +pub struct ClientBuilder { client: Client, - middleware_stack: Vec>, - initialiser_stack: Vec>, + middleware_stack: ServiceBuilder, + initialiser_stack: (), } -impl ClientBuilder { +impl ClientBuilder { pub fn new(client: Client) -> Self { ClientBuilder { client, - middleware_stack: Vec::new(), - initialiser_stack: Vec::new(), + middleware_stack: ServiceBuilder::new(), + initialiser_stack: (), } } +} +impl ClientBuilder { /// Convenience method to attach middleware. /// /// If you need to keep a reference to the middleware after attaching, use [`with_arc`]. /// /// [`with_arc`]: Self::with_arc - pub fn with(self, middleware: M) -> Self - where - M: Middleware, - { - self.with_arc(Arc::new(middleware)) + pub fn layer(self, layer: T) -> ClientBuilder> { + ClientBuilder { + client: self.client, + middleware_stack: self.middleware_stack.layer(layer), + initialiser_stack: (), + } } - /// Add middleware to the chain. [`with`] is more ergonomic if you don't need the `Arc`. - /// - /// [`with`]: Self::with - pub fn with_arc(mut self, middleware: Arc) -> Self { - self.middleware_stack.push(middleware); - self - } + // /// Add middleware to the chain. [`with`] is more ergonomic if you don't need the `Arc`. + // /// + // /// [`with`]: Self::with + // pub fn with_arc(mut self, middleware: Arc) -> Self { + // self.middleware_stack.push(middleware); + // self + // } - /// Convenience method to attach a request initialiser. - /// - /// If you need to keep a reference to the initialiser after attaching, use [`with_arc_init`]. - /// - /// [`with_arc_init`]: Self::with_arc_init - pub fn with_init(self, initialiser: I) -> Self - where - I: RequestInitialiser, - { - self.with_arc_init(Arc::new(initialiser)) - } + // /// Convenience method to attach a request initialiser. + // /// + // /// If you need to keep a reference to the initialiser after attaching, use [`with_arc_init`]. + // /// + // /// [`with_arc_init`]: Self::with_arc_init + // pub fn with_init(self, initialiser: I) -> Self + // where + // I: RequestInitialiser, + // { + // self.with_arc_init(Arc::new(initialiser)) + // } - /// Add a request initialiser to the chain. [`with_init`] is more ergonomic if you don't need the `Arc`. - /// - /// [`with_init`]: Self::with_init - pub fn with_arc_init(mut self, initialiser: Arc) -> Self { - self.initialiser_stack.push(initialiser); - self - } + // /// Add a request initialiser to the chain. [`with_init`] is more ergonomic if you don't need the `Arc`. + // /// + // /// [`with_init`]: Self::with_init + // pub fn with_arc_init(mut self, initialiser: Arc) -> Self { + // self.initialiser_stack.push(initialiser); + // self + // } /// Returns a `ClientWithMiddleware` using this builder configuration. - pub fn build(self) -> ClientWithMiddleware { + pub fn build(self) -> ClientWithMiddleware { ClientWithMiddleware { inner: self.client, - middleware_stack: self.middleware_stack.into_boxed_slice(), - initialiser_stack: self.initialiser_stack.into_boxed_slice(), + middleware_stack: self.middleware_stack, + initialiser_stack: self.initialiser_stack, } } } @@ -83,97 +86,85 @@ impl ClientBuilder { /// `ClientWithMiddleware` is a wrapper around [`reqwest::Client`] which runs middleware on every /// request. #[derive(Clone)] -pub struct ClientWithMiddleware { +pub struct ClientWithMiddleware { inner: reqwest::Client, - middleware_stack: Box<[Arc]>, - initialiser_stack: Box<[Arc]>, + middleware_stack: ServiceBuilder, + initialiser_stack: I, } -impl ClientWithMiddleware { - /// See [`ClientBuilder`] for a more ergonomic way to build `ClientWithMiddleware` instances. - pub fn new(client: Client, middleware_stack: T) -> Self - where - T: Into]>>, - { - ClientWithMiddleware { - inner: client, - middleware_stack: middleware_stack.into(), - // TODO(conradludgate) - allow downstream code to control this manually if desired - initialiser_stack: Box::new([]), - } - } +// impl> ClientWithMiddleware +// where +// M::Service: Service, +// { +// /// See [`ClientBuilder`] for a more ergonomic way to build `ClientWithMiddleware` instances. +// pub fn new(client: Client, middleware_stack: M) -> Self { +// ClientWithMiddleware { +// inner: client, +// middleware_stack, +// initialiser_stack: (), +// } +// } +// } +impl, I: RequestInitialiser> ClientWithMiddleware +where + M::Service: Service, +{ /// See [`Client::get`] - pub fn get(&self, url: U) -> RequestBuilder { + pub fn get(&self, url: U) -> RequestBuilder { self.request(Method::GET, url) } /// See [`Client::post`] - pub fn post(&self, url: U) -> RequestBuilder { + pub fn post(&self, url: U) -> RequestBuilder { self.request(Method::POST, url) } /// See [`Client::put`] - pub fn put(&self, url: U) -> RequestBuilder { + pub fn put(&self, url: U) -> RequestBuilder { self.request(Method::PUT, url) } /// See [`Client::patch`] - pub fn patch(&self, url: U) -> RequestBuilder { + pub fn patch(&self, url: U) -> RequestBuilder { self.request(Method::PATCH, url) } /// See [`Client::delete`] - pub fn delete(&self, url: U) -> RequestBuilder { + pub fn delete(&self, url: U) -> RequestBuilder { self.request(Method::DELETE, url) } /// See [`Client::head`] - pub fn head(&self, url: U) -> RequestBuilder { + pub fn head(&self, url: U) -> RequestBuilder { self.request(Method::HEAD, url) } /// See [`Client::request`] - pub fn request(&self, method: Method, url: U) -> RequestBuilder { - let req = RequestBuilder { - inner: self.inner.request(method, url), - client: self.clone(), - extensions: Extensions::new(), - }; - self.initialiser_stack - .iter() - .fold(req, |req, i| i.init(req)) - } - - /// See [`Client::execute`] - pub async fn execute(&self, req: Request) -> Result { - let mut ext = Extensions::new(); - self.execute_with_extensions(req, &mut ext).await - } - - /// Executes a request with initial [`Extensions`]. - pub async fn execute_with_extensions( - &self, - req: Request, - ext: &mut Extensions, - ) -> Result { - let next = Next::new(&self.inner, &self.middleware_stack); - next.run(req, ext).await - } -} - -/// Create a `ClientWithMiddleware` without any middleware. -impl From for ClientWithMiddleware { - fn from(client: Client) -> Self { - ClientWithMiddleware { - inner: client, - middleware_stack: Box::new([]), - initialiser_stack: Box::new([]), + pub fn request(&self, method: Method, url: U) -> RequestBuilder<'_, M, I> { + let mut extensions = Extensions::new(); + let request = self.inner.request(method, url); + let request = self.initialiser_stack.init(request, &mut extensions); + RequestBuilder { + inner: request, + client: self, + extensions, } } } -impl fmt::Debug for ClientWithMiddleware { +/// Create a `ClientWithMiddleware` without any middleware. +impl From for ClientWithMiddleware { + fn from(client: Client) -> Self { + ClientWithMiddleware { + inner: client, + middleware_stack: ServiceBuilder::new(), + initialiser_stack: (), + } + } +} + +impl fmt::Debug for ClientWithMiddleware { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { // skipping middleware_stack field for now f.debug_struct("ClientWithMiddleware") @@ -184,13 +175,37 @@ impl fmt::Debug for ClientWithMiddleware { /// This is a wrapper around [`reqwest::RequestBuilder`] exposing the same API. #[must_use = "RequestBuilder does nothing until you 'send' it"] -pub struct RequestBuilder { +pub struct RequestBuilder<'client, M, I> { inner: reqwest::RequestBuilder, - client: ClientWithMiddleware, + client: &'client ClientWithMiddleware, extensions: Extensions, } -impl RequestBuilder { +pub type BoxFuture<'a, T> = std::pin::Pin + Send + 'a>>; + +#[derive(Clone)] +pub struct ReqService(Client); + +impl Service for ReqService { + type Response = Response; + type Error = Error; + type Future = BoxFuture<'static, Result>; + + fn poll_ready(&mut self, _: &mut std::task::Context<'_>) -> std::task::Poll> { + std::task::Poll::Ready(Ok(())) + } + + fn call(&mut self, req: MiddlewareRequest) -> Self::Future { + let req = req.request; + let client = self.0.clone(); + Box::pin(async move { client.execute(req).await.map_err(Error::from) }) + } +} + +impl, I: RequestInitialiser> RequestBuilder<'_, M, I> +where + M::Service: Service, +{ pub fn header(self, key: K, value: V) -> Self where HeaderName: TryFrom, @@ -293,10 +308,19 @@ impl RequestBuilder { let Self { inner, client, - mut extensions, + extensions, } = self; let req = inner.build()?; - client.execute_with_extensions(req, &mut extensions).await + client + .middleware_stack + .service(ReqService(client.inner.clone())) + .oneshot(MiddlewareRequest { + request: req, + extensions, + }) + .await + + // client.execute_with_extensions(req, &mut extensions).await } /// Attempt to clone the RequestBuilder. @@ -309,13 +333,13 @@ impl RequestBuilder { pub fn try_clone(&self) -> Option { self.inner.try_clone().map(|inner| RequestBuilder { inner, - client: self.client.clone(), + client: self.client, extensions: Extensions::new(), }) } } -impl fmt::Debug for RequestBuilder { +impl fmt::Debug for RequestBuilder<'_, M, I> { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { // skipping middleware_stack field for now f.debug_struct("RequestBuilder") diff --git a/reqwest-middleware/src/lib.rs b/reqwest-middleware/src/lib.rs index a92243d..a4ea824 100644 --- a/reqwest-middleware/src/lib.rs +++ b/reqwest-middleware/src/lib.rs @@ -9,28 +9,44 @@ //! use reqwest::{Client, Request, Response}; //! use reqwest_middleware::{ClientBuilder, Middleware, Next, Result}; //! use task_local_extensions::Extensions; +//! use futures::FutureExt; +//! use std::task::{Context, Poll}; //! -//! struct LoggingMiddleware; +//! struct LoggingLayer; +//! struct LoggingService(S); +//! +//! impl tower::Layer for LoggingLayer { +//! type Service = LoggingService; +//! +//! fn layer(&self, inner: S) -> Self::Service { +//! LoggingService(inner) +//! } +//! } //! -//! #[async_trait::async_trait] -//! impl Middleware for LoggingMiddleware { -//! async fn handle( -//! &self, -//! req: Request, -//! extensions: &mut Extensions, -//! next: Next<'_>, -//! ) -> Result { -//! println!("Request started {:?}", req); -//! let res = next.run(req, extensions).await; -//! println!("Result: {:?}", res); -//! res +//! impl> tower::Service for LoggingService { +//! type Response = S::Response; +//! type Error = S::Error; +//! type Future = futures::BoxFuture<'static, Result>; +//! +//! fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { +//! self.0.poll_ready(cx) +//! } +//! +//! fn call(&mut self, req: MiddlewareRequest) -> Self::Future { +//! println!("Request started {:?}", &req.request); +//! let fut = self.0.call(req); +//! async { +//! let res = fut.await; +//! println!("Result: {:?}", res); +//! res +//! }.boxed() //! } //! } //! //! async fn run() { //! let reqwest_client = Client::builder().build().unwrap(); //! let client = ClientBuilder::new(reqwest_client) -//! .with(LoggingMiddleware) +//! .layer(LoggingLayer) //! .build(); //! let resp = client.get("https://truelayer.com").send().await.unwrap(); //! println!("TrueLayer page HTML: {}", resp.text().await.unwrap()); @@ -51,10 +67,13 @@ pub struct ReadmeDoctests; mod client; mod error; -mod middleware; mod req_init; -pub use client::{ClientBuilder, ClientWithMiddleware, RequestBuilder}; +pub use client::{ClientBuilder, ClientWithMiddleware, ReqService, RequestBuilder}; pub use error::{Error, Result}; -pub use middleware::{Middleware, Next}; pub use req_init::{Extension, RequestInitialiser}; + +pub struct MiddlewareRequest { + pub request: reqwest::Request, + pub extensions: task_local_extensions::Extensions, +} diff --git a/reqwest-middleware/src/middleware.rs b/reqwest-middleware/src/middleware.rs deleted file mode 100644 index 5120224..0000000 --- a/reqwest-middleware/src/middleware.rs +++ /dev/null @@ -1,100 +0,0 @@ -use reqwest::{Client, Request, Response}; -use std::sync::Arc; -use task_local_extensions::Extensions; - -use crate::error::{Error, Result}; - -/// When attached to a [`ClientWithMiddleware`] (generally using [`with`]), middleware is run -/// whenever the client issues a request, in the order it was attached. -/// -/// # Example -/// -/// ``` -/// use reqwest::{Client, Request, Response}; -/// use reqwest_middleware::{ClientBuilder, Middleware, Next, Result}; -/// use task_local_extensions::Extensions; -/// -/// struct TransparentMiddleware; -/// -/// #[async_trait::async_trait] -/// impl Middleware for TransparentMiddleware { -/// async fn handle( -/// &self, -/// req: Request, -/// extensions: &mut Extensions, -/// next: Next<'_>, -/// ) -> Result { -/// next.run(req, extensions).await -/// } -/// } -/// ``` -/// -/// [`ClientWithMiddleware`]: crate::ClientWithMiddleware -/// [`with`]: crate::ClientBuilder::with -#[async_trait::async_trait] -pub trait Middleware: 'static + Send + Sync { - /// Invoked with a request before sending it. If you want to continue processing the request, - /// you should explicitly call `next.run(req, extensions)`. - /// - /// If you need to forward data down the middleware stack, you can use the `extensions` - /// argument. - async fn handle( - &self, - req: Request, - extensions: &mut Extensions, - next: Next<'_>, - ) -> Result; -} - -#[async_trait::async_trait] -impl Middleware for F -where - F: Send - + Sync - + 'static - + for<'a> Fn(Request, &'a mut Extensions, Next<'a>) -> BoxFuture<'a, Result>, -{ - async fn handle( - &self, - req: Request, - extensions: &mut Extensions, - next: Next<'_>, - ) -> Result { - (self)(req, extensions, next).await - } -} - -/// Next encapsulates the remaining middleware chain to run in [`Middleware::handle`]. You can -/// forward the request down the chain with [`run`]. -/// -/// [`Middleware::handle`]: Middleware::handle -/// [`run`]: Self::run -#[derive(Clone)] -pub struct Next<'a> { - client: &'a Client, - middlewares: &'a [Arc], -} - -pub type BoxFuture<'a, T> = std::pin::Pin + Send + 'a>>; - -impl<'a> Next<'a> { - pub(crate) fn new(client: &'a Client, middlewares: &'a [Arc]) -> Self { - Next { - client, - middlewares, - } - } - - pub fn run( - mut self, - req: Request, - extensions: &'a mut Extensions, - ) -> BoxFuture<'a, Result> { - if let Some((current, rest)) = self.middlewares.split_first() { - self.middlewares = rest; - Box::pin(current.handle(req, extensions, self)) - } else { - Box::pin(async move { self.client.execute(req).await.map_err(Error::from) }) - } - } -} diff --git a/reqwest-middleware/src/req_init.rs b/reqwest-middleware/src/req_init.rs index 92c8167..6b15fd0 100644 --- a/reqwest-middleware/src/req_init.rs +++ b/reqwest-middleware/src/req_init.rs @@ -1,4 +1,5 @@ -use crate::RequestBuilder; +use reqwest::RequestBuilder; +use task_local_extensions::Extensions; /// When attached to a [`ClientWithMiddleware`] (generally using [`with_init`]), it is run /// whenever the client starts building a request, in the order it was attached. @@ -6,12 +7,12 @@ use crate::RequestBuilder; /// # Example /// /// ``` -/// use reqwest_middleware::{RequestInitialiser, RequestBuilder}; +/// use reqwest_middleware::{RequestInitialiser, MiddlewareRequest}; /// /// struct AuthInit; /// /// impl RequestInitialiser for AuthInit { -/// fn init(&self, req: RequestBuilder) -> RequestBuilder { +/// fn init(&self, req: MiddlewareRequest) -> MiddlewareRequest { /// req.bearer_auth("my_auth_token") /// } /// } @@ -20,18 +21,24 @@ use crate::RequestBuilder; /// [`ClientWithMiddleware`]: crate::ClientWithMiddleware /// [`with_init`]: crate::ClientBuilder::with_init pub trait RequestInitialiser: 'static + Send + Sync { - fn init(&self, req: RequestBuilder) -> RequestBuilder; + fn init(&self, req: RequestBuilder, ext: &mut Extensions) -> RequestBuilder; } -impl RequestInitialiser for F -where - F: Send + Sync + 'static + Fn(RequestBuilder) -> RequestBuilder, -{ - fn init(&self, req: RequestBuilder) -> RequestBuilder { - (self)(req) +impl RequestInitialiser for () { + fn init(&self, req: RequestBuilder, _: &mut Extensions) -> RequestBuilder { + req } } +// impl RequestInitialiser for F +// where +// F: Send + Sync + 'static + Fn(MiddlewareRequest) -> MiddlewareRequest, +// { +// fn init(&self, req: RequestBuilder, ext: &mut Extensions) -> RequestBuilder { +// (self)(req) +// } +// } + /// A middleware that inserts the value into the [`Extensions`](task_local_extensions::Extensions) during the call. /// /// This is a good way to inject extensions to middleware deeper in the stack @@ -78,7 +85,8 @@ where pub struct Extension(pub T); impl RequestInitialiser for Extension { - fn init(&self, req: RequestBuilder) -> RequestBuilder { - req.with_extension(self.0.clone()) + fn init(&self, req: RequestBuilder, ext: &mut Extensions) -> RequestBuilder { + ext.insert(self.0.clone()); + req } } diff --git a/reqwest-retry/Cargo.toml b/reqwest-retry/Cargo.toml index ef77147..a0098a7 100644 --- a/reqwest-retry/Cargo.toml +++ b/reqwest-retry/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "reqwest-retry" -version = "0.2.0" +version = "0.3.0" authors = ["Rodrigo Gryzinski "] edition = "2018" description = "Retry middleware for reqwest." @@ -10,7 +10,7 @@ keywords = ["reqwest", "http", "middleware", "retry"] categories = ["web-programming::http-client"] [dependencies] -reqwest-middleware = { version = "0.2.0", path = "../reqwest-middleware" } +reqwest-middleware = { version = "0.3.0", path = "../reqwest-middleware" } anyhow = "1" async-trait = "0.1.51" @@ -23,6 +23,8 @@ retry-policies = "0.1" task-local-extensions = "0.1.1" tokio = { version = "1.6", features = ["time"] } tracing = "0.1.26" +tower = { version = "0.4", features = ["retry"] } +pin-project-lite = "0.2" [dev-dependencies] async-std = { version = "1.10"} diff --git a/reqwest-retry/src/lib.rs b/reqwest-retry/src/lib.rs index bef3763..6cf81d3 100644 --- a/reqwest-retry/src/lib.rs +++ b/reqwest-retry/src/lib.rs @@ -13,7 +13,7 @@ //! // Retry up to 3 times with increasing intervals between attempts. //! let retry_policy = ExponentialBackoff::builder().build_with_max_retries(3); //! let client = ClientBuilder::new(reqwest::Client::new()) -//! .with(RetryTransientMiddleware::new_with_policy(retry_policy)) +//! .layer(RetryTransientMiddleware::new_with_policy(retry_policy)) //! .build(); //! //! client diff --git a/reqwest-retry/src/middleware.rs b/reqwest-retry/src/middleware.rs index f701b12..8ac29ee 100644 --- a/reqwest-retry/src/middleware.rs +++ b/reqwest-retry/src/middleware.rs @@ -1,15 +1,19 @@ //! `RetryTransientMiddleware` implements retrying requests on transient errors. +use std::pin::Pin; +use std::task::{ready, Context, Poll}; + use crate::retryable::Retryable; -use anyhow::anyhow; use chrono::Utc; -use reqwest::{Request, Response}; -use reqwest_middleware::{Error, Middleware, Next, Result}; +use futures::Future; +use pin_project_lite::pin_project; +use reqwest::Response; +use reqwest_middleware::{Error, MiddlewareRequest}; use retry_policies::RetryPolicy; use task_local_extensions::Extensions; - -/// We limit the number of retries to a maximum of `10` to avoid stack-overflow issues due to the recursion. -static MAXIMUM_NUMBER_OF_RETRIES: u32 = 10; +use tokio::time::Sleep; +use tower::retry::{Policy, Retry}; +use tower::{Layer, Service}; /// `RetryTransientMiddleware` offers retry logic for requests that fail in a transient manner /// and can be safely executed again. @@ -32,7 +36,7 @@ static MAXIMUM_NUMBER_OF_RETRIES: u32 = 10; /// }; /// /// let retry_transient_middleware = RetryTransientMiddleware::new_with_policy(retry_policy); -/// let client = ClientBuilder::new(Client::new()).with(retry_transient_middleware).build(); +/// let client = ClientBuilder::new(Client::new()).layer(retry_transient_middleware).build(); ///``` /// /// # Note @@ -58,76 +62,95 @@ impl RetryTransientMiddleware { } } -#[async_trait::async_trait] -impl Middleware for RetryTransientMiddleware { - async fn handle( - &self, - req: Request, - extensions: &mut Extensions, - next: Next<'_>, - ) -> Result { - // TODO: Ideally we should create a new instance of the `Extensions` map to pass - // downstream. This will guard against previous retries poluting `Extensions`. - // That is, we only return what's populated in the typemap for the last retry attempt - // and copy those into the the `global` Extensions map. - self.execute_with_retry(req, next, extensions).await +impl Layer for RetryTransientMiddleware +where + Svc: Service, +{ + type Service = Retry, Svc>; + + fn layer(&self, inner: Svc) -> Self::Service { + Retry::new( + TowerRetryPolicy { + n_past_retries: 0, + retry_policy: self.retry_policy.clone(), + }, + inner, + ) } } -impl RetryTransientMiddleware { - /// This function will try to execute the request, if it fails - /// with an error classified as transient it will call itself - /// to retry the request. - async fn execute_with_retry<'a>( - &'a self, - req: Request, - next: Next<'a>, - ext: &'a mut Extensions, - ) -> Result { - let mut n_past_retries = 0; - loop { - // Cloning the request object before-the-fact is not ideal.. - // However, if the body of the request is not static, e.g of type `Bytes`, - // the Clone operation should be of constant complexity and not O(N) - // since the byte abstraction is a shared pointer over a buffer. - let duplicate_request = req.try_clone().ok_or_else(|| { - Error::Middleware(anyhow!( - "Request object is not clonable. Are you passing a streaming body?".to_string() - )) - })?; +#[derive(Clone)] +pub struct TowerRetryPolicy { + n_past_retries: u32, + retry_policy: T, +} - let result = next.clone().run(duplicate_request, ext).await; +pin_project! { + pub struct RetryFuture + { + retry: Option>, + #[pin] + sleep: Sleep, + } +} - // We classify the response which will return None if not - // errors were returned. - break match Retryable::from_reqwest_response(&result) { - Some(retryable) - if retryable == Retryable::Transient - && n_past_retries < MAXIMUM_NUMBER_OF_RETRIES => - { - // If the response failed and the error type was transient - // we can safely try to retry the request. - let retry_decicion = self.retry_policy.should_retry(n_past_retries); - if let retry_policies::RetryDecision::Retry { execute_after } = retry_decicion { - let duration = (execute_after - Utc::now()) - .to_std() - .map_err(Error::middleware)?; - // Sleep the requested amount before we try again. - tracing::warn!( - "Retry attempt #{}. Sleeping {:?} before the next attempt", - n_past_retries, - duration - ); - tokio::time::sleep(duration).await; +impl Future for RetryFuture { + type Output = TowerRetryPolicy; - n_past_retries += 1; - continue; - } else { - result - } + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let this = self.project(); + ready!(this.sleep.poll(cx)); + Poll::Ready( + this.retry + .take() + .expect("poll should not be called more than once"), + ) + } +} + +impl Policy for TowerRetryPolicy { + type Future = RetryFuture; + + fn retry( + &self, + _req: &MiddlewareRequest, + result: std::result::Result<&Response, &Error>, + ) -> Option { + // We classify the response which will return None if not + // errors were returned. + match Retryable::from_reqwest_response(result) { + Some(Retryable::Transient) => { + // If the response failed and the error type was transient + // we can safely try to retry the request. + let retry_decicion = self.retry_policy.should_retry(self.n_past_retries); + if let retry_policies::RetryDecision::Retry { execute_after } = retry_decicion { + let duration = (execute_after - Utc::now()).to_std().ok()?; + // Sleep the requested amount before we try again. + tracing::warn!( + "Retry attempt #{}. Sleeping {:?} before the next attempt", + self.n_past_retries, + duration + ); + let sleep = tokio::time::sleep(duration); + Some(RetryFuture { + retry: Some(TowerRetryPolicy { + n_past_retries: self.n_past_retries + 1, + retry_policy: self.retry_policy.clone(), + }), + sleep, + }) + } else { + None } - Some(_) | None => result, - }; + } + Some(_) | None => None, } } + + fn clone_request(&self, req: &MiddlewareRequest) -> Option { + Some(MiddlewareRequest { + request: req.request.try_clone()?, + extensions: Extensions::new(), + }) + } } diff --git a/reqwest-retry/src/retryable.rs b/reqwest-retry/src/retryable.rs index d6e4cae..2a36deb 100644 --- a/reqwest-retry/src/retryable.rs +++ b/reqwest-retry/src/retryable.rs @@ -15,7 +15,7 @@ impl Retryable { /// /// Returns `None` if the response object does not contain any errors. /// - pub fn from_reqwest_response(res: &Result) -> Option { + pub fn from_reqwest_response(res: Result<&reqwest::Response, &Error>) -> Option { match res { Ok(success) => { let status = success.status(); diff --git a/reqwest-retry/tests/all/retry.rs b/reqwest-retry/tests/all/retry.rs index 9bfcbf3..923d635 100644 --- a/reqwest-retry/tests/all/retry.rs +++ b/reqwest-retry/tests/all/retry.rs @@ -48,7 +48,7 @@ macro_rules! assert_retry_succeeds_inner { let reqwest_client = Client::builder().build().unwrap(); let client = ClientBuilder::new(reqwest_client) - .with(RetryTransientMiddleware::new_with_policy( + .layer(RetryTransientMiddleware::new_with_policy( ExponentialBackoff { max_n_retries: retry_amount, max_retry_interval: std::time::Duration::from_millis(30), @@ -147,17 +147,6 @@ assert_retry_succeeds!(429, StatusCode::OK); assert_no_retry!(431, StatusCode::REQUEST_HEADER_FIELDS_TOO_LARGE); assert_no_retry!(451, StatusCode::UNAVAILABLE_FOR_LEGAL_REASONS); -// We assert that we cap retries at 10, which means that we will -// get 11 calls to the RetryResponder. -assert_retry_succeeds_inner!( - 500, - assert_maximum_retries_is_not_exceeded, - StatusCode::INTERNAL_SERVER_ERROR, - 100, - 11, - RetryResponder::new(100_u32, 500) -); - pub struct RetryTimeoutResponder(Arc, u32, std::time::Duration); impl RetryTimeoutResponder { @@ -195,7 +184,7 @@ async fn assert_retry_on_request_timeout() { let reqwest_client = Client::builder().build().unwrap(); let client = ClientBuilder::new(reqwest_client) - .with(RetryTransientMiddleware::new_with_policy( + .layer(RetryTransientMiddleware::new_with_policy( ExponentialBackoff { max_n_retries: 3, max_retry_interval: std::time::Duration::from_millis(100), @@ -250,7 +239,7 @@ async fn assert_retry_on_incomplete_message() { let reqwest_client = Client::builder().build().unwrap(); let client = ClientBuilder::new(reqwest_client) - .with(RetryTransientMiddleware::new_with_policy( + .layer(RetryTransientMiddleware::new_with_policy( ExponentialBackoff { max_n_retries: 3, max_retry_interval: std::time::Duration::from_millis(100), diff --git a/reqwest-tracing/Cargo.toml b/reqwest-tracing/Cargo.toml index ec2dbb2..9e214d2 100644 --- a/reqwest-tracing/Cargo.toml +++ b/reqwest-tracing/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "reqwest-tracing" -version = "0.4.0" +version = "0.5.0" authors = ["Rodrigo Gryzinski "] edition = "2018" description = "Opentracing middleware for reqwest." @@ -19,12 +19,14 @@ opentelemetry_0_18 = ["opentelemetry_0_18_pkg", "tracing-opentelemetry_0_18_pkg" [dependencies] -reqwest-middleware = { version = "0.2.0", path = "../reqwest-middleware" } +reqwest-middleware = { version = "0.3.0", path = "../reqwest-middleware" } async-trait = "0.1.51" reqwest = { version = "0.11", default-features = false } task-local-extensions = "0.1.1" tracing = "0.1.26" +tower = "0.4" +pin-project-lite = "0.2" opentelemetry_0_13_pkg = { package = "opentelemetry", version = "0.13", optional = true } opentelemetry_0_14_pkg = { package = "opentelemetry", version = "0.14", optional = true } diff --git a/reqwest-tracing/src/lib.rs b/reqwest-tracing/src/lib.rs index 2a6b766..749b817 100644 --- a/reqwest-tracing/src/lib.rs +++ b/reqwest-tracing/src/lib.rs @@ -103,3 +103,18 @@ pub use reqwest_otel_span_builder::{ #[doc(hidden)] pub mod reqwest_otel_span_macro; + +#[cfg(test)] +mod tests { + use crate::{TracingMiddleware, DefaultSpanBackend}; + use reqwest_middleware::ClientBuilder; + + #[tokio::test] + async fn compiles() { + let client = ClientBuilder::new(reqwest::Client::new()) + .layer(TracingMiddleware::::new()) + .build(); + let resp = client.get("http://example.com").send().await.unwrap(); + dbg!(resp); + } +} diff --git a/reqwest-tracing/src/middleware.rs b/reqwest-tracing/src/middleware.rs index a3c37d7..e9aa1ca 100644 --- a/reqwest-tracing/src/middleware.rs +++ b/reqwest-tracing/src/middleware.rs @@ -1,7 +1,10 @@ -use reqwest::{Request, Response}; -use reqwest_middleware::{Middleware, Next, Result}; -use task_local_extensions::Extensions; -use tracing::Instrument; +use std::{future::Future, task::ready}; + +use pin_project_lite::pin_project; +use reqwest::Response; +use reqwest_middleware::{Error, MiddlewareRequest}; +use tower::{Layer, Service}; +use tracing::Span; use crate::{DefaultSpanBackend, ReqwestOtelSpanBackend}; @@ -10,6 +13,8 @@ pub struct TracingMiddleware { span_backend: std::marker::PhantomData, } +impl Copy for TracingMiddleware {} + impl TracingMiddleware { pub fn new() -> TracingMiddleware { TracingMiddleware { @@ -30,38 +35,98 @@ impl Default for TracingMiddleware { } } -#[async_trait::async_trait] -impl Middleware for TracingMiddleware +impl Layer for TracingMiddleware where ReqwestOtelSpan: ReqwestOtelSpanBackend + Sync + Send + 'static, { - async fn handle( - &self, - req: Request, - extensions: &mut Extensions, - next: Next<'_>, - ) -> Result { - let request_span = ReqwestOtelSpan::on_request_start(&req, extensions); + type Service = TracingMiddlewareService; - let outcome_future = async { - // Adds tracing headers to the given request to propagate the OpenTelemetry context to downstream revivers of the request. - // Spans added by downstream consumers will be part of the same trace. - #[cfg(any( - feature = "opentelemetry_0_13", - feature = "opentelemetry_0_14", - feature = "opentelemetry_0_15", - feature = "opentelemetry_0_16", - feature = "opentelemetry_0_17", - feature = "opentelemetry_0_18", - ))] - let req = crate::otel::inject_opentelemetry_context_into_request(req); - - // Run the request - let outcome = next.run(req, extensions).await; - ReqwestOtelSpan::on_request_end(&request_span, &outcome, extensions); - outcome - }; - - outcome_future.instrument(request_span.clone()).await + fn layer(&self, inner: Svc) -> Self::Service { + TracingMiddlewareService { + service: inner, + layer: *self, + } + } +} + +/// Middleware Service for tracing requests using the current Opentelemetry Context. +pub struct TracingMiddlewareService { + layer: TracingMiddleware, + service: Svc, +} + +impl Service + for TracingMiddlewareService +where + ReqwestOtelSpan: ReqwestOtelSpanBackend + Sync + Send + 'static, + Svc: Service, +{ + type Response = Response; + type Error = Error; + type Future = TracingMiddlewareFuture; + + fn poll_ready( + &mut self, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + self.service.poll_ready(cx) + } + + fn call(&mut self, req: MiddlewareRequest) -> Self::Future { + let MiddlewareRequest { + request, + mut extensions, + } = req; + let request_span = ReqwestOtelSpan::on_request_start(&request, &mut extensions); + // Adds tracing headers to the given request to propagate the OpenTelemetry context to downstream revivers of the request. + // Spans added by downstream consumers will be part of the same trace. + #[cfg(any( + feature = "opentelemetry_0_13", + feature = "opentelemetry_0_14", + feature = "opentelemetry_0_15", + feature = "opentelemetry_0_16", + feature = "opentelemetry_0_17", + feature = "opentelemetry_0_18", + ))] + let request = crate::otel::inject_opentelemetry_context_into_request(request); + + let future = self.service.call(MiddlewareRequest { + request, + extensions, + }); + + TracingMiddlewareFuture { + layer: self.layer, + span: request_span, + future, + } + } +} + +pin_project!( + pub struct TracingMiddlewareFuture { + layer: TracingMiddleware, + span: Span, + #[pin] + future: F, + } +); + +impl>> Future + for TracingMiddlewareFuture +{ + type Output = F::Output; + + fn poll( + self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll { + let this = self.project(); + let outcome = { + let _guard = this.span.enter(); + ready!(this.future.poll(cx)) + }; + S::on_request_end(this.span, &outcome); + std::task::Poll::Ready(outcome) } } diff --git a/reqwest-tracing/src/reqwest_otel_span_builder.rs b/reqwest-tracing/src/reqwest_otel_span_builder.rs index 095b9b5..6fde890 100644 --- a/reqwest-tracing/src/reqwest_otel_span_builder.rs +++ b/reqwest-tracing/src/reqwest_otel_span_builder.rs @@ -44,7 +44,7 @@ pub trait ReqwestOtelSpanBackend { fn on_request_start(req: &Request, extension: &mut Extensions) -> Span; /// Runs after the request call has executed. - fn on_request_end(span: &Span, outcome: &Result, extension: &mut Extensions); + fn on_request_end(span: &Span, outcome: &Result); } /// Populates default success/failure fields for a given [`reqwest_otel_span!`] span. @@ -103,7 +103,7 @@ impl ReqwestOtelSpanBackend for DefaultSpanBackend { reqwest_otel_span!(name = name, req) } - fn on_request_end(span: &Span, outcome: &Result, _: &mut Extensions) { + fn on_request_end(span: &Span, outcome: &Result) { default_on_request_end(span, outcome) } } @@ -128,7 +128,7 @@ impl ReqwestOtelSpanBackend for SpanBackendWithUrl { reqwest_otel_span!(name = name, req, http.url = %remove_credentials(req.url())) } - fn on_request_end(span: &Span, outcome: &Result, _: &mut Extensions) { + fn on_request_end(span: &Span, outcome: &Result) { default_on_request_end(span, outcome) } }