expierments with tower

This commit is contained in:
Conrad Ludgate 2022-11-15 14:53:23 +00:00
parent 289bb0452c
commit 6eaa2365ed
No known key found for this signature in database
GPG key ID: 197E3CACA1C980B5
14 changed files with 411 additions and 363 deletions

View file

@ -1,6 +1,6 @@
[package] [package]
name = "reqwest-middleware" name = "reqwest-middleware"
version = "0.2.0" version = "0.3.0"
authors = ["Rodrigo Gryzinski <rodrigo.gryzinski@truelayer.com>"] authors = ["Rodrigo Gryzinski <rodrigo.gryzinski@truelayer.com>"]
edition = "2018" edition = "2018"
description = "Wrapper around reqwest to allow for client middleware chains." description = "Wrapper around reqwest to allow for client middleware chains."
@ -18,6 +18,7 @@ reqwest = { version = "0.11", default-features = false, features = ["json", "mul
serde = "1" serde = "1"
task-local-extensions = "0.1.1" task-local-extensions = "0.1.1"
thiserror = "1" thiserror = "1"
tower = { version = "0.4", features = ["util"] }
[dev-dependencies] [dev-dependencies]
reqwest = "0.11" reqwest = "0.11"

View file

@ -4,78 +4,81 @@ use reqwest::{Body, Client, IntoUrl, Method, Request, Response};
use serde::Serialize; use serde::Serialize;
use std::convert::TryFrom; use std::convert::TryFrom;
use std::fmt::{self, Display}; use std::fmt::{self, Display};
use std::sync::Arc;
use std::time::Duration; use std::time::Duration;
use task_local_extensions::Extensions; use task_local_extensions::Extensions;
use tower::layer::util::{Identity, Stack};
use tower::{Layer, Service, ServiceBuilder, ServiceExt};
use crate::error::Result; use crate::error::Result;
use crate::middleware::{Middleware, Next}; use crate::{Error, MiddlewareRequest, RequestInitialiser};
use crate::RequestInitialiser;
/// A `ClientBuilder` is used to build a [`ClientWithMiddleware`]. /// A `ClientBuilder` is used to build a [`ClientWithMiddleware`].
/// ///
/// [`ClientWithMiddleware`]: crate::ClientWithMiddleware /// [`ClientWithMiddleware`]: crate::ClientWithMiddleware
pub struct ClientBuilder { pub struct ClientBuilder<M> {
client: Client, client: Client,
middleware_stack: Vec<Arc<dyn Middleware>>, middleware_stack: ServiceBuilder<M>,
initialiser_stack: Vec<Arc<dyn RequestInitialiser>>, initialiser_stack: (),
} }
impl ClientBuilder { impl ClientBuilder<Identity> {
pub fn new(client: Client) -> Self { pub fn new(client: Client) -> Self {
ClientBuilder { ClientBuilder {
client, client,
middleware_stack: Vec::new(), middleware_stack: ServiceBuilder::new(),
initialiser_stack: Vec::new(), initialiser_stack: (),
} }
} }
}
impl<L> ClientBuilder<L> {
/// Convenience method to attach middleware. /// Convenience method to attach middleware.
/// ///
/// If you need to keep a reference to the middleware after attaching, use [`with_arc`]. /// If you need to keep a reference to the middleware after attaching, use [`with_arc`].
/// ///
/// [`with_arc`]: Self::with_arc /// [`with_arc`]: Self::with_arc
pub fn with<M>(self, middleware: M) -> Self pub fn layer<T>(self, layer: T) -> ClientBuilder<Stack<T, L>> {
where ClientBuilder {
M: Middleware, client: self.client,
{ middleware_stack: self.middleware_stack.layer(layer),
self.with_arc(Arc::new(middleware)) initialiser_stack: (),
}
} }
/// Add middleware to the chain. [`with`] is more ergonomic if you don't need the `Arc`. // /// Add middleware to the chain. [`with`] is more ergonomic if you don't need the `Arc`.
/// // ///
/// [`with`]: Self::with // /// [`with`]: Self::with
pub fn with_arc(mut self, middleware: Arc<dyn Middleware>) -> Self { // pub fn with_arc(mut self, middleware: Arc<dyn Middleware>) -> Self {
self.middleware_stack.push(middleware); // self.middleware_stack.push(middleware);
self // self
} // }
/// Convenience method to attach a request initialiser. // /// Convenience method to attach a request initialiser.
/// // ///
/// If you need to keep a reference to the initialiser after attaching, use [`with_arc_init`]. // /// If you need to keep a reference to the initialiser after attaching, use [`with_arc_init`].
/// // ///
/// [`with_arc_init`]: Self::with_arc_init // /// [`with_arc_init`]: Self::with_arc_init
pub fn with_init<I>(self, initialiser: I) -> Self // pub fn with_init<I>(self, initialiser: I) -> Self
where // where
I: RequestInitialiser, // I: RequestInitialiser,
{ // {
self.with_arc_init(Arc::new(initialiser)) // self.with_arc_init(Arc::new(initialiser))
} // }
/// Add a request initialiser to the chain. [`with_init`] is more ergonomic if you don't need the `Arc`. // /// Add a request initialiser to the chain. [`with_init`] is more ergonomic if you don't need the `Arc`.
/// // ///
/// [`with_init`]: Self::with_init // /// [`with_init`]: Self::with_init
pub fn with_arc_init(mut self, initialiser: Arc<dyn RequestInitialiser>) -> Self { // pub fn with_arc_init(mut self, initialiser: Arc<dyn RequestInitialiser>) -> Self {
self.initialiser_stack.push(initialiser); // self.initialiser_stack.push(initialiser);
self // self
} // }
/// Returns a `ClientWithMiddleware` using this builder configuration. /// Returns a `ClientWithMiddleware` using this builder configuration.
pub fn build(self) -> ClientWithMiddleware { pub fn build(self) -> ClientWithMiddleware<L, ()> {
ClientWithMiddleware { ClientWithMiddleware {
inner: self.client, inner: self.client,
middleware_stack: self.middleware_stack.into_boxed_slice(), middleware_stack: self.middleware_stack,
initialiser_stack: self.initialiser_stack.into_boxed_slice(), initialiser_stack: self.initialiser_stack,
} }
} }
} }
@ -83,97 +86,85 @@ impl ClientBuilder {
/// `ClientWithMiddleware` is a wrapper around [`reqwest::Client`] which runs middleware on every /// `ClientWithMiddleware` is a wrapper around [`reqwest::Client`] which runs middleware on every
/// request. /// request.
#[derive(Clone)] #[derive(Clone)]
pub struct ClientWithMiddleware { pub struct ClientWithMiddleware<M, I> {
inner: reqwest::Client, inner: reqwest::Client,
middleware_stack: Box<[Arc<dyn Middleware>]>, middleware_stack: ServiceBuilder<M>,
initialiser_stack: Box<[Arc<dyn RequestInitialiser>]>, initialiser_stack: I,
} }
impl ClientWithMiddleware { // impl<M: Layer<ReqService>> ClientWithMiddleware<M, ()>
/// See [`ClientBuilder`] for a more ergonomic way to build `ClientWithMiddleware` instances. // where
pub fn new<T>(client: Client, middleware_stack: T) -> Self // M::Service: Service<MiddlewareRequest, Response = Response, Error = Error>,
where // {
T: Into<Box<[Arc<dyn Middleware>]>>, // /// See [`ClientBuilder`] for a more ergonomic way to build `ClientWithMiddleware` instances.
{ // pub fn new(client: Client, middleware_stack: M) -> Self {
ClientWithMiddleware { // ClientWithMiddleware {
inner: client, // inner: client,
middleware_stack: middleware_stack.into(), // middleware_stack,
// TODO(conradludgate) - allow downstream code to control this manually if desired // initialiser_stack: (),
initialiser_stack: Box::new([]), // }
} // }
} // }
impl<M: Layer<ReqService>, I: RequestInitialiser> ClientWithMiddleware<M, I>
where
M::Service: Service<MiddlewareRequest, Response = Response, Error = Error>,
{
/// See [`Client::get`] /// See [`Client::get`]
pub fn get<U: IntoUrl>(&self, url: U) -> RequestBuilder { pub fn get<U: IntoUrl>(&self, url: U) -> RequestBuilder<M, I> {
self.request(Method::GET, url) self.request(Method::GET, url)
} }
/// See [`Client::post`] /// See [`Client::post`]
pub fn post<U: IntoUrl>(&self, url: U) -> RequestBuilder { pub fn post<U: IntoUrl>(&self, url: U) -> RequestBuilder<M, I> {
self.request(Method::POST, url) self.request(Method::POST, url)
} }
/// See [`Client::put`] /// See [`Client::put`]
pub fn put<U: IntoUrl>(&self, url: U) -> RequestBuilder { pub fn put<U: IntoUrl>(&self, url: U) -> RequestBuilder<M, I> {
self.request(Method::PUT, url) self.request(Method::PUT, url)
} }
/// See [`Client::patch`] /// See [`Client::patch`]
pub fn patch<U: IntoUrl>(&self, url: U) -> RequestBuilder { pub fn patch<U: IntoUrl>(&self, url: U) -> RequestBuilder<M, I> {
self.request(Method::PATCH, url) self.request(Method::PATCH, url)
} }
/// See [`Client::delete`] /// See [`Client::delete`]
pub fn delete<U: IntoUrl>(&self, url: U) -> RequestBuilder { pub fn delete<U: IntoUrl>(&self, url: U) -> RequestBuilder<M, I> {
self.request(Method::DELETE, url) self.request(Method::DELETE, url)
} }
/// See [`Client::head`] /// See [`Client::head`]
pub fn head<U: IntoUrl>(&self, url: U) -> RequestBuilder { pub fn head<U: IntoUrl>(&self, url: U) -> RequestBuilder<M, I> {
self.request(Method::HEAD, url) self.request(Method::HEAD, url)
} }
/// See [`Client::request`] /// See [`Client::request`]
pub fn request<U: IntoUrl>(&self, method: Method, url: U) -> RequestBuilder { pub fn request<U: IntoUrl>(&self, method: Method, url: U) -> RequestBuilder<'_, M, I> {
let req = RequestBuilder { let mut extensions = Extensions::new();
inner: self.inner.request(method, url), let request = self.inner.request(method, url);
client: self.clone(), let request = self.initialiser_stack.init(request, &mut extensions);
extensions: Extensions::new(), RequestBuilder {
}; inner: request,
self.initialiser_stack client: self,
.iter() extensions,
.fold(req, |req, i| i.init(req))
}
/// See [`Client::execute`]
pub async fn execute(&self, req: Request) -> Result<Response> {
let mut ext = Extensions::new();
self.execute_with_extensions(req, &mut ext).await
}
/// Executes a request with initial [`Extensions`].
pub async fn execute_with_extensions(
&self,
req: Request,
ext: &mut Extensions,
) -> Result<Response> {
let next = Next::new(&self.inner, &self.middleware_stack);
next.run(req, ext).await
}
}
/// Create a `ClientWithMiddleware` without any middleware.
impl From<Client> for ClientWithMiddleware {
fn from(client: Client) -> Self {
ClientWithMiddleware {
inner: client,
middleware_stack: Box::new([]),
initialiser_stack: Box::new([]),
} }
} }
} }
impl fmt::Debug for ClientWithMiddleware { /// Create a `ClientWithMiddleware` without any middleware.
impl From<Client> for ClientWithMiddleware<Identity, ()> {
fn from(client: Client) -> Self {
ClientWithMiddleware {
inner: client,
middleware_stack: ServiceBuilder::new(),
initialiser_stack: (),
}
}
}
impl<M, I> fmt::Debug for ClientWithMiddleware<M, I> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
// skipping middleware_stack field for now // skipping middleware_stack field for now
f.debug_struct("ClientWithMiddleware") f.debug_struct("ClientWithMiddleware")
@ -184,13 +175,37 @@ impl fmt::Debug for ClientWithMiddleware {
/// This is a wrapper around [`reqwest::RequestBuilder`] exposing the same API. /// This is a wrapper around [`reqwest::RequestBuilder`] exposing the same API.
#[must_use = "RequestBuilder does nothing until you 'send' it"] #[must_use = "RequestBuilder does nothing until you 'send' it"]
pub struct RequestBuilder { pub struct RequestBuilder<'client, M, I> {
inner: reqwest::RequestBuilder, inner: reqwest::RequestBuilder,
client: ClientWithMiddleware, client: &'client ClientWithMiddleware<M, I>,
extensions: Extensions, extensions: Extensions,
} }
impl RequestBuilder { pub type BoxFuture<'a, T> = std::pin::Pin<Box<dyn std::future::Future<Output = T> + Send + 'a>>;
#[derive(Clone)]
pub struct ReqService(Client);
impl Service<MiddlewareRequest> for ReqService {
type Response = Response;
type Error = Error;
type Future = BoxFuture<'static, Result<Response>>;
fn poll_ready(&mut self, _: &mut std::task::Context<'_>) -> std::task::Poll<Result<()>> {
std::task::Poll::Ready(Ok(()))
}
fn call(&mut self, req: MiddlewareRequest) -> Self::Future {
let req = req.request;
let client = self.0.clone();
Box::pin(async move { client.execute(req).await.map_err(Error::from) })
}
}
impl<M: Layer<ReqService>, I: RequestInitialiser> RequestBuilder<'_, M, I>
where
M::Service: Service<MiddlewareRequest, Response = Response, Error = Error>,
{
pub fn header<K, V>(self, key: K, value: V) -> Self pub fn header<K, V>(self, key: K, value: V) -> Self
where where
HeaderName: TryFrom<K>, HeaderName: TryFrom<K>,
@ -293,10 +308,19 @@ impl RequestBuilder {
let Self { let Self {
inner, inner,
client, client,
mut extensions, extensions,
} = self; } = self;
let req = inner.build()?; let req = inner.build()?;
client.execute_with_extensions(req, &mut extensions).await client
.middleware_stack
.service(ReqService(client.inner.clone()))
.oneshot(MiddlewareRequest {
request: req,
extensions,
})
.await
// client.execute_with_extensions(req, &mut extensions).await
} }
/// Attempt to clone the RequestBuilder. /// Attempt to clone the RequestBuilder.
@ -309,13 +333,13 @@ impl RequestBuilder {
pub fn try_clone(&self) -> Option<Self> { pub fn try_clone(&self) -> Option<Self> {
self.inner.try_clone().map(|inner| RequestBuilder { self.inner.try_clone().map(|inner| RequestBuilder {
inner, inner,
client: self.client.clone(), client: self.client,
extensions: Extensions::new(), extensions: Extensions::new(),
}) })
} }
} }
impl fmt::Debug for RequestBuilder { impl<M, I> fmt::Debug for RequestBuilder<'_, M, I> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
// skipping middleware_stack field for now // skipping middleware_stack field for now
f.debug_struct("RequestBuilder") f.debug_struct("RequestBuilder")

View file

@ -9,28 +9,44 @@
//! use reqwest::{Client, Request, Response}; //! use reqwest::{Client, Request, Response};
//! use reqwest_middleware::{ClientBuilder, Middleware, Next, Result}; //! use reqwest_middleware::{ClientBuilder, Middleware, Next, Result};
//! use task_local_extensions::Extensions; //! use task_local_extensions::Extensions;
//! use futures::FutureExt;
//! use std::task::{Context, Poll};
//! //!
//! struct LoggingMiddleware; //! struct LoggingLayer;
//! struct LoggingService<S>(S);
//!
//! impl<S> tower::Layer<S> for LoggingLayer {
//! type Service = LoggingService<S>;
//!
//! fn layer(&self, inner: S) -> Self::Service {
//! LoggingService(inner)
//! }
//! }
//! //!
//! #[async_trait::async_trait] //! impl<S: tower::Service<MiddlewareRequest>> tower::Service<MiddlewareRequest> for LoggingService<S> {
//! impl Middleware for LoggingMiddleware { //! type Response = S::Response;
//! async fn handle( //! type Error = S::Error;
//! &self, //! type Future = futures::BoxFuture<'static, Result<S::Response, S::Error>>;
//! req: Request, //!
//! extensions: &mut Extensions, //! fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
//! next: Next<'_>, //! self.0.poll_ready(cx)
//! ) -> Result<Response> { //! }
//! println!("Request started {:?}", req); //!
//! let res = next.run(req, extensions).await; //! fn call(&mut self, req: MiddlewareRequest) -> Self::Future {
//! println!("Result: {:?}", res); //! println!("Request started {:?}", &req.request);
//! res //! let fut = self.0.call(req);
//! async {
//! let res = fut.await;
//! println!("Result: {:?}", res);
//! res
//! }.boxed()
//! } //! }
//! } //! }
//! //!
//! async fn run() { //! async fn run() {
//! let reqwest_client = Client::builder().build().unwrap(); //! let reqwest_client = Client::builder().build().unwrap();
//! let client = ClientBuilder::new(reqwest_client) //! let client = ClientBuilder::new(reqwest_client)
//! .with(LoggingMiddleware) //! .layer(LoggingLayer)
//! .build(); //! .build();
//! let resp = client.get("https://truelayer.com").send().await.unwrap(); //! let resp = client.get("https://truelayer.com").send().await.unwrap();
//! println!("TrueLayer page HTML: {}", resp.text().await.unwrap()); //! println!("TrueLayer page HTML: {}", resp.text().await.unwrap());
@ -51,10 +67,13 @@ pub struct ReadmeDoctests;
mod client; mod client;
mod error; mod error;
mod middleware;
mod req_init; mod req_init;
pub use client::{ClientBuilder, ClientWithMiddleware, RequestBuilder}; pub use client::{ClientBuilder, ClientWithMiddleware, ReqService, RequestBuilder};
pub use error::{Error, Result}; pub use error::{Error, Result};
pub use middleware::{Middleware, Next};
pub use req_init::{Extension, RequestInitialiser}; pub use req_init::{Extension, RequestInitialiser};
pub struct MiddlewareRequest {
pub request: reqwest::Request,
pub extensions: task_local_extensions::Extensions,
}

View file

@ -1,100 +0,0 @@
use reqwest::{Client, Request, Response};
use std::sync::Arc;
use task_local_extensions::Extensions;
use crate::error::{Error, Result};
/// When attached to a [`ClientWithMiddleware`] (generally using [`with`]), middleware is run
/// whenever the client issues a request, in the order it was attached.
///
/// # Example
///
/// ```
/// use reqwest::{Client, Request, Response};
/// use reqwest_middleware::{ClientBuilder, Middleware, Next, Result};
/// use task_local_extensions::Extensions;
///
/// struct TransparentMiddleware;
///
/// #[async_trait::async_trait]
/// impl Middleware for TransparentMiddleware {
/// async fn handle(
/// &self,
/// req: Request,
/// extensions: &mut Extensions,
/// next: Next<'_>,
/// ) -> Result<Response> {
/// next.run(req, extensions).await
/// }
/// }
/// ```
///
/// [`ClientWithMiddleware`]: crate::ClientWithMiddleware
/// [`with`]: crate::ClientBuilder::with
#[async_trait::async_trait]
pub trait Middleware: 'static + Send + Sync {
/// Invoked with a request before sending it. If you want to continue processing the request,
/// you should explicitly call `next.run(req, extensions)`.
///
/// If you need to forward data down the middleware stack, you can use the `extensions`
/// argument.
async fn handle(
&self,
req: Request,
extensions: &mut Extensions,
next: Next<'_>,
) -> Result<Response>;
}
#[async_trait::async_trait]
impl<F> Middleware for F
where
F: Send
+ Sync
+ 'static
+ for<'a> Fn(Request, &'a mut Extensions, Next<'a>) -> BoxFuture<'a, Result<Response>>,
{
async fn handle(
&self,
req: Request,
extensions: &mut Extensions,
next: Next<'_>,
) -> Result<Response> {
(self)(req, extensions, next).await
}
}
/// Next encapsulates the remaining middleware chain to run in [`Middleware::handle`]. You can
/// forward the request down the chain with [`run`].
///
/// [`Middleware::handle`]: Middleware::handle
/// [`run`]: Self::run
#[derive(Clone)]
pub struct Next<'a> {
client: &'a Client,
middlewares: &'a [Arc<dyn Middleware>],
}
pub type BoxFuture<'a, T> = std::pin::Pin<Box<dyn std::future::Future<Output = T> + Send + 'a>>;
impl<'a> Next<'a> {
pub(crate) fn new(client: &'a Client, middlewares: &'a [Arc<dyn Middleware>]) -> Self {
Next {
client,
middlewares,
}
}
pub fn run(
mut self,
req: Request,
extensions: &'a mut Extensions,
) -> BoxFuture<'a, Result<Response>> {
if let Some((current, rest)) = self.middlewares.split_first() {
self.middlewares = rest;
Box::pin(current.handle(req, extensions, self))
} else {
Box::pin(async move { self.client.execute(req).await.map_err(Error::from) })
}
}
}

View file

@ -1,4 +1,5 @@
use crate::RequestBuilder; use reqwest::RequestBuilder;
use task_local_extensions::Extensions;
/// When attached to a [`ClientWithMiddleware`] (generally using [`with_init`]), it is run /// When attached to a [`ClientWithMiddleware`] (generally using [`with_init`]), it is run
/// whenever the client starts building a request, in the order it was attached. /// whenever the client starts building a request, in the order it was attached.
@ -6,12 +7,12 @@ use crate::RequestBuilder;
/// # Example /// # Example
/// ///
/// ``` /// ```
/// use reqwest_middleware::{RequestInitialiser, RequestBuilder}; /// use reqwest_middleware::{RequestInitialiser, MiddlewareRequest};
/// ///
/// struct AuthInit; /// struct AuthInit;
/// ///
/// impl RequestInitialiser for AuthInit { /// impl RequestInitialiser for AuthInit {
/// fn init(&self, req: RequestBuilder) -> RequestBuilder { /// fn init(&self, req: MiddlewareRequest) -> MiddlewareRequest {
/// req.bearer_auth("my_auth_token") /// req.bearer_auth("my_auth_token")
/// } /// }
/// } /// }
@ -20,18 +21,24 @@ use crate::RequestBuilder;
/// [`ClientWithMiddleware`]: crate::ClientWithMiddleware /// [`ClientWithMiddleware`]: crate::ClientWithMiddleware
/// [`with_init`]: crate::ClientBuilder::with_init /// [`with_init`]: crate::ClientBuilder::with_init
pub trait RequestInitialiser: 'static + Send + Sync { pub trait RequestInitialiser: 'static + Send + Sync {
fn init(&self, req: RequestBuilder) -> RequestBuilder; fn init(&self, req: RequestBuilder, ext: &mut Extensions) -> RequestBuilder;
} }
impl<F> RequestInitialiser for F impl RequestInitialiser for () {
where fn init(&self, req: RequestBuilder, _: &mut Extensions) -> RequestBuilder {
F: Send + Sync + 'static + Fn(RequestBuilder) -> RequestBuilder, req
{
fn init(&self, req: RequestBuilder) -> RequestBuilder {
(self)(req)
} }
} }
// impl<F> RequestInitialiser for F
// where
// F: Send + Sync + 'static + Fn(MiddlewareRequest) -> MiddlewareRequest,
// {
// fn init(&self, req: RequestBuilder, ext: &mut Extensions) -> RequestBuilder {
// (self)(req)
// }
// }
/// A middleware that inserts the value into the [`Extensions`](task_local_extensions::Extensions) during the call. /// A middleware that inserts the value into the [`Extensions`](task_local_extensions::Extensions) during the call.
/// ///
/// This is a good way to inject extensions to middleware deeper in the stack /// This is a good way to inject extensions to middleware deeper in the stack
@ -78,7 +85,8 @@ where
pub struct Extension<T>(pub T); pub struct Extension<T>(pub T);
impl<T: Send + Sync + Clone + 'static> RequestInitialiser for Extension<T> { impl<T: Send + Sync + Clone + 'static> RequestInitialiser for Extension<T> {
fn init(&self, req: RequestBuilder) -> RequestBuilder { fn init(&self, req: RequestBuilder, ext: &mut Extensions) -> RequestBuilder {
req.with_extension(self.0.clone()) ext.insert(self.0.clone());
req
} }
} }

View file

@ -1,6 +1,6 @@
[package] [package]
name = "reqwest-retry" name = "reqwest-retry"
version = "0.2.0" version = "0.3.0"
authors = ["Rodrigo Gryzinski <rodrigo.gryzinski@truelayer.com>"] authors = ["Rodrigo Gryzinski <rodrigo.gryzinski@truelayer.com>"]
edition = "2018" edition = "2018"
description = "Retry middleware for reqwest." description = "Retry middleware for reqwest."
@ -10,7 +10,7 @@ keywords = ["reqwest", "http", "middleware", "retry"]
categories = ["web-programming::http-client"] categories = ["web-programming::http-client"]
[dependencies] [dependencies]
reqwest-middleware = { version = "0.2.0", path = "../reqwest-middleware" } reqwest-middleware = { version = "0.3.0", path = "../reqwest-middleware" }
anyhow = "1" anyhow = "1"
async-trait = "0.1.51" async-trait = "0.1.51"
@ -23,6 +23,8 @@ retry-policies = "0.1"
task-local-extensions = "0.1.1" task-local-extensions = "0.1.1"
tokio = { version = "1.6", features = ["time"] } tokio = { version = "1.6", features = ["time"] }
tracing = "0.1.26" tracing = "0.1.26"
tower = { version = "0.4", features = ["retry"] }
pin-project-lite = "0.2"
[dev-dependencies] [dev-dependencies]
async-std = { version = "1.10"} async-std = { version = "1.10"}

View file

@ -13,7 +13,7 @@
//! // Retry up to 3 times with increasing intervals between attempts. //! // Retry up to 3 times with increasing intervals between attempts.
//! let retry_policy = ExponentialBackoff::builder().build_with_max_retries(3); //! let retry_policy = ExponentialBackoff::builder().build_with_max_retries(3);
//! let client = ClientBuilder::new(reqwest::Client::new()) //! let client = ClientBuilder::new(reqwest::Client::new())
//! .with(RetryTransientMiddleware::new_with_policy(retry_policy)) //! .layer(RetryTransientMiddleware::new_with_policy(retry_policy))
//! .build(); //! .build();
//! //!
//! client //! client

View file

@ -1,15 +1,19 @@
//! `RetryTransientMiddleware` implements retrying requests on transient errors. //! `RetryTransientMiddleware` implements retrying requests on transient errors.
use std::pin::Pin;
use std::task::{ready, Context, Poll};
use crate::retryable::Retryable; use crate::retryable::Retryable;
use anyhow::anyhow;
use chrono::Utc; use chrono::Utc;
use reqwest::{Request, Response}; use futures::Future;
use reqwest_middleware::{Error, Middleware, Next, Result}; use pin_project_lite::pin_project;
use reqwest::Response;
use reqwest_middleware::{Error, MiddlewareRequest};
use retry_policies::RetryPolicy; use retry_policies::RetryPolicy;
use task_local_extensions::Extensions; use task_local_extensions::Extensions;
use tokio::time::Sleep;
/// We limit the number of retries to a maximum of `10` to avoid stack-overflow issues due to the recursion. use tower::retry::{Policy, Retry};
static MAXIMUM_NUMBER_OF_RETRIES: u32 = 10; use tower::{Layer, Service};
/// `RetryTransientMiddleware` offers retry logic for requests that fail in a transient manner /// `RetryTransientMiddleware` offers retry logic for requests that fail in a transient manner
/// and can be safely executed again. /// and can be safely executed again.
@ -32,7 +36,7 @@ static MAXIMUM_NUMBER_OF_RETRIES: u32 = 10;
/// }; /// };
/// ///
/// let retry_transient_middleware = RetryTransientMiddleware::new_with_policy(retry_policy); /// let retry_transient_middleware = RetryTransientMiddleware::new_with_policy(retry_policy);
/// let client = ClientBuilder::new(Client::new()).with(retry_transient_middleware).build(); /// let client = ClientBuilder::new(Client::new()).layer(retry_transient_middleware).build();
///``` ///```
/// ///
/// # Note /// # Note
@ -58,76 +62,95 @@ impl<T: RetryPolicy + Send + Sync> RetryTransientMiddleware<T> {
} }
} }
#[async_trait::async_trait] impl<T: RetryPolicy + Clone + Send + Sync + 'static, Svc> Layer<Svc> for RetryTransientMiddleware<T>
impl<T: RetryPolicy + Send + Sync> Middleware for RetryTransientMiddleware<T> { where
async fn handle( Svc: Service<MiddlewareRequest, Response = Response, Error = Error>,
&self, {
req: Request, type Service = Retry<TowerRetryPolicy<T>, Svc>;
extensions: &mut Extensions,
next: Next<'_>, fn layer(&self, inner: Svc) -> Self::Service {
) -> Result<Response> { Retry::new(
// TODO: Ideally we should create a new instance of the `Extensions` map to pass TowerRetryPolicy {
// downstream. This will guard against previous retries poluting `Extensions`. n_past_retries: 0,
// That is, we only return what's populated in the typemap for the last retry attempt retry_policy: self.retry_policy.clone(),
// and copy those into the the `global` Extensions map. },
self.execute_with_retry(req, next, extensions).await inner,
)
} }
} }
impl<T: RetryPolicy + Send + Sync> RetryTransientMiddleware<T> { #[derive(Clone)]
/// This function will try to execute the request, if it fails pub struct TowerRetryPolicy<T> {
/// with an error classified as transient it will call itself n_past_retries: u32,
/// to retry the request. retry_policy: T,
async fn execute_with_retry<'a>( }
&'a self,
req: Request,
next: Next<'a>,
ext: &'a mut Extensions,
) -> Result<Response> {
let mut n_past_retries = 0;
loop {
// Cloning the request object before-the-fact is not ideal..
// However, if the body of the request is not static, e.g of type `Bytes`,
// the Clone operation should be of constant complexity and not O(N)
// since the byte abstraction is a shared pointer over a buffer.
let duplicate_request = req.try_clone().ok_or_else(|| {
Error::Middleware(anyhow!(
"Request object is not clonable. Are you passing a streaming body?".to_string()
))
})?;
let result = next.clone().run(duplicate_request, ext).await; pin_project! {
pub struct RetryFuture<T>
{
retry: Option<TowerRetryPolicy<T>>,
#[pin]
sleep: Sleep,
}
}
// We classify the response which will return None if not impl<T> Future for RetryFuture<T> {
// errors were returned. type Output = TowerRetryPolicy<T>;
break match Retryable::from_reqwest_response(&result) {
Some(retryable)
if retryable == Retryable::Transient
&& n_past_retries < MAXIMUM_NUMBER_OF_RETRIES =>
{
// If the response failed and the error type was transient
// we can safely try to retry the request.
let retry_decicion = self.retry_policy.should_retry(n_past_retries);
if let retry_policies::RetryDecision::Retry { execute_after } = retry_decicion {
let duration = (execute_after - Utc::now())
.to_std()
.map_err(Error::middleware)?;
// Sleep the requested amount before we try again.
tracing::warn!(
"Retry attempt #{}. Sleeping {:?} before the next attempt",
n_past_retries,
duration
);
tokio::time::sleep(duration).await;
n_past_retries += 1; fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
continue; let this = self.project();
} else { ready!(this.sleep.poll(cx));
result Poll::Ready(
} this.retry
.take()
.expect("poll should not be called more than once"),
)
}
}
impl<T: RetryPolicy + Clone> Policy<MiddlewareRequest, Response, Error> for TowerRetryPolicy<T> {
type Future = RetryFuture<T>;
fn retry(
&self,
_req: &MiddlewareRequest,
result: std::result::Result<&Response, &Error>,
) -> Option<Self::Future> {
// We classify the response which will return None if not
// errors were returned.
match Retryable::from_reqwest_response(result) {
Some(Retryable::Transient) => {
// If the response failed and the error type was transient
// we can safely try to retry the request.
let retry_decicion = self.retry_policy.should_retry(self.n_past_retries);
if let retry_policies::RetryDecision::Retry { execute_after } = retry_decicion {
let duration = (execute_after - Utc::now()).to_std().ok()?;
// Sleep the requested amount before we try again.
tracing::warn!(
"Retry attempt #{}. Sleeping {:?} before the next attempt",
self.n_past_retries,
duration
);
let sleep = tokio::time::sleep(duration);
Some(RetryFuture {
retry: Some(TowerRetryPolicy {
n_past_retries: self.n_past_retries + 1,
retry_policy: self.retry_policy.clone(),
}),
sleep,
})
} else {
None
} }
Some(_) | None => result, }
}; Some(_) | None => None,
} }
} }
fn clone_request(&self, req: &MiddlewareRequest) -> Option<MiddlewareRequest> {
Some(MiddlewareRequest {
request: req.request.try_clone()?,
extensions: Extensions::new(),
})
}
} }

View file

@ -15,7 +15,7 @@ impl Retryable {
/// ///
/// Returns `None` if the response object does not contain any errors. /// Returns `None` if the response object does not contain any errors.
/// ///
pub fn from_reqwest_response(res: &Result<reqwest::Response, Error>) -> Option<Self> { pub fn from_reqwest_response(res: Result<&reqwest::Response, &Error>) -> Option<Self> {
match res { match res {
Ok(success) => { Ok(success) => {
let status = success.status(); let status = success.status();

View file

@ -48,7 +48,7 @@ macro_rules! assert_retry_succeeds_inner {
let reqwest_client = Client::builder().build().unwrap(); let reqwest_client = Client::builder().build().unwrap();
let client = ClientBuilder::new(reqwest_client) let client = ClientBuilder::new(reqwest_client)
.with(RetryTransientMiddleware::new_with_policy( .layer(RetryTransientMiddleware::new_with_policy(
ExponentialBackoff { ExponentialBackoff {
max_n_retries: retry_amount, max_n_retries: retry_amount,
max_retry_interval: std::time::Duration::from_millis(30), max_retry_interval: std::time::Duration::from_millis(30),
@ -147,17 +147,6 @@ assert_retry_succeeds!(429, StatusCode::OK);
assert_no_retry!(431, StatusCode::REQUEST_HEADER_FIELDS_TOO_LARGE); assert_no_retry!(431, StatusCode::REQUEST_HEADER_FIELDS_TOO_LARGE);
assert_no_retry!(451, StatusCode::UNAVAILABLE_FOR_LEGAL_REASONS); assert_no_retry!(451, StatusCode::UNAVAILABLE_FOR_LEGAL_REASONS);
// We assert that we cap retries at 10, which means that we will
// get 11 calls to the RetryResponder.
assert_retry_succeeds_inner!(
500,
assert_maximum_retries_is_not_exceeded,
StatusCode::INTERNAL_SERVER_ERROR,
100,
11,
RetryResponder::new(100_u32, 500)
);
pub struct RetryTimeoutResponder(Arc<AtomicU32>, u32, std::time::Duration); pub struct RetryTimeoutResponder(Arc<AtomicU32>, u32, std::time::Duration);
impl RetryTimeoutResponder { impl RetryTimeoutResponder {
@ -195,7 +184,7 @@ async fn assert_retry_on_request_timeout() {
let reqwest_client = Client::builder().build().unwrap(); let reqwest_client = Client::builder().build().unwrap();
let client = ClientBuilder::new(reqwest_client) let client = ClientBuilder::new(reqwest_client)
.with(RetryTransientMiddleware::new_with_policy( .layer(RetryTransientMiddleware::new_with_policy(
ExponentialBackoff { ExponentialBackoff {
max_n_retries: 3, max_n_retries: 3,
max_retry_interval: std::time::Duration::from_millis(100), max_retry_interval: std::time::Duration::from_millis(100),
@ -250,7 +239,7 @@ async fn assert_retry_on_incomplete_message() {
let reqwest_client = Client::builder().build().unwrap(); let reqwest_client = Client::builder().build().unwrap();
let client = ClientBuilder::new(reqwest_client) let client = ClientBuilder::new(reqwest_client)
.with(RetryTransientMiddleware::new_with_policy( .layer(RetryTransientMiddleware::new_with_policy(
ExponentialBackoff { ExponentialBackoff {
max_n_retries: 3, max_n_retries: 3,
max_retry_interval: std::time::Duration::from_millis(100), max_retry_interval: std::time::Duration::from_millis(100),

View file

@ -1,6 +1,6 @@
[package] [package]
name = "reqwest-tracing" name = "reqwest-tracing"
version = "0.4.0" version = "0.5.0"
authors = ["Rodrigo Gryzinski <rodrigo.gryzinski@truelayer.com>"] authors = ["Rodrigo Gryzinski <rodrigo.gryzinski@truelayer.com>"]
edition = "2018" edition = "2018"
description = "Opentracing middleware for reqwest." description = "Opentracing middleware for reqwest."
@ -19,12 +19,14 @@ opentelemetry_0_18 = ["opentelemetry_0_18_pkg", "tracing-opentelemetry_0_18_pkg"
[dependencies] [dependencies]
reqwest-middleware = { version = "0.2.0", path = "../reqwest-middleware" } reqwest-middleware = { version = "0.3.0", path = "../reqwest-middleware" }
async-trait = "0.1.51" async-trait = "0.1.51"
reqwest = { version = "0.11", default-features = false } reqwest = { version = "0.11", default-features = false }
task-local-extensions = "0.1.1" task-local-extensions = "0.1.1"
tracing = "0.1.26" tracing = "0.1.26"
tower = "0.4"
pin-project-lite = "0.2"
opentelemetry_0_13_pkg = { package = "opentelemetry", version = "0.13", optional = true } opentelemetry_0_13_pkg = { package = "opentelemetry", version = "0.13", optional = true }
opentelemetry_0_14_pkg = { package = "opentelemetry", version = "0.14", optional = true } opentelemetry_0_14_pkg = { package = "opentelemetry", version = "0.14", optional = true }

View file

@ -103,3 +103,18 @@ pub use reqwest_otel_span_builder::{
#[doc(hidden)] #[doc(hidden)]
pub mod reqwest_otel_span_macro; pub mod reqwest_otel_span_macro;
#[cfg(test)]
mod tests {
use crate::{TracingMiddleware, DefaultSpanBackend};
use reqwest_middleware::ClientBuilder;
#[tokio::test]
async fn compiles() {
let client = ClientBuilder::new(reqwest::Client::new())
.layer(TracingMiddleware::<DefaultSpanBackend>::new())
.build();
let resp = client.get("http://example.com").send().await.unwrap();
dbg!(resp);
}
}

View file

@ -1,7 +1,10 @@
use reqwest::{Request, Response}; use std::{future::Future, task::ready};
use reqwest_middleware::{Middleware, Next, Result};
use task_local_extensions::Extensions; use pin_project_lite::pin_project;
use tracing::Instrument; use reqwest::Response;
use reqwest_middleware::{Error, MiddlewareRequest};
use tower::{Layer, Service};
use tracing::Span;
use crate::{DefaultSpanBackend, ReqwestOtelSpanBackend}; use crate::{DefaultSpanBackend, ReqwestOtelSpanBackend};
@ -10,6 +13,8 @@ pub struct TracingMiddleware<S: ReqwestOtelSpanBackend> {
span_backend: std::marker::PhantomData<S>, span_backend: std::marker::PhantomData<S>,
} }
impl<S: ReqwestOtelSpanBackend> Copy for TracingMiddleware<S> {}
impl<S: ReqwestOtelSpanBackend> TracingMiddleware<S> { impl<S: ReqwestOtelSpanBackend> TracingMiddleware<S> {
pub fn new() -> TracingMiddleware<S> { pub fn new() -> TracingMiddleware<S> {
TracingMiddleware { TracingMiddleware {
@ -30,38 +35,98 @@ impl Default for TracingMiddleware<DefaultSpanBackend> {
} }
} }
#[async_trait::async_trait] impl<ReqwestOtelSpan, Svc> Layer<Svc> for TracingMiddleware<ReqwestOtelSpan>
impl<ReqwestOtelSpan> Middleware for TracingMiddleware<ReqwestOtelSpan>
where where
ReqwestOtelSpan: ReqwestOtelSpanBackend + Sync + Send + 'static, ReqwestOtelSpan: ReqwestOtelSpanBackend + Sync + Send + 'static,
{ {
async fn handle( type Service = TracingMiddlewareService<ReqwestOtelSpan, Svc>;
&self,
req: Request,
extensions: &mut Extensions,
next: Next<'_>,
) -> Result<Response> {
let request_span = ReqwestOtelSpan::on_request_start(&req, extensions);
let outcome_future = async { fn layer(&self, inner: Svc) -> Self::Service {
// Adds tracing headers to the given request to propagate the OpenTelemetry context to downstream revivers of the request. TracingMiddlewareService {
// Spans added by downstream consumers will be part of the same trace. service: inner,
#[cfg(any( layer: *self,
feature = "opentelemetry_0_13", }
feature = "opentelemetry_0_14", }
feature = "opentelemetry_0_15", }
feature = "opentelemetry_0_16",
feature = "opentelemetry_0_17", /// Middleware Service for tracing requests using the current Opentelemetry Context.
feature = "opentelemetry_0_18", pub struct TracingMiddlewareService<S: ReqwestOtelSpanBackend, Svc> {
))] layer: TracingMiddleware<S>,
let req = crate::otel::inject_opentelemetry_context_into_request(req); service: Svc,
}
// Run the request
let outcome = next.run(req, extensions).await; impl<ReqwestOtelSpan, Svc> Service<MiddlewareRequest>
ReqwestOtelSpan::on_request_end(&request_span, &outcome, extensions); for TracingMiddlewareService<ReqwestOtelSpan, Svc>
outcome where
}; ReqwestOtelSpan: ReqwestOtelSpanBackend + Sync + Send + 'static,
Svc: Service<MiddlewareRequest, Response = Response, Error = Error>,
outcome_future.instrument(request_span.clone()).await {
type Response = Response;
type Error = Error;
type Future = TracingMiddlewareFuture<ReqwestOtelSpan, Svc::Future>;
fn poll_ready(
&mut self,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Result<(), Self::Error>> {
self.service.poll_ready(cx)
}
fn call(&mut self, req: MiddlewareRequest) -> Self::Future {
let MiddlewareRequest {
request,
mut extensions,
} = req;
let request_span = ReqwestOtelSpan::on_request_start(&request, &mut extensions);
// Adds tracing headers to the given request to propagate the OpenTelemetry context to downstream revivers of the request.
// Spans added by downstream consumers will be part of the same trace.
#[cfg(any(
feature = "opentelemetry_0_13",
feature = "opentelemetry_0_14",
feature = "opentelemetry_0_15",
feature = "opentelemetry_0_16",
feature = "opentelemetry_0_17",
feature = "opentelemetry_0_18",
))]
let request = crate::otel::inject_opentelemetry_context_into_request(request);
let future = self.service.call(MiddlewareRequest {
request,
extensions,
});
TracingMiddlewareFuture {
layer: self.layer,
span: request_span,
future,
}
}
}
pin_project!(
pub struct TracingMiddlewareFuture<S: ReqwestOtelSpanBackend, F> {
layer: TracingMiddleware<S>,
span: Span,
#[pin]
future: F,
}
);
impl<S: ReqwestOtelSpanBackend, F: Future<Output = Result<Response, Error>>> Future
for TracingMiddlewareFuture<S, F>
{
type Output = F::Output;
fn poll(
self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Self::Output> {
let this = self.project();
let outcome = {
let _guard = this.span.enter();
ready!(this.future.poll(cx))
};
S::on_request_end(this.span, &outcome);
std::task::Poll::Ready(outcome)
} }
} }

View file

@ -44,7 +44,7 @@ pub trait ReqwestOtelSpanBackend {
fn on_request_start(req: &Request, extension: &mut Extensions) -> Span; fn on_request_start(req: &Request, extension: &mut Extensions) -> Span;
/// Runs after the request call has executed. /// Runs after the request call has executed.
fn on_request_end(span: &Span, outcome: &Result<Response>, extension: &mut Extensions); fn on_request_end(span: &Span, outcome: &Result<Response>);
} }
/// Populates default success/failure fields for a given [`reqwest_otel_span!`] span. /// Populates default success/failure fields for a given [`reqwest_otel_span!`] span.
@ -103,7 +103,7 @@ impl ReqwestOtelSpanBackend for DefaultSpanBackend {
reqwest_otel_span!(name = name, req) reqwest_otel_span!(name = name, req)
} }
fn on_request_end(span: &Span, outcome: &Result<Response>, _: &mut Extensions) { fn on_request_end(span: &Span, outcome: &Result<Response>) {
default_on_request_end(span, outcome) default_on_request_end(span, outcome)
} }
} }
@ -128,7 +128,7 @@ impl ReqwestOtelSpanBackend for SpanBackendWithUrl {
reqwest_otel_span!(name = name, req, http.url = %remove_credentials(req.url())) reqwest_otel_span!(name = name, req, http.url = %remove_credentials(req.url()))
} }
fn on_request_end(span: &Span, outcome: &Result<Response>, _: &mut Extensions) { fn on_request_end(span: &Span, outcome: &Result<Response>) {
default_on_request_end(span, outcome) default_on_request_end(span, outcome)
} }
} }