Compare commits

...

4 Commits
main ... tower

Author SHA1 Message Date
Conrad Ludgate 8f3623eae0
remove but inspired by tower 2022-11-15 21:51:54 +00:00
Conrad Ludgate df4990d62e
fix last tracing test 2022-11-15 17:07:36 +00:00
Conrad Ludgate 571b9abc49
fix docs 2022-11-15 16:34:33 +00:00
Conrad Ludgate 6eaa2365ed
expierments with tower 2022-11-15 14:53:23 +00:00
15 changed files with 628 additions and 424 deletions

View File

@ -11,8 +11,8 @@ to allow for client middleware chains.
This crate provides functionality for building and running middleware but no middleware This crate provides functionality for building and running middleware but no middleware
implementations. This repository also contains a couple of useful concrete middleware crates: implementations. This repository also contains a couple of useful concrete middleware crates:
* [`reqwest-retry`](https://crates.io/crates/reqwest-retry): retry failed requests. - [`reqwest-retry`](https://crates.io/crates/reqwest-retry): retry failed requests.
* [`reqwest-tracing`](https://crates.io/crates/reqwest-tracing): - [`reqwest-tracing`](https://crates.io/crates/reqwest-tracing):
[`tracing`](https://crates.io/crates/tracing) integration, optional opentelemetry support. [`tracing`](https://crates.io/crates/tracing) integration, optional opentelemetry support.
## Overview ## Overview
@ -29,10 +29,12 @@ reqwest-middleware = "0.1.6"
reqwest-retry = "0.1.5" reqwest-retry = "0.1.5"
reqwest-tracing = "0.2.3" reqwest-tracing = "0.2.3"
tokio = { version = "1.12.0", features = ["macros", "rt-multi-thread"] } tokio = { version = "1.12.0", features = ["macros", "rt-multi-thread"] }
tower = "0.4"
``` ```
```rust ```rust
use reqwest_middleware::{ClientBuilder, ClientWithMiddleware}; use reqwest::Response;
use reqwest_middleware::{ClientBuilder, ClientWithMiddleware, Error, Layer, RequestInitialiser, ReqwestService, Service};
use reqwest_retry::{RetryTransientMiddleware, policies::ExponentialBackoff}; use reqwest_retry::{RetryTransientMiddleware, policies::ExponentialBackoff};
use reqwest_tracing::TracingMiddleware; use reqwest_tracing::TracingMiddleware;
@ -49,7 +51,12 @@ async fn main() {
run(client).await; run(client).await;
} }
async fn run(client: ClientWithMiddleware) { async fn run<M, I>(client: ClientWithMiddleware<M, I>)
where
M: Layer<ReqwestService>,
M::Service: Service,
I: RequestInitialiser,
{
client client
.get("https://truelayer.com") .get("https://truelayer.com")
.header("foo", "bar") .header("foo", "bar")

View File

@ -1,6 +1,6 @@
[package] [package]
name = "reqwest-middleware" name = "reqwest-middleware"
version = "0.2.0" version = "0.3.0"
authors = ["Rodrigo Gryzinski <rodrigo.gryzinski@truelayer.com>"] authors = ["Rodrigo Gryzinski <rodrigo.gryzinski@truelayer.com>"]
edition = "2018" edition = "2018"
description = "Wrapper around reqwest to allow for client middleware chains." description = "Wrapper around reqwest to allow for client middleware chains."
@ -18,6 +18,7 @@ reqwest = { version = "0.11", default-features = false, features = ["json", "mul
serde = "1" serde = "1"
task-local-extensions = "0.1.1" task-local-extensions = "0.1.1"
thiserror = "1" thiserror = "1"
futures = "0.3"
[dev-dependencies] [dev-dependencies]
reqwest = "0.11" reqwest = "0.11"

View File

@ -1,81 +1,67 @@
use futures::future::BoxFuture;
use futures::FutureExt;
use reqwest::header::{HeaderMap, HeaderName, HeaderValue}; use reqwest::header::{HeaderMap, HeaderName, HeaderValue};
use reqwest::multipart::Form; use reqwest::multipart::Form;
use reqwest::{Body, Client, IntoUrl, Method, Request, Response}; use reqwest::{Body, Client, IntoUrl, Method, Request, Response};
use serde::Serialize; use serde::Serialize;
use std::convert::TryFrom; use std::convert::TryFrom;
use std::fmt::{self, Display}; use std::fmt::{self, Display};
use std::sync::Arc;
use std::time::Duration; use std::time::Duration;
use task_local_extensions::Extensions; use task_local_extensions::Extensions;
// use tower::{Layer, Service, ServiceBuilder, ServiceExt};
use crate::error::Result; use crate::{Error, Identity, Layer, RequestInitialiser, RequestStack, Service, Stack};
use crate::middleware::{Middleware, Next};
use crate::RequestInitialiser;
/// A `ClientBuilder` is used to build a [`ClientWithMiddleware`]. /// A `ClientBuilder` is used to build a [`ClientWithMiddleware`].
/// ///
/// [`ClientWithMiddleware`]: crate::ClientWithMiddleware /// [`ClientWithMiddleware`]: crate::ClientWithMiddleware
pub struct ClientBuilder { pub struct ClientBuilder<M, I> {
client: Client, client: Client,
middleware_stack: Vec<Arc<dyn Middleware>>, middleware_stack: M,
initialiser_stack: Vec<Arc<dyn RequestInitialiser>>, initialiser_stack: I,
} }
impl ClientBuilder { impl ClientBuilder<Identity, Identity> {
pub fn new(client: Client) -> Self { pub fn new(client: Client) -> Self {
ClientBuilder { ClientBuilder {
client, client,
middleware_stack: Vec::new(), middleware_stack: Identity,
initialiser_stack: Vec::new(), initialiser_stack: Identity,
}
} }
} }
impl<M, I> ClientBuilder<M, I> {
/// Convenience method to attach middleware. /// Convenience method to attach middleware.
/// pub fn with<T>(self, layer: T) -> ClientBuilder<Stack<T, M>, I> {
/// If you need to keep a reference to the middleware after attaching, use [`with_arc`]. ClientBuilder {
/// client: self.client,
/// [`with_arc`]: Self::with_arc middleware_stack: Stack {
pub fn with<M>(self, middleware: M) -> Self inner: layer,
where outer: self.middleware_stack,
M: Middleware, },
{ initialiser_stack: self.initialiser_stack,
self.with_arc(Arc::new(middleware))
} }
/// Add middleware to the chain. [`with`] is more ergonomic if you don't need the `Arc`.
///
/// [`with`]: Self::with
pub fn with_arc(mut self, middleware: Arc<dyn Middleware>) -> Self {
self.middleware_stack.push(middleware);
self
} }
/// Convenience method to attach a request initialiser. /// Convenience method to attach a request initialiser.
/// pub fn with_init<T>(self, initialiser: T) -> ClientBuilder<M, RequestStack<T, I>> {
/// If you need to keep a reference to the initialiser after attaching, use [`with_arc_init`]. ClientBuilder {
/// client: self.client,
/// [`with_arc_init`]: Self::with_arc_init middleware_stack: self.middleware_stack,
pub fn with_init<I>(self, initialiser: I) -> Self initialiser_stack: RequestStack {
where inner: initialiser,
I: RequestInitialiser, outer: self.initialiser_stack,
{ },
self.with_arc_init(Arc::new(initialiser))
} }
/// Add a request initialiser to the chain. [`with_init`] is more ergonomic if you don't need the `Arc`.
///
/// [`with_init`]: Self::with_init
pub fn with_arc_init(mut self, initialiser: Arc<dyn RequestInitialiser>) -> Self {
self.initialiser_stack.push(initialiser);
self
} }
/// Returns a `ClientWithMiddleware` using this builder configuration. /// Returns a `ClientWithMiddleware` using this builder configuration.
pub fn build(self) -> ClientWithMiddleware { pub fn build(self) -> ClientWithMiddleware<M, I> {
ClientWithMiddleware { ClientWithMiddleware {
inner: self.client, inner: self.client,
middleware_stack: self.middleware_stack.into_boxed_slice(), middleware_stack: self.middleware_stack,
initialiser_stack: self.initialiser_stack.into_boxed_slice(), initialiser_stack: self.initialiser_stack,
} }
} }
} }
@ -83,97 +69,68 @@ impl ClientBuilder {
/// `ClientWithMiddleware` is a wrapper around [`reqwest::Client`] which runs middleware on every /// `ClientWithMiddleware` is a wrapper around [`reqwest::Client`] which runs middleware on every
/// request. /// request.
#[derive(Clone)] #[derive(Clone)]
pub struct ClientWithMiddleware { pub struct ClientWithMiddleware<M, I> {
inner: reqwest::Client, inner: reqwest::Client,
middleware_stack: Box<[Arc<dyn Middleware>]>, middleware_stack: M,
initialiser_stack: Box<[Arc<dyn RequestInitialiser>]>, initialiser_stack: I,
}
impl ClientWithMiddleware {
/// See [`ClientBuilder`] for a more ergonomic way to build `ClientWithMiddleware` instances.
pub fn new<T>(client: Client, middleware_stack: T) -> Self
where
T: Into<Box<[Arc<dyn Middleware>]>>,
{
ClientWithMiddleware {
inner: client,
middleware_stack: middleware_stack.into(),
// TODO(conradludgate) - allow downstream code to control this manually if desired
initialiser_stack: Box::new([]),
}
} }
impl<M: Layer<ReqwestService>, I: RequestInitialiser> ClientWithMiddleware<M, I> {
/// See [`Client::get`] /// See [`Client::get`]
pub fn get<U: IntoUrl>(&self, url: U) -> RequestBuilder { pub fn get<U: IntoUrl>(&self, url: U) -> RequestBuilder<M, I> {
self.request(Method::GET, url) self.request(Method::GET, url)
} }
/// See [`Client::post`] /// See [`Client::post`]
pub fn post<U: IntoUrl>(&self, url: U) -> RequestBuilder { pub fn post<U: IntoUrl>(&self, url: U) -> RequestBuilder<M, I> {
self.request(Method::POST, url) self.request(Method::POST, url)
} }
/// See [`Client::put`] /// See [`Client::put`]
pub fn put<U: IntoUrl>(&self, url: U) -> RequestBuilder { pub fn put<U: IntoUrl>(&self, url: U) -> RequestBuilder<M, I> {
self.request(Method::PUT, url) self.request(Method::PUT, url)
} }
/// See [`Client::patch`] /// See [`Client::patch`]
pub fn patch<U: IntoUrl>(&self, url: U) -> RequestBuilder { pub fn patch<U: IntoUrl>(&self, url: U) -> RequestBuilder<M, I> {
self.request(Method::PATCH, url) self.request(Method::PATCH, url)
} }
/// See [`Client::delete`] /// See [`Client::delete`]
pub fn delete<U: IntoUrl>(&self, url: U) -> RequestBuilder { pub fn delete<U: IntoUrl>(&self, url: U) -> RequestBuilder<M, I> {
self.request(Method::DELETE, url) self.request(Method::DELETE, url)
} }
/// See [`Client::head`] /// See [`Client::head`]
pub fn head<U: IntoUrl>(&self, url: U) -> RequestBuilder { pub fn head<U: IntoUrl>(&self, url: U) -> RequestBuilder<M, I> {
self.request(Method::HEAD, url) self.request(Method::HEAD, url)
} }
/// See [`Client::request`] /// See [`Client::request`]
pub fn request<U: IntoUrl>(&self, method: Method, url: U) -> RequestBuilder { pub fn request<U: IntoUrl>(&self, method: Method, url: U) -> RequestBuilder<'_, M, I> {
let req = RequestBuilder { let mut extensions = Extensions::new();
inner: self.inner.request(method, url), let request = self.inner.request(method, url);
client: self.clone(), let request = self.initialiser_stack.init(request, &mut extensions);
extensions: Extensions::new(), RequestBuilder {
}; inner: request,
self.initialiser_stack client: self,
.iter() extensions,
.fold(req, |req, i| i.init(req))
} }
/// See [`Client::execute`]
pub async fn execute(&self, req: Request) -> Result<Response> {
let mut ext = Extensions::new();
self.execute_with_extensions(req, &mut ext).await
}
/// Executes a request with initial [`Extensions`].
pub async fn execute_with_extensions(
&self,
req: Request,
ext: &mut Extensions,
) -> Result<Response> {
let next = Next::new(&self.inner, &self.middleware_stack);
next.run(req, ext).await
} }
} }
/// Create a `ClientWithMiddleware` without any middleware. /// Create a `ClientWithMiddleware` without any middleware.
impl From<Client> for ClientWithMiddleware { impl From<Client> for ClientWithMiddleware<Identity, Identity> {
fn from(client: Client) -> Self { fn from(client: Client) -> Self {
ClientWithMiddleware { ClientWithMiddleware {
inner: client, inner: client,
middleware_stack: Box::new([]), middleware_stack: Identity,
initialiser_stack: Box::new([]), initialiser_stack: Identity,
} }
} }
} }
impl fmt::Debug for ClientWithMiddleware { impl<M, I> fmt::Debug for ClientWithMiddleware<M, I> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
// skipping middleware_stack field for now // skipping middleware_stack field for now
f.debug_struct("ClientWithMiddleware") f.debug_struct("ClientWithMiddleware")
@ -184,13 +141,28 @@ impl fmt::Debug for ClientWithMiddleware {
/// This is a wrapper around [`reqwest::RequestBuilder`] exposing the same API. /// This is a wrapper around [`reqwest::RequestBuilder`] exposing the same API.
#[must_use = "RequestBuilder does nothing until you 'send' it"] #[must_use = "RequestBuilder does nothing until you 'send' it"]
pub struct RequestBuilder { pub struct RequestBuilder<'client, M, I> {
inner: reqwest::RequestBuilder, inner: reqwest::RequestBuilder,
client: ClientWithMiddleware, client: &'client ClientWithMiddleware<M, I>,
extensions: Extensions, extensions: Extensions,
} }
impl RequestBuilder { #[derive(Clone)]
pub struct ReqwestService(Client);
impl Service for ReqwestService {
type Future = BoxFuture<'static, Result<Response, Error>>;
fn call(&mut self, req: Request, _: &mut Extensions) -> Self::Future {
let fut = self.0.execute(req);
async { fut.await.map_err(Error::from) }.boxed()
}
}
impl<M: Layer<ReqwestService>, I: RequestInitialiser> RequestBuilder<'_, M, I>
where
M::Service: Service,
{
pub fn header<K, V>(self, key: K, value: V) -> Self pub fn header<K, V>(self, key: K, value: V) -> Self
where where
HeaderName: TryFrom<K>, HeaderName: TryFrom<K>,
@ -289,14 +261,19 @@ impl RequestBuilder {
&mut self.extensions &mut self.extensions
} }
pub async fn send(self) -> Result<Response> { pub async fn send(self) -> Result<Response, Error> {
let Self { let Self {
inner, inner,
client, client,
mut extensions, mut extensions,
} = self; } = self;
let req = inner.build()?; let req = inner.build()?;
client.execute_with_extensions(req, &mut extensions).await let mut svc = client
.middleware_stack
.layer(ReqwestService(client.inner.clone()));
svc.call(req, &mut extensions).await
// client.execute_with_extensions(req, &mut extensions).await
} }
/// Attempt to clone the RequestBuilder. /// Attempt to clone the RequestBuilder.
@ -309,13 +286,13 @@ impl RequestBuilder {
pub fn try_clone(&self) -> Option<Self> { pub fn try_clone(&self) -> Option<Self> {
self.inner.try_clone().map(|inner| RequestBuilder { self.inner.try_clone().map(|inner| RequestBuilder {
inner, inner,
client: self.client.clone(), client: self.client,
extensions: Extensions::new(), extensions: Extensions::new(),
}) })
} }
} }
impl fmt::Debug for RequestBuilder { impl<M, I> fmt::Debug for RequestBuilder<'_, M, I> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
// skipping middleware_stack field for now // skipping middleware_stack field for now
f.debug_struct("RequestBuilder") f.debug_struct("RequestBuilder")

View File

@ -1,7 +1,5 @@
use thiserror::Error; use thiserror::Error;
pub type Result<T> = std::result::Result<T, Error>;
#[derive(Error, Debug)] #[derive(Error, Debug)]
pub enum Error { pub enum Error {
/// There was an error running some middleware /// There was an error running some middleware

View File

@ -7,30 +7,44 @@
//! //!
//! ``` //! ```
//! use reqwest::{Client, Request, Response}; //! use reqwest::{Client, Request, Response};
//! use reqwest_middleware::{ClientBuilder, Middleware, Next, Result}; //! use reqwest_middleware::{ClientBuilder, Error, Extension, Layer, Service};
//! use task_local_extensions::Extensions; //! use task_local_extensions::Extensions;
//! use futures::future::{BoxFuture, FutureExt};
//! use std::task::{Context, Poll};
//! //!
//! struct LoggingMiddleware; //! struct LoggingLayer;
//! struct LoggingService<S>(S);
//! //!
//! #[async_trait::async_trait] //! impl<S> Layer<S> for LoggingLayer {
//! impl Middleware for LoggingMiddleware { //! type Service = LoggingService<S>;
//! async fn handle( //!
//! &self, //! fn layer(&self, inner: S) -> Self::Service {
//! req: Request, //! LoggingService(inner)
//! extensions: &mut Extensions, //! }
//! next: Next<'_>, //! }
//! ) -> Result<Response> { //!
//! println!("Request started {:?}", req); //! impl<S> Service for LoggingService<S>
//! let res = next.run(req, extensions).await; //! where
//! println!("Result: {:?}", res); //! S: Service,
//! S::Future: Send + 'static,
//! {
//! type Future = BoxFuture<'static, Result<Response, Error>>;
//!
//! fn call(&mut self, req: Request, ext: &mut Extensions) -> Self::Future {
//! println!("Request started {req:?}");
//! let fut = self.0.call(req, ext);
//! async {
//! let res = fut.await;
//! println!("Result: {res:?}");
//! res //! res
//! }.boxed()
//! } //! }
//! } //! }
//! //!
//! async fn run() { //! async fn run() {
//! let reqwest_client = Client::builder().build().unwrap(); //! let reqwest_client = Client::builder().build().unwrap();
//! let client = ClientBuilder::new(reqwest_client) //! let client = ClientBuilder::new(reqwest_client)
//! .with(LoggingMiddleware) //! .with(LoggingLayer)
//! .build(); //! .build();
//! let resp = client.get("https://truelayer.com").send().await.unwrap(); //! let resp = client.get("https://truelayer.com").send().await.unwrap();
//! println!("TrueLayer page HTML: {}", resp.text().await.unwrap()); //! println!("TrueLayer page HTML: {}", resp.text().await.unwrap());
@ -51,10 +65,54 @@ pub struct ReadmeDoctests;
mod client; mod client;
mod error; mod error;
mod middleware;
mod req_init; mod req_init;
pub use client::{ClientBuilder, ClientWithMiddleware, RequestBuilder}; pub use client::{ClientBuilder, ClientWithMiddleware, RequestBuilder, ReqwestService};
pub use error::{Error, Result}; pub use error::Error;
pub use middleware::{Middleware, Next}; pub use req_init::{Extension, RequestInitialiser, RequestStack};
pub use req_init::{Extension, RequestInitialiser}; use reqwest::{Request, Response};
use task_local_extensions::Extensions;
/// Two [`RequestInitialiser`]s or [`Service`]s chained together.
#[derive(Clone)]
pub struct Stack<Inner, Outer> {
pub(crate) inner: Inner,
pub(crate) outer: Outer,
}
pub trait Service {
type Future: std::future::Future<Output = Result<Response, Error>>;
fn call(&mut self, req: Request, ext: &mut Extensions) -> Self::Future;
}
pub struct Identity;
impl<S: Service> Layer<S> for Identity {
type Service = S;
fn layer(&self, inner: S) -> Self::Service {
inner
}
}
pub trait Layer<S> {
/// The wrapped service
type Service;
/// Wrap the given service with the middleware, returning a new service
/// that has been decorated with the middleware.
fn layer(&self, inner: S) -> Self::Service;
}
impl<S, Inner, Outer> Layer<S> for Stack<Inner, Outer>
where
Inner: Layer<S>,
Outer: Layer<Inner::Service>,
{
type Service = Outer::Service;
fn layer(&self, service: S) -> Self::Service {
let inner = self.inner.layer(service);
self.outer.layer(inner)
}
}

View File

@ -1,100 +0,0 @@
use reqwest::{Client, Request, Response};
use std::sync::Arc;
use task_local_extensions::Extensions;
use crate::error::{Error, Result};
/// When attached to a [`ClientWithMiddleware`] (generally using [`with`]), middleware is run
/// whenever the client issues a request, in the order it was attached.
///
/// # Example
///
/// ```
/// use reqwest::{Client, Request, Response};
/// use reqwest_middleware::{ClientBuilder, Middleware, Next, Result};
/// use task_local_extensions::Extensions;
///
/// struct TransparentMiddleware;
///
/// #[async_trait::async_trait]
/// impl Middleware for TransparentMiddleware {
/// async fn handle(
/// &self,
/// req: Request,
/// extensions: &mut Extensions,
/// next: Next<'_>,
/// ) -> Result<Response> {
/// next.run(req, extensions).await
/// }
/// }
/// ```
///
/// [`ClientWithMiddleware`]: crate::ClientWithMiddleware
/// [`with`]: crate::ClientBuilder::with
#[async_trait::async_trait]
pub trait Middleware: 'static + Send + Sync {
/// Invoked with a request before sending it. If you want to continue processing the request,
/// you should explicitly call `next.run(req, extensions)`.
///
/// If you need to forward data down the middleware stack, you can use the `extensions`
/// argument.
async fn handle(
&self,
req: Request,
extensions: &mut Extensions,
next: Next<'_>,
) -> Result<Response>;
}
#[async_trait::async_trait]
impl<F> Middleware for F
where
F: Send
+ Sync
+ 'static
+ for<'a> Fn(Request, &'a mut Extensions, Next<'a>) -> BoxFuture<'a, Result<Response>>,
{
async fn handle(
&self,
req: Request,
extensions: &mut Extensions,
next: Next<'_>,
) -> Result<Response> {
(self)(req, extensions, next).await
}
}
/// Next encapsulates the remaining middleware chain to run in [`Middleware::handle`]. You can
/// forward the request down the chain with [`run`].
///
/// [`Middleware::handle`]: Middleware::handle
/// [`run`]: Self::run
#[derive(Clone)]
pub struct Next<'a> {
client: &'a Client,
middlewares: &'a [Arc<dyn Middleware>],
}
pub type BoxFuture<'a, T> = std::pin::Pin<Box<dyn std::future::Future<Output = T> + Send + 'a>>;
impl<'a> Next<'a> {
pub(crate) fn new(client: &'a Client, middlewares: &'a [Arc<dyn Middleware>]) -> Self {
Next {
client,
middlewares,
}
}
pub fn run(
mut self,
req: Request,
extensions: &'a mut Extensions,
) -> BoxFuture<'a, Result<Response>> {
if let Some((current, rest)) = self.middlewares.split_first() {
self.middlewares = rest;
Box::pin(current.handle(req, extensions, self))
} else {
Box::pin(async move { self.client.execute(req).await.map_err(Error::from) })
}
}
}

View File

@ -1,4 +1,7 @@
use crate::RequestBuilder; use reqwest::RequestBuilder;
use task_local_extensions::Extensions;
use crate::Identity;
/// When attached to a [`ClientWithMiddleware`] (generally using [`with_init`]), it is run /// When attached to a [`ClientWithMiddleware`] (generally using [`with_init`]), it is run
/// whenever the client starts building a request, in the order it was attached. /// whenever the client starts building a request, in the order it was attached.
@ -6,12 +9,14 @@ use crate::RequestBuilder;
/// # Example /// # Example
/// ///
/// ``` /// ```
/// use reqwest_middleware::{RequestInitialiser, RequestBuilder}; /// use reqwest::RequestBuilder;
/// use reqwest_middleware::RequestInitialiser;
/// use task_local_extensions::Extensions;
/// ///
/// struct AuthInit; /// struct AuthInit;
/// ///
/// impl RequestInitialiser for AuthInit { /// impl RequestInitialiser for AuthInit {
/// fn init(&self, req: RequestBuilder) -> RequestBuilder { /// fn init(&self, req: RequestBuilder, ext: &mut Extensions) -> RequestBuilder {
/// req.bearer_auth("my_auth_token") /// req.bearer_auth("my_auth_token")
/// } /// }
/// } /// }
@ -20,15 +25,30 @@ use crate::RequestBuilder;
/// [`ClientWithMiddleware`]: crate::ClientWithMiddleware /// [`ClientWithMiddleware`]: crate::ClientWithMiddleware
/// [`with_init`]: crate::ClientBuilder::with_init /// [`with_init`]: crate::ClientBuilder::with_init
pub trait RequestInitialiser: 'static + Send + Sync { pub trait RequestInitialiser: 'static + Send + Sync {
fn init(&self, req: RequestBuilder) -> RequestBuilder; fn init(&self, req: RequestBuilder, ext: &mut Extensions) -> RequestBuilder;
} }
impl<F> RequestInitialiser for F impl RequestInitialiser for Identity {
fn init(&self, req: RequestBuilder, _: &mut Extensions) -> RequestBuilder {
req
}
}
/// Two [`RequestInitialiser`]s chained together.
#[derive(Clone)]
pub struct RequestStack<Inner, Outer> {
pub(crate) inner: Inner,
pub(crate) outer: Outer,
}
impl<I, O> RequestInitialiser for RequestStack<I, O>
where where
F: Send + Sync + 'static + Fn(RequestBuilder) -> RequestBuilder, I: RequestInitialiser,
O: RequestInitialiser,
{ {
fn init(&self, req: RequestBuilder) -> RequestBuilder { fn init(&self, req: RequestBuilder, ext: &mut Extensions) -> RequestBuilder {
(self)(req) let req = self.inner.init(req, ext);
self.outer.init(req, ext)
} }
} }
@ -38,30 +58,45 @@ where
/// ///
/// ``` /// ```
/// use reqwest::{Client, Request, Response}; /// use reqwest::{Client, Request, Response};
/// use reqwest_middleware::{ClientBuilder, Middleware, Next, Result, Extension}; /// use reqwest_middleware::{ClientBuilder, Error, Extension, Layer, Service};
/// use task_local_extensions::Extensions; /// use task_local_extensions::Extensions;
/// use futures::future::{BoxFuture, FutureExt};
/// use std::task::{Context, Poll};
/// ///
/// #[derive(Clone)] /// #[derive(Clone)]
/// struct LogName(&'static str); /// struct LogName(&'static str);
/// struct LoggingMiddleware;
/// ///
/// #[async_trait::async_trait] /// struct LoggingLayer;
/// impl Middleware for LoggingMiddleware { /// struct LoggingService<S>(S);
/// async fn handle( ///
/// &self, /// impl<S> Layer<S> for LoggingLayer {
/// req: Request, /// type Service = LoggingService<S>;
/// extensions: &mut Extensions, ///
/// next: Next<'_>, /// fn layer(&self, inner: S) -> Self::Service {
/// ) -> Result<Response> { /// LoggingService(inner)
/// }
/// }
///
/// impl<S> Service for LoggingService<S>
/// where
/// S: Service,
/// S::Future: Send + 'static,
/// {
/// type Future = BoxFuture<'static, Result<Response, Error>>;
///
/// fn call(&mut self, req: Request, ext: &mut Extensions) -> Self::Future {
/// // get the log name or default to "unknown" /// // get the log name or default to "unknown"
/// let name = extensions /// let name = ext
/// .get() /// .get()
/// .map(|&LogName(name)| name) /// .map(|&LogName(name)| name)
/// .unwrap_or("unknown"); /// .unwrap_or("unknown");
/// println!("[{name}] Request started {req:?}"); /// println!("[{name}] Request started {req:?}");
/// let res = next.run(req, extensions).await; /// let fut = self.0.call(req, ext);
/// async move {
/// let res = fut.await;
/// println!("[{name}] Result: {res:?}"); /// println!("[{name}] Result: {res:?}");
/// res /// res
/// }.boxed()
/// } /// }
/// } /// }
/// ///
@ -69,7 +104,7 @@ where
/// let reqwest_client = Client::builder().build().unwrap(); /// let reqwest_client = Client::builder().build().unwrap();
/// let client = ClientBuilder::new(reqwest_client) /// let client = ClientBuilder::new(reqwest_client)
/// .with_init(Extension(LogName("my-client"))) /// .with_init(Extension(LogName("my-client")))
/// .with(LoggingMiddleware) /// .with(LoggingLayer)
/// .build(); /// .build();
/// let resp = client.get("https://truelayer.com").send().await.unwrap(); /// let resp = client.get("https://truelayer.com").send().await.unwrap();
/// println!("TrueLayer page HTML: {}", resp.text().await.unwrap()); /// println!("TrueLayer page HTML: {}", resp.text().await.unwrap());
@ -78,7 +113,8 @@ where
pub struct Extension<T>(pub T); pub struct Extension<T>(pub T);
impl<T: Send + Sync + Clone + 'static> RequestInitialiser for Extension<T> { impl<T: Send + Sync + Clone + 'static> RequestInitialiser for Extension<T> {
fn init(&self, req: RequestBuilder) -> RequestBuilder { fn init(&self, req: RequestBuilder, ext: &mut Extensions) -> RequestBuilder {
req.with_extension(self.0.clone()) ext.insert(self.0.clone());
req
} }
} }

View File

@ -1,6 +1,6 @@
[package] [package]
name = "reqwest-retry" name = "reqwest-retry"
version = "0.2.0" version = "0.3.0"
authors = ["Rodrigo Gryzinski <rodrigo.gryzinski@truelayer.com>"] authors = ["Rodrigo Gryzinski <rodrigo.gryzinski@truelayer.com>"]
edition = "2018" edition = "2018"
description = "Retry middleware for reqwest." description = "Retry middleware for reqwest."
@ -10,7 +10,7 @@ keywords = ["reqwest", "http", "middleware", "retry"]
categories = ["web-programming::http-client"] categories = ["web-programming::http-client"]
[dependencies] [dependencies]
reqwest-middleware = { version = "0.2.0", path = "../reqwest-middleware" } reqwest-middleware = { version = "0.3.0", path = "../reqwest-middleware" }
anyhow = "1" anyhow = "1"
async-trait = "0.1.51" async-trait = "0.1.51"
@ -23,6 +23,7 @@ retry-policies = "0.1"
task-local-extensions = "0.1.1" task-local-extensions = "0.1.1"
tokio = { version = "1.6", features = ["time"] } tokio = { version = "1.6", features = ["time"] }
tracing = "0.1.26" tracing = "0.1.26"
pin-project-lite = "0.2"
[dev-dependencies] [dev-dependencies]
async-std = { version = "1.10"} async-std = { version = "1.10"}

View File

@ -1,15 +1,17 @@
//! `RetryTransientMiddleware` implements retrying requests on transient errors. //! `RetryTransientMiddleware` implements retrying requests on transient errors.
use std::pin::Pin;
use std::task::{ready, Context, Poll};
use crate::retryable::Retryable; use crate::retryable::Retryable;
use anyhow::anyhow;
use chrono::Utc; use chrono::Utc;
use futures::Future;
use pin_project_lite::pin_project;
use reqwest::{Request, Response}; use reqwest::{Request, Response};
use reqwest_middleware::{Error, Middleware, Next, Result}; use reqwest_middleware::{Error, Layer, Service};
use retry_policies::RetryPolicy; use retry_policies::RetryPolicy;
use task_local_extensions::Extensions; use task_local_extensions::Extensions;
use tokio::time::Sleep;
/// We limit the number of retries to a maximum of `10` to avoid stack-overflow issues due to the recursion.
static MAXIMUM_NUMBER_OF_RETRIES: u32 = 10;
/// `RetryTransientMiddleware` offers retry logic for requests that fail in a transient manner /// `RetryTransientMiddleware` offers retry logic for requests that fail in a transient manner
/// and can be safely executed again. /// and can be safely executed again.
@ -58,76 +60,253 @@ impl<T: RetryPolicy + Send + Sync> RetryTransientMiddleware<T> {
} }
} }
#[async_trait::async_trait] impl<T, Svc> Layer<Svc> for RetryTransientMiddleware<T>
impl<T: RetryPolicy + Send + Sync> Middleware for RetryTransientMiddleware<T> { where
async fn handle( T: RetryPolicy + Clone + Send + Sync + 'static,
&self, {
req: Request, type Service = Retry<TowerRetryPolicy<T>, Svc>;
extensions: &mut Extensions,
next: Next<'_>, fn layer(&self, inner: Svc) -> Self::Service {
) -> Result<Response> { Retry {
// TODO: Ideally we should create a new instance of the `Extensions` map to pass policy: TowerRetryPolicy {
// downstream. This will guard against previous retries poluting `Extensions`. n_past_retries: 0,
// That is, we only return what's populated in the typemap for the last retry attempt retry_policy: self.retry_policy.clone(),
// and copy those into the the `global` Extensions map. },
self.execute_with_retry(req, next, extensions).await service: inner,
}
} }
} }
impl<T: RetryPolicy + Send + Sync> RetryTransientMiddleware<T> { #[derive(Clone)]
/// This function will try to execute the request, if it fails pub struct TowerRetryPolicy<T> {
/// with an error classified as transient it will call itself n_past_retries: u32,
/// to retry the request. retry_policy: T,
async fn execute_with_retry<'a>( }
&'a self,
req: Request,
next: Next<'a>,
ext: &'a mut Extensions,
) -> Result<Response> {
let mut n_past_retries = 0;
loop {
// Cloning the request object before-the-fact is not ideal..
// However, if the body of the request is not static, e.g of type `Bytes`,
// the Clone operation should be of constant complexity and not O(N)
// since the byte abstraction is a shared pointer over a buffer.
let duplicate_request = req.try_clone().ok_or_else(|| {
Error::Middleware(anyhow!(
"Request object is not clonable. Are you passing a streaming body?".to_string()
))
})?;
let result = next.clone().run(duplicate_request, ext).await; pin_project! {
pub struct RetryFuture<T>
{
retry: Option<TowerRetryPolicy<T>>,
#[pin]
sleep: Sleep,
}
}
impl<T> Future for RetryFuture<T> {
type Output = TowerRetryPolicy<T>;
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let this = self.project();
ready!(this.sleep.poll(cx));
Poll::Ready(
this.retry
.take()
.expect("poll should not be called more than once"),
)
}
}
impl<T: RetryPolicy + Clone> Policy for TowerRetryPolicy<T> {
type Future = RetryFuture<T>;
fn retry(&self, _req: &Request, result: &Result<Response, Error>) -> Option<Self::Future> {
// We classify the response which will return None if not // We classify the response which will return None if not
// errors were returned. // errors were returned.
break match Retryable::from_reqwest_response(&result) { match Retryable::from_reqwest_response(result) {
Some(retryable) Some(Retryable::Transient) => {
if retryable == Retryable::Transient
&& n_past_retries < MAXIMUM_NUMBER_OF_RETRIES =>
{
// If the response failed and the error type was transient // If the response failed and the error type was transient
// we can safely try to retry the request. // we can safely try to retry the request.
let retry_decicion = self.retry_policy.should_retry(n_past_retries); let retry_decicion = self.retry_policy.should_retry(self.n_past_retries);
if let retry_policies::RetryDecision::Retry { execute_after } = retry_decicion { if let retry_policies::RetryDecision::Retry { execute_after } = retry_decicion {
let duration = (execute_after - Utc::now()) let duration = (execute_after - Utc::now()).to_std().ok()?;
.to_std()
.map_err(Error::middleware)?;
// Sleep the requested amount before we try again. // Sleep the requested amount before we try again.
tracing::warn!( tracing::warn!(
"Retry attempt #{}. Sleeping {:?} before the next attempt", "Retry attempt #{}. Sleeping {:?} before the next attempt",
n_past_retries, self.n_past_retries,
duration duration
); );
tokio::time::sleep(duration).await; let sleep = tokio::time::sleep(duration);
Some(RetryFuture {
n_past_retries += 1; retry: Some(TowerRetryPolicy {
continue; n_past_retries: self.n_past_retries + 1,
retry_policy: self.retry_policy.clone(),
}),
sleep,
})
} else { } else {
result None
}
}
Some(_) | None => None,
}
}
fn clone_request(&self, req: &Request) -> Option<Request> {
req.try_clone()
}
}
pub trait Policy: Sized {
/// The [`Future`] type returned by [`Policy::retry`].
type Future: Future<Output = Self>;
/// Check the policy if a certain request should be retried.
///
/// This method is passed a reference to the original request, and either
/// the [`Service::Response`] or [`Service::Error`] from the inner service.
///
/// If the request should **not** be retried, return `None`.
///
/// If the request *should* be retried, return `Some` future of a new
/// policy that would apply for the next request attempt.
///
/// [`Service::Response`]: crate::Service::Response
/// [`Service::Error`]: crate::Service::Error
fn retry(&self, req: &Request, result: &Result<Response, Error>) -> Option<Self::Future>;
/// Tries to clone a request before being passed to the inner service.
///
/// If the request cannot be cloned, return [`None`].
fn clone_request(&self, req: &Request) -> Option<Request>;
}
pin_project! {
/// Configure retrying requests of "failed" responses.
///
/// A [`Policy`] classifies what is a "failed" response.
#[derive(Clone, Debug)]
pub struct Retry<P, S> {
#[pin]
policy: P,
service: S,
}
}
impl<P, S> Service for Retry<P, S>
where
P: 'static + Policy + Clone,
S: 'static + Service + Clone,
{
type Future = ResponseFuture<P, S>;
fn call(&mut self, request: Request, ext: &mut Extensions) -> Self::Future {
let cloned = self.policy.clone_request(&request);
let future = self.service.call(request, ext);
ResponseFuture::new(cloned, self.clone(), future)
}
// fn call(&mut self, request: Request) -> Self::Future {
// let cloned = self.policy.clone_request(&request);
// let future = self.service.call(request);
// ResponseFuture::new(cloned, self.clone(), future)
// }
}
pin_project! {
/// The [`Future`] returned by a [`Retry`] service.
#[derive(Debug)]
pub struct ResponseFuture<P, S>
where
P: Policy,
S: Service,
{
request: Option<Request>,
#[pin]
retry: Retry<P, S>,
#[pin]
state: State<S::Future, P::Future>,
}
}
pin_project! {
#[project = StateProj]
#[derive(Debug)]
enum State<F, P> {
// Polling the future from [`Service::call`]
Called {
#[pin]
future: F
},
// Polling the future from [`Policy::retry`]
Checking {
#[pin]
checking: P
},
// Polling [`Service::poll_ready`] after [`Checking`] was OK.
Retrying,
}
}
impl<P, S> ResponseFuture<P, S>
where
P: Policy,
S: Service,
{
pub(crate) fn new(
request: Option<Request>,
retry: Retry<P, S>,
future: S::Future,
) -> ResponseFuture<P, S> {
ResponseFuture {
request,
retry,
state: State::Called { future },
}
}
}
impl<P, S> Future for ResponseFuture<P, S>
where
P: Policy + Clone,
S: Service + Clone,
{
type Output = Result<Response, Error>;
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let mut this = self.project();
loop {
match this.state.as_mut().project() {
StateProj::Called { future } => {
let result = ready!(future.poll(cx));
if let Some(ref req) = this.request {
match this.retry.policy.retry(req, &result) {
Some(checking) => {
this.state.set(State::Checking { checking });
}
None => return Poll::Ready(result),
}
} else {
// request wasn't cloned, so no way to retry it
return Poll::Ready(result);
}
}
StateProj::Checking { checking } => {
this.retry
.as_mut()
.project()
.policy
.set(ready!(checking.poll(cx)));
this.state.set(State::Retrying);
}
StateProj::Retrying => {
let req = this
.request
.take()
.expect("retrying requires cloned request");
*this.request = this.retry.policy.clone_request(&req);
this.state.set(State::Called {
future: this
.retry
.as_mut()
.project()
.service
.call(req, &mut Extensions::new()),
});
} }
} }
Some(_) | None => result,
};
} }
} }
} }

View File

@ -147,17 +147,6 @@ assert_retry_succeeds!(429, StatusCode::OK);
assert_no_retry!(431, StatusCode::REQUEST_HEADER_FIELDS_TOO_LARGE); assert_no_retry!(431, StatusCode::REQUEST_HEADER_FIELDS_TOO_LARGE);
assert_no_retry!(451, StatusCode::UNAVAILABLE_FOR_LEGAL_REASONS); assert_no_retry!(451, StatusCode::UNAVAILABLE_FOR_LEGAL_REASONS);
// We assert that we cap retries at 10, which means that we will
// get 11 calls to the RetryResponder.
assert_retry_succeeds_inner!(
500,
assert_maximum_retries_is_not_exceeded,
StatusCode::INTERNAL_SERVER_ERROR,
100,
11,
RetryResponder::new(100_u32, 500)
);
pub struct RetryTimeoutResponder(Arc<AtomicU32>, u32, std::time::Duration); pub struct RetryTimeoutResponder(Arc<AtomicU32>, u32, std::time::Duration);
impl RetryTimeoutResponder { impl RetryTimeoutResponder {

View File

@ -1,6 +1,6 @@
[package] [package]
name = "reqwest-tracing" name = "reqwest-tracing"
version = "0.4.0" version = "0.5.0"
authors = ["Rodrigo Gryzinski <rodrigo.gryzinski@truelayer.com>"] authors = ["Rodrigo Gryzinski <rodrigo.gryzinski@truelayer.com>"]
edition = "2018" edition = "2018"
description = "Opentracing middleware for reqwest." description = "Opentracing middleware for reqwest."
@ -19,12 +19,13 @@ opentelemetry_0_18 = ["opentelemetry_0_18_pkg", "tracing-opentelemetry_0_18_pkg"
[dependencies] [dependencies]
reqwest-middleware = { version = "0.2.0", path = "../reqwest-middleware" } reqwest-middleware = { version = "0.3.0", path = "../reqwest-middleware" }
async-trait = "0.1.51" async-trait = "0.1.51"
reqwest = { version = "0.11", default-features = false } reqwest = { version = "0.11", default-features = false }
task-local-extensions = "0.1.1" task-local-extensions = "0.1.1"
tracing = "0.1.26" tracing = "0.1.26"
pin-project-lite = "0.2"
opentelemetry_0_13_pkg = { package = "opentelemetry", version = "0.13", optional = true } opentelemetry_0_13_pkg = { package = "opentelemetry", version = "0.13", optional = true }
opentelemetry_0_14_pkg = { package = "opentelemetry", version = "0.14", optional = true } opentelemetry_0_14_pkg = { package = "opentelemetry", version = "0.14", optional = true }

View File

@ -4,11 +4,11 @@
//! //!
//! The simplest possible usage: //! The simplest possible usage:
//! ```no_run //! ```no_run
//! # use reqwest_middleware::Result; //! # use reqwest_middleware::Error;
//! use reqwest_middleware::{ClientBuilder}; //! use reqwest_middleware::{ClientBuilder};
//! use reqwest_tracing::TracingMiddleware; //! use reqwest_tracing::TracingMiddleware;
//! //!
//! # async fn example() -> Result<()> { //! # async fn example() -> Result<(), Error> {
//! let reqwest_client = reqwest::Client::builder().build().unwrap(); //! let reqwest_client = reqwest::Client::builder().build().unwrap();
//! let client = ClientBuilder::new(reqwest_client) //! let client = ClientBuilder::new(reqwest_client)
//! // Insert the tracing middleware //! // Insert the tracing middleware
@ -22,12 +22,12 @@
//! //!
//! To customise the span names use [`OtelName`]. //! To customise the span names use [`OtelName`].
//! ```no_run //! ```no_run
//! # use reqwest_middleware::Result; //! # use reqwest_middleware::Error;
//! use reqwest_middleware::{ClientBuilder, Extension}; //! use reqwest_middleware::{ClientBuilder, Extension};
//! use reqwest_tracing::{ //! use reqwest_tracing::{
//! TracingMiddleware, OtelName //! TracingMiddleware, OtelName
//! }; //! };
//! # async fn example() -> Result<()> { //! # async fn example() -> Result<(), Error> {
//! let reqwest_client = reqwest::Client::builder().build().unwrap(); //! let reqwest_client = reqwest::Client::builder().build().unwrap();
//! let client = ClientBuilder::new(reqwest_client) //! let client = ClientBuilder::new(reqwest_client)
//! // Inserts the extension before the request is started //! // Inserts the extension before the request is started
@ -52,7 +52,7 @@
//! //!
//! Note that Opentelemetry tracks start and stop already, there is no need to have a custom builder like this. //! Note that Opentelemetry tracks start and stop already, there is no need to have a custom builder like this.
//! ```rust //! ```rust
//! use reqwest_middleware::Result; //! use reqwest_middleware::Error;
//! use task_local_extensions::Extensions; //! use task_local_extensions::Extensions;
//! use reqwest::{Request, Response}; //! use reqwest::{Request, Response};
//! use reqwest_middleware::ClientBuilder; //! use reqwest_middleware::ClientBuilder;
@ -62,16 +62,17 @@
//! use tracing::Span; //! use tracing::Span;
//! use std::time::{Duration, Instant}; //! use std::time::{Duration, Instant};
//! //!
//! pub struct TimeTrace; //! pub struct TimeTrace(Instant);
//! //!
//! impl ReqwestOtelSpanBackend for TimeTrace { //! impl ReqwestOtelSpanBackend for TimeTrace {
//! fn on_request_start(req: &Request, extension: &mut Extensions) -> Span { //! fn on_request_start(req: &Request, _extension: &mut Extensions) -> (Self, Span) {
//! extension.insert(Instant::now()); //! let now = Self(Instant::now());
//! reqwest_otel_span!(name="example-request", req, time_elapsed = tracing::field::Empty) //! let span = reqwest_otel_span!(name="example-request", req, time_elapsed = tracing::field::Empty);
//! (now, span)
//! } //! }
//! //!
//! fn on_request_end(span: &Span, outcome: &Result<Response>, extension: &mut Extensions) { //! fn on_request_end(self, span: &Span, outcome: &Result<Response, Error>) {
//! let time_elapsed = extension.get::<Instant>().unwrap().elapsed().as_millis() as i64; //! let time_elapsed = self.0.elapsed().as_millis() as i64;
//! default_on_request_end(span, outcome); //! default_on_request_end(span, outcome);
//! span.record("time_elapsed", &time_elapsed); //! span.record("time_elapsed", &time_elapsed);
//! } //! }

View File

@ -1,7 +1,14 @@
use std::{
future::Future,
task::{ready, Context, Poll},
};
use pin_project_lite::pin_project;
use reqwest::{Request, Response}; use reqwest::{Request, Response};
use reqwest_middleware::{Middleware, Next, Result}; use reqwest_middleware::{Error, Layer, Service};
use task_local_extensions::Extensions; use task_local_extensions::Extensions;
use tracing::Instrument; // use tower::{Layer, Service};
use tracing::Span;
use crate::{DefaultSpanBackend, ReqwestOtelSpanBackend}; use crate::{DefaultSpanBackend, ReqwestOtelSpanBackend};
@ -10,6 +17,8 @@ pub struct TracingMiddleware<S: ReqwestOtelSpanBackend> {
span_backend: std::marker::PhantomData<S>, span_backend: std::marker::PhantomData<S>,
} }
impl<S: ReqwestOtelSpanBackend> Copy for TracingMiddleware<S> {}
impl<S: ReqwestOtelSpanBackend> TracingMiddleware<S> { impl<S: ReqwestOtelSpanBackend> TracingMiddleware<S> {
pub fn new() -> TracingMiddleware<S> { pub fn new() -> TracingMiddleware<S> {
TracingMiddleware { TracingMiddleware {
@ -30,20 +39,36 @@ impl Default for TracingMiddleware<DefaultSpanBackend> {
} }
} }
#[async_trait::async_trait] impl<ReqwestOtelSpan, Svc> Layer<Svc> for TracingMiddleware<ReqwestOtelSpan>
impl<ReqwestOtelSpan> Middleware for TracingMiddleware<ReqwestOtelSpan>
where where
ReqwestOtelSpan: ReqwestOtelSpanBackend + Sync + Send + 'static, ReqwestOtelSpan: ReqwestOtelSpanBackend + Sync + Send + 'static,
Svc: Service,
{ {
async fn handle( type Service = TracingMiddlewareService<ReqwestOtelSpan, Svc>;
&self,
req: Request,
extensions: &mut Extensions,
next: Next<'_>,
) -> Result<Response> {
let request_span = ReqwestOtelSpan::on_request_start(&req, extensions);
let outcome_future = async { fn layer(&self, inner: Svc) -> Self::Service {
TracingMiddlewareService {
service: inner,
_layer: *self,
}
}
}
/// Middleware Service for tracing requests using the current Opentelemetry Context.
pub struct TracingMiddlewareService<S: ReqwestOtelSpanBackend, Svc> {
_layer: TracingMiddleware<S>,
service: Svc,
}
impl<ReqwestOtelSpan, Svc> Service for TracingMiddlewareService<ReqwestOtelSpan, Svc>
where
ReqwestOtelSpan: ReqwestOtelSpanBackend + Sync + Send + 'static,
Svc: Service,
{
type Future = TracingMiddlewareFuture<ReqwestOtelSpan, Svc::Future>;
fn call(&mut self, req: Request, ext: &mut Extensions) -> Self::Future {
let (backend, span) = ReqwestOtelSpan::on_request_start(&req, ext);
// Adds tracing headers to the given request to propagate the OpenTelemetry context to downstream revivers of the request. // Adds tracing headers to the given request to propagate the OpenTelemetry context to downstream revivers of the request.
// Spans added by downstream consumers will be part of the same trace. // Spans added by downstream consumers will be part of the same trace.
#[cfg(any( #[cfg(any(
@ -54,14 +79,42 @@ where
feature = "opentelemetry_0_17", feature = "opentelemetry_0_17",
feature = "opentelemetry_0_18", feature = "opentelemetry_0_18",
))] ))]
let req = crate::otel::inject_opentelemetry_context_into_request(req); let request = crate::otel::inject_opentelemetry_context_into_request(request);
// Run the request let future = self.service.call(req, ext);
let outcome = next.run(req, extensions).await;
ReqwestOtelSpan::on_request_end(&request_span, &outcome, extensions); TracingMiddlewareFuture {
outcome span,
backend: Some(backend),
future,
}
}
}
pin_project!(
pub struct TracingMiddlewareFuture<S: ReqwestOtelSpanBackend, F> {
span: Span,
backend: Option<S>,
#[pin]
future: F,
}
);
impl<S: ReqwestOtelSpanBackend, F: Future<Output = Result<Response, Error>>> Future
for TracingMiddlewareFuture<S, F>
{
type Output = F::Output;
fn poll(self: std::pin::Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let this = self.project();
let outcome = {
let _guard = this.span.enter();
ready!(this.future.poll(cx))
}; };
this.backend
outcome_future.instrument(request_span.clone()).await .take()
.expect("poll should not be called after completion")
.on_request_end(this.span, &outcome);
Poll::Ready(outcome)
} }
} }

View File

@ -2,7 +2,7 @@ use std::borrow::Cow;
use reqwest::header::{HeaderMap, HeaderValue}; use reqwest::header::{HeaderMap, HeaderValue};
use reqwest::{Request, Response, StatusCode as RequestStatusCode, Url}; use reqwest::{Request, Response, StatusCode as RequestStatusCode, Url};
use reqwest_middleware::{Error, Result}; use reqwest_middleware::Error;
use task_local_extensions::Extensions; use task_local_extensions::Extensions;
use tracing::Span; use tracing::Span;
@ -39,17 +39,17 @@ pub const HTTP_USER_AGENT: &str = "http.user_agent";
/// Check out [`reqwest_otel_span`] documentation for examples. /// Check out [`reqwest_otel_span`] documentation for examples.
/// ///
/// [`TracingMiddleware`]: crate::middleware::TracingMiddleware. /// [`TracingMiddleware`]: crate::middleware::TracingMiddleware.
pub trait ReqwestOtelSpanBackend { pub trait ReqwestOtelSpanBackend: Sized {
/// Initalized a new span before the request is executed. /// Initalized a new span before the request is executed.
fn on_request_start(req: &Request, extension: &mut Extensions) -> Span; fn on_request_start(req: &Request, extension: &mut Extensions) -> (Self, Span);
/// Runs after the request call has executed. /// Runs after the request call has executed.
fn on_request_end(span: &Span, outcome: &Result<Response>, extension: &mut Extensions); fn on_request_end(self, span: &Span, outcome: &Result<Response, Error>);
} }
/// Populates default success/failure fields for a given [`reqwest_otel_span!`] span. /// Populates default success/failure fields for a given [`reqwest_otel_span!`] span.
#[inline] #[inline]
pub fn default_on_request_end(span: &Span, outcome: &Result<Response>) { pub fn default_on_request_end(span: &Span, outcome: &Result<Response, Error>) {
match outcome { match outcome {
Ok(res) => default_on_request_success(span, res), Ok(res) => default_on_request_success(span, res),
Err(err) => default_on_request_failure(span, err), Err(err) => default_on_request_failure(span, err),
@ -95,15 +95,15 @@ pub fn default_on_request_failure(span: &Span, e: &Error) {
pub struct DefaultSpanBackend; pub struct DefaultSpanBackend;
impl ReqwestOtelSpanBackend for DefaultSpanBackend { impl ReqwestOtelSpanBackend for DefaultSpanBackend {
fn on_request_start(req: &Request, ext: &mut Extensions) -> Span { fn on_request_start(req: &Request, ext: &mut Extensions) -> (DefaultSpanBackend, Span) {
let name = ext let name = ext
.get::<OtelName>() .get::<OtelName>()
.map(|on| on.0.as_ref()) .map(|on| on.0.as_ref())
.unwrap_or("reqwest-http-client"); .unwrap_or("reqwest-http-client");
reqwest_otel_span!(name = name, req) (Self, reqwest_otel_span!(name = name, req))
} }
fn on_request_end(span: &Span, outcome: &Result<Response>, _: &mut Extensions) { fn on_request_end(self, span: &Span, outcome: &Result<Response, Error>) {
default_on_request_end(span, outcome) default_on_request_end(span, outcome)
} }
} }
@ -119,16 +119,19 @@ fn get_header_value(key: &str, headers: &HeaderMap) -> String {
pub struct SpanBackendWithUrl; pub struct SpanBackendWithUrl;
impl ReqwestOtelSpanBackend for SpanBackendWithUrl { impl ReqwestOtelSpanBackend for SpanBackendWithUrl {
fn on_request_start(req: &Request, ext: &mut Extensions) -> Span { fn on_request_start(req: &Request, ext: &mut Extensions) -> (Self, Span) {
let name = ext let name = ext
.get::<OtelName>() .get::<OtelName>()
.map(|on| on.0.as_ref()) .map(|on| on.0.as_ref())
.unwrap_or("reqwest-http-client"); .unwrap_or("reqwest-http-client");
reqwest_otel_span!(name = name, req, http.url = %remove_credentials(req.url())) (
Self,
reqwest_otel_span!(name = name, req, http.url = %remove_credentials(req.url())),
)
} }
fn on_request_end(span: &Span, outcome: &Result<Response>, _: &mut Extensions) { fn on_request_end(self, span: &Span, outcome: &Result<Response, Error>) {
default_on_request_end(span, outcome) default_on_request_end(span, outcome)
} }
} }
@ -156,12 +159,12 @@ fn get_span_status(request_status: RequestStatusCode) -> Option<&'static str> {
/// ///
/// Usage: /// Usage:
/// ```no_run /// ```no_run
/// # use reqwest_middleware::Result; /// # use reqwest_middleware::Error;
/// use reqwest_middleware::{ClientBuilder, Extension}; /// use reqwest_middleware::{ClientBuilder, Extension};
/// use reqwest_tracing::{ /// use reqwest_tracing::{
/// TracingMiddleware, OtelName /// TracingMiddleware, OtelName
/// }; /// };
/// # async fn example() -> Result<()> { /// # async fn example() -> Result<(), Error> {
/// let reqwest_client = reqwest::Client::builder().build().unwrap(); /// let reqwest_client = reqwest::Client::builder().build().unwrap();
/// let client = ClientBuilder::new(reqwest_client) /// let client = ClientBuilder::new(reqwest_client)
/// // Inserts the extension before the request is started /// // Inserts the extension before the request is started

View File

@ -30,7 +30,7 @@
/// The second argument passed to [`reqwest_otel_span!`](crate::reqwest_otel_span) is a reference to an [`reqwest::Request`]. /// The second argument passed to [`reqwest_otel_span!`](crate::reqwest_otel_span) is a reference to an [`reqwest::Request`].
/// ///
/// ```rust /// ```rust
/// use reqwest_middleware::Result; /// use reqwest_middleware::Error;
/// use task_local_extensions::Extensions; /// use task_local_extensions::Extensions;
/// use reqwest::{Request, Response}; /// use reqwest::{Request, Response};
/// use reqwest_tracing::{ /// use reqwest_tracing::{
@ -41,11 +41,11 @@
/// pub struct CustomReqwestOtelSpanBackend; /// pub struct CustomReqwestOtelSpanBackend;
/// ///
/// impl ReqwestOtelSpanBackend for CustomReqwestOtelSpanBackend { /// impl ReqwestOtelSpanBackend for CustomReqwestOtelSpanBackend {
/// fn on_request_start(req: &Request, _extension: &mut Extensions) -> Span { /// fn on_request_start(req: &Request, _extension: &mut Extensions) -> (Self, Span) {
/// reqwest_otel_span!(name = "reqwest-http-request", req) /// (Self, reqwest_otel_span!(name = "reqwest-http-request", req))
/// } /// }
/// ///
/// fn on_request_end(span: &Span, outcome: &Result<Response>, _extension: &mut Extensions) { /// fn on_request_end(self, span: &Span, outcome: &Result<Response, Error>) {
/// default_on_request_end(span, outcome) /// default_on_request_end(span, outcome)
/// } /// }
/// } /// }