Compare commits

..

4 Commits
main ... tower

Author SHA1 Message Date
Conrad Ludgate 8f3623eae0
remove but inspired by tower 2022-11-15 21:51:54 +00:00
Conrad Ludgate df4990d62e
fix last tracing test 2022-11-15 17:07:36 +00:00
Conrad Ludgate 571b9abc49
fix docs 2022-11-15 16:34:33 +00:00
Conrad Ludgate 6eaa2365ed
expierments with tower 2022-11-15 14:53:23 +00:00
26 changed files with 835 additions and 1386 deletions

View File

@ -18,9 +18,6 @@ jobs:
- opentelemetry_0_16
- opentelemetry_0_17
- opentelemetry_0_18
- opentelemetry_0_19
- opentelemetry_0_20
- opentelemetry_0_21
steps:
- name: Checkout repository
uses: actions/checkout@v2
@ -66,9 +63,6 @@ jobs:
- opentelemetry_0_16
- opentelemetry_0_17
- opentelemetry_0_18
- opentelemetry_0_19
- opentelemetry_0_20
- opentelemetry_0_21
steps:
- name: Checkout repository
uses: actions/checkout@v2
@ -97,9 +91,6 @@ jobs:
- opentelemetry_0_16
- opentelemetry_0_17
- opentelemetry_0_18
- opentelemetry_0_19
- opentelemetry_0_20
- opentelemetry_0_21
steps:
- name: Checkout repository
uses: actions/checkout@v2
@ -140,3 +131,27 @@ jobs:
with:
command: publish
args: --dry-run --manifest-path reqwest-tracing/Cargo.toml
coverage:
name: Code coverage
runs-on: ubuntu-latest
steps:
- name: Checkout repository
uses: actions/checkout@v2
- name: Install stable toolchain
uses: actions-rs/toolchain@v1
with:
toolchain: stable
profile: minimal
override: true
- name: Run cargo-tarpaulin
uses: actions-rs/tarpaulin@v0.1
with:
args: '--ignore-tests --out Lcov'
- name: Upload to Coveralls
# upload only if push
if: ${{ github.event_name == 'push' }}
uses: coverallsapp/github-action@master
with:
github-token: ${{ secrets.GITHUB_TOKEN }}
path-to-lcov: './lcov.info'

10
.gitignore vendored
View File

@ -1,10 +1,2 @@
# OS
.DS_Store
# IDE
.idea/
.vscode/
# Rust
Cargo.lock
/target
Cargo.lock

View File

@ -4,25 +4,7 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
### [0.2.4] - 2023-09-21
### Added
- Added `fetch_mode_no_cors` method to `reqwest_middleware::RequestBuilder`
## [0.2.3] - 2023-08-07
### Added
- Added all `reqwest::Error` methods for `reqwest_middleware::Error`
## [0.2.2] - 2023-05-11
### Added
- `RequestBuilder::version` method to configure the HTTP version
## [0.2.1] - 2023-03-09
### Added
- Support for `wasm32-unknown-unknown`
## [Unreleased]
## [0.2.0] - 2022-11-15

View File

@ -11,15 +11,10 @@ to allow for client middleware chains.
This crate provides functionality for building and running middleware but no middleware
implementations. This repository also contains a couple of useful concrete middleware crates:
* [`reqwest-retry`](https://crates.io/crates/reqwest-retry): retry failed requests.
* [`reqwest-tracing`](https://crates.io/crates/reqwest-tracing):
- [`reqwest-retry`](https://crates.io/crates/reqwest-retry): retry failed requests.
- [`reqwest-tracing`](https://crates.io/crates/reqwest-tracing):
[`tracing`](https://crates.io/crates/tracing) integration, optional opentelemetry support.
Note about browser support: automated tests targetting wasm are disabled. The crate may work with
wasm but wasm support is unmaintained. PRs improving wasm are still welcome but you'd need to
reintroduce the tests and get them passing before we'd merge it (see
https://github.com/TrueLayer/reqwest-middleware/pull/105).
## Overview
The `reqwest-middleware` client exposes the same interface as a plain `reqwest` client, but
@ -34,10 +29,12 @@ reqwest-middleware = "0.1.6"
reqwest-retry = "0.1.5"
reqwest-tracing = "0.2.3"
tokio = { version = "1.12.0", features = ["macros", "rt-multi-thread"] }
tower = "0.4"
```
```rust
use reqwest_middleware::{ClientBuilder, ClientWithMiddleware};
use reqwest::Response;
use reqwest_middleware::{ClientBuilder, ClientWithMiddleware, Error, Layer, RequestInitialiser, ReqwestService, Service};
use reqwest_retry::{RetryTransientMiddleware, policies::ExponentialBackoff};
use reqwest_tracing::TracingMiddleware;
@ -54,7 +51,12 @@ async fn main() {
run(client).await;
}
async fn run(client: ClientWithMiddleware) {
async fn run<M, I>(client: ClientWithMiddleware<M, I>)
where
M: Layer<ReqwestService>,
M::Service: Service,
I: RequestInitialiser,
{
client
.get("https://truelayer.com")
.header("foo", "bar")
@ -78,14 +80,3 @@ Unless you explicitly state otherwise, any contribution intentionally submitted
for inclusion in the work by you, as defined in the Apache-2.0 license, shall be
dual licensed as above, without any additional terms or conditions.
</sub>
## Third-party middleware
The following third-party middleware use `request-middleware`:
- [`reqwest-conditional-middleware`](https://github.com/oxidecomputer/reqwest-conditional-middleware) - Per-request basis middleware
- [`http-cache`](https://github.com/06chaynes/http-cache) - HTTP caching rules
- [`reqwest-cache`](https://gitlab.com/famedly/company/backend/libraries/reqwest-cache) - HTTP caching
- [`aliri_reqwest`](https://github.com/neoeinstein/aliri/tree/main/aliri_reqwest) - Background token management and renewal
- [`http-signature-normalization-reqwest`](https://crates.io/crates/http-signature-normalization-reqwest) (not free software) - HTTP Signatures
- [`reqwest-chain`](https://github.com/tommilligan/reqwest-chain) - Apply custom criteria to any reqwest response, deciding when and how to retry.

View File

@ -1,6 +1,6 @@
[package]
name = "reqwest-middleware"
version = "0.2.4"
version = "0.3.0"
authors = ["Rodrigo Gryzinski <rodrigo.gryzinski@truelayer.com>"]
edition = "2018"
description = "Wrapper around reqwest to allow for client middleware chains."
@ -11,17 +11,18 @@ categories = ["web-programming::http-client"]
readme = "../README.md"
[dependencies]
anyhow = "1.0.0"
anyhow = "1"
async-trait = "0.1.51"
http = "0.2.0"
reqwest = { version = "0.11.4", default-features = false, features = ["json", "multipart"] }
serde = "1.0.106"
task-local-extensions = "0.1.4"
thiserror = "1.0.21"
http = "0.2"
reqwest = { version = "0.11", default-features = false, features = ["json", "multipart"] }
serde = "1"
task-local-extensions = "0.1.1"
thiserror = "1"
futures = "0.3"
[dev-dependencies]
reqwest = "0.11.4"
reqwest = "0.11"
reqwest-retry = { path = "../reqwest-retry" }
reqwest-tracing = { path = "../reqwest-tracing" }
tokio = { version = "1.0.0", features = ["macros", "rt-multi-thread"] }
wiremock = "0.5.0"
tokio = { version = "1", features = ["macros", "rt-multi-thread"] }
wiremock = "0.5"

View File

@ -1,81 +1,67 @@
use futures::future::BoxFuture;
use futures::FutureExt;
use reqwest::header::{HeaderMap, HeaderName, HeaderValue};
use reqwest::multipart::Form;
use reqwest::{Body, Client, IntoUrl, Method, Request, Response};
use serde::Serialize;
use std::convert::TryFrom;
use std::fmt::{self, Display};
use std::sync::Arc;
use std::time::Duration;
use task_local_extensions::Extensions;
// use tower::{Layer, Service, ServiceBuilder, ServiceExt};
use crate::error::Result;
use crate::middleware::{Middleware, Next};
use crate::RequestInitialiser;
use crate::{Error, Identity, Layer, RequestInitialiser, RequestStack, Service, Stack};
/// A `ClientBuilder` is used to build a [`ClientWithMiddleware`].
///
/// [`ClientWithMiddleware`]: crate::ClientWithMiddleware
#[derive(Debug)]
pub struct ClientBuilder {
pub struct ClientBuilder<M, I> {
client: Client,
middleware_stack: Vec<Arc<dyn Middleware>>,
initialiser_stack: Vec<Arc<dyn RequestInitialiser>>,
middleware_stack: M,
initialiser_stack: I,
}
impl ClientBuilder {
impl ClientBuilder<Identity, Identity> {
pub fn new(client: Client) -> Self {
ClientBuilder {
client,
middleware_stack: Vec::new(),
initialiser_stack: Vec::new(),
middleware_stack: Identity,
initialiser_stack: Identity,
}
}
}
impl<M, I> ClientBuilder<M, I> {
/// Convenience method to attach middleware.
pub fn with<T>(self, layer: T) -> ClientBuilder<Stack<T, M>, I> {
ClientBuilder {
client: self.client,
middleware_stack: Stack {
inner: layer,
outer: self.middleware_stack,
},
initialiser_stack: self.initialiser_stack,
}
}
/// Convenience method to attach middleware.
///
/// If you need to keep a reference to the middleware after attaching, use [`with_arc`].
///
/// [`with_arc`]: Self::with_arc
pub fn with<M>(self, middleware: M) -> Self
where
M: Middleware,
{
self.with_arc(Arc::new(middleware))
}
/// Add middleware to the chain. [`with`] is more ergonomic if you don't need the `Arc`.
///
/// [`with`]: Self::with
pub fn with_arc(mut self, middleware: Arc<dyn Middleware>) -> Self {
self.middleware_stack.push(middleware);
self
}
/// Convenience method to attach a request initialiser.
///
/// If you need to keep a reference to the initialiser after attaching, use [`with_arc_init`].
///
/// [`with_arc_init`]: Self::with_arc_init
pub fn with_init<I>(self, initialiser: I) -> Self
where
I: RequestInitialiser,
{
self.with_arc_init(Arc::new(initialiser))
}
/// Add a request initialiser to the chain. [`with_init`] is more ergonomic if you don't need the `Arc`.
///
/// [`with_init`]: Self::with_init
pub fn with_arc_init(mut self, initialiser: Arc<dyn RequestInitialiser>) -> Self {
self.initialiser_stack.push(initialiser);
self
pub fn with_init<T>(self, initialiser: T) -> ClientBuilder<M, RequestStack<T, I>> {
ClientBuilder {
client: self.client,
middleware_stack: self.middleware_stack,
initialiser_stack: RequestStack {
inner: initialiser,
outer: self.initialiser_stack,
},
}
}
/// Returns a `ClientWithMiddleware` using this builder configuration.
pub fn build(self) -> ClientWithMiddleware {
pub fn build(self) -> ClientWithMiddleware<M, I> {
ClientWithMiddleware {
inner: self.client,
middleware_stack: self.middleware_stack.into_boxed_slice(),
initialiser_stack: self.initialiser_stack.into_boxed_slice(),
middleware_stack: self.middleware_stack,
initialiser_stack: self.initialiser_stack,
}
}
}
@ -83,97 +69,68 @@ impl ClientBuilder {
/// `ClientWithMiddleware` is a wrapper around [`reqwest::Client`] which runs middleware on every
/// request.
#[derive(Clone)]
pub struct ClientWithMiddleware {
pub struct ClientWithMiddleware<M, I> {
inner: reqwest::Client,
middleware_stack: Box<[Arc<dyn Middleware>]>,
initialiser_stack: Box<[Arc<dyn RequestInitialiser>]>,
middleware_stack: M,
initialiser_stack: I,
}
impl ClientWithMiddleware {
/// See [`ClientBuilder`] for a more ergonomic way to build `ClientWithMiddleware` instances.
pub fn new<T>(client: Client, middleware_stack: T) -> Self
where
T: Into<Box<[Arc<dyn Middleware>]>>,
{
ClientWithMiddleware {
inner: client,
middleware_stack: middleware_stack.into(),
// TODO(conradludgate) - allow downstream code to control this manually if desired
initialiser_stack: Box::new([]),
}
}
impl<M: Layer<ReqwestService>, I: RequestInitialiser> ClientWithMiddleware<M, I> {
/// See [`Client::get`]
pub fn get<U: IntoUrl>(&self, url: U) -> RequestBuilder {
pub fn get<U: IntoUrl>(&self, url: U) -> RequestBuilder<M, I> {
self.request(Method::GET, url)
}
/// See [`Client::post`]
pub fn post<U: IntoUrl>(&self, url: U) -> RequestBuilder {
pub fn post<U: IntoUrl>(&self, url: U) -> RequestBuilder<M, I> {
self.request(Method::POST, url)
}
/// See [`Client::put`]
pub fn put<U: IntoUrl>(&self, url: U) -> RequestBuilder {
pub fn put<U: IntoUrl>(&self, url: U) -> RequestBuilder<M, I> {
self.request(Method::PUT, url)
}
/// See [`Client::patch`]
pub fn patch<U: IntoUrl>(&self, url: U) -> RequestBuilder {
pub fn patch<U: IntoUrl>(&self, url: U) -> RequestBuilder<M, I> {
self.request(Method::PATCH, url)
}
/// See [`Client::delete`]
pub fn delete<U: IntoUrl>(&self, url: U) -> RequestBuilder {
pub fn delete<U: IntoUrl>(&self, url: U) -> RequestBuilder<M, I> {
self.request(Method::DELETE, url)
}
/// See [`Client::head`]
pub fn head<U: IntoUrl>(&self, url: U) -> RequestBuilder {
pub fn head<U: IntoUrl>(&self, url: U) -> RequestBuilder<M, I> {
self.request(Method::HEAD, url)
}
/// See [`Client::request`]
pub fn request<U: IntoUrl>(&self, method: Method, url: U) -> RequestBuilder {
let req = RequestBuilder {
inner: self.inner.request(method, url),
client: self.clone(),
extensions: Extensions::new(),
};
self.initialiser_stack
.iter()
.fold(req, |req, i| i.init(req))
}
/// See [`Client::execute`]
pub async fn execute(&self, req: Request) -> Result<Response> {
let mut ext = Extensions::new();
self.execute_with_extensions(req, &mut ext).await
}
/// Executes a request with initial [`Extensions`].
pub async fn execute_with_extensions(
&self,
req: Request,
ext: &mut Extensions,
) -> Result<Response> {
let next = Next::new(&self.inner, &self.middleware_stack);
next.run(req, ext).await
}
}
/// Create a `ClientWithMiddleware` without any middleware.
impl From<Client> for ClientWithMiddleware {
fn from(client: Client) -> Self {
ClientWithMiddleware {
inner: client,
middleware_stack: Box::new([]),
initialiser_stack: Box::new([]),
pub fn request<U: IntoUrl>(&self, method: Method, url: U) -> RequestBuilder<'_, M, I> {
let mut extensions = Extensions::new();
let request = self.inner.request(method, url);
let request = self.initialiser_stack.init(request, &mut extensions);
RequestBuilder {
inner: request,
client: self,
extensions,
}
}
}
impl fmt::Debug for ClientWithMiddleware {
/// Create a `ClientWithMiddleware` without any middleware.
impl From<Client> for ClientWithMiddleware<Identity, Identity> {
fn from(client: Client) -> Self {
ClientWithMiddleware {
inner: client,
middleware_stack: Identity,
initialiser_stack: Identity,
}
}
}
impl<M, I> fmt::Debug for ClientWithMiddleware<M, I> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
// skipping middleware_stack field for now
f.debug_struct("ClientWithMiddleware")
@ -184,13 +141,28 @@ impl fmt::Debug for ClientWithMiddleware {
/// This is a wrapper around [`reqwest::RequestBuilder`] exposing the same API.
#[must_use = "RequestBuilder does nothing until you 'send' it"]
pub struct RequestBuilder {
pub struct RequestBuilder<'client, M, I> {
inner: reqwest::RequestBuilder,
client: ClientWithMiddleware,
client: &'client ClientWithMiddleware<M, I>,
extensions: Extensions,
}
impl RequestBuilder {
#[derive(Clone)]
pub struct ReqwestService(Client);
impl Service for ReqwestService {
type Future = BoxFuture<'static, Result<Response, Error>>;
fn call(&mut self, req: Request, _: &mut Extensions) -> Self::Future {
let fut = self.0.execute(req);
async { fut.await.map_err(Error::from) }.boxed()
}
}
impl<M: Layer<ReqwestService>, I: RequestInitialiser> RequestBuilder<'_, M, I>
where
M::Service: Service,
{
pub fn header<K, V>(self, key: K, value: V) -> Self
where
HeaderName: TryFrom<K>,
@ -211,14 +183,6 @@ impl RequestBuilder {
}
}
#[cfg(not(target_arch = "wasm32"))]
pub fn version(self, version: reqwest::Version) -> Self {
RequestBuilder {
inner: self.inner.version(version),
..self
}
}
pub fn basic_auth<U, P>(self, username: U, password: Option<P>) -> Self
where
U: Display,
@ -247,8 +211,7 @@ impl RequestBuilder {
}
}
#[cfg(not(target_arch = "wasm32"))]
pub fn timeout(self, timeout: std::time::Duration) -> Self {
pub fn timeout(self, timeout: Duration) -> Self {
RequestBuilder {
inner: self.inner.timeout(timeout),
..self
@ -283,13 +246,6 @@ impl RequestBuilder {
}
}
pub fn fetch_mode_no_cors(self) -> Self {
RequestBuilder {
inner: self.inner.fetch_mode_no_cors(),
..self
}
}
pub fn build(self) -> reqwest::Result<Request> {
self.inner.build()
}
@ -305,14 +261,19 @@ impl RequestBuilder {
&mut self.extensions
}
pub async fn send(self) -> Result<Response> {
pub async fn send(self) -> Result<Response, Error> {
let Self {
inner,
client,
mut extensions,
} = self;
let req = inner.build()?;
client.execute_with_extensions(req, &mut extensions).await
let mut svc = client
.middleware_stack
.layer(ReqwestService(client.inner.clone()));
svc.call(req, &mut extensions).await
// client.execute_with_extensions(req, &mut extensions).await
}
/// Attempt to clone the RequestBuilder.
@ -325,13 +286,13 @@ impl RequestBuilder {
pub fn try_clone(&self) -> Option<Self> {
self.inner.try_clone().map(|inner| RequestBuilder {
inner,
client: self.client.clone(),
client: self.client,
extensions: Extensions::new(),
})
}
}
impl fmt::Debug for RequestBuilder {
impl<M, I> fmt::Debug for RequestBuilder<'_, M, I> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
// skipping middleware_stack field for now
f.debug_struct("RequestBuilder")

View File

@ -1,8 +1,5 @@
use reqwest::{StatusCode, Url};
use thiserror::Error;
pub type Result<T> = std::result::Result<T, Error>;
#[derive(Error, Debug)]
pub enum Error {
/// There was an error running some middleware
@ -20,122 +17,4 @@ impl Error {
{
Error::Middleware(err.into())
}
/// Returns a possible URL related to this error.
pub fn url(&self) -> Option<&Url> {
match self {
Error::Middleware(_) => None,
Error::Reqwest(e) => e.url(),
}
}
/// Returns a mutable reference to the URL related to this error.
///
/// This is useful if you need to remove sensitive information from the URL
/// (e.g. an API key in the query), but do not want to remove the URL
/// entirely.
pub fn url_mut(&mut self) -> Option<&mut Url> {
match self {
Error::Middleware(_) => None,
Error::Reqwest(e) => e.url_mut(),
}
}
/// Adds a url related to this error (overwriting any existing).
pub fn with_url(self, url: Url) -> Self {
match self {
Error::Middleware(_) => self,
Error::Reqwest(e) => e.with_url(url).into(),
}
}
/// Strips the related URL from this error (if, for example, it contains
/// sensitive information).
pub fn without_url(self) -> Self {
match self {
Error::Middleware(_) => self,
Error::Reqwest(e) => e.without_url().into(),
}
}
/// Returns true if the error is from any middleware.
pub fn is_middleware(&self) -> bool {
match self {
Error::Middleware(_) => true,
Error::Reqwest(_) => false,
}
}
/// Returns true if the error is from a type `Builder`.
pub fn is_builder(&self) -> bool {
match self {
Error::Middleware(_) => false,
Error::Reqwest(e) => e.is_builder(),
}
}
/// Returns true if the error is from a `RedirectPolicy`.
pub fn is_redirect(&self) -> bool {
match self {
Error::Middleware(_) => false,
Error::Reqwest(e) => e.is_redirect(),
}
}
/// Returns true if the error is from `Response::error_for_status`.
pub fn is_status(&self) -> bool {
match self {
Error::Middleware(_) => false,
Error::Reqwest(e) => e.is_status(),
}
}
/// Returns true if the error is related to a timeout.
pub fn is_timeout(&self) -> bool {
match self {
Error::Middleware(_) => false,
Error::Reqwest(e) => e.is_timeout(),
}
}
/// Returns true if the error is related to the request.
pub fn is_request(&self) -> bool {
match self {
Error::Middleware(_) => false,
Error::Reqwest(e) => e.is_request(),
}
}
#[cfg(not(target_arch = "wasm32"))]
/// Returns true if the error is related to connect.
pub fn is_connect(&self) -> bool {
match self {
Error::Middleware(_) => false,
Error::Reqwest(e) => e.is_connect(),
}
}
/// Returns true if the error is related to the request or response body.
pub fn is_body(&self) -> bool {
match self {
Error::Middleware(_) => false,
Error::Reqwest(e) => e.is_body(),
}
}
/// Returns true if the error is related to decoding the response's body.
pub fn is_decode(&self) -> bool {
match self {
Error::Middleware(_) => false,
Error::Reqwest(e) => e.is_decode(),
}
}
/// Returns the status code, if the error was generated from a response.
pub fn status(&self) -> Option<StatusCode> {
match self {
Error::Middleware(_) => None,
Error::Reqwest(e) => e.status(),
}
}
}

View File

@ -7,30 +7,44 @@
//!
//! ```
//! use reqwest::{Client, Request, Response};
//! use reqwest_middleware::{ClientBuilder, Middleware, Next, Result};
//! use reqwest_middleware::{ClientBuilder, Error, Extension, Layer, Service};
//! use task_local_extensions::Extensions;
//! use futures::future::{BoxFuture, FutureExt};
//! use std::task::{Context, Poll};
//!
//! struct LoggingMiddleware;
//! struct LoggingLayer;
//! struct LoggingService<S>(S);
//!
//! #[async_trait::async_trait]
//! impl Middleware for LoggingMiddleware {
//! async fn handle(
//! &self,
//! req: Request,
//! extensions: &mut Extensions,
//! next: Next<'_>,
//! ) -> Result<Response> {
//! println!("Request started {:?}", req);
//! let res = next.run(req, extensions).await;
//! println!("Result: {:?}", res);
//! res
//! impl<S> Layer<S> for LoggingLayer {
//! type Service = LoggingService<S>;
//!
//! fn layer(&self, inner: S) -> Self::Service {
//! LoggingService(inner)
//! }
//! }
//!
//! impl<S> Service for LoggingService<S>
//! where
//! S: Service,
//! S::Future: Send + 'static,
//! {
//! type Future = BoxFuture<'static, Result<Response, Error>>;
//!
//! fn call(&mut self, req: Request, ext: &mut Extensions) -> Self::Future {
//! println!("Request started {req:?}");
//! let fut = self.0.call(req, ext);
//! async {
//! let res = fut.await;
//! println!("Result: {res:?}");
//! res
//! }.boxed()
//! }
//! }
//!
//! async fn run() {
//! let reqwest_client = Client::builder().build().unwrap();
//! let client = ClientBuilder::new(reqwest_client)
//! .with(LoggingMiddleware)
//! .with(LoggingLayer)
//! .build();
//! let resp = client.get("https://truelayer.com").send().await.unwrap();
//! println!("TrueLayer page HTML: {}", resp.text().await.unwrap());
@ -51,10 +65,54 @@ pub struct ReadmeDoctests;
mod client;
mod error;
mod middleware;
mod req_init;
pub use client::{ClientBuilder, ClientWithMiddleware, RequestBuilder};
pub use error::{Error, Result};
pub use middleware::{Middleware, Next};
pub use req_init::{Extension, RequestInitialiser};
pub use client::{ClientBuilder, ClientWithMiddleware, RequestBuilder, ReqwestService};
pub use error::Error;
pub use req_init::{Extension, RequestInitialiser, RequestStack};
use reqwest::{Request, Response};
use task_local_extensions::Extensions;
/// Two [`RequestInitialiser`]s or [`Service`]s chained together.
#[derive(Clone)]
pub struct Stack<Inner, Outer> {
pub(crate) inner: Inner,
pub(crate) outer: Outer,
}
pub trait Service {
type Future: std::future::Future<Output = Result<Response, Error>>;
fn call(&mut self, req: Request, ext: &mut Extensions) -> Self::Future;
}
pub struct Identity;
impl<S: Service> Layer<S> for Identity {
type Service = S;
fn layer(&self, inner: S) -> Self::Service {
inner
}
}
pub trait Layer<S> {
/// The wrapped service
type Service;
/// Wrap the given service with the middleware, returning a new service
/// that has been decorated with the middleware.
fn layer(&self, inner: S) -> Self::Service;
}
impl<S, Inner, Outer> Layer<S> for Stack<Inner, Outer>
where
Inner: Layer<S>,
Outer: Layer<Inner::Service>,
{
type Service = Outer::Service;
fn layer(&self, service: S) -> Self::Service {
let inner = self.inner.layer(service);
self.outer.layer(inner)
}
}

View File

@ -1,111 +0,0 @@
use reqwest::{Client, Request, Response};
use std::sync::Arc;
use task_local_extensions::Extensions;
use crate::error::{Error, Result};
/// When attached to a [`ClientWithMiddleware`] (generally using [`with`]), middleware is run
/// whenever the client issues a request, in the order it was attached.
///
/// # Example
///
/// ```
/// use reqwest::{Client, Request, Response};
/// use reqwest_middleware::{ClientBuilder, Middleware, Next, Result};
/// use task_local_extensions::Extensions;
///
/// struct TransparentMiddleware;
///
/// #[async_trait::async_trait]
/// impl Middleware for TransparentMiddleware {
/// async fn handle(
/// &self,
/// req: Request,
/// extensions: &mut Extensions,
/// next: Next<'_>,
/// ) -> Result<Response> {
/// next.run(req, extensions).await
/// }
/// }
/// ```
///
/// [`ClientWithMiddleware`]: crate::ClientWithMiddleware
/// [`with`]: crate::ClientBuilder::with
#[cfg_attr(not(target_arch = "wasm32"), async_trait::async_trait)]
#[cfg_attr(target_arch = "wasm32", async_trait::async_trait(?Send))]
pub trait Middleware: 'static + Send + Sync {
/// Invoked with a request before sending it. If you want to continue processing the request,
/// you should explicitly call `next.run(req, extensions)`.
///
/// If you need to forward data down the middleware stack, you can use the `extensions`
/// argument.
async fn handle(
&self,
req: Request,
extensions: &mut Extensions,
next: Next<'_>,
) -> Result<Response>;
}
impl std::fmt::Debug for dyn Middleware {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "Middleware")
}
}
#[cfg_attr(not(target_arch = "wasm32"), async_trait::async_trait)]
#[cfg_attr(target_arch = "wasm32", async_trait::async_trait(?Send))]
impl<F> Middleware for F
where
F: Send
+ Sync
+ 'static
+ for<'a> Fn(Request, &'a mut Extensions, Next<'a>) -> BoxFuture<'a, Result<Response>>,
{
async fn handle(
&self,
req: Request,
extensions: &mut Extensions,
next: Next<'_>,
) -> Result<Response> {
(self)(req, extensions, next).await
}
}
/// Next encapsulates the remaining middleware chain to run in [`Middleware::handle`]. You can
/// forward the request down the chain with [`run`].
///
/// [`Middleware::handle`]: Middleware::handle
/// [`run`]: Self::run
#[derive(Clone)]
pub struct Next<'a> {
client: &'a Client,
middlewares: &'a [Arc<dyn Middleware>],
}
#[cfg(not(target_arch = "wasm32"))]
pub type BoxFuture<'a, T> = std::pin::Pin<Box<dyn std::future::Future<Output = T> + Send + 'a>>;
#[cfg(target_arch = "wasm32")]
pub type BoxFuture<'a, T> = std::pin::Pin<Box<dyn std::future::Future<Output = T> + 'a>>;
impl<'a> Next<'a> {
pub(crate) fn new(client: &'a Client, middlewares: &'a [Arc<dyn Middleware>]) -> Self {
Next {
client,
middlewares,
}
}
pub fn run(
mut self,
req: Request,
extensions: &'a mut Extensions,
) -> BoxFuture<'a, Result<Response>> {
if let Some((current, rest)) = self.middlewares.split_first() {
self.middlewares = rest;
Box::pin(current.handle(req, extensions, self))
} else {
Box::pin(async move { self.client.execute(req).await.map_err(Error::from) })
}
}
}

View File

@ -1,4 +1,7 @@
use crate::RequestBuilder;
use reqwest::RequestBuilder;
use task_local_extensions::Extensions;
use crate::Identity;
/// When attached to a [`ClientWithMiddleware`] (generally using [`with_init`]), it is run
/// whenever the client starts building a request, in the order it was attached.
@ -6,12 +9,14 @@ use crate::RequestBuilder;
/// # Example
///
/// ```
/// use reqwest_middleware::{RequestInitialiser, RequestBuilder};
/// use reqwest::RequestBuilder;
/// use reqwest_middleware::RequestInitialiser;
/// use task_local_extensions::Extensions;
///
/// struct AuthInit;
///
/// impl RequestInitialiser for AuthInit {
/// fn init(&self, req: RequestBuilder) -> RequestBuilder {
/// fn init(&self, req: RequestBuilder, ext: &mut Extensions) -> RequestBuilder {
/// req.bearer_auth("my_auth_token")
/// }
/// }
@ -20,21 +25,30 @@ use crate::RequestBuilder;
/// [`ClientWithMiddleware`]: crate::ClientWithMiddleware
/// [`with_init`]: crate::ClientBuilder::with_init
pub trait RequestInitialiser: 'static + Send + Sync {
fn init(&self, req: RequestBuilder) -> RequestBuilder;
fn init(&self, req: RequestBuilder, ext: &mut Extensions) -> RequestBuilder;
}
impl core::fmt::Debug for dyn RequestInitialiser {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "RequestInitialiser")
impl RequestInitialiser for Identity {
fn init(&self, req: RequestBuilder, _: &mut Extensions) -> RequestBuilder {
req
}
}
impl<F> RequestInitialiser for F
/// Two [`RequestInitialiser`]s chained together.
#[derive(Clone)]
pub struct RequestStack<Inner, Outer> {
pub(crate) inner: Inner,
pub(crate) outer: Outer,
}
impl<I, O> RequestInitialiser for RequestStack<I, O>
where
F: Send + Sync + 'static + Fn(RequestBuilder) -> RequestBuilder,
I: RequestInitialiser,
O: RequestInitialiser,
{
fn init(&self, req: RequestBuilder) -> RequestBuilder {
(self)(req)
fn init(&self, req: RequestBuilder, ext: &mut Extensions) -> RequestBuilder {
let req = self.inner.init(req, ext);
self.outer.init(req, ext)
}
}
@ -44,30 +58,45 @@ where
///
/// ```
/// use reqwest::{Client, Request, Response};
/// use reqwest_middleware::{ClientBuilder, Middleware, Next, Result, Extension};
/// use reqwest_middleware::{ClientBuilder, Error, Extension, Layer, Service};
/// use task_local_extensions::Extensions;
/// use futures::future::{BoxFuture, FutureExt};
/// use std::task::{Context, Poll};
///
/// #[derive(Clone)]
/// struct LogName(&'static str);
/// struct LoggingMiddleware;
///
/// #[async_trait::async_trait]
/// impl Middleware for LoggingMiddleware {
/// async fn handle(
/// &self,
/// req: Request,
/// extensions: &mut Extensions,
/// next: Next<'_>,
/// ) -> Result<Response> {
/// struct LoggingLayer;
/// struct LoggingService<S>(S);
///
/// impl<S> Layer<S> for LoggingLayer {
/// type Service = LoggingService<S>;
///
/// fn layer(&self, inner: S) -> Self::Service {
/// LoggingService(inner)
/// }
/// }
///
/// impl<S> Service for LoggingService<S>
/// where
/// S: Service,
/// S::Future: Send + 'static,
/// {
/// type Future = BoxFuture<'static, Result<Response, Error>>;
///
/// fn call(&mut self, req: Request, ext: &mut Extensions) -> Self::Future {
/// // get the log name or default to "unknown"
/// let name = extensions
/// let name = ext
/// .get()
/// .map(|&LogName(name)| name)
/// .unwrap_or("unknown");
/// println!("[{name}] Request started {req:?}");
/// let res = next.run(req, extensions).await;
/// println!("[{name}] Result: {res:?}");
/// res
/// let fut = self.0.call(req, ext);
/// async move {
/// let res = fut.await;
/// println!("[{name}] Result: {res:?}");
/// res
/// }.boxed()
/// }
/// }
///
@ -75,7 +104,7 @@ where
/// let reqwest_client = Client::builder().build().unwrap();
/// let client = ClientBuilder::new(reqwest_client)
/// .with_init(Extension(LogName("my-client")))
/// .with(LoggingMiddleware)
/// .with(LoggingLayer)
/// .build();
/// let resp = client.get("https://truelayer.com").send().await.unwrap();
/// println!("TrueLayer page HTML: {}", resp.text().await.unwrap());
@ -84,7 +113,8 @@ where
pub struct Extension<T>(pub T);
impl<T: Send + Sync + Clone + 'static> RequestInitialiser for Extension<T> {
fn init(&self, req: RequestBuilder) -> RequestBuilder {
req.with_extension(self.0.clone())
fn init(&self, req: RequestBuilder, ext: &mut Extensions) -> RequestBuilder {
ext.insert(self.0.clone());
req
}
}

View File

@ -4,20 +4,10 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
## [0.3.0] - 2023-09-07
### Changed
- `retry-policies` upgraded to 0.2.0
## [0.2.3] - 2023-08-30
### Added
- `RetryableStrategy` which allows for custom retry decisions based on the response that a request got
## [0.2.1] - 2022-12-01
### Changed
- Classify `io::Error`s and `hyper::Error(Canceled)` as transient
## [Unreleased]
## [0.2.0] - 2022-11-15
### Changed
- Updated `reqwest-middleware` to `0.2.0`

View File

@ -10,29 +10,23 @@ keywords = ["reqwest", "http", "middleware", "retry"]
categories = ["web-programming::http-client"]
[dependencies]
reqwest-middleware = { version = "0.2.0", path = "../reqwest-middleware" }
reqwest-middleware = { version = "0.3.0", path = "../reqwest-middleware" }
anyhow = "1.0.0"
anyhow = "1"
async-trait = "0.1.51"
chrono = { version = "0.4.19", features = ["clock"], default-features = false }
futures = "0.3.0"
http = "0.2.0"
reqwest = { version = "0.11.0", default-features = false }
retry-policies = "0.2.0"
task-local-extensions = "0.1.4"
futures = "0.3"
http = "0.2"
hyper = "0.14"
reqwest = { version = "0.11", default-features = false }
retry-policies = "0.1"
task-local-extensions = "0.1.1"
tokio = { version = "1.6", features = ["time"] }
tracing = "0.1.26"
[target.'cfg(not(target_arch = "wasm32"))'.dependencies]
hyper = "0.14.0"
tokio = { version = "1.6.0", features = ["time"] }
[target.'cfg(target_arch = "wasm32")'.dependencies]
parking_lot = { version = "0.11.2", features = ["wasm-bindgen"] } # work around https://github.com/tomaka/wasm-timer/issues/14
wasm-timer = "0.2.5"
getrandom = { version = "0.2.0", features = ["js"] }
pin-project-lite = "0.2"
[dev-dependencies]
paste = "1.0.0"
tokio = { version = "1.0.0", features = ["full"] }
wiremock = "0.5.0"
futures = "0.3.0"
async-std = { version = "1.10"}
paste = "1"
tokio = { version = "1", features = ["macros"] }
wiremock = "0.5"

View File

@ -27,13 +27,8 @@
mod middleware;
mod retryable;
mod retryable_strategy;
pub use retry_policies::{policies, RetryPolicy};
pub use middleware::RetryTransientMiddleware;
pub use retryable::Retryable;
pub use retryable_strategy::{
default_on_request_failure, default_on_request_success, DefaultRetryableStrategy,
RetryableStrategy,
};

View File

@ -1,20 +1,23 @@
//! `RetryTransientMiddleware` implements retrying requests on transient errors.
use crate::retryable_strategy::RetryableStrategy;
use crate::{retryable::Retryable, retryable_strategy::DefaultRetryableStrategy};
use anyhow::anyhow;
use std::pin::Pin;
use std::task::{ready, Context, Poll};
use crate::retryable::Retryable;
use chrono::Utc;
use futures::Future;
use pin_project_lite::pin_project;
use reqwest::{Request, Response};
use reqwest_middleware::{Error, Middleware, Next, Result};
use reqwest_middleware::{Error, Layer, Service};
use retry_policies::RetryPolicy;
use task_local_extensions::Extensions;
use tokio::time::Sleep;
/// `RetryTransientMiddleware` offers retry logic for requests that fail in a transient manner
/// and can be safely executed again.
///
/// Currently, it allows setting a [RetryPolicy] algorithm for calculating the __wait_time__
/// between each request retry. Sleeping on non-`wasm32` archs is performed using
/// [`tokio::time::sleep`], therefore it will respect pauses/auto-advance if run under a
/// runtime that supports them.
/// Currently, it allows setting a [RetryPolicy][retry_policies::RetryPolicy] algorithm for calculating the __wait_time__
/// between each request retry.
///
///```rust
/// use reqwest_middleware::ClientBuilder;
@ -46,116 +49,264 @@ use task_local_extensions::Extensions;
/// * You can wrap this middleware in a custom one which skips retries for streaming requests.
/// * You can write a custom retry middleware that builds new streaming requests from the data
/// source directly, avoiding the issue of streaming requests not being clonable.
pub struct RetryTransientMiddleware<
T: RetryPolicy + Send + Sync + 'static,
R: RetryableStrategy + Send + Sync + 'static = DefaultRetryableStrategy,
> {
pub struct RetryTransientMiddleware<T: RetryPolicy + Send + Sync + 'static> {
retry_policy: T,
retryable_strategy: R,
}
impl<T: RetryPolicy + Send + Sync> RetryTransientMiddleware<T, DefaultRetryableStrategy> {
/// Construct `RetryTransientMiddleware` with a [retry_policy][RetryPolicy].
impl<T: RetryPolicy + Send + Sync> RetryTransientMiddleware<T> {
/// Construct `RetryTransientMiddleware` with a [retry_policy][retry_policies::RetryPolicy].
pub fn new_with_policy(retry_policy: T) -> Self {
Self::new_with_policy_and_strategy(retry_policy, DefaultRetryableStrategy)
Self { retry_policy }
}
}
impl<T, R> RetryTransientMiddleware<T, R>
impl<T, Svc> Layer<Svc> for RetryTransientMiddleware<T>
where
T: RetryPolicy + Send + Sync,
R: RetryableStrategy + Send + Sync,
T: RetryPolicy + Clone + Send + Sync + 'static,
{
/// Construct `RetryTransientMiddleware` with a [retry_policy][RetryPolicy] and [retryable_strategy](RetryableStrategy).
pub fn new_with_policy_and_strategy(retry_policy: T, retryable_strategy: R) -> Self {
Self {
retry_policy,
retryable_strategy,
type Service = Retry<TowerRetryPolicy<T>, Svc>;
fn layer(&self, inner: Svc) -> Self::Service {
Retry {
policy: TowerRetryPolicy {
n_past_retries: 0,
retry_policy: self.retry_policy.clone(),
},
service: inner,
}
}
}
#[cfg_attr(not(target_arch = "wasm32"), async_trait::async_trait)]
#[cfg_attr(target_arch = "wasm32", async_trait::async_trait(?Send))]
impl<T, R> Middleware for RetryTransientMiddleware<T, R>
where
T: RetryPolicy + Send + Sync,
R: RetryableStrategy + Send + Sync + 'static,
{
async fn handle(
&self,
req: Request,
extensions: &mut Extensions,
next: Next<'_>,
) -> Result<Response> {
// TODO: Ideally we should create a new instance of the `Extensions` map to pass
// downstream. This will guard against previous retries poluting `Extensions`.
// That is, we only return what's populated in the typemap for the last retry attempt
// and copy those into the the `global` Extensions map.
self.execute_with_retry(req, next, extensions).await
#[derive(Clone)]
pub struct TowerRetryPolicy<T> {
n_past_retries: u32,
retry_policy: T,
}
pin_project! {
pub struct RetryFuture<T>
{
retry: Option<TowerRetryPolicy<T>>,
#[pin]
sleep: Sleep,
}
}
impl<T, R> RetryTransientMiddleware<T, R>
impl<T> Future for RetryFuture<T> {
type Output = TowerRetryPolicy<T>;
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let this = self.project();
ready!(this.sleep.poll(cx));
Poll::Ready(
this.retry
.take()
.expect("poll should not be called more than once"),
)
}
}
impl<T: RetryPolicy + Clone> Policy for TowerRetryPolicy<T> {
type Future = RetryFuture<T>;
fn retry(&self, _req: &Request, result: &Result<Response, Error>) -> Option<Self::Future> {
// We classify the response which will return None if not
// errors were returned.
match Retryable::from_reqwest_response(result) {
Some(Retryable::Transient) => {
// If the response failed and the error type was transient
// we can safely try to retry the request.
let retry_decicion = self.retry_policy.should_retry(self.n_past_retries);
if let retry_policies::RetryDecision::Retry { execute_after } = retry_decicion {
let duration = (execute_after - Utc::now()).to_std().ok()?;
// Sleep the requested amount before we try again.
tracing::warn!(
"Retry attempt #{}. Sleeping {:?} before the next attempt",
self.n_past_retries,
duration
);
let sleep = tokio::time::sleep(duration);
Some(RetryFuture {
retry: Some(TowerRetryPolicy {
n_past_retries: self.n_past_retries + 1,
retry_policy: self.retry_policy.clone(),
}),
sleep,
})
} else {
None
}
}
Some(_) | None => None,
}
}
fn clone_request(&self, req: &Request) -> Option<Request> {
req.try_clone()
}
}
pub trait Policy: Sized {
/// The [`Future`] type returned by [`Policy::retry`].
type Future: Future<Output = Self>;
/// Check the policy if a certain request should be retried.
///
/// This method is passed a reference to the original request, and either
/// the [`Service::Response`] or [`Service::Error`] from the inner service.
///
/// If the request should **not** be retried, return `None`.
///
/// If the request *should* be retried, return `Some` future of a new
/// policy that would apply for the next request attempt.
///
/// [`Service::Response`]: crate::Service::Response
/// [`Service::Error`]: crate::Service::Error
fn retry(&self, req: &Request, result: &Result<Response, Error>) -> Option<Self::Future>;
/// Tries to clone a request before being passed to the inner service.
///
/// If the request cannot be cloned, return [`None`].
fn clone_request(&self, req: &Request) -> Option<Request>;
}
pin_project! {
/// Configure retrying requests of "failed" responses.
///
/// A [`Policy`] classifies what is a "failed" response.
#[derive(Clone, Debug)]
pub struct Retry<P, S> {
#[pin]
policy: P,
service: S,
}
}
impl<P, S> Service for Retry<P, S>
where
T: RetryPolicy + Send + Sync,
R: RetryableStrategy + Send + Sync,
P: 'static + Policy + Clone,
S: 'static + Service + Clone,
{
/// This function will try to execute the request, if it fails
/// with an error classified as transient it will call itself
/// to retry the request.
async fn execute_with_retry<'a>(
&'a self,
req: Request,
next: Next<'a>,
ext: &'a mut Extensions,
) -> Result<Response> {
let mut n_past_retries = 0;
type Future = ResponseFuture<P, S>;
fn call(&mut self, request: Request, ext: &mut Extensions) -> Self::Future {
let cloned = self.policy.clone_request(&request);
let future = self.service.call(request, ext);
ResponseFuture::new(cloned, self.clone(), future)
}
// fn call(&mut self, request: Request) -> Self::Future {
// let cloned = self.policy.clone_request(&request);
// let future = self.service.call(request);
// ResponseFuture::new(cloned, self.clone(), future)
// }
}
pin_project! {
/// The [`Future`] returned by a [`Retry`] service.
#[derive(Debug)]
pub struct ResponseFuture<P, S>
where
P: Policy,
S: Service,
{
request: Option<Request>,
#[pin]
retry: Retry<P, S>,
#[pin]
state: State<S::Future, P::Future>,
}
}
pin_project! {
#[project = StateProj]
#[derive(Debug)]
enum State<F, P> {
// Polling the future from [`Service::call`]
Called {
#[pin]
future: F
},
// Polling the future from [`Policy::retry`]
Checking {
#[pin]
checking: P
},
// Polling [`Service::poll_ready`] after [`Checking`] was OK.
Retrying,
}
}
impl<P, S> ResponseFuture<P, S>
where
P: Policy,
S: Service,
{
pub(crate) fn new(
request: Option<Request>,
retry: Retry<P, S>,
future: S::Future,
) -> ResponseFuture<P, S> {
ResponseFuture {
request,
retry,
state: State::Called { future },
}
}
}
impl<P, S> Future for ResponseFuture<P, S>
where
P: Policy + Clone,
S: Service + Clone,
{
type Output = Result<Response, Error>;
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let mut this = self.project();
loop {
// Cloning the request object before-the-fact is not ideal..
// However, if the body of the request is not static, e.g of type `Bytes`,
// the Clone operation should be of constant complexity and not O(N)
// since the byte abstraction is a shared pointer over a buffer.
let duplicate_request = req.try_clone().ok_or_else(|| {
Error::Middleware(anyhow!(
"Request object is not clonable. Are you passing a streaming body?".to_string()
))
})?;
let result = next.clone().run(duplicate_request, ext).await;
// We classify the response which will return None if not
// errors were returned.
break match self.retryable_strategy.handle(&result) {
Some(Retryable::Transient) => {
// If the response failed and the error type was transient
// we can safely try to retry the request.
let retry_decision = self.retry_policy.should_retry(n_past_retries);
if let retry_policies::RetryDecision::Retry { execute_after } = retry_decision {
let duration = (execute_after - Utc::now())
.to_std()
.map_err(Error::middleware)?;
// Sleep the requested amount before we try again.
tracing::warn!(
"Retry attempt #{}. Sleeping {:?} before the next attempt",
n_past_retries,
duration
);
#[cfg(not(target_arch = "wasm32"))]
tokio::time::sleep(duration).await;
#[cfg(target_arch = "wasm32")]
wasm_timer::Delay::new(duration)
.await
.expect("failed sleeping");
n_past_retries += 1;
continue;
match this.state.as_mut().project() {
StateProj::Called { future } => {
let result = ready!(future.poll(cx));
if let Some(ref req) = this.request {
match this.retry.policy.retry(req, &result) {
Some(checking) => {
this.state.set(State::Checking { checking });
}
None => return Poll::Ready(result),
}
} else {
result
// request wasn't cloned, so no way to retry it
return Poll::Ready(result);
}
}
Some(_) | None => result,
};
StateProj::Checking { checking } => {
this.retry
.as_mut()
.project()
.policy
.set(ready!(checking.poll(cx)));
this.state.set(State::Retrying);
}
StateProj::Retrying => {
let req = this
.request
.take()
.expect("retrying requires cloned request");
*this.request = this.retry.policy.clone_request(&req);
this.state.set(State::Called {
future: this
.retry
.as_mut()
.project()
.service
.call(req, &mut Extensions::new()),
});
}
}
}
}
}

View File

@ -1,4 +1,4 @@
use crate::retryable_strategy::{DefaultRetryableStrategy, RetryableStrategy};
use http::StatusCode;
use reqwest_middleware::Error;
/// Classification of an error/status returned by request.
@ -16,7 +16,62 @@ impl Retryable {
/// Returns `None` if the response object does not contain any errors.
///
pub fn from_reqwest_response(res: &Result<reqwest::Response, Error>) -> Option<Self> {
DefaultRetryableStrategy.handle(res)
match res {
Ok(success) => {
let status = success.status();
if status.is_server_error() {
Some(Retryable::Transient)
} else if status.is_client_error()
&& status != StatusCode::REQUEST_TIMEOUT
&& status != StatusCode::TOO_MANY_REQUESTS
{
Some(Retryable::Fatal)
} else if status.is_success() {
None
} else if status == StatusCode::REQUEST_TIMEOUT
|| status == StatusCode::TOO_MANY_REQUESTS
{
Some(Retryable::Transient)
} else {
Some(Retryable::Fatal)
}
}
Err(error) => match error {
// If something fails in the middleware we're screwed.
Error::Middleware(_) => Some(Retryable::Fatal),
Error::Reqwest(error) => {
if error.is_timeout() || error.is_connect() {
Some(Retryable::Transient)
} else if error.is_body()
|| error.is_decode()
|| error.is_builder()
|| error.is_redirect()
{
Some(Retryable::Fatal)
} else if error.is_request() {
// It seems that hyper::Error(IncompleteMessage) is not correctly handled by reqwest.
// Here we check if the Reqwest error was originated by hyper and map it consistently.
if let Some(hyper_error) = get_source_error_type::<hyper::Error>(&error) {
// The hyper::Error(IncompleteMessage) is raised if the HTTP response is well formatted but does not contain all the bytes.
// This can happen when the server has started sending back the response but the connection is cut halfway thorugh.
// We can safely retry the call, hence marking this error as [`Retryable::Transient`].
if hyper_error.is_incomplete_message() {
Some(Retryable::Transient)
} else {
Some(Retryable::Fatal)
}
} else {
Some(Retryable::Fatal)
}
} else {
// We omit checking if error.is_status() since we check that already.
// However, if Response::error_for_status is used the status will still
// remain in the response object.
None
}
}
},
}
}
}
@ -25,3 +80,19 @@ impl From<&reqwest::Error> for Retryable {
Retryable::Transient
}
}
/// Downcasts the given err source into T.
fn get_source_error_type<T: std::error::Error + 'static>(
err: &dyn std::error::Error,
) -> Option<&T> {
let mut source = err.source();
while let Some(err) = source {
if let Some(hyper_err) = err.downcast_ref::<T>() {
return Some(hyper_err);
}
source = err.source();
}
None
}

View File

@ -1,213 +0,0 @@
use crate::retryable::Retryable;
use http::StatusCode;
use reqwest_middleware::Error;
/// A strategy to create a [`Retryable`] from a [`Result<reqwest::Response, reqwest_middleware::Error>`]
///
/// A [`RetryableStrategy`] has a single `handler` functions.
/// The result of calling the request could be:
/// - [`reqwest::Response`] In case the request has been sent and received correctly
/// This could however still mean that the server responded with a erroneous response.
/// For example a HTTP statuscode of 500
/// - [`reqwest_middleware::Error`] In this case the request actually failed.
/// This could, for example, be caused by a timeout on the connection.
///
/// Example:
///
/// ```
/// use reqwest_retry::{default_on_request_failure, policies::ExponentialBackoff, Retryable, RetryableStrategy, RetryTransientMiddleware};
/// use reqwest::{Request, Response};
/// use reqwest_middleware::{ClientBuilder, Middleware, Next, Result};
/// use task_local_extensions::Extensions;
///
/// // Log each request to show that the requests will be retried
/// struct LoggingMiddleware;
///
/// #[async_trait::async_trait]
/// impl Middleware for LoggingMiddleware {
/// async fn handle(
/// &self,
/// req: Request,
/// extensions: &mut Extensions,
/// next: Next<'_>,
/// ) -> Result<Response> {
/// println!("Request started {}", req.url());
/// let res = next.run(req, extensions).await;
/// println!("Request finished");
/// res
/// }
/// }
///
/// // Just a toy example, retry when the successful response code is 201, else do nothing.
/// struct Retry201;
/// impl RetryableStrategy for Retry201 {
/// fn handle(&self, res: &Result<reqwest::Response>) -> Option<Retryable> {
/// match res {
/// // retry if 201
/// Ok(success) if success.status() == 201 => Some(Retryable::Transient),
/// // otherwise do not retry a successful request
/// Ok(success) => None,
/// // but maybe retry a request failure
/// Err(error) => default_on_request_failure(error),
/// }
/// }
/// }
///
/// #[tokio::main]
/// async fn main() {
/// // Exponential backoff with max 2 retries
/// let retry_policy = ExponentialBackoff::builder()
/// .build_with_max_retries(2);
///
/// // Create the actual middleware, with the exponential backoff and custom retry stategy.
/// let ret_s = RetryTransientMiddleware::new_with_policy_and_strategy(
/// retry_policy,
/// Retry201,
/// );
///
/// let client = ClientBuilder::new(reqwest::Client::new())
/// // Retry failed requests.
/// .with(ret_s)
/// // Log the requests
/// .with(LoggingMiddleware)
/// .build();
///
/// // Send request which should get a 201 response. So it will be retried
/// let r = client
/// .get("https://httpbin.org/status/201")
/// .send()
/// .await;
/// println!("{:?}", r);
///
/// // Send request which should get a 200 response. So it will not be retried
/// let r = client
/// .get("https://httpbin.org/status/200")
/// .send()
/// .await;
/// println!("{:?}", r);
/// }
/// ```
pub trait RetryableStrategy {
fn handle(&self, res: &Result<reqwest::Response, Error>) -> Option<Retryable>;
}
/// The default [`RetryableStrategy`] for [`RetryTransientMiddleware`](crate::RetryTransientMiddleware).
pub struct DefaultRetryableStrategy;
impl RetryableStrategy for DefaultRetryableStrategy {
fn handle(&self, res: &Result<reqwest::Response, Error>) -> Option<Retryable> {
match res {
Ok(success) => default_on_request_success(success),
Err(error) => default_on_request_failure(error),
}
}
}
/// Default request success retry strategy.
///
/// Will only retry if:
/// * The status was 5XX (server error)
/// * The status was 408 (request timeout) or 429 (too many requests)
///
/// Note that success here means that the request finished without interruption, not that it was logically OK.
pub fn default_on_request_success(success: &reqwest::Response) -> Option<Retryable> {
let status = success.status();
if status.is_server_error() {
Some(Retryable::Transient)
} else if status.is_client_error()
&& status != StatusCode::REQUEST_TIMEOUT
&& status != StatusCode::TOO_MANY_REQUESTS
{
Some(Retryable::Fatal)
} else if status.is_success() {
None
} else if status == StatusCode::REQUEST_TIMEOUT || status == StatusCode::TOO_MANY_REQUESTS {
Some(Retryable::Transient)
} else {
Some(Retryable::Fatal)
}
}
/// Default request failure retry strategy.
///
/// Will only retry if the request failed due to a network error
pub fn default_on_request_failure(error: &Error) -> Option<Retryable> {
match error {
// If something fails in the middleware we're screwed.
Error::Middleware(_) => Some(Retryable::Fatal),
Error::Reqwest(error) => {
#[cfg(not(target_arch = "wasm32"))]
let is_connect = error.is_connect();
#[cfg(target_arch = "wasm32")]
let is_connect = false;
if error.is_timeout() || is_connect {
Some(Retryable::Transient)
} else if error.is_body()
|| error.is_decode()
|| error.is_builder()
|| error.is_redirect()
{
Some(Retryable::Fatal)
} else if error.is_request() {
// It seems that hyper::Error(IncompleteMessage) is not correctly handled by reqwest.
// Here we check if the Reqwest error was originated by hyper and map it consistently.
#[cfg(not(target_arch = "wasm32"))]
if let Some(hyper_error) = get_source_error_type::<hyper::Error>(&error) {
// The hyper::Error(IncompleteMessage) is raised if the HTTP response is well formatted but does not contain all the bytes.
// This can happen when the server has started sending back the response but the connection is cut halfway thorugh.
// We can safely retry the call, hence marking this error as [`Retryable::Transient`].
// Instead hyper::Error(Canceled) is raised when the connection is
// gracefully closed on the server side.
if hyper_error.is_incomplete_message() || hyper_error.is_canceled() {
Some(Retryable::Transient)
// Try and downcast the hyper error to io::Error if that is the
// underlying error, and try and classify it.
} else if let Some(io_error) =
get_source_error_type::<std::io::Error>(hyper_error)
{
Some(classify_io_error(io_error))
} else {
Some(Retryable::Fatal)
}
} else {
Some(Retryable::Fatal)
}
#[cfg(target_arch = "wasm32")]
Some(Retryable::Fatal)
} else {
// We omit checking if error.is_status() since we check that already.
// However, if Response::error_for_status is used the status will still
// remain in the response object.
None
}
}
}
}
#[cfg(not(target_arch = "wasm32"))]
fn classify_io_error(error: &std::io::Error) -> Retryable {
match error.kind() {
std::io::ErrorKind::ConnectionReset | std::io::ErrorKind::ConnectionAborted => {
Retryable::Transient
}
_ => Retryable::Fatal,
}
}
/// Downcasts the given err source into T.
#[cfg(not(target_arch = "wasm32"))]
fn get_source_error_type<T: std::error::Error + 'static>(
err: &dyn std::error::Error,
) -> Option<&T> {
let mut source = err.source();
while let Some(err) = source {
if let Some(err) = err.downcast_ref::<T>() {
return Some(err);
}
source = err.source();
}
None
}

View File

@ -1,12 +1,9 @@
use futures::future::BoxFuture;
use async_std::io::ReadExt;
use async_std::io::WriteExt;
use async_std::net::{TcpListener, TcpStream};
use futures::stream::StreamExt;
use std::error::Error;
use std::fmt;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::{TcpListener, TcpStream};
type CustomMessageHandler = Box<
dyn Fn(TcpStream) -> BoxFuture<'static, Result<(), Box<dyn std::error::Error>>> + Send + Sync,
>;
/// This is a simple server that returns the responses given at creation time: [`self.raw_http_responses`] following a round-robin mechanism.
pub struct SimpleServer {
@ -15,7 +12,6 @@ pub struct SimpleServer {
host: String,
raw_http_responses: Vec<String>,
calls_counter: usize,
custom_handler: Option<CustomMessageHandler>,
}
/// Request-Line = Method SP Request-URI SP HTTP-Version CRLF
@ -50,21 +46,9 @@ impl SimpleServer {
host: host.to_string(),
raw_http_responses,
calls_counter: 0,
custom_handler: None,
})
}
pub fn set_custom_handler(
&mut self,
custom_handler: impl Fn(TcpStream) -> BoxFuture<'static, Result<(), Box<dyn std::error::Error>>>
+ Send
+ Sync
+ 'static,
) -> &mut Self {
self.custom_handler.replace(Box::new(custom_handler));
self
}
/// Returns the uri in which the server is listening to.
pub fn uri(&self) -> String {
format!("http://{}:{}", self.host, self.port)
@ -72,9 +56,9 @@ impl SimpleServer {
/// Starts the TcpListener and handles the requests.
pub async fn start(mut self) {
loop {
match self.listener.accept().await {
Ok((stream, _)) => {
while let Some(stream) = self.listener.incoming().next().await {
match stream {
Ok(stream) => {
match self.handle_connection(stream).await {
Ok(_) => (),
Err(e) => {
@ -95,15 +79,11 @@ impl SimpleServer {
///
/// Returns a 400 if the request if formatted badly.
async fn handle_connection(&self, mut stream: TcpStream) -> Result<(), Box<dyn Error>> {
if let Some(ref custom_handler) = self.custom_handler {
return custom_handler(stream).await;
}
let mut buffer = vec![0; 1024];
let n = stream.read(&mut buffer).await.unwrap();
stream.read(&mut buffer).await.unwrap();
let request = String::from_utf8_lossy(&buffer[..n]);
let request = String::from_utf8_lossy(&buffer[..]);
let request_line = request.lines().next().unwrap();
let response = match Self::parse_request_line(request_line) {
@ -118,7 +98,7 @@ impl SimpleServer {
};
println!("-- Response --\n{}\n--------------", response.clone());
stream.write_all(response.as_bytes()).await.unwrap();
stream.write(response.as_bytes()).await.unwrap();
stream.flush().await.unwrap();
Ok(())

View File

@ -1,16 +1,12 @@
use futures::FutureExt;
use paste::paste;
use reqwest::Client;
use reqwest::StatusCode;
use reqwest_middleware::ClientBuilder;
use reqwest_retry::{policies::ExponentialBackoff, RetryTransientMiddleware};
use std::sync::atomic::AtomicI8;
use std::sync::{
atomic::{AtomicU32, Ordering},
Arc,
};
use tokio::io::AsyncReadExt;
use tokio::io::AsyncWriteExt;
use wiremock::matchers::{method, path};
use wiremock::{Mock, MockServer, Respond, ResponseTemplate};
@ -53,12 +49,12 @@ macro_rules! assert_retry_succeeds_inner {
let reqwest_client = Client::builder().build().unwrap();
let client = ClientBuilder::new(reqwest_client)
.with(RetryTransientMiddleware::new_with_policy(
ExponentialBackoff::builder()
.retry_bounds(
std::time::Duration::from_millis(30),
std::time::Duration::from_millis(100),
)
.build_with_max_retries(retry_amount),
ExponentialBackoff {
max_n_retries: retry_amount,
max_retry_interval: std::time::Duration::from_millis(30),
min_retry_interval: std::time::Duration::from_millis(100),
backoff_exponent: 2,
},
))
.build();
@ -189,12 +185,12 @@ async fn assert_retry_on_request_timeout() {
let reqwest_client = Client::builder().build().unwrap();
let client = ClientBuilder::new(reqwest_client)
.with(RetryTransientMiddleware::new_with_policy(
ExponentialBackoff::builder()
.retry_bounds(
std::time::Duration::from_millis(30),
std::time::Duration::from_millis(100),
)
.build_with_max_retries(3),
ExponentialBackoff {
max_n_retries: 3,
max_retry_interval: std::time::Duration::from_millis(100),
min_retry_interval: std::time::Duration::from_millis(30),
backoff_exponent: 2,
},
))
.build();
@ -244,111 +240,12 @@ async fn assert_retry_on_incomplete_message() {
let reqwest_client = Client::builder().build().unwrap();
let client = ClientBuilder::new(reqwest_client)
.with(RetryTransientMiddleware::new_with_policy(
ExponentialBackoff::builder()
.retry_bounds(
std::time::Duration::from_millis(30),
std::time::Duration::from_millis(100),
)
.build_with_max_retries(3),
))
.build();
let resp = client
.get(&format!("{}/foo", uri))
.timeout(std::time::Duration::from_millis(100))
.send()
.await
.expect("call failed");
assert_eq!(resp.status(), 200);
}
#[tokio::test]
async fn assert_retry_on_hyper_canceled() {
let counter = Arc::new(AtomicI8::new(0));
let mut simple_server = SimpleServer::new("127.0.0.1", None, vec![])
.await
.expect("Error when creating a simple server");
simple_server.set_custom_handler(move |mut stream| {
let counter = counter.clone();
async move {
let mut buffer = Vec::new();
stream.read_buf(&mut buffer).await.unwrap();
if counter.fetch_add(1, Ordering::SeqCst) > 1 {
// This triggeres hyper:Error(Canceled).
let _res = stream
.into_std()
.unwrap()
.shutdown(std::net::Shutdown::Both);
} else {
let _res = stream.write("HTTP/1.1 200 OK\r\n\r\n".as_bytes()).await;
}
Ok(())
}
.boxed()
});
let uri = simple_server.uri();
tokio::spawn(simple_server.start());
let reqwest_client = Client::builder().build().unwrap();
let client = ClientBuilder::new(reqwest_client)
.with(RetryTransientMiddleware::new_with_policy(
ExponentialBackoff::builder()
.retry_bounds(
std::time::Duration::from_millis(30),
std::time::Duration::from_millis(100),
)
.build_with_max_retries(3),
))
.build();
let resp = client
.get(&format!("{}/foo", uri))
.timeout(std::time::Duration::from_millis(100))
.send()
.await
.expect("call failed");
assert_eq!(resp.status(), 200);
}
#[tokio::test]
async fn assert_retry_on_connection_reset_by_peer() {
let counter = Arc::new(AtomicI8::new(0));
let mut simple_server = SimpleServer::new("127.0.0.1", None, vec![])
.await
.expect("Error when creating a simple server");
simple_server.set_custom_handler(move |mut stream| {
let counter = counter.clone();
async move {
let mut buffer = Vec::new();
stream.read_buf(&mut buffer).await.unwrap();
if counter.fetch_add(1, Ordering::SeqCst) > 1 {
// This triggeres hyper:Error(Io, io::Error(ConnectionReset)).
drop(stream);
} else {
let _res = stream.write("HTTP/1.1 200 OK\r\n\r\n".as_bytes()).await;
}
Ok(())
}
.boxed()
});
let uri = simple_server.uri();
tokio::spawn(simple_server.start());
let reqwest_client = Client::builder().build().unwrap();
let client = ClientBuilder::new(reqwest_client)
.with(RetryTransientMiddleware::new_with_policy(
ExponentialBackoff::builder()
.retry_bounds(
std::time::Duration::from_millis(30),
std::time::Duration::from_millis(100),
)
.build_with_max_retries(3),
ExponentialBackoff {
max_n_retries: 3,
max_retry_interval: std::time::Duration::from_millis(100),
min_retry_interval: std::time::Duration::from_millis(30),
backoff_exponent: 2,
},
))
.build();

View File

@ -6,41 +6,6 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
## [Unreleased]
## [0.4.6] - 2023-08-23
### Added
- Add support for opentelemetry 0.20
## [0.4.5] - 2023-06-20
### Added
- A new extension `DisableOtelPropagation` which stops opentelemetry contexts propagating
- Support for opentelemetry 0.19
## [0.4.4] - 2023-05-15
### Added
- A new `default_span_name` method for use in custom span backends.
## [0.4.3] - 2023-05-15
### Fixed
- Fix span and http status codes
## [0.4.2] - 2023-05-12
### Added
- `OtelPathNames` extension to provide known parameterized paths that will be used in span names
### Changed
- `DefaultSpanBackend` and `SpanBackendWithUrl` default span name to HTTP method name instead of `reqwest-http-client`
## [0.4.1] - 2023-03-09
### Added
- Support for `wasm32-unknown-unknown` target
## [0.4.0] - 2022-11-15
### Changed

View File

@ -1,6 +1,6 @@
[package]
name = "reqwest-tracing"
version = "0.4.7"
version = "0.5.0"
authors = ["Rodrigo Gryzinski <rodrigo.gryzinski@truelayer.com>"]
edition = "2018"
description = "Opentracing middleware for reqwest."
@ -16,49 +16,33 @@ opentelemetry_0_15 = ["opentelemetry_0_15_pkg", "tracing-opentelemetry_0_14_pkg"
opentelemetry_0_16 = ["opentelemetry_0_16_pkg", "tracing-opentelemetry_0_16_pkg"]
opentelemetry_0_17 = ["opentelemetry_0_17_pkg", "tracing-opentelemetry_0_17_pkg"]
opentelemetry_0_18 = ["opentelemetry_0_18_pkg", "tracing-opentelemetry_0_18_pkg"]
opentelemetry_0_19 = ["opentelemetry_0_19_pkg", "tracing-opentelemetry_0_19_pkg"]
opentelemetry_0_20 = ["opentelemetry_0_20_pkg", "tracing-opentelemetry_0_20_pkg"]
opentelemetry_0_21 = ["opentelemetry_0_21_pkg", "tracing-opentelemetry_0_22_pkg"]
[dependencies]
reqwest-middleware = { version = "0.2.0", path = "../reqwest-middleware" }
reqwest-middleware = { version = "0.3.0", path = "../reqwest-middleware" }
anyhow = "1.0.70"
async-trait = "0.1.51"
matchit = "0.7.0"
reqwest = { version = "0.11.0", default-features = false }
task-local-extensions = "0.1.4"
reqwest = { version = "0.11", default-features = false }
task-local-extensions = "0.1.1"
tracing = "0.1.26"
pin-project-lite = "0.2"
opentelemetry_0_13_pkg = { package = "opentelemetry", version = "0.13.0", optional = true }
opentelemetry_0_14_pkg = { package = "opentelemetry", version = "0.14.0", optional = true }
opentelemetry_0_15_pkg = { package = "opentelemetry", version = "0.15.0", optional = true }
opentelemetry_0_16_pkg = { package = "opentelemetry", version = "0.16.0", optional = true }
opentelemetry_0_17_pkg = { package = "opentelemetry", version = "0.17.0", optional = true }
opentelemetry_0_18_pkg = { package = "opentelemetry", version = "0.18.0", optional = true }
opentelemetry_0_19_pkg = { package = "opentelemetry", version = "0.19.0", optional = true }
opentelemetry_0_20_pkg = { package = "opentelemetry", version = "0.20.0", optional = true }
opentelemetry_0_21_pkg = { package = "opentelemetry", version = "0.21.0", optional = true }
tracing-opentelemetry_0_12_pkg = { package = "tracing-opentelemetry", version = "0.12.0", optional = true }
tracing-opentelemetry_0_13_pkg = { package = "tracing-opentelemetry", version = "0.13.0", optional = true }
tracing-opentelemetry_0_14_pkg = { package = "tracing-opentelemetry", version = "0.14.0", optional = true }
tracing-opentelemetry_0_16_pkg = { package = "tracing-opentelemetry", version = "0.16.0", optional = true }
tracing-opentelemetry_0_17_pkg = { package = "tracing-opentelemetry", version = "0.17.0", optional = true }
tracing-opentelemetry_0_18_pkg = { package = "tracing-opentelemetry", version = "0.18.0", optional = true }
tracing-opentelemetry_0_19_pkg = { package = "tracing-opentelemetry", version = "0.19.0", optional = true }
tracing-opentelemetry_0_20_pkg = { package = "tracing-opentelemetry", version = "0.20.0", optional = true }
tracing-opentelemetry_0_22_pkg = { package = "tracing-opentelemetry", version = "0.22.0", optional = true }
opentelemetry_0_13_pkg = { package = "opentelemetry", version = "0.13", optional = true }
opentelemetry_0_14_pkg = { package = "opentelemetry", version = "0.14", optional = true }
opentelemetry_0_15_pkg = { package = "opentelemetry", version = "0.15", optional = true }
opentelemetry_0_16_pkg = { package = "opentelemetry", version = "0.16", optional = true }
opentelemetry_0_17_pkg = { package = "opentelemetry", version = "0.17", optional = true }
opentelemetry_0_18_pkg = { package = "opentelemetry", version = "0.18", optional = true }
tracing-opentelemetry_0_12_pkg = { package = "tracing-opentelemetry",version = "0.12", optional = true }
tracing-opentelemetry_0_13_pkg = { package = "tracing-opentelemetry", version = "0.13", optional = true }
tracing-opentelemetry_0_14_pkg = { package = "tracing-opentelemetry",version = "0.14", optional = true }
tracing-opentelemetry_0_16_pkg = { package = "tracing-opentelemetry",version = "0.16", optional = true }
tracing-opentelemetry_0_17_pkg = { package = "tracing-opentelemetry",version = "0.17", optional = true }
tracing-opentelemetry_0_18_pkg = { package = "tracing-opentelemetry",version = "0.18", optional = true }
[target.'cfg(target_arch = "wasm32")'.dependencies]
getrandom = { version = "0.2.0", features = ["js"] }
[dev-dependencies]
tokio = { version = "1.0.0", features = ["macros"] }
tracing_subscriber_0_2 = { package = "tracing-subscriber", version = "0.2.0" }
tracing_subscriber_0_3 = { package = "tracing-subscriber", version = "0.3.0" }
wiremock = "0.5.0"
opentelemetry_sdk_0_21 = { package = "opentelemetry_sdk", version = "0.21.0", features = ["trace"] }
opentelemetry_stdout_0_1 = { package = "opentelemetry-stdout", version = "0.1.0", features = ["trace"] }
opentelemetry_stdout_0_2 = { package = "opentelemetry-stdout", version = "0.2.0", features = ["trace"] }
tokio = { version = "1", features = ["macros"] }
tracing_subscriber_0_2 = { package = "tracing-subscriber", version = "0.2" }
tracing_subscriber_0_3 = { package = "tracing-subscriber", version = "0.3" }
wiremock = "0.5"

View File

@ -25,7 +25,7 @@ tokio = { version = "1.12.0", features = ["macros", "rt-multi-thread"] }
tracing = "0.1"
tracing-opentelemetry = "0.18"
tracing-subscriber = "0.3"
task-local-extensions = "0.1.4"
task-local-extensions = "0.1.0"
```
```rust,skip
@ -44,7 +44,7 @@ pub struct TimeTrace;
impl ReqwestOtelSpanBackend for TimeTrace {
fn on_request_start(req: &Request, extension: &mut Extensions) -> Span {
extension.insert(Instant::now());
reqwest_otel_span!(name="example-request", req, time_elapsed = tracing::field::Empty)
reqwest_otel_span!(req, time_elapsed = tracing::field::Empty)
}
fn on_request_end(span: &Span, outcome: &Result<Response>, extension: &mut Extensions) {
@ -92,9 +92,8 @@ an opentelemetry version feature:
reqwest-tracing = { version = "0.3.1", features = ["opentelemetry_0_18"] }
```
Available opentelemetry features are `opentelemetry_0_21`, `opentelemetry_0_20`,
`opentelemetry_0_19`, `opentelemetry_0_18`, `opentelemetry_0_17`, `opentelemetry_0_16`,
`opentelemetry_0_15`, `opentelemetry_0_14` and `opentelemetry_0_13`.
Available opentelemetry features are `opentelemetry_0_18`, `opentelemetry_0_17`, `opentelemetry_0_16`, `opentelemetry_0_15`, `opentelemetry_0_14` and
`opentelemetry_0_13`.
#### License

View File

@ -4,11 +4,11 @@
//!
//! The simplest possible usage:
//! ```no_run
//! # use reqwest_middleware::Result;
//! # use reqwest_middleware::Error;
//! use reqwest_middleware::{ClientBuilder};
//! use reqwest_tracing::TracingMiddleware;
//!
//! # async fn example() -> Result<()> {
//! # async fn example() -> Result<(), Error> {
//! let reqwest_client = reqwest::Client::builder().build().unwrap();
//! let client = ClientBuilder::new(reqwest_client)
//! // Insert the tracing middleware
@ -22,12 +22,12 @@
//!
//! To customise the span names use [`OtelName`].
//! ```no_run
//! # use reqwest_middleware::Result;
//! # use reqwest_middleware::Error;
//! use reqwest_middleware::{ClientBuilder, Extension};
//! use reqwest_tracing::{
//! TracingMiddleware, OtelName
//! };
//! # async fn example() -> Result<()> {
//! # async fn example() -> Result<(), Error> {
//! let reqwest_client = reqwest::Client::builder().build().unwrap();
//! let client = ClientBuilder::new(reqwest_client)
//! // Inserts the extension before the request is started
@ -52,7 +52,7 @@
//!
//! Note that Opentelemetry tracks start and stop already, there is no need to have a custom builder like this.
//! ```rust
//! use reqwest_middleware::Result;
//! use reqwest_middleware::Error;
//! use task_local_extensions::Extensions;
//! use reqwest::{Request, Response};
//! use reqwest_middleware::ClientBuilder;
@ -62,16 +62,17 @@
//! use tracing::Span;
//! use std::time::{Duration, Instant};
//!
//! pub struct TimeTrace;
//! pub struct TimeTrace(Instant);
//!
//! impl ReqwestOtelSpanBackend for TimeTrace {
//! fn on_request_start(req: &Request, extension: &mut Extensions) -> Span {
//! extension.insert(Instant::now());
//! reqwest_otel_span!(name="example-request", req, time_elapsed = tracing::field::Empty)
//! fn on_request_start(req: &Request, _extension: &mut Extensions) -> (Self, Span) {
//! let now = Self(Instant::now());
//! let span = reqwest_otel_span!(name="example-request", req, time_elapsed = tracing::field::Empty);
//! (now, span)
//! }
//!
//! fn on_request_end(span: &Span, outcome: &Result<Response>, extension: &mut Extensions) {
//! let time_elapsed = extension.get::<Instant>().unwrap().elapsed().as_millis() as i64;
//! fn on_request_end(self, span: &Span, outcome: &Result<Response, Error>) {
//! let time_elapsed = self.0.elapsed().as_millis() as i64;
//! default_on_request_end(span, outcome);
//! span.record("time_elapsed", &time_elapsed);
//! }
@ -90,19 +91,15 @@ mod middleware;
feature = "opentelemetry_0_16",
feature = "opentelemetry_0_17",
feature = "opentelemetry_0_18",
feature = "opentelemetry_0_19",
feature = "opentelemetry_0_20",
feature = "opentelemetry_0_21",
))]
mod otel;
mod reqwest_otel_span_builder;
pub use middleware::TracingMiddleware;
pub use reqwest_otel_span_builder::{
default_on_request_end, default_on_request_failure, default_on_request_success,
default_span_name, DefaultSpanBackend, DisableOtelPropagation, OtelName, OtelPathNames,
ReqwestOtelSpanBackend, SpanBackendWithUrl, ERROR_CAUSE_CHAIN, ERROR_MESSAGE, HTTP_HOST,
HTTP_METHOD, HTTP_SCHEME, HTTP_STATUS_CODE, HTTP_URL, HTTP_USER_AGENT, NET_HOST_PORT,
OTEL_KIND, OTEL_NAME, OTEL_STATUS_CODE,
DefaultSpanBackend, OtelName, ReqwestOtelSpanBackend, SpanBackendWithUrl, ERROR_CAUSE_CHAIN,
ERROR_MESSAGE, HTTP_HOST, HTTP_METHOD, HTTP_SCHEME, HTTP_STATUS_CODE, HTTP_URL,
HTTP_USER_AGENT, NET_HOST_PORT, OTEL_KIND, OTEL_NAME, OTEL_STATUS_CODE,
};
#[doc(hidden)]

View File

@ -1,7 +1,14 @@
use std::{
future::Future,
task::{ready, Context, Poll},
};
use pin_project_lite::pin_project;
use reqwest::{Request, Response};
use reqwest_middleware::{Middleware, Next, Result};
use reqwest_middleware::{Error, Layer, Service};
use task_local_extensions::Extensions;
use tracing::Instrument;
// use tower::{Layer, Service};
use tracing::Span;
use crate::{DefaultSpanBackend, ReqwestOtelSpanBackend};
@ -10,6 +17,8 @@ pub struct TracingMiddleware<S: ReqwestOtelSpanBackend> {
span_backend: std::marker::PhantomData<S>,
}
impl<S: ReqwestOtelSpanBackend> Copy for TracingMiddleware<S> {}
impl<S: ReqwestOtelSpanBackend> TracingMiddleware<S> {
pub fn new() -> TracingMiddleware<S> {
TracingMiddleware {
@ -30,46 +39,82 @@ impl Default for TracingMiddleware<DefaultSpanBackend> {
}
}
#[cfg_attr(not(target_arch = "wasm32"), async_trait::async_trait)]
#[cfg_attr(target_arch = "wasm32", async_trait::async_trait(?Send))]
impl<ReqwestOtelSpan> Middleware for TracingMiddleware<ReqwestOtelSpan>
impl<ReqwestOtelSpan, Svc> Layer<Svc> for TracingMiddleware<ReqwestOtelSpan>
where
ReqwestOtelSpan: ReqwestOtelSpanBackend + Sync + Send + 'static,
Svc: Service,
{
async fn handle(
&self,
req: Request,
extensions: &mut Extensions,
next: Next<'_>,
) -> Result<Response> {
let request_span = ReqwestOtelSpan::on_request_start(&req, extensions);
type Service = TracingMiddlewareService<ReqwestOtelSpan, Svc>;
let outcome_future = async {
#[cfg(any(
feature = "opentelemetry_0_13",
feature = "opentelemetry_0_14",
feature = "opentelemetry_0_15",
feature = "opentelemetry_0_16",
feature = "opentelemetry_0_17",
feature = "opentelemetry_0_18",
feature = "opentelemetry_0_19",
feature = "opentelemetry_0_20",
feature = "opentelemetry_0_21",
))]
let req = if !extensions.contains::<crate::DisableOtelPropagation>() {
// Adds tracing headers to the given request to propagate the OpenTelemetry context to downstream revivers of the request.
// Spans added by downstream consumers will be part of the same trace.
crate::otel::inject_opentelemetry_context_into_request(req)
} else {
req
};
// Run the request
let outcome = next.run(req, extensions).await;
ReqwestOtelSpan::on_request_end(&request_span, &outcome, extensions);
outcome
};
outcome_future.instrument(request_span.clone()).await
fn layer(&self, inner: Svc) -> Self::Service {
TracingMiddlewareService {
service: inner,
_layer: *self,
}
}
}
/// Middleware Service for tracing requests using the current Opentelemetry Context.
pub struct TracingMiddlewareService<S: ReqwestOtelSpanBackend, Svc> {
_layer: TracingMiddleware<S>,
service: Svc,
}
impl<ReqwestOtelSpan, Svc> Service for TracingMiddlewareService<ReqwestOtelSpan, Svc>
where
ReqwestOtelSpan: ReqwestOtelSpanBackend + Sync + Send + 'static,
Svc: Service,
{
type Future = TracingMiddlewareFuture<ReqwestOtelSpan, Svc::Future>;
fn call(&mut self, req: Request, ext: &mut Extensions) -> Self::Future {
let (backend, span) = ReqwestOtelSpan::on_request_start(&req, ext);
// Adds tracing headers to the given request to propagate the OpenTelemetry context to downstream revivers of the request.
// Spans added by downstream consumers will be part of the same trace.
#[cfg(any(
feature = "opentelemetry_0_13",
feature = "opentelemetry_0_14",
feature = "opentelemetry_0_15",
feature = "opentelemetry_0_16",
feature = "opentelemetry_0_17",
feature = "opentelemetry_0_18",
))]
let request = crate::otel::inject_opentelemetry_context_into_request(request);
let future = self.service.call(req, ext);
TracingMiddlewareFuture {
span,
backend: Some(backend),
future,
}
}
}
pin_project!(
pub struct TracingMiddlewareFuture<S: ReqwestOtelSpanBackend, F> {
span: Span,
backend: Option<S>,
#[pin]
future: F,
}
);
impl<S: ReqwestOtelSpanBackend, F: Future<Output = Result<Response, Error>>> Future
for TracingMiddlewareFuture<S, F>
{
type Output = F::Output;
fn poll(self: std::pin::Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let this = self.project();
let outcome = {
let _guard = this.span.enter();
ready!(this.future.poll(cx))
};
this.backend
.take()
.expect("poll should not be called after completion")
.on_request_end(this.span, &outcome);
Poll::Ready(outcome)
}
}

View File

@ -21,15 +21,6 @@ use opentelemetry_0_17_pkg as opentelemetry;
#[cfg(feature = "opentelemetry_0_18")]
use opentelemetry_0_18_pkg as opentelemetry;
#[cfg(feature = "opentelemetry_0_19")]
use opentelemetry_0_19_pkg as opentelemetry;
#[cfg(feature = "opentelemetry_0_20")]
use opentelemetry_0_20_pkg as opentelemetry;
#[cfg(feature = "opentelemetry_0_21")]
use opentelemetry_0_21_pkg as opentelemetry;
#[cfg(feature = "opentelemetry_0_13")]
pub use tracing_opentelemetry_0_12_pkg as tracing_opentelemetry;
@ -48,15 +39,6 @@ pub use tracing_opentelemetry_0_17_pkg as tracing_opentelemetry;
#[cfg(feature = "opentelemetry_0_18")]
pub use tracing_opentelemetry_0_18_pkg as tracing_opentelemetry;
#[cfg(feature = "opentelemetry_0_19")]
pub use tracing_opentelemetry_0_19_pkg as tracing_opentelemetry;
#[cfg(feature = "opentelemetry_0_20")]
pub use tracing_opentelemetry_0_20_pkg as tracing_opentelemetry;
#[cfg(feature = "opentelemetry_0_21")]
pub use tracing_opentelemetry_0_22_pkg as tracing_opentelemetry;
use opentelemetry::global;
use opentelemetry::propagation::Injector;
use tracing_opentelemetry::OpenTelemetrySpanExt;
@ -98,16 +80,10 @@ impl<'a> Injector for RequestCarrier<'a> {
#[cfg(test)]
mod test {
use std::sync::OnceLock;
use super::*;
use crate::{DisableOtelPropagation, TracingMiddleware};
#[cfg(not(feature = "opentelemetry_0_21"))]
use crate::TracingMiddleware;
use opentelemetry::sdk::propagation::TraceContextPropagator;
#[cfg(feature = "opentelemetry_0_21")]
use opentelemetry_sdk_0_21::propagation::TraceContextPropagator;
use reqwest::Response;
use reqwest_middleware::{ClientBuilder, ClientWithMiddleware, Extension};
use reqwest_middleware::ClientBuilder;
use tracing::{info_span, Instrument, Level};
#[cfg(any(
feature = "opentelemetry_0_13",
@ -123,50 +99,17 @@ mod test {
use tracing_subscriber_0_3::{filter, layer::SubscriberExt, Registry};
use wiremock::{matchers::any, Mock, MockServer, ResponseTemplate};
async fn make_echo_request_in_otel_context(client: ClientWithMiddleware) -> Response {
static TELEMETRY: OnceLock<()> = OnceLock::new();
TELEMETRY.get_or_init(|| {
#[cfg(all(
not(feature = "opentelemetry_0_20"),
not(feature = "opentelemetry_0_21")
))]
let tracer = opentelemetry::sdk::export::trace::stdout::new_pipeline()
.with_writer(std::io::sink())
.install_simple();
#[cfg(any(feature = "opentelemetry_0_20", feature = "opentelemetry_0_21"))]
let tracer = {
use opentelemetry::trace::TracerProvider;
#[cfg(feature = "opentelemetry_0_20")]
use opentelemetry_stdout_0_1::SpanExporterBuilder;
#[cfg(feature = "opentelemetry_0_21")]
use opentelemetry_stdout_0_2::SpanExporterBuilder;
let exporter = SpanExporterBuilder::default()
.with_writer(std::io::sink())
.build();
#[cfg(feature = "opentelemetry_0_20")]
let provider = opentelemetry::sdk::trace::TracerProvider::builder()
.with_simple_exporter(exporter)
.build();
#[cfg(feature = "opentelemetry_0_21")]
let provider = opentelemetry_sdk_0_21::trace::TracerProvider::builder()
.with_simple_exporter(exporter)
.build();
let tracer = provider.versioned_tracer("reqwest", None::<&str>, None::<&str>, None);
let _ = global::set_tracer_provider(provider);
tracer
};
let telemetry = tracing_opentelemetry::layer().with_tracer(tracer);
let subscriber = Registry::default()
.with(
filter::Targets::new().with_target("reqwest_tracing::otel::test", Level::DEBUG),
)
.with(telemetry);
tracing::subscriber::set_global_default(subscriber).unwrap();
global::set_text_map_propagator(TraceContextPropagator::new());
});
#[tokio::test]
async fn tracing_middleware_propagates_otel_data_even_when_the_span_is_disabled() {
let tracer = opentelemetry::sdk::export::trace::stdout::new_pipeline()
.with_writer(std::io::sink())
.install_simple();
let telemetry = tracing_opentelemetry::layer().with_tracer(tracer);
let subscriber = Registry::default()
.with(filter::Targets::new().with_target("reqwest_tracing::otel::test", Level::DEBUG))
.with(telemetry);
tracing::subscriber::set_global_default(subscriber).unwrap();
global::set_text_map_propagator(TraceContextPropagator::new());
// Mock server - sends all request headers back in the response
let server = MockServer::start().await;
@ -181,40 +124,17 @@ mod test {
.mount(&server)
.await;
client
let client = ClientBuilder::new(reqwest::Client::new())
.with(TracingMiddleware::default())
.build();
let resp = client
.get(server.uri())
.send()
.instrument(info_span!("some_span"))
.await
.unwrap()
}
.unwrap();
#[tokio::test]
async fn tracing_middleware_propagates_otel_data_even_when_the_span_is_disabled() {
let client = ClientBuilder::new(reqwest::Client::new())
.with(TracingMiddleware::default())
.build();
let resp = make_echo_request_in_otel_context(client).await;
assert!(
resp.headers().contains_key("traceparent"),
"by default, the tracing middleware will propagate otel contexts"
);
}
#[tokio::test]
async fn context_no_propagated() {
let client = ClientBuilder::new(reqwest::Client::new())
.with_init(Extension(DisableOtelPropagation))
.with(TracingMiddleware::default())
.build();
let resp = make_echo_request_in_otel_context(client).await;
assert!(
!resp.headers().contains_key("traceparent"),
"request should not contain traceparent if context propagation is disabled"
);
assert!(resp.headers().contains_key("traceparent"));
}
}

View File

@ -1,11 +1,10 @@
use std::borrow::Cow;
use matchit::Router;
use reqwest::header::{HeaderMap, HeaderValue};
use reqwest::{Request, Response, StatusCode as RequestStatusCode, Url};
use reqwest_middleware::{Error, Result};
use reqwest_middleware::Error;
use task_local_extensions::Extensions;
use tracing::{warn, Span};
use tracing::Span;
use crate::reqwest_otel_span;
@ -23,8 +22,8 @@ pub const NET_HOST_PORT: &str = "net.host.port";
pub const OTEL_KIND: &str = "otel.kind";
/// The `otel.name` field added to the span by [`reqwest_otel_span`]
pub const OTEL_NAME: &str = "otel.name";
/// The `otel.status_code` field added to the span by [`reqwest_otel_span`]
pub const OTEL_STATUS_CODE: &str = "otel.status_code";
/// The `otel.status_code.code` field added to the span by [`reqwest_otel_span`]
pub const OTEL_STATUS_CODE: &str = "http.status_code";
/// The `error.message` field added to the span by [`reqwest_otel_span`]
pub const ERROR_MESSAGE: &str = "error.message";
/// The `error.cause_chain` field added to the span by [`reqwest_otel_span`]
@ -40,17 +39,17 @@ pub const HTTP_USER_AGENT: &str = "http.user_agent";
/// Check out [`reqwest_otel_span`] documentation for examples.
///
/// [`TracingMiddleware`]: crate::middleware::TracingMiddleware.
pub trait ReqwestOtelSpanBackend {
pub trait ReqwestOtelSpanBackend: Sized {
/// Initalized a new span before the request is executed.
fn on_request_start(req: &Request, extension: &mut Extensions) -> Span;
fn on_request_start(req: &Request, extension: &mut Extensions) -> (Self, Span);
/// Runs after the request call has executed.
fn on_request_end(span: &Span, outcome: &Result<Response>, extension: &mut Extensions);
fn on_request_end(self, span: &Span, outcome: &Result<Response, Error>);
}
/// Populates default success/failure fields for a given [`reqwest_otel_span!`] span.
#[inline]
pub fn default_on_request_end(span: &Span, outcome: &Result<Response>) {
pub fn default_on_request_end(span: &Span, outcome: &Result<Response, Error>) {
match outcome {
Ok(res) => default_on_request_success(span, res),
Err(err) => default_on_request_failure(span, err),
@ -61,11 +60,12 @@ pub fn default_on_request_end(span: &Span, outcome: &Result<Response>) {
#[inline]
pub fn default_on_request_success(span: &Span, response: &Response) {
let span_status = get_span_status(response.status());
let status_code = response.status().as_u16() as i64;
let user_agent = get_header_value("user_agent", response.headers());
if let Some(span_status) = span_status {
span.record(OTEL_STATUS_CODE, span_status);
}
span.record(HTTP_STATUS_CODE, response.status().as_u16());
span.record(HTTP_STATUS_CODE, status_code);
span.record(HTTP_USER_AGENT, user_agent.as_str());
}
@ -78,30 +78,13 @@ pub fn default_on_request_failure(span: &Span, e: &Error) {
span.record(ERROR_MESSAGE, error_message.as_str());
span.record(ERROR_CAUSE_CHAIN, error_cause_chain.as_str());
if let Error::Reqwest(e) = e {
if let Some(status) = e.status() {
span.record(HTTP_STATUS_CODE, status.as_u16());
}
}
}
/// Determine the name of the span that should be associated with this request.
///
/// This tries to be PII safe by default, not including any path information unless
/// specifically opted in using either [`OtelName`] or [`OtelPathNames`]
#[inline]
pub fn default_span_name<'a>(req: &'a Request, ext: &'a Extensions) -> Cow<'a, str> {
if let Some(name) = ext.get::<OtelName>() {
Cow::Borrowed(name.0.as_ref())
} else if let Some(path_names) = ext.get::<OtelPathNames>() {
path_names
.find(req.url().path())
.map(|path| Cow::Owned(format!("{} {}", req.method(), path)))
.unwrap_or_else(|| {
warn!("no OTEL path name found");
Cow::Owned(format!("{} UNKNOWN", req.method().as_str()))
})
} else {
Cow::Borrowed(req.method().as_str())
span.record(
HTTP_STATUS_CODE,
e.status()
.map(|s| s.to_string())
.unwrap_or_else(|| "".to_string())
.as_str(),
);
}
}
@ -112,12 +95,15 @@ pub fn default_span_name<'a>(req: &'a Request, ext: &'a Extensions) -> Cow<'a, s
pub struct DefaultSpanBackend;
impl ReqwestOtelSpanBackend for DefaultSpanBackend {
fn on_request_start(req: &Request, ext: &mut Extensions) -> Span {
let name = default_span_name(req, ext);
reqwest_otel_span!(name = name, req)
fn on_request_start(req: &Request, ext: &mut Extensions) -> (DefaultSpanBackend, Span) {
let name = ext
.get::<OtelName>()
.map(|on| on.0.as_ref())
.unwrap_or("reqwest-http-client");
(Self, reqwest_otel_span!(name = name, req))
}
fn on_request_end(span: &Span, outcome: &Result<Response>, _: &mut Extensions) {
fn on_request_end(self, span: &Span, outcome: &Result<Response, Error>) {
default_on_request_end(span, outcome)
}
}
@ -133,12 +119,19 @@ fn get_header_value(key: &str, headers: &HeaderMap) -> String {
pub struct SpanBackendWithUrl;
impl ReqwestOtelSpanBackend for SpanBackendWithUrl {
fn on_request_start(req: &Request, ext: &mut Extensions) -> Span {
let name = default_span_name(req, ext);
reqwest_otel_span!(name = name, req, http.url = %remove_credentials(req.url()))
fn on_request_start(req: &Request, ext: &mut Extensions) -> (Self, Span) {
let name = ext
.get::<OtelName>()
.map(|on| on.0.as_ref())
.unwrap_or("reqwest-http-client");
(
Self,
reqwest_otel_span!(name = name, req, http.url = %remove_credentials(req.url())),
)
}
fn on_request_end(span: &Span, outcome: &Result<Response>, _: &mut Extensions) {
fn on_request_end(self, span: &Span, outcome: &Result<Response, Error>) {
default_on_request_end(span, outcome)
}
}
@ -162,146 +155,29 @@ fn get_span_status(request_status: RequestStatusCode) -> Option<&'static str> {
}
/// [`OtelName`] allows customisation of the name of the spans created by
/// [`DefaultSpanBackend`] and [`SpanBackendWithUrl`].
/// DefaultSpanBackend.
///
/// Usage:
/// ```no_run
/// # use reqwest_middleware::Result;
/// # use reqwest_middleware::Error;
/// use reqwest_middleware::{ClientBuilder, Extension};
/// use reqwest_tracing::{
/// TracingMiddleware, OtelName
/// };
/// # async fn example() -> Result<()> {
/// # async fn example() -> Result<(), Error> {
/// let reqwest_client = reqwest::Client::builder().build().unwrap();
/// let client = ClientBuilder::new(reqwest_client)
/// // Inserts the extension before the request is started
/// .with_init(Extension(OtelName("my-client".into())))
/// // Makes use of that extension to specify the otel name
/// .with(TracingMiddleware::default())
/// .build();
/// // Inserts the extension before the request is started
/// .with_init(Extension(OtelName("my-client".into())))
/// // Makes use of that extension to specify the otel name
/// .with(TracingMiddleware::default())
/// .build();
///
/// let resp = client.get("https://truelayer.com").send().await.unwrap();
///
/// // Or specify it on the individual request (will take priority)
/// let resp = client.post("https://api.truelayer.com/payment")
/// .with_extension(OtelName("POST /payment".into()))
/// .send()
/// .await
/// .unwrap();
/// # Ok(())
/// # }
/// ```
#[derive(Clone)]
pub struct OtelName(pub Cow<'static, str>);
/// [`OtelPathNames`] allows including templated paths in the spans created by
/// [`DefaultSpanBackend`] and [`SpanBackendWithUrl`].
///
/// When creating spans this can be used to try to match the path against some
/// known paths. If the path matches value returned is the templated path. This
/// can be used in span names as it will not contain values that would
/// increase the cardinality.
///
/// ```
/// /// # use reqwest_middleware::Result;
/// use reqwest_middleware::{ClientBuilder, Extension};
/// use reqwest_tracing::{
/// TracingMiddleware, OtelPathNames
/// };
/// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
/// let reqwest_client = reqwest::Client::builder().build()?;
/// let client = ClientBuilder::new(reqwest_client)
/// // Inserts the extension before the request is started
/// .with_init(Extension(OtelPathNames::known_paths(["/payment/:paymentId"])?))
/// // Makes use of that extension to specify the otel name
/// .with(TracingMiddleware::default())
/// .build();
///
/// let resp = client.get("https://truelayer.com/payment/id-123").send().await?;
///
/// // Or specify it on the individual request (will take priority)
/// let resp = client.post("https://api.truelayer.com/payment/id-123/authorization-flow")
/// .with_extension(OtelPathNames::known_paths(["/payment/:paymentId/authorization-flow"])?)
/// .send()
/// .await?;
/// # Ok(())
/// # }
/// ```
#[derive(Clone)]
pub struct OtelPathNames(matchit::Router<String>);
impl OtelPathNames {
/// Create a new [`OtelPathNames`] from a set of known paths.
///
/// Paths in this set will be found with `find`.
///
/// Paths can have different parameters:
/// - Named parameters like `:paymentId` match anything until the next `/` or the end of the path.
/// - Catch-all parameters start with `*` and match everything after the `/`. They must be at the end of the route.
/// ```
/// # use reqwest_tracing::OtelPathNames;
/// OtelPathNames::known_paths([
/// "/",
/// "/payment",
/// "/payment/:paymentId",
/// "/payment/:paymentId/*action",
/// ]).unwrap();
/// ```
pub fn known_paths<Paths, Path>(paths: Paths) -> anyhow::Result<Self>
where
Paths: IntoIterator<Item = Path>,
Path: Into<String>,
{
let mut router = Router::new();
for path in paths {
let path = path.into();
router.insert(path.clone(), path)?;
}
Ok(Self(router))
}
/// Find the templated path from the actual path.
///
/// Returns the templated path if a match is found.
///
/// ```
/// # use reqwest_tracing::OtelPathNames;
/// let path_names = OtelPathNames::known_paths(["/payment/:paymentId"]).unwrap();
/// let path = path_names.find("/payment/payment-id-123");
/// assert_eq!(path, Some("/payment/:paymentId"));
/// ```
pub fn find(&self, path: &str) -> Option<&str> {
self.0.at(path).map(|mtch| mtch.value.as_str()).ok()
}
}
/// `DisableOtelPropagation` disables opentelemetry header propagation, while still tracing the HTTP request.
///
/// By default, the [`TracingMiddleware`](super::TracingMiddleware) middleware will also propagate any opentelemtry
/// contexts to the server. For any external facing requests, this can be problematic and it should be disabled.
///
/// Usage:
/// ```no_run
/// # use reqwest_middleware::Result;
/// use reqwest_middleware::{ClientBuilder, Extension};
/// use reqwest_tracing::{
/// TracingMiddleware, DisableOtelPropagation
/// };
/// # async fn example() -> Result<()> {
/// let reqwest_client = reqwest::Client::builder().build().unwrap();
/// let client = ClientBuilder::new(reqwest_client)
/// // Inserts the extension before the request is started
/// .with_init(Extension(DisableOtelPropagation))
/// // Makes use of that extension to specify the otel name
/// .with(TracingMiddleware::default())
/// .build();
///
/// let resp = client.get("https://truelayer.com").send().await.unwrap();
///
/// // Or specify it on the individual request (will take priority)
/// let resp = client.post("https://api.truelayer.com/payment")
/// .with_extension(DisableOtelPropagation)
/// .send()
/// .await
/// .unwrap();
@ -309,7 +185,7 @@ impl OtelPathNames {
/// # }
/// ```
#[derive(Clone)]
pub struct DisableOtelPropagation;
pub struct OtelName(pub Cow<'static, str>);
/// Removes the username and/or password parts of the url, if present.
fn remove_credentials(url: &Url) -> Cow<'_, str> {

View File

@ -30,7 +30,7 @@
/// The second argument passed to [`reqwest_otel_span!`](crate::reqwest_otel_span) is a reference to an [`reqwest::Request`].
///
/// ```rust
/// use reqwest_middleware::Result;
/// use reqwest_middleware::Error;
/// use task_local_extensions::Extensions;
/// use reqwest::{Request, Response};
/// use reqwest_tracing::{
@ -41,11 +41,11 @@
/// pub struct CustomReqwestOtelSpanBackend;
///
/// impl ReqwestOtelSpanBackend for CustomReqwestOtelSpanBackend {
/// fn on_request_start(req: &Request, _extension: &mut Extensions) -> Span {
/// reqwest_otel_span!(name = "reqwest-http-request", req)
/// fn on_request_start(req: &Request, _extension: &mut Extensions) -> (Self, Span) {
/// (Self, reqwest_otel_span!(name = "reqwest-http-request", req))
/// }
///
/// fn on_request_end(span: &Span, outcome: &Result<Response>, _extension: &mut Extensions) {
/// fn on_request_end(self, span: &Span, outcome: &Result<Response, Error>) {
/// default_on_request_end(span, outcome)
/// }
/// }