This commit is contained in:
Conrad Ludgate 2022-11-15 16:34:33 +00:00
parent 6eaa2365ed
commit 571b9abc49
No known key found for this signature in database
GPG key ID: 197E3CACA1C980B5
13 changed files with 159 additions and 167 deletions

View file

@ -11,8 +11,8 @@ to allow for client middleware chains.
This crate provides functionality for building and running middleware but no middleware This crate provides functionality for building and running middleware but no middleware
implementations. This repository also contains a couple of useful concrete middleware crates: implementations. This repository also contains a couple of useful concrete middleware crates:
* [`reqwest-retry`](https://crates.io/crates/reqwest-retry): retry failed requests. - [`reqwest-retry`](https://crates.io/crates/reqwest-retry): retry failed requests.
* [`reqwest-tracing`](https://crates.io/crates/reqwest-tracing): - [`reqwest-tracing`](https://crates.io/crates/reqwest-tracing):
[`tracing`](https://crates.io/crates/tracing) integration, optional opentelemetry support. [`tracing`](https://crates.io/crates/tracing) integration, optional opentelemetry support.
## Overview ## Overview
@ -29,10 +29,12 @@ reqwest-middleware = "0.1.6"
reqwest-retry = "0.1.5" reqwest-retry = "0.1.5"
reqwest-tracing = "0.2.3" reqwest-tracing = "0.2.3"
tokio = { version = "1.12.0", features = ["macros", "rt-multi-thread"] } tokio = { version = "1.12.0", features = ["macros", "rt-multi-thread"] }
tower = "0.4"
``` ```
```rust ```rust
use reqwest_middleware::{ClientBuilder, ClientWithMiddleware}; use reqwest::Response;
use reqwest_middleware::{ClientBuilder, ClientWithMiddleware, Error, MiddlewareRequest, RequestInitialiser, ReqwestService};
use reqwest_retry::{RetryTransientMiddleware, policies::ExponentialBackoff}; use reqwest_retry::{RetryTransientMiddleware, policies::ExponentialBackoff};
use reqwest_tracing::TracingMiddleware; use reqwest_tracing::TracingMiddleware;
@ -49,7 +51,12 @@ async fn main() {
run(client).await; run(client).await;
} }
async fn run(client: ClientWithMiddleware) { async fn run<M, I>(client: ClientWithMiddleware<M, I>)
where
M: tower::Layer<ReqwestService>,
M::Service: tower::Service<MiddlewareRequest, Response = Response, Error = Error>,
I: RequestInitialiser,
{
client client
.get("https://truelayer.com") .get("https://truelayer.com")
.header("foo", "bar") .header("foo", "bar")

View file

@ -19,6 +19,7 @@ serde = "1"
task-local-extensions = "0.1.1" task-local-extensions = "0.1.1"
thiserror = "1" thiserror = "1"
tower = { version = "0.4", features = ["util"] } tower = { version = "0.4", features = ["util"] }
futures = "0.3"
[dev-dependencies] [dev-dependencies]
reqwest = "0.11" reqwest = "0.11"

View file

@ -1,80 +1,62 @@
use futures::future::BoxFuture;
use futures::FutureExt;
use reqwest::header::{HeaderMap, HeaderName, HeaderValue}; use reqwest::header::{HeaderMap, HeaderName, HeaderValue};
use reqwest::multipart::Form; use reqwest::multipart::Form;
use reqwest::{Body, Client, IntoUrl, Method, Request, Response}; use reqwest::{Body, Client, IntoUrl, Method, Request, Response};
use serde::Serialize; use serde::Serialize;
use std::convert::TryFrom; use std::convert::TryFrom;
use std::fmt::{self, Display}; use std::fmt::{self, Display};
use std::task::{Context, Poll};
use std::time::Duration; use std::time::Duration;
use task_local_extensions::Extensions; use task_local_extensions::Extensions;
use tower::layer::util::{Identity, Stack}; use tower::layer::util::{Identity, Stack};
use tower::{Layer, Service, ServiceBuilder, ServiceExt}; use tower::{Layer, Service, ServiceBuilder, ServiceExt};
use crate::error::Result; use crate::{Error, MiddlewareRequest, RequestInitialiser, RequestStack};
use crate::{Error, MiddlewareRequest, RequestInitialiser};
/// A `ClientBuilder` is used to build a [`ClientWithMiddleware`]. /// A `ClientBuilder` is used to build a [`ClientWithMiddleware`].
/// ///
/// [`ClientWithMiddleware`]: crate::ClientWithMiddleware /// [`ClientWithMiddleware`]: crate::ClientWithMiddleware
pub struct ClientBuilder<M> { pub struct ClientBuilder<M, I> {
client: Client, client: Client,
middleware_stack: ServiceBuilder<M>, middleware_stack: ServiceBuilder<M>,
initialiser_stack: (), initialiser_stack: I,
} }
impl ClientBuilder<Identity> { impl ClientBuilder<Identity, Identity> {
pub fn new(client: Client) -> Self { pub fn new(client: Client) -> Self {
ClientBuilder { ClientBuilder {
client, client,
middleware_stack: ServiceBuilder::new(), middleware_stack: ServiceBuilder::new(),
initialiser_stack: (), initialiser_stack: Identity::new(),
} }
} }
} }
impl<L> ClientBuilder<L> { impl<M, I> ClientBuilder<M, I> {
/// Convenience method to attach middleware. /// Convenience method to attach middleware.
/// pub fn with<T>(self, layer: T) -> ClientBuilder<Stack<T, M>, I> {
/// If you need to keep a reference to the middleware after attaching, use [`with_arc`].
///
/// [`with_arc`]: Self::with_arc
pub fn layer<T>(self, layer: T) -> ClientBuilder<Stack<T, L>> {
ClientBuilder { ClientBuilder {
client: self.client, client: self.client,
middleware_stack: self.middleware_stack.layer(layer), middleware_stack: self.middleware_stack.layer(layer),
initialiser_stack: (), initialiser_stack: self.initialiser_stack,
} }
} }
// /// Add middleware to the chain. [`with`] is more ergonomic if you don't need the `Arc`. /// Convenience method to attach a request initialiser.
// /// pub fn with_init<T>(self, initialiser: T) -> ClientBuilder<M, RequestStack<T, I>> {
// /// [`with`]: Self::with ClientBuilder {
// pub fn with_arc(mut self, middleware: Arc<dyn Middleware>) -> Self { client: self.client,
// self.middleware_stack.push(middleware); middleware_stack: self.middleware_stack,
// self initialiser_stack: RequestStack {
// } inner: initialiser,
outer: self.initialiser_stack,
// /// Convenience method to attach a request initialiser. },
// /// }
// /// If you need to keep a reference to the initialiser after attaching, use [`with_arc_init`]. }
// ///
// /// [`with_arc_init`]: Self::with_arc_init
// pub fn with_init<I>(self, initialiser: I) -> Self
// where
// I: RequestInitialiser,
// {
// self.with_arc_init(Arc::new(initialiser))
// }
// /// Add a request initialiser to the chain. [`with_init`] is more ergonomic if you don't need the `Arc`.
// ///
// /// [`with_init`]: Self::with_init
// pub fn with_arc_init(mut self, initialiser: Arc<dyn RequestInitialiser>) -> Self {
// self.initialiser_stack.push(initialiser);
// self
// }
/// Returns a `ClientWithMiddleware` using this builder configuration. /// Returns a `ClientWithMiddleware` using this builder configuration.
pub fn build(self) -> ClientWithMiddleware<L, ()> { pub fn build(self) -> ClientWithMiddleware<M, I> {
ClientWithMiddleware { ClientWithMiddleware {
inner: self.client, inner: self.client,
middleware_stack: self.middleware_stack, middleware_stack: self.middleware_stack,
@ -92,21 +74,7 @@ pub struct ClientWithMiddleware<M, I> {
initialiser_stack: I, initialiser_stack: I,
} }
// impl<M: Layer<ReqService>> ClientWithMiddleware<M, ()> impl<M: Layer<ReqwestService>, I: RequestInitialiser> ClientWithMiddleware<M, I>
// where
// M::Service: Service<MiddlewareRequest, Response = Response, Error = Error>,
// {
// /// See [`ClientBuilder`] for a more ergonomic way to build `ClientWithMiddleware` instances.
// pub fn new(client: Client, middleware_stack: M) -> Self {
// ClientWithMiddleware {
// inner: client,
// middleware_stack,
// initialiser_stack: (),
// }
// }
// }
impl<M: Layer<ReqService>, I: RequestInitialiser> ClientWithMiddleware<M, I>
where where
M::Service: Service<MiddlewareRequest, Response = Response, Error = Error>, M::Service: Service<MiddlewareRequest, Response = Response, Error = Error>,
{ {
@ -181,28 +149,26 @@ pub struct RequestBuilder<'client, M, I> {
extensions: Extensions, extensions: Extensions,
} }
pub type BoxFuture<'a, T> = std::pin::Pin<Box<dyn std::future::Future<Output = T> + Send + 'a>>;
#[derive(Clone)] #[derive(Clone)]
pub struct ReqService(Client); pub struct ReqwestService(Client);
impl Service<MiddlewareRequest> for ReqService { impl Service<MiddlewareRequest> for ReqwestService {
type Response = Response; type Response = Response;
type Error = Error; type Error = Error;
type Future = BoxFuture<'static, Result<Response>>; type Future = BoxFuture<'static, Result<Response, Error>>;
fn poll_ready(&mut self, _: &mut std::task::Context<'_>) -> std::task::Poll<Result<()>> { fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll<Result<(), Error>> {
std::task::Poll::Ready(Ok(())) Poll::Ready(Ok(()))
} }
fn call(&mut self, req: MiddlewareRequest) -> Self::Future { fn call(&mut self, req: MiddlewareRequest) -> Self::Future {
let req = req.request; let req = req.request;
let client = self.0.clone(); let client = self.0.clone();
Box::pin(async move { client.execute(req).await.map_err(Error::from) }) async move { client.execute(req).await.map_err(Error::from) }.boxed()
} }
} }
impl<M: Layer<ReqService>, I: RequestInitialiser> RequestBuilder<'_, M, I> impl<M: Layer<ReqwestService>, I: RequestInitialiser> RequestBuilder<'_, M, I>
where where
M::Service: Service<MiddlewareRequest, Response = Response, Error = Error>, M::Service: Service<MiddlewareRequest, Response = Response, Error = Error>,
{ {
@ -304,7 +270,7 @@ where
&mut self.extensions &mut self.extensions
} }
pub async fn send(self) -> Result<Response> { pub async fn send(self) -> Result<Response, Error> {
let Self { let Self {
inner, inner,
client, client,
@ -313,7 +279,7 @@ where
let req = inner.build()?; let req = inner.build()?;
client client
.middleware_stack .middleware_stack
.service(ReqService(client.inner.clone())) .service(ReqwestService(client.inner.clone()))
.oneshot(MiddlewareRequest { .oneshot(MiddlewareRequest {
request: req, request: req,
extensions, extensions,

View file

@ -1,7 +1,5 @@
use thiserror::Error; use thiserror::Error;
pub type Result<T> = std::result::Result<T, Error>;
#[derive(Error, Debug)] #[derive(Error, Debug)]
pub enum Error { pub enum Error {
/// There was an error running some middleware /// There was an error running some middleware

View file

@ -7,27 +7,31 @@
//! //!
//! ``` //! ```
//! use reqwest::{Client, Request, Response}; //! use reqwest::{Client, Request, Response};
//! use reqwest_middleware::{ClientBuilder, Middleware, Next, Result}; //! use reqwest_middleware::{ClientBuilder, Error, Extension, MiddlewareRequest};
//! use task_local_extensions::Extensions; //! use task_local_extensions::Extensions;
//! use futures::FutureExt; //! use futures::future::{BoxFuture, FutureExt};
//! use std::task::{Context, Poll}; //! use std::task::{Context, Poll};
//! //!
//! struct LoggingLayer; //! struct LoggingLayer;
//! struct LoggingService<S>(S); //! struct LoggingService<S>(S);
//! //!
//! impl<S> tower::Layer<S> for LoggingLayer { //! impl<S> tower::Layer<S> for LoggingLayer {
//! type Service = LoggingService<S>; //! type Service = LoggingService<S>;
//! //!
//! fn layer(&self, inner: S) -> Self::Service { //! fn layer(&self, inner: S) -> Self::Service {
//! LoggingService(inner) //! LoggingService(inner)
//! } //! }
//! } //! }
//! //!
//! impl<S: tower::Service<MiddlewareRequest>> tower::Service<MiddlewareRequest> for LoggingService<S> { //! impl<S> tower::Service<MiddlewareRequest> for LoggingService<S>
//! type Response = S::Response; //! where
//! type Error = S::Error; //! S: tower::Service<MiddlewareRequest, Response = Response, Error = Error>,
//! type Future = futures::BoxFuture<'static, Result<S::Response, S::Error>>; //! S::Future: Send + 'static,
//! //! {
//! type Response = Response;
//! type Error = Error;
//! type Future = BoxFuture<'static, Result<Response, Error>>;
//!
//! fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { //! fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
//! self.0.poll_ready(cx) //! self.0.poll_ready(cx)
//! } //! }
@ -46,7 +50,7 @@
//! async fn run() { //! async fn run() {
//! let reqwest_client = Client::builder().build().unwrap(); //! let reqwest_client = Client::builder().build().unwrap();
//! let client = ClientBuilder::new(reqwest_client) //! let client = ClientBuilder::new(reqwest_client)
//! .layer(LoggingLayer) //! .with(LoggingLayer)
//! .build(); //! .build();
//! let resp = client.get("https://truelayer.com").send().await.unwrap(); //! let resp = client.get("https://truelayer.com").send().await.unwrap();
//! println!("TrueLayer page HTML: {}", resp.text().await.unwrap()); //! println!("TrueLayer page HTML: {}", resp.text().await.unwrap());
@ -69,9 +73,9 @@ mod client;
mod error; mod error;
mod req_init; mod req_init;
pub use client::{ClientBuilder, ClientWithMiddleware, ReqService, RequestBuilder}; pub use client::{ClientBuilder, ClientWithMiddleware, RequestBuilder, ReqwestService};
pub use error::{Error, Result}; pub use error::Error;
pub use req_init::{Extension, RequestInitialiser}; pub use req_init::{Extension, RequestInitialiser, RequestStack};
pub struct MiddlewareRequest { pub struct MiddlewareRequest {
pub request: reqwest::Request, pub request: reqwest::Request,

View file

@ -1,5 +1,6 @@
use reqwest::RequestBuilder; use reqwest::RequestBuilder;
use task_local_extensions::Extensions; use task_local_extensions::Extensions;
use tower::layer::util::Identity;
/// When attached to a [`ClientWithMiddleware`] (generally using [`with_init`]), it is run /// When attached to a [`ClientWithMiddleware`] (generally using [`with_init`]), it is run
/// whenever the client starts building a request, in the order it was attached. /// whenever the client starts building a request, in the order it was attached.
@ -7,12 +8,14 @@ use task_local_extensions::Extensions;
/// # Example /// # Example
/// ///
/// ``` /// ```
/// use reqwest_middleware::{RequestInitialiser, MiddlewareRequest}; /// use reqwest::RequestBuilder;
/// use reqwest_middleware::RequestInitialiser;
/// use task_local_extensions::Extensions;
/// ///
/// struct AuthInit; /// struct AuthInit;
/// ///
/// impl RequestInitialiser for AuthInit { /// impl RequestInitialiser for AuthInit {
/// fn init(&self, req: MiddlewareRequest) -> MiddlewareRequest { /// fn init(&self, req: RequestBuilder, ext: &mut Extensions) -> RequestBuilder {
/// req.bearer_auth("my_auth_token") /// req.bearer_auth("my_auth_token")
/// } /// }
/// } /// }
@ -24,51 +27,82 @@ pub trait RequestInitialiser: 'static + Send + Sync {
fn init(&self, req: RequestBuilder, ext: &mut Extensions) -> RequestBuilder; fn init(&self, req: RequestBuilder, ext: &mut Extensions) -> RequestBuilder;
} }
impl RequestInitialiser for () { impl RequestInitialiser for Identity {
fn init(&self, req: RequestBuilder, _: &mut Extensions) -> RequestBuilder { fn init(&self, req: RequestBuilder, _: &mut Extensions) -> RequestBuilder {
req req
} }
} }
// impl<F> RequestInitialiser for F /// Two [`RequestInitialiser`]s chained together.
// where #[derive(Clone)]
// F: Send + Sync + 'static + Fn(MiddlewareRequest) -> MiddlewareRequest, pub struct RequestStack<Inner, Outer> {
// { pub(crate) inner: Inner,
// fn init(&self, req: RequestBuilder, ext: &mut Extensions) -> RequestBuilder { pub(crate) outer: Outer,
// (self)(req) }
// }
// } impl<I, O> RequestInitialiser for RequestStack<I, O>
where
I: RequestInitialiser,
O: RequestInitialiser,
{
fn init(&self, req: RequestBuilder, ext: &mut Extensions) -> RequestBuilder {
let req = self.inner.init(req, ext);
self.outer.init(req, ext)
}
}
/// A middleware that inserts the value into the [`Extensions`](task_local_extensions::Extensions) during the call. /// A middleware that inserts the value into the [`Extensions`](task_local_extensions::Extensions) during the call.
/// ///
/// This is a good way to inject extensions to middleware deeper in the stack /// This is a good way to inject extensions to middleware deeper in the stack
/// ///
/// ``` /// ```
/// use reqwest::{Client, Request, Response}; /// use reqwest::{Client, RequestBuilder, Response};
/// use reqwest_middleware::{ClientBuilder, Middleware, Next, Result, Extension}; /// use reqwest_middleware::{ClientBuilder, Error, Extension, MiddlewareRequest};
/// use task_local_extensions::Extensions; /// use task_local_extensions::Extensions;
/// use futures::future::{BoxFuture, FutureExt};
/// use std::task::{Context, Poll};
/// ///
/// #[derive(Clone)] /// #[derive(Clone)]
/// struct LogName(&'static str); /// struct LogName(&'static str);
/// struct LoggingMiddleware;
/// ///
/// #[async_trait::async_trait] /// struct LoggingLayer;
/// impl Middleware for LoggingMiddleware { /// struct LoggingService<S>(S);
/// async fn handle( ///
/// &self, /// impl<S> tower::Layer<S> for LoggingLayer {
/// req: Request, /// type Service = LoggingService<S>;
/// extensions: &mut Extensions, ///
/// next: Next<'_>, /// fn layer(&self, inner: S) -> Self::Service {
/// ) -> Result<Response> { /// LoggingService(inner)
/// }
/// }
///
/// impl<S> tower::Service<MiddlewareRequest> for LoggingService<S>
/// where
/// S: tower::Service<MiddlewareRequest, Response = Response, Error = Error>,
/// S::Future: Send + 'static,
/// {
/// type Response = Response;
/// type Error = Error;
/// type Future = BoxFuture<'static, Result<Response, Error>>;
///
/// fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
/// self.0.poll_ready(cx)
/// }
///
/// fn call(&mut self, req: MiddlewareRequest) -> Self::Future {
/// // get the log name or default to "unknown" /// // get the log name or default to "unknown"
/// let name = extensions /// let name = req
/// .extensions
/// .get() /// .get()
/// .map(|&LogName(name)| name) /// .map(|&LogName(name)| name)
/// .unwrap_or("unknown"); /// .unwrap_or("unknown");
/// println!("[{name}] Request started {req:?}"); /// println!("[{name}] Request started {:?}", &req.request);
/// let res = next.run(req, extensions).await; /// let fut = self.0.call(req);
/// println!("[{name}] Result: {res:?}"); /// async move {
/// res /// let res = fut.await;
/// println!("[{name}] Result: {res:?}");
/// res
/// }.boxed()
/// } /// }
/// } /// }
/// ///
@ -76,7 +110,7 @@ impl RequestInitialiser for () {
/// let reqwest_client = Client::builder().build().unwrap(); /// let reqwest_client = Client::builder().build().unwrap();
/// let client = ClientBuilder::new(reqwest_client) /// let client = ClientBuilder::new(reqwest_client)
/// .with_init(Extension(LogName("my-client"))) /// .with_init(Extension(LogName("my-client")))
/// .with(LoggingMiddleware) /// .with(LoggingLayer)
/// .build(); /// .build();
/// let resp = client.get("https://truelayer.com").send().await.unwrap(); /// let resp = client.get("https://truelayer.com").send().await.unwrap();
/// println!("TrueLayer page HTML: {}", resp.text().await.unwrap()); /// println!("TrueLayer page HTML: {}", resp.text().await.unwrap());

View file

@ -13,7 +13,7 @@
//! // Retry up to 3 times with increasing intervals between attempts. //! // Retry up to 3 times with increasing intervals between attempts.
//! let retry_policy = ExponentialBackoff::builder().build_with_max_retries(3); //! let retry_policy = ExponentialBackoff::builder().build_with_max_retries(3);
//! let client = ClientBuilder::new(reqwest::Client::new()) //! let client = ClientBuilder::new(reqwest::Client::new())
//! .layer(RetryTransientMiddleware::new_with_policy(retry_policy)) //! .with(RetryTransientMiddleware::new_with_policy(retry_policy))
//! .build(); //! .build();
//! //!
//! client //! client

View file

@ -36,7 +36,7 @@ use tower::{Layer, Service};
/// }; /// };
/// ///
/// let retry_transient_middleware = RetryTransientMiddleware::new_with_policy(retry_policy); /// let retry_transient_middleware = RetryTransientMiddleware::new_with_policy(retry_policy);
/// let client = ClientBuilder::new(Client::new()).layer(retry_transient_middleware).build(); /// let client = ClientBuilder::new(Client::new()).with(retry_transient_middleware).build();
///``` ///```
/// ///
/// # Note /// # Note

View file

@ -48,7 +48,7 @@ macro_rules! assert_retry_succeeds_inner {
let reqwest_client = Client::builder().build().unwrap(); let reqwest_client = Client::builder().build().unwrap();
let client = ClientBuilder::new(reqwest_client) let client = ClientBuilder::new(reqwest_client)
.layer(RetryTransientMiddleware::new_with_policy( .with(RetryTransientMiddleware::new_with_policy(
ExponentialBackoff { ExponentialBackoff {
max_n_retries: retry_amount, max_n_retries: retry_amount,
max_retry_interval: std::time::Duration::from_millis(30), max_retry_interval: std::time::Duration::from_millis(30),
@ -184,7 +184,7 @@ async fn assert_retry_on_request_timeout() {
let reqwest_client = Client::builder().build().unwrap(); let reqwest_client = Client::builder().build().unwrap();
let client = ClientBuilder::new(reqwest_client) let client = ClientBuilder::new(reqwest_client)
.layer(RetryTransientMiddleware::new_with_policy( .with(RetryTransientMiddleware::new_with_policy(
ExponentialBackoff { ExponentialBackoff {
max_n_retries: 3, max_n_retries: 3,
max_retry_interval: std::time::Duration::from_millis(100), max_retry_interval: std::time::Duration::from_millis(100),
@ -239,7 +239,7 @@ async fn assert_retry_on_incomplete_message() {
let reqwest_client = Client::builder().build().unwrap(); let reqwest_client = Client::builder().build().unwrap();
let client = ClientBuilder::new(reqwest_client) let client = ClientBuilder::new(reqwest_client)
.layer(RetryTransientMiddleware::new_with_policy( .with(RetryTransientMiddleware::new_with_policy(
ExponentialBackoff { ExponentialBackoff {
max_n_retries: 3, max_n_retries: 3,
max_retry_interval: std::time::Duration::from_millis(100), max_retry_interval: std::time::Duration::from_millis(100),

View file

@ -4,11 +4,11 @@
//! //!
//! The simplest possible usage: //! The simplest possible usage:
//! ```no_run //! ```no_run
//! # use reqwest_middleware::Result; //! # use reqwest_middleware::Error;
//! use reqwest_middleware::{ClientBuilder}; //! use reqwest_middleware::{ClientBuilder};
//! use reqwest_tracing::TracingMiddleware; //! use reqwest_tracing::TracingMiddleware;
//! //!
//! # async fn example() -> Result<()> { //! # async fn example() -> Result<(), Error> {
//! let reqwest_client = reqwest::Client::builder().build().unwrap(); //! let reqwest_client = reqwest::Client::builder().build().unwrap();
//! let client = ClientBuilder::new(reqwest_client) //! let client = ClientBuilder::new(reqwest_client)
//! // Insert the tracing middleware //! // Insert the tracing middleware
@ -22,12 +22,12 @@
//! //!
//! To customise the span names use [`OtelName`]. //! To customise the span names use [`OtelName`].
//! ```no_run //! ```no_run
//! # use reqwest_middleware::Result; //! # use reqwest_middleware::Error;
//! use reqwest_middleware::{ClientBuilder, Extension}; //! use reqwest_middleware::{ClientBuilder, Extension};
//! use reqwest_tracing::{ //! use reqwest_tracing::{
//! TracingMiddleware, OtelName //! TracingMiddleware, OtelName
//! }; //! };
//! # async fn example() -> Result<()> { //! # async fn example() -> Result<(), Error> {
//! let reqwest_client = reqwest::Client::builder().build().unwrap(); //! let reqwest_client = reqwest::Client::builder().build().unwrap();
//! let client = ClientBuilder::new(reqwest_client) //! let client = ClientBuilder::new(reqwest_client)
//! // Inserts the extension before the request is started //! // Inserts the extension before the request is started
@ -52,7 +52,7 @@
//! //!
//! Note that Opentelemetry tracks start and stop already, there is no need to have a custom builder like this. //! Note that Opentelemetry tracks start and stop already, there is no need to have a custom builder like this.
//! ```rust //! ```rust
//! use reqwest_middleware::Result; //! use reqwest_middleware::Error;
//! use task_local_extensions::Extensions; //! use task_local_extensions::Extensions;
//! use reqwest::{Request, Response}; //! use reqwest::{Request, Response};
//! use reqwest_middleware::ClientBuilder; //! use reqwest_middleware::ClientBuilder;
@ -70,7 +70,7 @@
//! reqwest_otel_span!(name="example-request", req, time_elapsed = tracing::field::Empty) //! reqwest_otel_span!(name="example-request", req, time_elapsed = tracing::field::Empty)
//! } //! }
//! //!
//! fn on_request_end(span: &Span, outcome: &Result<Response>, extension: &mut Extensions) { //! fn on_request_end(span: &Span, outcome: &Result<Response, Error>, extension: &mut Extensions) {
//! let time_elapsed = extension.get::<Instant>().unwrap().elapsed().as_millis() as i64; //! let time_elapsed = extension.get::<Instant>().unwrap().elapsed().as_millis() as i64;
//! default_on_request_end(span, outcome); //! default_on_request_end(span, outcome);
//! span.record("time_elapsed", &time_elapsed); //! span.record("time_elapsed", &time_elapsed);
@ -103,18 +103,3 @@ pub use reqwest_otel_span_builder::{
#[doc(hidden)] #[doc(hidden)]
pub mod reqwest_otel_span_macro; pub mod reqwest_otel_span_macro;
#[cfg(test)]
mod tests {
use crate::{TracingMiddleware, DefaultSpanBackend};
use reqwest_middleware::ClientBuilder;
#[tokio::test]
async fn compiles() {
let client = ClientBuilder::new(reqwest::Client::new())
.layer(TracingMiddleware::<DefaultSpanBackend>::new())
.build();
let resp = client.get("http://example.com").send().await.unwrap();
dbg!(resp);
}
}

View file

@ -1,4 +1,7 @@
use std::{future::Future, task::ready}; use std::{
future::Future,
task::{ready, Context, Poll},
};
use pin_project_lite::pin_project; use pin_project_lite::pin_project;
use reqwest::Response; use reqwest::Response;
@ -65,10 +68,7 @@ where
type Error = Error; type Error = Error;
type Future = TracingMiddlewareFuture<ReqwestOtelSpan, Svc::Future>; type Future = TracingMiddlewareFuture<ReqwestOtelSpan, Svc::Future>;
fn poll_ready( fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
&mut self,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Result<(), Self::Error>> {
self.service.poll_ready(cx) self.service.poll_ready(cx)
} }
@ -117,16 +117,13 @@ impl<S: ReqwestOtelSpanBackend, F: Future<Output = Result<Response, Error>>> Fut
{ {
type Output = F::Output; type Output = F::Output;
fn poll( fn poll(self: std::pin::Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Self::Output> {
let this = self.project(); let this = self.project();
let outcome = { let outcome = {
let _guard = this.span.enter(); let _guard = this.span.enter();
ready!(this.future.poll(cx)) ready!(this.future.poll(cx))
}; };
S::on_request_end(this.span, &outcome); S::on_request_end(this.span, &outcome);
std::task::Poll::Ready(outcome) Poll::Ready(outcome)
} }
} }

View file

@ -2,7 +2,7 @@ use std::borrow::Cow;
use reqwest::header::{HeaderMap, HeaderValue}; use reqwest::header::{HeaderMap, HeaderValue};
use reqwest::{Request, Response, StatusCode as RequestStatusCode, Url}; use reqwest::{Request, Response, StatusCode as RequestStatusCode, Url};
use reqwest_middleware::{Error, Result}; use reqwest_middleware::Error;
use task_local_extensions::Extensions; use task_local_extensions::Extensions;
use tracing::Span; use tracing::Span;
@ -44,12 +44,12 @@ pub trait ReqwestOtelSpanBackend {
fn on_request_start(req: &Request, extension: &mut Extensions) -> Span; fn on_request_start(req: &Request, extension: &mut Extensions) -> Span;
/// Runs after the request call has executed. /// Runs after the request call has executed.
fn on_request_end(span: &Span, outcome: &Result<Response>); fn on_request_end(span: &Span, outcome: &Result<Response, Error>);
} }
/// Populates default success/failure fields for a given [`reqwest_otel_span!`] span. /// Populates default success/failure fields for a given [`reqwest_otel_span!`] span.
#[inline] #[inline]
pub fn default_on_request_end(span: &Span, outcome: &Result<Response>) { pub fn default_on_request_end(span: &Span, outcome: &Result<Response, Error>) {
match outcome { match outcome {
Ok(res) => default_on_request_success(span, res), Ok(res) => default_on_request_success(span, res),
Err(err) => default_on_request_failure(span, err), Err(err) => default_on_request_failure(span, err),
@ -103,7 +103,7 @@ impl ReqwestOtelSpanBackend for DefaultSpanBackend {
reqwest_otel_span!(name = name, req) reqwest_otel_span!(name = name, req)
} }
fn on_request_end(span: &Span, outcome: &Result<Response>) { fn on_request_end(span: &Span, outcome: &Result<Response, Error>) {
default_on_request_end(span, outcome) default_on_request_end(span, outcome)
} }
} }
@ -128,7 +128,7 @@ impl ReqwestOtelSpanBackend for SpanBackendWithUrl {
reqwest_otel_span!(name = name, req, http.url = %remove_credentials(req.url())) reqwest_otel_span!(name = name, req, http.url = %remove_credentials(req.url()))
} }
fn on_request_end(span: &Span, outcome: &Result<Response>) { fn on_request_end(span: &Span, outcome: &Result<Response, Error>) {
default_on_request_end(span, outcome) default_on_request_end(span, outcome)
} }
} }
@ -156,28 +156,28 @@ fn get_span_status(request_status: RequestStatusCode) -> Option<&'static str> {
/// ///
/// Usage: /// Usage:
/// ```no_run /// ```no_run
/// # use reqwest_middleware::Result; /// # use reqwest_middleware::Error;
/// use reqwest_middleware::{ClientBuilder, Extension}; /// use reqwest_middleware::{ClientBuilder, Extension};
/// use reqwest_tracing::{ /// use reqwest_tracing::{
/// TracingMiddleware, OtelName /// TracingMiddleware, OtelName
/// }; /// };
/// # async fn example() -> Result<()> { /// # async fn example() -> Result<(), Error> {
/// let reqwest_client = reqwest::Client::builder().build().unwrap(); /// let reqwest_client = reqwest::Client::builder().build().unwrap();
/// let client = ClientBuilder::new(reqwest_client) /// let client = ClientBuilder::new(reqwest_client)
/// // Inserts the extension before the request is started /// // Inserts the extension before the request is started
/// .with_init(Extension(OtelName("my-client".into()))) /// .with_init(Extension(OtelName("my-client".into())))
/// // Makes use of that extension to specify the otel name /// // Makes use of that extension to specify the otel name
/// .with(TracingMiddleware::default()) /// .with(TracingMiddleware::default())
/// .build(); /// .build();
/// ///
/// let resp = client.get("https://truelayer.com").send().await.unwrap(); /// let resp = client.get("https://truelayer.com").send().await.unwrap();
/// ///
/// // Or specify it on the individual request (will take priority) /// // Or specify it on the individual request (will take priority)
/// let resp = client.post("https://api.truelayer.com/payment") /// let resp = client.post("https://api.truelayer.com/payment")
/// .with_extension(OtelName("POST /payment".into())) /// .with_extension(OtelName("POST /payment".into()))
/// .send() /// .send()
/// .await /// .await
/// .unwrap(); /// .unwrap();
/// # Ok(()) /// # Ok(())
/// # } /// # }
/// ``` /// ```

View file

@ -30,7 +30,7 @@
/// The second argument passed to [`reqwest_otel_span!`](crate::reqwest_otel_span) is a reference to an [`reqwest::Request`]. /// The second argument passed to [`reqwest_otel_span!`](crate::reqwest_otel_span) is a reference to an [`reqwest::Request`].
/// ///
/// ```rust /// ```rust
/// use reqwest_middleware::Result; /// use reqwest_middleware::Error;
/// use task_local_extensions::Extensions; /// use task_local_extensions::Extensions;
/// use reqwest::{Request, Response}; /// use reqwest::{Request, Response};
/// use reqwest_tracing::{ /// use reqwest_tracing::{
@ -45,7 +45,7 @@
/// reqwest_otel_span!(name = "reqwest-http-request", req) /// reqwest_otel_span!(name = "reqwest-http-request", req)
/// } /// }
/// ///
/// fn on_request_end(span: &Span, outcome: &Result<Response>, _extension: &mut Extensions) { /// fn on_request_end(span: &Span, outcome: &Result<Response, Error>) {
/// default_on_request_end(span, outcome) /// default_on_request_end(span, outcome)
/// } /// }
/// } /// }