Compare commits

..

3 Commits

Author SHA1 Message Date
Luca Palmieri feb73a7e60 (cargo-release) start next development iteration 0.2.1-alpha.0 2021-11-30 16:13:56 +00:00
Luca Palmieri 251e03ae7a (cargo-release) version 0.2.0 2021-11-30 16:13:21 +00:00
Luca Palmieri 7489e2d4b7 Cut 0.2.0 release of `reqwest-tracing`. 2021-11-30 16:11:59 +00:00
27 changed files with 405 additions and 1944 deletions

View File

@ -1,6 +0,0 @@
# See https://github.com/rustsec/rustsec/blob/59e1d2ad0b9cbc6892c26de233d4925074b4b97b/cargo-audit/audit.toml.example for example.
[advisories]
ignore = [
"RUSTSEC-2020-0159",
]

View File

@ -16,11 +16,6 @@ jobs:
- opentelemetry_0_14 - opentelemetry_0_14
- opentelemetry_0_15 - opentelemetry_0_15
- opentelemetry_0_16 - opentelemetry_0_16
- opentelemetry_0_17
- opentelemetry_0_18
- opentelemetry_0_19
- opentelemetry_0_20
- opentelemetry_0_21
steps: steps:
- name: Checkout repository - name: Checkout repository
uses: actions/checkout@v2 uses: actions/checkout@v2
@ -64,11 +59,6 @@ jobs:
- opentelemetry_0_14 - opentelemetry_0_14
- opentelemetry_0_15 - opentelemetry_0_15
- opentelemetry_0_16 - opentelemetry_0_16
- opentelemetry_0_17
- opentelemetry_0_18
- opentelemetry_0_19
- opentelemetry_0_20
- opentelemetry_0_21
steps: steps:
- name: Checkout repository - name: Checkout repository
uses: actions/checkout@v2 uses: actions/checkout@v2
@ -95,11 +85,6 @@ jobs:
- opentelemetry_0_14 - opentelemetry_0_14
- opentelemetry_0_15 - opentelemetry_0_15
- opentelemetry_0_16 - opentelemetry_0_16
- opentelemetry_0_17
- opentelemetry_0_18
- opentelemetry_0_19
- opentelemetry_0_20
- opentelemetry_0_21
steps: steps:
- name: Checkout repository - name: Checkout repository
uses: actions/checkout@v2 uses: actions/checkout@v2
@ -140,3 +125,27 @@ jobs:
with: with:
command: publish command: publish
args: --dry-run --manifest-path reqwest-tracing/Cargo.toml args: --dry-run --manifest-path reqwest-tracing/Cargo.toml
coverage:
name: Code coverage
runs-on: ubuntu-latest
steps:
- name: Checkout repository
uses: actions/checkout@v2
- name: Install stable toolchain
uses: actions-rs/toolchain@v1
with:
toolchain: stable
profile: minimal
override: true
- name: Run cargo-tarpaulin
uses: actions-rs/tarpaulin@v0.1
with:
args: '--ignore-tests --out Lcov'
- name: Upload to Coveralls
# upload only if push
if: ${{ github.event_name == 'push' }}
uses: coverallsapp/github-action@master
with:
github-token: ${{ secrets.GITHUB_TOKEN }}
path-to-lcov: './lcov.info'

10
.gitignore vendored
View File

@ -1,10 +1,2 @@
# OS
.DS_Store
# IDE
.idea/
.vscode/
# Rust
Cargo.lock
/target /target
Cargo.lock

View File

@ -4,53 +4,7 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
### [0.2.4] - 2023-09-21 ## [Unreleased]
### Added
- Added `fetch_mode_no_cors` method to `reqwest_middleware::RequestBuilder`
## [0.2.3] - 2023-08-07
### Added
- Added all `reqwest::Error` methods for `reqwest_middleware::Error`
## [0.2.2] - 2023-05-11
### Added
- `RequestBuilder::version` method to configure the HTTP version
## [0.2.1] - 2023-03-09
### Added
- Support for `wasm32-unknown-unknown`
## [0.2.0] - 2022-11-15
### Changed
- `RequestBuilder::try_clone` has a fixed function signature now
### Removed
- `RequestBuilder::send_with_extensions` - use `RequestBuilder::with_extensions` + `RequestBuilder::send` instead.
### Added
- Implementation of `Debug` trait for `RequestBuilder`.
- A new `RequestInitialiser` trait that can be added to `ClientWithMiddleware`
- A new `Extension` initialiser that adds extensions to each request
- Adds `with_extension` method functionality to `RequestBuilder` that can add extensions for the `send` method to use.
## [0.1.6] - 2022-04-21
Absolutely nothing changed
## [0.1.5] - 2022-02-21
### Added
- Added support for `opentelemetry` version `0.17`.
## [0.1.4] - 2022-01-24
### Changed
- Made `Debug` impl for `ClientWithExtensions` non-exhaustive.
## [0.1.3] - 2021-10-18 ## [0.1.3] - 2021-10-18

View File

@ -15,11 +15,6 @@ implementations. This repository also contains a couple of useful concrete middl
* [`reqwest-tracing`](https://crates.io/crates/reqwest-tracing): * [`reqwest-tracing`](https://crates.io/crates/reqwest-tracing):
[`tracing`](https://crates.io/crates/tracing) integration, optional opentelemetry support. [`tracing`](https://crates.io/crates/tracing) integration, optional opentelemetry support.
Note about browser support: automated tests targetting wasm are disabled. The crate may work with
wasm but wasm support is unmaintained. PRs improving wasm are still welcome but you'd need to
reintroduce the tests and get them passing before we'd merge it (see
https://github.com/TrueLayer/reqwest-middleware/pull/105).
## Overview ## Overview
The `reqwest-middleware` client exposes the same interface as a plain `reqwest` client, but The `reqwest-middleware` client exposes the same interface as a plain `reqwest` client, but
@ -30,9 +25,9 @@ The `reqwest-middleware` client exposes the same interface as a plain `reqwest`
# ... # ...
[dependencies] [dependencies]
reqwest = "0.11" reqwest = "0.11"
reqwest-middleware = "0.1.6" reqwest-middleware = "0.1.1"
reqwest-retry = "0.1.5" reqwest-retry = "0.1.1"
reqwest-tracing = "0.2.3" reqwest-tracing = "0.1.2"
tokio = { version = "1.12.0", features = ["macros", "rt-multi-thread"] } tokio = { version = "1.12.0", features = ["macros", "rt-multi-thread"] }
``` ```
@ -47,7 +42,7 @@ async fn main() {
let retry_policy = ExponentialBackoff::builder().build_with_max_retries(3); let retry_policy = ExponentialBackoff::builder().build_with_max_retries(3);
let client = ClientBuilder::new(reqwest::Client::new()) let client = ClientBuilder::new(reqwest::Client::new())
// Trace HTTP requests. See the tracing crate to make use of these traces. // Trace HTTP requests. See the tracing crate to make use of these traces.
.with(TracingMiddleware::default()) .with(TracingMiddleware)
// Retry failed requests. // Retry failed requests.
.with(RetryTransientMiddleware::new_with_policy(retry_policy)) .with(RetryTransientMiddleware::new_with_policy(retry_policy))
.build(); .build();
@ -78,14 +73,3 @@ Unless you explicitly state otherwise, any contribution intentionally submitted
for inclusion in the work by you, as defined in the Apache-2.0 license, shall be for inclusion in the work by you, as defined in the Apache-2.0 license, shall be
dual licensed as above, without any additional terms or conditions. dual licensed as above, without any additional terms or conditions.
</sub> </sub>
## Third-party middleware
The following third-party middleware use `request-middleware`:
- [`reqwest-conditional-middleware`](https://github.com/oxidecomputer/reqwest-conditional-middleware) - Per-request basis middleware
- [`http-cache`](https://github.com/06chaynes/http-cache) - HTTP caching rules
- [`reqwest-cache`](https://gitlab.com/famedly/company/backend/libraries/reqwest-cache) - HTTP caching
- [`aliri_reqwest`](https://github.com/neoeinstein/aliri/tree/main/aliri_reqwest) - Background token management and renewal
- [`http-signature-normalization-reqwest`](https://crates.io/crates/http-signature-normalization-reqwest) (not free software) - HTTP Signatures
- [`reqwest-chain`](https://github.com/tommilligan/reqwest-chain) - Apply custom criteria to any reqwest response, deciding when and how to retry.

View File

@ -1,6 +1,6 @@
[package] [package]
name = "reqwest-middleware" name = "reqwest-middleware"
version = "0.2.4" version = "0.1.2"
authors = ["Rodrigo Gryzinski <rodrigo.gryzinski@truelayer.com>"] authors = ["Rodrigo Gryzinski <rodrigo.gryzinski@truelayer.com>"]
edition = "2018" edition = "2018"
description = "Wrapper around reqwest to allow for client middleware chains." description = "Wrapper around reqwest to allow for client middleware chains."
@ -11,17 +11,18 @@ categories = ["web-programming::http-client"]
readme = "../README.md" readme = "../README.md"
[dependencies] [dependencies]
anyhow = "1.0.0" anyhow = "1"
async-trait = "0.1.51" async-trait = "0.1.51"
http = "0.2.0" futures = "0.3"
reqwest = { version = "0.11.4", default-features = false, features = ["json", "multipart"] } http = "0.2"
serde = "1.0.106" reqwest = { version = "0.11", default-features = false, features = ["json", "multipart"] }
task-local-extensions = "0.1.4" serde = "1"
thiserror = "1.0.21" thiserror = "1"
task-local-extensions = "0.1.1"
[dev-dependencies] [dev-dependencies]
reqwest = "0.11.4" reqwest = "0.11"
reqwest-retry = { path = "../reqwest-retry" } reqwest-retry = { path = "../reqwest-retry" }
reqwest-tracing = { path = "../reqwest-tracing" } reqwest-tracing = { path = "../reqwest-tracing" }
tokio = { version = "1.0.0", features = ["macros", "rt-multi-thread"] } wiremock = "0.5"
wiremock = "0.5.0" tokio = { version = "1", features = ["macros", "rt-multi-thread"] }

View File

@ -3,22 +3,20 @@ use reqwest::multipart::Form;
use reqwest::{Body, Client, IntoUrl, Method, Request, Response}; use reqwest::{Body, Client, IntoUrl, Method, Request, Response};
use serde::Serialize; use serde::Serialize;
use std::convert::TryFrom; use std::convert::TryFrom;
use std::fmt::{self, Display}; use std::fmt::Display;
use std::sync::Arc; use std::sync::Arc;
use std::time::Duration;
use task_local_extensions::Extensions; use task_local_extensions::Extensions;
use crate::error::Result; use crate::error::Result;
use crate::middleware::{Middleware, Next}; use crate::middleware::{Middleware, Next};
use crate::RequestInitialiser;
/// A `ClientBuilder` is used to build a [`ClientWithMiddleware`]. /// A `ClientBuilder` is used to build a [`ClientWithMiddleware`].
/// ///
/// [`ClientWithMiddleware`]: crate::ClientWithMiddleware /// [`ClientWithMiddleware`]: crate::ClientWithMiddleware
#[derive(Debug)]
pub struct ClientBuilder { pub struct ClientBuilder {
client: Client, client: Client,
middleware_stack: Vec<Arc<dyn Middleware>>, middleware_stack: Vec<Arc<dyn Middleware>>,
initialiser_stack: Vec<Arc<dyn RequestInitialiser>>,
} }
impl ClientBuilder { impl ClientBuilder {
@ -26,7 +24,6 @@ impl ClientBuilder {
ClientBuilder { ClientBuilder {
client, client,
middleware_stack: Vec::new(), middleware_stack: Vec::new(),
initialiser_stack: Vec::new(),
} }
} }
@ -50,33 +47,9 @@ impl ClientBuilder {
self self
} }
/// Convenience method to attach a request initialiser.
///
/// If you need to keep a reference to the initialiser after attaching, use [`with_arc_init`].
///
/// [`with_arc_init`]: Self::with_arc_init
pub fn with_init<I>(self, initialiser: I) -> Self
where
I: RequestInitialiser,
{
self.with_arc_init(Arc::new(initialiser))
}
/// Add a request initialiser to the chain. [`with_init`] is more ergonomic if you don't need the `Arc`.
///
/// [`with_init`]: Self::with_init
pub fn with_arc_init(mut self, initialiser: Arc<dyn RequestInitialiser>) -> Self {
self.initialiser_stack.push(initialiser);
self
}
/// Returns a `ClientWithMiddleware` using this builder configuration. /// Returns a `ClientWithMiddleware` using this builder configuration.
pub fn build(self) -> ClientWithMiddleware { pub fn build(self) -> ClientWithMiddleware {
ClientWithMiddleware { ClientWithMiddleware::new(self.client, self.middleware_stack)
inner: self.client,
middleware_stack: self.middleware_stack.into_boxed_slice(),
initialiser_stack: self.initialiser_stack.into_boxed_slice(),
}
} }
} }
@ -86,7 +59,6 @@ impl ClientBuilder {
pub struct ClientWithMiddleware { pub struct ClientWithMiddleware {
inner: reqwest::Client, inner: reqwest::Client,
middleware_stack: Box<[Arc<dyn Middleware>]>, middleware_stack: Box<[Arc<dyn Middleware>]>,
initialiser_stack: Box<[Arc<dyn RequestInitialiser>]>,
} }
impl ClientWithMiddleware { impl ClientWithMiddleware {
@ -98,8 +70,6 @@ impl ClientWithMiddleware {
ClientWithMiddleware { ClientWithMiddleware {
inner: client, inner: client,
middleware_stack: middleware_stack.into(), middleware_stack: middleware_stack.into(),
// TODO(conradludgate) - allow downstream code to control this manually if desired
initialiser_stack: Box::new([]),
} }
} }
@ -135,14 +105,10 @@ impl ClientWithMiddleware {
/// See [`Client::request`] /// See [`Client::request`]
pub fn request<U: IntoUrl>(&self, method: Method, url: U) -> RequestBuilder { pub fn request<U: IntoUrl>(&self, method: Method, url: U) -> RequestBuilder {
let req = RequestBuilder { RequestBuilder {
inner: self.inner.request(method, url), inner: self.inner.request(method, url),
client: self.clone(), client: self.clone(),
extensions: Extensions::new(), }
};
self.initialiser_stack
.iter()
.fold(req, |req, i| i.init(req))
} }
/// See [`Client::execute`] /// See [`Client::execute`]
@ -168,26 +134,15 @@ impl From<Client> for ClientWithMiddleware {
ClientWithMiddleware { ClientWithMiddleware {
inner: client, inner: client,
middleware_stack: Box::new([]), middleware_stack: Box::new([]),
initialiser_stack: Box::new([]),
} }
} }
} }
impl fmt::Debug for ClientWithMiddleware {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
// skipping middleware_stack field for now
f.debug_struct("ClientWithMiddleware")
.field("inner", &self.inner)
.finish_non_exhaustive()
}
}
/// This is a wrapper around [`reqwest::RequestBuilder`] exposing the same API. /// This is a wrapper around [`reqwest::RequestBuilder`] exposing the same API.
#[must_use = "RequestBuilder does nothing until you 'send' it"] #[must_use = "RequestBuilder does nothing until you 'send' it"]
pub struct RequestBuilder { pub struct RequestBuilder {
inner: reqwest::RequestBuilder, inner: reqwest::RequestBuilder,
client: ClientWithMiddleware, client: ClientWithMiddleware,
extensions: Extensions,
} }
impl RequestBuilder { impl RequestBuilder {
@ -200,22 +155,14 @@ impl RequestBuilder {
{ {
RequestBuilder { RequestBuilder {
inner: self.inner.header(key, value), inner: self.inner.header(key, value),
..self client: self.client,
} }
} }
pub fn headers(self, headers: HeaderMap) -> Self { pub fn headers(self, headers: HeaderMap) -> Self {
RequestBuilder { RequestBuilder {
inner: self.inner.headers(headers), inner: self.inner.headers(headers),
..self client: self.client,
}
}
#[cfg(not(target_arch = "wasm32"))]
pub fn version(self, version: reqwest::Version) -> Self {
RequestBuilder {
inner: self.inner.version(version),
..self
} }
} }
@ -226,7 +173,7 @@ impl RequestBuilder {
{ {
RequestBuilder { RequestBuilder {
inner: self.inner.basic_auth(username, password), inner: self.inner.basic_auth(username, password),
..self client: self.client,
} }
} }
@ -236,57 +183,49 @@ impl RequestBuilder {
{ {
RequestBuilder { RequestBuilder {
inner: self.inner.bearer_auth(token), inner: self.inner.bearer_auth(token),
..self client: self.client,
} }
} }
pub fn body<T: Into<Body>>(self, body: T) -> Self { pub fn body<T: Into<Body>>(self, body: T) -> Self {
RequestBuilder { RequestBuilder {
inner: self.inner.body(body), inner: self.inner.body(body),
..self client: self.client,
} }
} }
#[cfg(not(target_arch = "wasm32"))] pub fn timeout(self, timeout: Duration) -> Self {
pub fn timeout(self, timeout: std::time::Duration) -> Self {
RequestBuilder { RequestBuilder {
inner: self.inner.timeout(timeout), inner: self.inner.timeout(timeout),
..self client: self.client,
} }
} }
pub fn multipart(self, multipart: Form) -> Self { pub fn multipart(self, multipart: Form) -> Self {
RequestBuilder { RequestBuilder {
inner: self.inner.multipart(multipart), inner: self.inner.multipart(multipart),
..self client: self.client,
} }
} }
pub fn query<T: Serialize + ?Sized>(self, query: &T) -> Self { pub fn query<T: Serialize + ?Sized>(self, query: &T) -> Self {
RequestBuilder { RequestBuilder {
inner: self.inner.query(query), inner: self.inner.query(query),
..self client: self.client,
} }
} }
pub fn form<T: Serialize + ?Sized>(self, form: &T) -> Self { pub fn form<T: Serialize + ?Sized>(self, form: &T) -> Self {
RequestBuilder { RequestBuilder {
inner: self.inner.form(form), inner: self.inner.form(form),
..self client: self.client,
} }
} }
pub fn json<T: Serialize + ?Sized>(self, json: &T) -> Self { pub fn json<T: Serialize + ?Sized>(self, json: &T) -> Self {
RequestBuilder { RequestBuilder {
inner: self.inner.json(json), inner: self.inner.json(json),
..self client: self.client,
}
}
pub fn fetch_mode_no_cors(self) -> Self {
RequestBuilder {
inner: self.inner.fetch_mode_no_cors(),
..self
} }
} }
@ -294,48 +233,21 @@ impl RequestBuilder {
self.inner.build() self.inner.build()
} }
/// Inserts the extension into this request builder
pub fn with_extension<T: Send + Sync + 'static>(mut self, extension: T) -> Self {
self.extensions.insert(extension);
self
}
/// Returns a mutable reference to the internal set of extensions for this request
pub fn extensions(&mut self) -> &mut Extensions {
&mut self.extensions
}
pub async fn send(self) -> Result<Response> { pub async fn send(self) -> Result<Response> {
let Self { let req = self.inner.build()?;
inner, self.client.execute(req).await
client,
mut extensions,
} = self;
let req = inner.build()?;
client.execute_with_extensions(req, &mut extensions).await
} }
/// Attempt to clone the RequestBuilder. /// Sends a request with initial [`Extensions`].
/// pub async fn send_with_extensions(self, ext: &mut Extensions) -> Result<Response> {
/// `None` is returned if the RequestBuilder can not be cloned, let req = self.inner.build()?;
/// i.e. if the request body is a stream. self.client.execute_with_extensions(req, ext).await
///
/// # Extensions
/// Note that extensions are not preserved through cloning.
pub fn try_clone(&self) -> Option<Self> {
self.inner.try_clone().map(|inner| RequestBuilder {
inner,
client: self.client.clone(),
extensions: Extensions::new(),
})
}
} }
impl fmt::Debug for RequestBuilder { pub fn try_clone(self) -> Option<Self> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { let client = self.client;
// skipping middleware_stack field for now self.inner
f.debug_struct("RequestBuilder") .try_clone()
.field("inner", &self.inner) .map(|inner| RequestBuilder { inner, client })
.finish_non_exhaustive()
} }
} }

View File

@ -1,4 +1,3 @@
use reqwest::{StatusCode, Url};
use thiserror::Error; use thiserror::Error;
pub type Result<T> = std::result::Result<T, Error>; pub type Result<T> = std::result::Result<T, Error>;
@ -20,122 +19,4 @@ impl Error {
{ {
Error::Middleware(err.into()) Error::Middleware(err.into())
} }
/// Returns a possible URL related to this error.
pub fn url(&self) -> Option<&Url> {
match self {
Error::Middleware(_) => None,
Error::Reqwest(e) => e.url(),
}
}
/// Returns a mutable reference to the URL related to this error.
///
/// This is useful if you need to remove sensitive information from the URL
/// (e.g. an API key in the query), but do not want to remove the URL
/// entirely.
pub fn url_mut(&mut self) -> Option<&mut Url> {
match self {
Error::Middleware(_) => None,
Error::Reqwest(e) => e.url_mut(),
}
}
/// Adds a url related to this error (overwriting any existing).
pub fn with_url(self, url: Url) -> Self {
match self {
Error::Middleware(_) => self,
Error::Reqwest(e) => e.with_url(url).into(),
}
}
/// Strips the related URL from this error (if, for example, it contains
/// sensitive information).
pub fn without_url(self) -> Self {
match self {
Error::Middleware(_) => self,
Error::Reqwest(e) => e.without_url().into(),
}
}
/// Returns true if the error is from any middleware.
pub fn is_middleware(&self) -> bool {
match self {
Error::Middleware(_) => true,
Error::Reqwest(_) => false,
}
}
/// Returns true if the error is from a type `Builder`.
pub fn is_builder(&self) -> bool {
match self {
Error::Middleware(_) => false,
Error::Reqwest(e) => e.is_builder(),
}
}
/// Returns true if the error is from a `RedirectPolicy`.
pub fn is_redirect(&self) -> bool {
match self {
Error::Middleware(_) => false,
Error::Reqwest(e) => e.is_redirect(),
}
}
/// Returns true if the error is from `Response::error_for_status`.
pub fn is_status(&self) -> bool {
match self {
Error::Middleware(_) => false,
Error::Reqwest(e) => e.is_status(),
}
}
/// Returns true if the error is related to a timeout.
pub fn is_timeout(&self) -> bool {
match self {
Error::Middleware(_) => false,
Error::Reqwest(e) => e.is_timeout(),
}
}
/// Returns true if the error is related to the request.
pub fn is_request(&self) -> bool {
match self {
Error::Middleware(_) => false,
Error::Reqwest(e) => e.is_request(),
}
}
#[cfg(not(target_arch = "wasm32"))]
/// Returns true if the error is related to connect.
pub fn is_connect(&self) -> bool {
match self {
Error::Middleware(_) => false,
Error::Reqwest(e) => e.is_connect(),
}
}
/// Returns true if the error is related to the request or response body.
pub fn is_body(&self) -> bool {
match self {
Error::Middleware(_) => false,
Error::Reqwest(e) => e.is_body(),
}
}
/// Returns true if the error is related to decoding the response's body.
pub fn is_decode(&self) -> bool {
match self {
Error::Middleware(_) => false,
Error::Reqwest(e) => e.is_decode(),
}
}
/// Returns the status code, if the error was generated from a response.
pub fn status(&self) -> Option<StatusCode> {
match self {
Error::Middleware(_) => None,
Error::Reqwest(e) => e.status(),
}
}
} }

View File

@ -52,9 +52,7 @@ pub struct ReadmeDoctests;
mod client; mod client;
mod error; mod error;
mod middleware; mod middleware;
mod req_init;
pub use client::{ClientBuilder, ClientWithMiddleware, RequestBuilder}; pub use client::{ClientBuilder, ClientWithMiddleware, RequestBuilder};
pub use error::{Error, Result}; pub use error::{Error, Result};
pub use middleware::{Middleware, Next}; pub use middleware::{Middleware, Next};
pub use req_init::{Extension, RequestInitialiser};

View File

@ -1,3 +1,4 @@
use futures::future::{BoxFuture, FutureExt, TryFutureExt};
use reqwest::{Client, Request, Response}; use reqwest::{Client, Request, Response};
use std::sync::Arc; use std::sync::Arc;
use task_local_extensions::Extensions; use task_local_extensions::Extensions;
@ -31,8 +32,7 @@ use crate::error::{Error, Result};
/// ///
/// [`ClientWithMiddleware`]: crate::ClientWithMiddleware /// [`ClientWithMiddleware`]: crate::ClientWithMiddleware
/// [`with`]: crate::ClientBuilder::with /// [`with`]: crate::ClientBuilder::with
#[cfg_attr(not(target_arch = "wasm32"), async_trait::async_trait)] #[async_trait::async_trait]
#[cfg_attr(target_arch = "wasm32", async_trait::async_trait(?Send))]
pub trait Middleware: 'static + Send + Sync { pub trait Middleware: 'static + Send + Sync {
/// Invoked with a request before sending it. If you want to continue processing the request, /// Invoked with a request before sending it. If you want to continue processing the request,
/// you should explicitly call `next.run(req, extensions)`. /// you should explicitly call `next.run(req, extensions)`.
@ -47,14 +47,7 @@ pub trait Middleware: 'static + Send + Sync {
) -> Result<Response>; ) -> Result<Response>;
} }
impl std::fmt::Debug for dyn Middleware { #[async_trait::async_trait]
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "Middleware")
}
}
#[cfg_attr(not(target_arch = "wasm32"), async_trait::async_trait)]
#[cfg_attr(target_arch = "wasm32", async_trait::async_trait(?Send))]
impl<F> Middleware for F impl<F> Middleware for F
where where
F: Send F: Send
@ -83,11 +76,6 @@ pub struct Next<'a> {
middlewares: &'a [Arc<dyn Middleware>], middlewares: &'a [Arc<dyn Middleware>],
} }
#[cfg(not(target_arch = "wasm32"))]
pub type BoxFuture<'a, T> = std::pin::Pin<Box<dyn std::future::Future<Output = T> + Send + 'a>>;
#[cfg(target_arch = "wasm32")]
pub type BoxFuture<'a, T> = std::pin::Pin<Box<dyn std::future::Future<Output = T> + 'a>>;
impl<'a> Next<'a> { impl<'a> Next<'a> {
pub(crate) fn new(client: &'a Client, middlewares: &'a [Arc<dyn Middleware>]) -> Self { pub(crate) fn new(client: &'a Client, middlewares: &'a [Arc<dyn Middleware>]) -> Self {
Next { Next {
@ -103,9 +91,9 @@ impl<'a> Next<'a> {
) -> BoxFuture<'a, Result<Response>> { ) -> BoxFuture<'a, Result<Response>> {
if let Some((current, rest)) = self.middlewares.split_first() { if let Some((current, rest)) = self.middlewares.split_first() {
self.middlewares = rest; self.middlewares = rest;
Box::pin(current.handle(req, extensions, self)) current.handle(req, extensions, self).boxed()
} else { } else {
Box::pin(async move { self.client.execute(req).await.map_err(Error::from) }) self.client.execute(req).map_err(Error::from).boxed()
} }
} }
} }

View File

@ -1,90 +0,0 @@
use crate::RequestBuilder;
/// When attached to a [`ClientWithMiddleware`] (generally using [`with_init`]), it is run
/// whenever the client starts building a request, in the order it was attached.
///
/// # Example
///
/// ```
/// use reqwest_middleware::{RequestInitialiser, RequestBuilder};
///
/// struct AuthInit;
///
/// impl RequestInitialiser for AuthInit {
/// fn init(&self, req: RequestBuilder) -> RequestBuilder {
/// req.bearer_auth("my_auth_token")
/// }
/// }
/// ```
///
/// [`ClientWithMiddleware`]: crate::ClientWithMiddleware
/// [`with_init`]: crate::ClientBuilder::with_init
pub trait RequestInitialiser: 'static + Send + Sync {
fn init(&self, req: RequestBuilder) -> RequestBuilder;
}
impl core::fmt::Debug for dyn RequestInitialiser {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "RequestInitialiser")
}
}
impl<F> RequestInitialiser for F
where
F: Send + Sync + 'static + Fn(RequestBuilder) -> RequestBuilder,
{
fn init(&self, req: RequestBuilder) -> RequestBuilder {
(self)(req)
}
}
/// A middleware that inserts the value into the [`Extensions`](task_local_extensions::Extensions) during the call.
///
/// This is a good way to inject extensions to middleware deeper in the stack
///
/// ```
/// use reqwest::{Client, Request, Response};
/// use reqwest_middleware::{ClientBuilder, Middleware, Next, Result, Extension};
/// use task_local_extensions::Extensions;
///
/// #[derive(Clone)]
/// struct LogName(&'static str);
/// struct LoggingMiddleware;
///
/// #[async_trait::async_trait]
/// impl Middleware for LoggingMiddleware {
/// async fn handle(
/// &self,
/// req: Request,
/// extensions: &mut Extensions,
/// next: Next<'_>,
/// ) -> Result<Response> {
/// // get the log name or default to "unknown"
/// let name = extensions
/// .get()
/// .map(|&LogName(name)| name)
/// .unwrap_or("unknown");
/// println!("[{name}] Request started {req:?}");
/// let res = next.run(req, extensions).await;
/// println!("[{name}] Result: {res:?}");
/// res
/// }
/// }
///
/// async fn run() {
/// let reqwest_client = Client::builder().build().unwrap();
/// let client = ClientBuilder::new(reqwest_client)
/// .with_init(Extension(LogName("my-client")))
/// .with(LoggingMiddleware)
/// .build();
/// let resp = client.get("https://truelayer.com").send().await.unwrap();
/// println!("TrueLayer page HTML: {}", resp.text().await.unwrap());
/// }
/// ```
pub struct Extension<T>(pub T);
impl<T: Send + Sync + Clone + 'static> RequestInitialiser for Extension<T> {
fn init(&self, req: RequestBuilder) -> RequestBuilder {
req.with_extension(self.0.clone())
}
}

View File

@ -4,30 +4,7 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
## [0.3.0] - 2023-09-07 ## [Unreleased]
### Changed
- `retry-policies` upgraded to 0.2.0
## [0.2.3] - 2023-08-30
### Added
- `RetryableStrategy` which allows for custom retry decisions based on the response that a request got
## [0.2.1] - 2022-12-01
### Changed
- Classify `io::Error`s and `hyper::Error(Canceled)` as transient
## [0.2.0] - 2022-11-15
### Changed
- Updated `reqwest-middleware` to `0.2.0`
## [0.1.4] - 2022-02-21
### Changed
- Updated `reqwest-middleware` to `0.1.5`
## [0.1.3] - 2022-01-24
### Changed
- Updated `reqwest-middleware` to `0.1.4`
## [0.1.2] - 2021-09-28 ## [0.1.2] - 2021-09-28
### Added ### Added

View File

@ -1,6 +1,6 @@
[package] [package]
name = "reqwest-retry" name = "reqwest-retry"
version = "0.3.0" version = "0.1.4-alpha.0"
authors = ["Rodrigo Gryzinski <rodrigo.gryzinski@truelayer.com>"] authors = ["Rodrigo Gryzinski <rodrigo.gryzinski@truelayer.com>"]
edition = "2018" edition = "2018"
description = "Retry middleware for reqwest." description = "Retry middleware for reqwest."
@ -10,29 +10,22 @@ keywords = ["reqwest", "http", "middleware", "retry"]
categories = ["web-programming::http-client"] categories = ["web-programming::http-client"]
[dependencies] [dependencies]
reqwest-middleware = { version = "0.2.0", path = "../reqwest-middleware" } reqwest-middleware = { version = "0.1.2", path = "../reqwest-middleware" }
anyhow = "1.0.0" anyhow = "1"
async-trait = "0.1.51" async-trait = "0.1.51"
chrono = { version = "0.4.19", features = ["clock"], default-features = false } chrono = { version = "0.4.19", features = ["clock"], default-features = false }
futures = "0.3.0" futures = "0.3"
http = "0.2.0" http = "0.2"
reqwest = { version = "0.11.0", default-features = false } hyper = "0.14"
retry-policies = "0.2.0" retry-policies = "0.1"
task-local-extensions = "0.1.4" reqwest = { version = "0.11", default-features = false }
tokio = { version = "1.6", features = ["time"] }
tracing = "0.1.26" tracing = "0.1.26"
task-local-extensions = "0.1.1"
[target.'cfg(not(target_arch = "wasm32"))'.dependencies]
hyper = "0.14.0"
tokio = { version = "1.6.0", features = ["time"] }
[target.'cfg(target_arch = "wasm32")'.dependencies]
parking_lot = { version = "0.11.2", features = ["wasm-bindgen"] } # work around https://github.com/tomaka/wasm-timer/issues/14
wasm-timer = "0.2.5"
getrandom = { version = "0.2.0", features = ["js"] }
[dev-dependencies] [dev-dependencies]
paste = "1.0.0" wiremock = "0.5"
tokio = { version = "1.0.0", features = ["full"] } tokio = { version = "1", features = ["macros"] }
wiremock = "0.5.0" paste = "1"
futures = "0.3.0" async-std = { version = "1.10"}

View File

@ -27,13 +27,8 @@
mod middleware; mod middleware;
mod retryable; mod retryable;
mod retryable_strategy;
pub use retry_policies::{policies, RetryPolicy}; pub use retry_policies::{policies, RetryPolicy};
pub use middleware::RetryTransientMiddleware; pub use middleware::RetryTransientMiddleware;
pub use retryable::Retryable; pub use retryable::Retryable;
pub use retryable_strategy::{
default_on_request_failure, default_on_request_success, DefaultRetryableStrategy,
RetryableStrategy,
};

View File

@ -1,6 +1,6 @@
//! `RetryTransientMiddleware` implements retrying requests on transient errors. //! `RetryTransientMiddleware` implements retrying requests on transient errors.
use crate::retryable_strategy::RetryableStrategy;
use crate::{retryable::Retryable, retryable_strategy::DefaultRetryableStrategy}; use crate::retryable::Retryable;
use anyhow::anyhow; use anyhow::anyhow;
use chrono::Utc; use chrono::Utc;
use reqwest::{Request, Response}; use reqwest::{Request, Response};
@ -8,13 +8,14 @@ use reqwest_middleware::{Error, Middleware, Next, Result};
use retry_policies::RetryPolicy; use retry_policies::RetryPolicy;
use task_local_extensions::Extensions; use task_local_extensions::Extensions;
/// We limit the number of retries to a maximum of `10` to avoid stack-overflow issues due to the recursion.
static MAXIMUM_NUMBER_OF_RETRIES: u32 = 10;
/// `RetryTransientMiddleware` offers retry logic for requests that fail in a transient manner /// `RetryTransientMiddleware` offers retry logic for requests that fail in a transient manner
/// and can be safely executed again. /// and can be safely executed again.
/// ///
/// Currently, it allows setting a [RetryPolicy] algorithm for calculating the __wait_time__ /// Currently, it allows setting a [RetryPolicy][retry_policies::RetryPolicy] algorithm for calculating the __wait_time__
/// between each request retry. Sleeping on non-`wasm32` archs is performed using /// between each request retry.
/// [`tokio::time::sleep`], therefore it will respect pauses/auto-advance if run under a
/// runtime that supports them.
/// ///
///```rust ///```rust
/// use reqwest_middleware::ClientBuilder; /// use reqwest_middleware::ClientBuilder;
@ -34,54 +35,19 @@ use task_local_extensions::Extensions;
/// let client = ClientBuilder::new(Client::new()).with(retry_transient_middleware).build(); /// let client = ClientBuilder::new(Client::new()).with(retry_transient_middleware).build();
///``` ///```
/// ///
/// # Note pub struct RetryTransientMiddleware<T: RetryPolicy + Send + Sync + 'static> {
///
/// This middleware always errors when given requests with streaming bodies, before even executing
/// the request. When this happens you'll get an [`Error::Middleware`] with the message
/// 'Request object is not clonable. Are you passing a streaming body?'.
///
/// Some workaround suggestions:
/// * If you can fit the data in memory, you can instead build static request bodies e.g. with
/// `Body`'s `From<String>` or `From<Bytes>` implementations.
/// * You can wrap this middleware in a custom one which skips retries for streaming requests.
/// * You can write a custom retry middleware that builds new streaming requests from the data
/// source directly, avoiding the issue of streaming requests not being clonable.
pub struct RetryTransientMiddleware<
T: RetryPolicy + Send + Sync + 'static,
R: RetryableStrategy + Send + Sync + 'static = DefaultRetryableStrategy,
> {
retry_policy: T, retry_policy: T,
retryable_strategy: R,
} }
impl<T: RetryPolicy + Send + Sync> RetryTransientMiddleware<T, DefaultRetryableStrategy> { impl<T: RetryPolicy + Send + Sync> RetryTransientMiddleware<T> {
/// Construct `RetryTransientMiddleware` with a [retry_policy][RetryPolicy]. /// Construct `RetryTransientMiddleware` with a [retry_policy][retry_policies::RetryPolicy].
pub fn new_with_policy(retry_policy: T) -> Self { pub fn new_with_policy(retry_policy: T) -> Self {
Self::new_with_policy_and_strategy(retry_policy, DefaultRetryableStrategy) Self { retry_policy }
} }
} }
impl<T, R> RetryTransientMiddleware<T, R> #[async_trait::async_trait]
where impl<T: RetryPolicy + Send + Sync> Middleware for RetryTransientMiddleware<T> {
T: RetryPolicy + Send + Sync,
R: RetryableStrategy + Send + Sync,
{
/// Construct `RetryTransientMiddleware` with a [retry_policy][RetryPolicy] and [retryable_strategy](RetryableStrategy).
pub fn new_with_policy_and_strategy(retry_policy: T, retryable_strategy: R) -> Self {
Self {
retry_policy,
retryable_strategy,
}
}
}
#[cfg_attr(not(target_arch = "wasm32"), async_trait::async_trait)]
#[cfg_attr(target_arch = "wasm32", async_trait::async_trait(?Send))]
impl<T, R> Middleware for RetryTransientMiddleware<T, R>
where
T: RetryPolicy + Send + Sync,
R: RetryableStrategy + Send + Sync + 'static,
{
async fn handle( async fn handle(
&self, &self,
req: Request, req: Request,
@ -92,26 +58,33 @@ where
// downstream. This will guard against previous retries poluting `Extensions`. // downstream. This will guard against previous retries poluting `Extensions`.
// That is, we only return what's populated in the typemap for the last retry attempt // That is, we only return what's populated in the typemap for the last retry attempt
// and copy those into the the `global` Extensions map. // and copy those into the the `global` Extensions map.
self.execute_with_retry(req, next, extensions).await self.execute_with_retry_recursive(req, next, extensions, 0)
.await
} }
} }
impl<T, R> RetryTransientMiddleware<T, R> impl<T: RetryPolicy + Send + Sync> RetryTransientMiddleware<T> {
where /// **RECURSIVE**.
T: RetryPolicy + Send + Sync, ///
R: RetryableStrategy + Send + Sync, /// SAFETY: The condition for termination is the number of retries
{ /// set on the `RetryOption` object which is capped to 10 therefore
/// we can know that this will not cause a overflow of the stack.
///
/// This function will try to execute the request, if it fails /// This function will try to execute the request, if it fails
/// with an error classified as transient it will call itself /// with an error classified as transient it will call itself
/// to retry the request. /// to retry the request.
async fn execute_with_retry<'a>( ///
/// NOTE: This function is not async because calling an async function
/// recursively is not allowed.
///
fn execute_with_retry_recursive<'a>(
&'a self, &'a self,
req: Request, req: Request,
next: Next<'a>, next: Next<'a>,
ext: &'a mut Extensions, mut ext: &'a mut Extensions,
) -> Result<Response> { n_past_retries: u32,
let mut n_past_retries = 0; ) -> futures::future::BoxFuture<'a, Result<Response>> {
loop { Box::pin(async move {
// Cloning the request object before-the-fact is not ideal.. // Cloning the request object before-the-fact is not ideal..
// However, if the body of the request is not static, e.g of type `Bytes`, // However, if the body of the request is not static, e.g of type `Bytes`,
// the Clone operation should be of constant complexity and not O(N) // the Clone operation should be of constant complexity and not O(N)
@ -122,16 +95,21 @@ where
)) ))
})?; })?;
let result = next.clone().run(duplicate_request, ext).await; let cloned_next = next.clone();
let result = next.run(req, &mut ext).await;
// We classify the response which will return None if not // We classify the response which will return None if not
// errors were returned. // errors were returned.
break match self.retryable_strategy.handle(&result) { match Retryable::from_reqwest_response(&result) {
Some(Retryable::Transient) => { Some(retryable)
if retryable == Retryable::Transient
&& n_past_retries < MAXIMUM_NUMBER_OF_RETRIES =>
{
// If the response failed and the error type was transient // If the response failed and the error type was transient
// we can safely try to retry the request. // we can safely try to retry the request.
let retry_decision = self.retry_policy.should_retry(n_past_retries); let retry_decicion = self.retry_policy.should_retry(n_past_retries);
if let retry_policies::RetryDecision::Retry { execute_after } = retry_decision { if let retry_policies::RetryDecision::Retry { execute_after } = retry_decicion {
let duration = (execute_after - Utc::now()) let duration = (execute_after - Utc::now())
.to_std() .to_std()
.map_err(Error::middleware)?; .map_err(Error::middleware)?;
@ -141,21 +119,21 @@ where
n_past_retries, n_past_retries,
duration duration
); );
#[cfg(not(target_arch = "wasm32"))]
tokio::time::sleep(duration).await; tokio::time::sleep(duration).await;
#[cfg(target_arch = "wasm32")]
wasm_timer::Delay::new(duration)
.await
.expect("failed sleeping");
n_past_retries += 1; self.execute_with_retry_recursive(
continue; duplicate_request,
cloned_next,
ext,
n_past_retries + 1,
)
.await
} else { } else {
result result
} }
} }
Some(_) | None => result, Some(_) | None => result,
}; }
} })
} }
} }

View File

@ -1,4 +1,4 @@
use crate::retryable_strategy::{DefaultRetryableStrategy, RetryableStrategy}; use http::StatusCode;
use reqwest_middleware::Error; use reqwest_middleware::Error;
/// Classification of an error/status returned by request. /// Classification of an error/status returned by request.
@ -16,7 +16,62 @@ impl Retryable {
/// Returns `None` if the response object does not contain any errors. /// Returns `None` if the response object does not contain any errors.
/// ///
pub fn from_reqwest_response(res: &Result<reqwest::Response, Error>) -> Option<Self> { pub fn from_reqwest_response(res: &Result<reqwest::Response, Error>) -> Option<Self> {
DefaultRetryableStrategy.handle(res) match res {
Ok(success) => {
let status = success.status();
if status.is_server_error() {
Some(Retryable::Transient)
} else if status.is_client_error()
&& status != StatusCode::REQUEST_TIMEOUT
&& status != StatusCode::TOO_MANY_REQUESTS
{
Some(Retryable::Fatal)
} else if status.is_success() {
None
} else if status == StatusCode::REQUEST_TIMEOUT
|| status == StatusCode::TOO_MANY_REQUESTS
{
Some(Retryable::Transient)
} else {
Some(Retryable::Fatal)
}
}
Err(error) => match error {
// If something fails in the middleware we're screwed.
Error::Middleware(_) => Some(Retryable::Fatal),
Error::Reqwest(error) => {
if error.is_timeout() || error.is_connect() {
Some(Retryable::Transient)
} else if error.is_body()
|| error.is_decode()
|| error.is_builder()
|| error.is_redirect()
{
Some(Retryable::Fatal)
} else if error.is_request() {
// It seems that hyper::Error(IncompleteMessage) is not correctly handled by reqwest.
// Here we check if the Reqwest error was originated by hyper and map it consistently.
if let Some(hyper_error) = get_source_error_type::<hyper::Error>(&error) {
// The hyper::Error(IncompleteMessage) is raised if the HTTP response is well formatted but does not contain all the bytes.
// This can happen when the server has started sending back the response but the connection is cut halfway thorugh.
// We can safely retry the call, hence marking this error as [`Retryable::Transient`].
if hyper_error.is_incomplete_message() {
Some(Retryable::Transient)
} else {
Some(Retryable::Fatal)
}
} else {
Some(Retryable::Fatal)
}
} else {
// We omit checking if error.is_status() since we check that already.
// However, if Response::error_for_status is used the status will still
// remain in the response object.
None
}
}
},
}
} }
} }
@ -25,3 +80,19 @@ impl From<&reqwest::Error> for Retryable {
Retryable::Transient Retryable::Transient
} }
} }
/// Downcasts the given err source into T.
fn get_source_error_type<T: std::error::Error + 'static>(
err: &dyn std::error::Error,
) -> Option<&T> {
let mut source = err.source();
while let Some(err) = source {
if let Some(hyper_err) = err.downcast_ref::<T>() {
return Some(hyper_err);
}
source = err.source();
}
None
}

View File

@ -1,213 +0,0 @@
use crate::retryable::Retryable;
use http::StatusCode;
use reqwest_middleware::Error;
/// A strategy to create a [`Retryable`] from a [`Result<reqwest::Response, reqwest_middleware::Error>`]
///
/// A [`RetryableStrategy`] has a single `handler` functions.
/// The result of calling the request could be:
/// - [`reqwest::Response`] In case the request has been sent and received correctly
/// This could however still mean that the server responded with a erroneous response.
/// For example a HTTP statuscode of 500
/// - [`reqwest_middleware::Error`] In this case the request actually failed.
/// This could, for example, be caused by a timeout on the connection.
///
/// Example:
///
/// ```
/// use reqwest_retry::{default_on_request_failure, policies::ExponentialBackoff, Retryable, RetryableStrategy, RetryTransientMiddleware};
/// use reqwest::{Request, Response};
/// use reqwest_middleware::{ClientBuilder, Middleware, Next, Result};
/// use task_local_extensions::Extensions;
///
/// // Log each request to show that the requests will be retried
/// struct LoggingMiddleware;
///
/// #[async_trait::async_trait]
/// impl Middleware for LoggingMiddleware {
/// async fn handle(
/// &self,
/// req: Request,
/// extensions: &mut Extensions,
/// next: Next<'_>,
/// ) -> Result<Response> {
/// println!("Request started {}", req.url());
/// let res = next.run(req, extensions).await;
/// println!("Request finished");
/// res
/// }
/// }
///
/// // Just a toy example, retry when the successful response code is 201, else do nothing.
/// struct Retry201;
/// impl RetryableStrategy for Retry201 {
/// fn handle(&self, res: &Result<reqwest::Response>) -> Option<Retryable> {
/// match res {
/// // retry if 201
/// Ok(success) if success.status() == 201 => Some(Retryable::Transient),
/// // otherwise do not retry a successful request
/// Ok(success) => None,
/// // but maybe retry a request failure
/// Err(error) => default_on_request_failure(error),
/// }
/// }
/// }
///
/// #[tokio::main]
/// async fn main() {
/// // Exponential backoff with max 2 retries
/// let retry_policy = ExponentialBackoff::builder()
/// .build_with_max_retries(2);
///
/// // Create the actual middleware, with the exponential backoff and custom retry stategy.
/// let ret_s = RetryTransientMiddleware::new_with_policy_and_strategy(
/// retry_policy,
/// Retry201,
/// );
///
/// let client = ClientBuilder::new(reqwest::Client::new())
/// // Retry failed requests.
/// .with(ret_s)
/// // Log the requests
/// .with(LoggingMiddleware)
/// .build();
///
/// // Send request which should get a 201 response. So it will be retried
/// let r = client
/// .get("https://httpbin.org/status/201")
/// .send()
/// .await;
/// println!("{:?}", r);
///
/// // Send request which should get a 200 response. So it will not be retried
/// let r = client
/// .get("https://httpbin.org/status/200")
/// .send()
/// .await;
/// println!("{:?}", r);
/// }
/// ```
pub trait RetryableStrategy {
fn handle(&self, res: &Result<reqwest::Response, Error>) -> Option<Retryable>;
}
/// The default [`RetryableStrategy`] for [`RetryTransientMiddleware`](crate::RetryTransientMiddleware).
pub struct DefaultRetryableStrategy;
impl RetryableStrategy for DefaultRetryableStrategy {
fn handle(&self, res: &Result<reqwest::Response, Error>) -> Option<Retryable> {
match res {
Ok(success) => default_on_request_success(success),
Err(error) => default_on_request_failure(error),
}
}
}
/// Default request success retry strategy.
///
/// Will only retry if:
/// * The status was 5XX (server error)
/// * The status was 408 (request timeout) or 429 (too many requests)
///
/// Note that success here means that the request finished without interruption, not that it was logically OK.
pub fn default_on_request_success(success: &reqwest::Response) -> Option<Retryable> {
let status = success.status();
if status.is_server_error() {
Some(Retryable::Transient)
} else if status.is_client_error()
&& status != StatusCode::REQUEST_TIMEOUT
&& status != StatusCode::TOO_MANY_REQUESTS
{
Some(Retryable::Fatal)
} else if status.is_success() {
None
} else if status == StatusCode::REQUEST_TIMEOUT || status == StatusCode::TOO_MANY_REQUESTS {
Some(Retryable::Transient)
} else {
Some(Retryable::Fatal)
}
}
/// Default request failure retry strategy.
///
/// Will only retry if the request failed due to a network error
pub fn default_on_request_failure(error: &Error) -> Option<Retryable> {
match error {
// If something fails in the middleware we're screwed.
Error::Middleware(_) => Some(Retryable::Fatal),
Error::Reqwest(error) => {
#[cfg(not(target_arch = "wasm32"))]
let is_connect = error.is_connect();
#[cfg(target_arch = "wasm32")]
let is_connect = false;
if error.is_timeout() || is_connect {
Some(Retryable::Transient)
} else if error.is_body()
|| error.is_decode()
|| error.is_builder()
|| error.is_redirect()
{
Some(Retryable::Fatal)
} else if error.is_request() {
// It seems that hyper::Error(IncompleteMessage) is not correctly handled by reqwest.
// Here we check if the Reqwest error was originated by hyper and map it consistently.
#[cfg(not(target_arch = "wasm32"))]
if let Some(hyper_error) = get_source_error_type::<hyper::Error>(&error) {
// The hyper::Error(IncompleteMessage) is raised if the HTTP response is well formatted but does not contain all the bytes.
// This can happen when the server has started sending back the response but the connection is cut halfway thorugh.
// We can safely retry the call, hence marking this error as [`Retryable::Transient`].
// Instead hyper::Error(Canceled) is raised when the connection is
// gracefully closed on the server side.
if hyper_error.is_incomplete_message() || hyper_error.is_canceled() {
Some(Retryable::Transient)
// Try and downcast the hyper error to io::Error if that is the
// underlying error, and try and classify it.
} else if let Some(io_error) =
get_source_error_type::<std::io::Error>(hyper_error)
{
Some(classify_io_error(io_error))
} else {
Some(Retryable::Fatal)
}
} else {
Some(Retryable::Fatal)
}
#[cfg(target_arch = "wasm32")]
Some(Retryable::Fatal)
} else {
// We omit checking if error.is_status() since we check that already.
// However, if Response::error_for_status is used the status will still
// remain in the response object.
None
}
}
}
}
#[cfg(not(target_arch = "wasm32"))]
fn classify_io_error(error: &std::io::Error) -> Retryable {
match error.kind() {
std::io::ErrorKind::ConnectionReset | std::io::ErrorKind::ConnectionAborted => {
Retryable::Transient
}
_ => Retryable::Fatal,
}
}
/// Downcasts the given err source into T.
#[cfg(not(target_arch = "wasm32"))]
fn get_source_error_type<T: std::error::Error + 'static>(
err: &dyn std::error::Error,
) -> Option<&T> {
let mut source = err.source();
while let Some(err) = source {
if let Some(err) = err.downcast_ref::<T>() {
return Some(err);
}
source = err.source();
}
None
}

View File

@ -1,12 +1,9 @@
use futures::future::BoxFuture; use async_std::io::ReadExt;
use async_std::io::WriteExt;
use async_std::net::{TcpListener, TcpStream};
use futures::stream::StreamExt;
use std::error::Error; use std::error::Error;
use std::fmt; use std::fmt;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::{TcpListener, TcpStream};
type CustomMessageHandler = Box<
dyn Fn(TcpStream) -> BoxFuture<'static, Result<(), Box<dyn std::error::Error>>> + Send + Sync,
>;
/// This is a simple server that returns the responses given at creation time: [`self.raw_http_responses`] following a round-robin mechanism. /// This is a simple server that returns the responses given at creation time: [`self.raw_http_responses`] following a round-robin mechanism.
pub struct SimpleServer { pub struct SimpleServer {
@ -15,7 +12,6 @@ pub struct SimpleServer {
host: String, host: String,
raw_http_responses: Vec<String>, raw_http_responses: Vec<String>,
calls_counter: usize, calls_counter: usize,
custom_handler: Option<CustomMessageHandler>,
} }
/// Request-Line = Method SP Request-URI SP HTTP-Version CRLF /// Request-Line = Method SP Request-URI SP HTTP-Version CRLF
@ -50,21 +46,9 @@ impl SimpleServer {
host: host.to_string(), host: host.to_string(),
raw_http_responses, raw_http_responses,
calls_counter: 0, calls_counter: 0,
custom_handler: None,
}) })
} }
pub fn set_custom_handler(
&mut self,
custom_handler: impl Fn(TcpStream) -> BoxFuture<'static, Result<(), Box<dyn std::error::Error>>>
+ Send
+ Sync
+ 'static,
) -> &mut Self {
self.custom_handler.replace(Box::new(custom_handler));
self
}
/// Returns the uri in which the server is listening to. /// Returns the uri in which the server is listening to.
pub fn uri(&self) -> String { pub fn uri(&self) -> String {
format!("http://{}:{}", self.host, self.port) format!("http://{}:{}", self.host, self.port)
@ -72,9 +56,9 @@ impl SimpleServer {
/// Starts the TcpListener and handles the requests. /// Starts the TcpListener and handles the requests.
pub async fn start(mut self) { pub async fn start(mut self) {
loop { while let Some(stream) = self.listener.incoming().next().await {
match self.listener.accept().await { match stream {
Ok((stream, _)) => { Ok(stream) => {
match self.handle_connection(stream).await { match self.handle_connection(stream).await {
Ok(_) => (), Ok(_) => (),
Err(e) => { Err(e) => {
@ -95,15 +79,11 @@ impl SimpleServer {
/// ///
/// Returns a 400 if the request if formatted badly. /// Returns a 400 if the request if formatted badly.
async fn handle_connection(&self, mut stream: TcpStream) -> Result<(), Box<dyn Error>> { async fn handle_connection(&self, mut stream: TcpStream) -> Result<(), Box<dyn Error>> {
if let Some(ref custom_handler) = self.custom_handler {
return custom_handler(stream).await;
}
let mut buffer = vec![0; 1024]; let mut buffer = vec![0; 1024];
let n = stream.read(&mut buffer).await.unwrap(); stream.read(&mut buffer).await.unwrap();
let request = String::from_utf8_lossy(&buffer[..n]); let request = String::from_utf8_lossy(&buffer[..]);
let request_line = request.lines().next().unwrap(); let request_line = request.lines().next().unwrap();
let response = match Self::parse_request_line(request_line) { let response = match Self::parse_request_line(request_line) {
@ -118,7 +98,7 @@ impl SimpleServer {
}; };
println!("-- Response --\n{}\n--------------", response.clone()); println!("-- Response --\n{}\n--------------", response.clone());
stream.write_all(response.as_bytes()).await.unwrap(); stream.write(response.as_bytes()).await.unwrap();
stream.flush().await.unwrap(); stream.flush().await.unwrap();
Ok(()) Ok(())

View File

@ -1,16 +1,12 @@
use futures::FutureExt;
use paste::paste; use paste::paste;
use reqwest::Client; use reqwest::Client;
use reqwest::StatusCode; use reqwest::StatusCode;
use reqwest_middleware::ClientBuilder; use reqwest_middleware::ClientBuilder;
use reqwest_retry::{policies::ExponentialBackoff, RetryTransientMiddleware}; use reqwest_retry::{policies::ExponentialBackoff, RetryTransientMiddleware};
use std::sync::atomic::AtomicI8;
use std::sync::{ use std::sync::{
atomic::{AtomicU32, Ordering}, atomic::{AtomicU32, Ordering},
Arc, Arc,
}; };
use tokio::io::AsyncReadExt;
use tokio::io::AsyncWriteExt;
use wiremock::matchers::{method, path}; use wiremock::matchers::{method, path};
use wiremock::{Mock, MockServer, Respond, ResponseTemplate}; use wiremock::{Mock, MockServer, Respond, ResponseTemplate};
@ -53,12 +49,12 @@ macro_rules! assert_retry_succeeds_inner {
let reqwest_client = Client::builder().build().unwrap(); let reqwest_client = Client::builder().build().unwrap();
let client = ClientBuilder::new(reqwest_client) let client = ClientBuilder::new(reqwest_client)
.with(RetryTransientMiddleware::new_with_policy( .with(RetryTransientMiddleware::new_with_policy(
ExponentialBackoff::builder() ExponentialBackoff {
.retry_bounds( max_n_retries: retry_amount,
std::time::Duration::from_millis(30), max_retry_interval: std::time::Duration::from_millis(30),
std::time::Duration::from_millis(100), min_retry_interval: std::time::Duration::from_millis(100),
) backoff_exponent: 2,
.build_with_max_retries(retry_amount), },
)) ))
.build(); .build();
@ -151,6 +147,17 @@ assert_retry_succeeds!(429, StatusCode::OK);
assert_no_retry!(431, StatusCode::REQUEST_HEADER_FIELDS_TOO_LARGE); assert_no_retry!(431, StatusCode::REQUEST_HEADER_FIELDS_TOO_LARGE);
assert_no_retry!(451, StatusCode::UNAVAILABLE_FOR_LEGAL_REASONS); assert_no_retry!(451, StatusCode::UNAVAILABLE_FOR_LEGAL_REASONS);
// We assert that we cap retries at 10, which means that we will
// get 11 calls to the RetryResponder.
assert_retry_succeeds_inner!(
500,
assert_maximum_retries_is_not_exceeded,
StatusCode::INTERNAL_SERVER_ERROR,
100,
11,
RetryResponder::new(100_u32, 500)
);
pub struct RetryTimeoutResponder(Arc<AtomicU32>, u32, std::time::Duration); pub struct RetryTimeoutResponder(Arc<AtomicU32>, u32, std::time::Duration);
impl RetryTimeoutResponder { impl RetryTimeoutResponder {
@ -189,12 +196,12 @@ async fn assert_retry_on_request_timeout() {
let reqwest_client = Client::builder().build().unwrap(); let reqwest_client = Client::builder().build().unwrap();
let client = ClientBuilder::new(reqwest_client) let client = ClientBuilder::new(reqwest_client)
.with(RetryTransientMiddleware::new_with_policy( .with(RetryTransientMiddleware::new_with_policy(
ExponentialBackoff::builder() ExponentialBackoff {
.retry_bounds( max_n_retries: 3,
std::time::Duration::from_millis(30), max_retry_interval: std::time::Duration::from_millis(100),
std::time::Duration::from_millis(100), min_retry_interval: std::time::Duration::from_millis(30),
) backoff_exponent: 2,
.build_with_max_retries(3), },
)) ))
.build(); .build();
@ -244,111 +251,12 @@ async fn assert_retry_on_incomplete_message() {
let reqwest_client = Client::builder().build().unwrap(); let reqwest_client = Client::builder().build().unwrap();
let client = ClientBuilder::new(reqwest_client) let client = ClientBuilder::new(reqwest_client)
.with(RetryTransientMiddleware::new_with_policy( .with(RetryTransientMiddleware::new_with_policy(
ExponentialBackoff::builder() ExponentialBackoff {
.retry_bounds( max_n_retries: 3,
std::time::Duration::from_millis(30), max_retry_interval: std::time::Duration::from_millis(100),
std::time::Duration::from_millis(100), min_retry_interval: std::time::Duration::from_millis(30),
) backoff_exponent: 2,
.build_with_max_retries(3), },
))
.build();
let resp = client
.get(&format!("{}/foo", uri))
.timeout(std::time::Duration::from_millis(100))
.send()
.await
.expect("call failed");
assert_eq!(resp.status(), 200);
}
#[tokio::test]
async fn assert_retry_on_hyper_canceled() {
let counter = Arc::new(AtomicI8::new(0));
let mut simple_server = SimpleServer::new("127.0.0.1", None, vec![])
.await
.expect("Error when creating a simple server");
simple_server.set_custom_handler(move |mut stream| {
let counter = counter.clone();
async move {
let mut buffer = Vec::new();
stream.read_buf(&mut buffer).await.unwrap();
if counter.fetch_add(1, Ordering::SeqCst) > 1 {
// This triggeres hyper:Error(Canceled).
let _res = stream
.into_std()
.unwrap()
.shutdown(std::net::Shutdown::Both);
} else {
let _res = stream.write("HTTP/1.1 200 OK\r\n\r\n".as_bytes()).await;
}
Ok(())
}
.boxed()
});
let uri = simple_server.uri();
tokio::spawn(simple_server.start());
let reqwest_client = Client::builder().build().unwrap();
let client = ClientBuilder::new(reqwest_client)
.with(RetryTransientMiddleware::new_with_policy(
ExponentialBackoff::builder()
.retry_bounds(
std::time::Duration::from_millis(30),
std::time::Duration::from_millis(100),
)
.build_with_max_retries(3),
))
.build();
let resp = client
.get(&format!("{}/foo", uri))
.timeout(std::time::Duration::from_millis(100))
.send()
.await
.expect("call failed");
assert_eq!(resp.status(), 200);
}
#[tokio::test]
async fn assert_retry_on_connection_reset_by_peer() {
let counter = Arc::new(AtomicI8::new(0));
let mut simple_server = SimpleServer::new("127.0.0.1", None, vec![])
.await
.expect("Error when creating a simple server");
simple_server.set_custom_handler(move |mut stream| {
let counter = counter.clone();
async move {
let mut buffer = Vec::new();
stream.read_buf(&mut buffer).await.unwrap();
if counter.fetch_add(1, Ordering::SeqCst) > 1 {
// This triggeres hyper:Error(Io, io::Error(ConnectionReset)).
drop(stream);
} else {
let _res = stream.write("HTTP/1.1 200 OK\r\n\r\n".as_bytes()).await;
}
Ok(())
}
.boxed()
});
let uri = simple_server.uri();
tokio::spawn(simple_server.start());
let reqwest_client = Client::builder().build().unwrap();
let client = ClientBuilder::new(reqwest_client)
.with(RetryTransientMiddleware::new_with_policy(
ExponentialBackoff::builder()
.retry_bounds(
std::time::Duration::from_millis(30),
std::time::Duration::from_millis(100),
)
.build_with_max_retries(3),
)) ))
.build(); .build();

View File

@ -6,73 +6,6 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
## [Unreleased] ## [Unreleased]
## [0.4.6] - 2023-08-23
### Added
- Add support for opentelemetry 0.20
## [0.4.5] - 2023-06-20
### Added
- A new extension `DisableOtelPropagation` which stops opentelemetry contexts propagating
- Support for opentelemetry 0.19
## [0.4.4] - 2023-05-15
### Added
- A new `default_span_name` method for use in custom span backends.
## [0.4.3] - 2023-05-15
### Fixed
- Fix span and http status codes
## [0.4.2] - 2023-05-12
### Added
- `OtelPathNames` extension to provide known parameterized paths that will be used in span names
### Changed
- `DefaultSpanBackend` and `SpanBackendWithUrl` default span name to HTTP method name instead of `reqwest-http-client`
## [0.4.1] - 2023-03-09
### Added
- Support for `wasm32-unknown-unknown` target
## [0.4.0] - 2022-11-15
### Changed
- Updated `reqwest-middleware` to `0.2.0`
- Before, `root_span!`/`DefaultSpanBacked` would name your spans `{METHOD} {PATH}`. Since this can be quite
high cardinality, this was changed and now the macro requires an explicit otel name.
`DefaultSpanBacked`/`SpanBackendWithUrl` will default to `reqwest-http-client` but this can be configured
using the `OtelName` Request Initialiser.
### Added
- `SpanBackendWithUrl` for capturing `http.url` in traces
- `OtelName` Request Initialiser Extension for configuring
## [0.3.1] - 2022-09-21
- Added support for `opentelemetry` version `0.18`.
## [0.3.0] - 2022-06-10
### Breaking
- Created `ReqwestOtelSpanBackend` trait with `reqwest_otel_span` macro to provide extendable default request otel fields
## [0.2.3] - 2022-06-23
### Fixed
- Fix how we set the OpenTelemetry span status, based on the HTTP response status.
## [0.2.2] - 2022-04-21
### Fixed
- Opentelemetry context is now propagated when the request span is disabled.
## [0.2.1] - 2022-02-21
### Changed
- Updated `reqwest-middleware` to `0.1.5`
## [0.2.0] - 2021-11-30 ## [0.2.0] - 2021-11-30
### Breaking ### Breaking
- Update to `tracing-subscriber` `0.3.x` when `opentelemetry_0_16` is active. - Update to `tracing-subscriber` `0.3.x` when `opentelemetry_0_16` is active.

View File

@ -1,12 +1,12 @@
[package] [package]
name = "reqwest-tracing" name = "reqwest-tracing"
version = "0.4.7" version = "0.2.1-alpha.0"
authors = ["Rodrigo Gryzinski <rodrigo.gryzinski@truelayer.com>"] authors = ["Rodrigo Gryzinski <rodrigo.gryzinski@truelayer.com>"]
edition = "2018" edition = "2018"
description = "Opentracing middleware for reqwest." description = "Opentracing middleware for reqwest."
repository = "https://github.com/TrueLayer/reqwest-middleware" repository = "https://github.com/TrueLayer/reqwest-middleware"
license = "MIT OR Apache-2.0" license = "MIT OR Apache-2.0"
keywords = ["reqwest", "http", "middleware", "opentelemetry", "tracing"] keywords = ["reqwest", "http", "middleware", "opentracing", "tracing"]
categories = ["web-programming::http-client"] categories = ["web-programming::http-client"]
[features] [features]
@ -14,51 +14,25 @@ opentelemetry_0_13 = ["opentelemetry_0_13_pkg", "tracing-opentelemetry_0_12_pkg"
opentelemetry_0_14 = ["opentelemetry_0_14_pkg", "tracing-opentelemetry_0_13_pkg"] opentelemetry_0_14 = ["opentelemetry_0_14_pkg", "tracing-opentelemetry_0_13_pkg"]
opentelemetry_0_15 = ["opentelemetry_0_15_pkg", "tracing-opentelemetry_0_14_pkg"] opentelemetry_0_15 = ["opentelemetry_0_15_pkg", "tracing-opentelemetry_0_14_pkg"]
opentelemetry_0_16 = ["opentelemetry_0_16_pkg", "tracing-opentelemetry_0_16_pkg"] opentelemetry_0_16 = ["opentelemetry_0_16_pkg", "tracing-opentelemetry_0_16_pkg"]
opentelemetry_0_17 = ["opentelemetry_0_17_pkg", "tracing-opentelemetry_0_17_pkg"]
opentelemetry_0_18 = ["opentelemetry_0_18_pkg", "tracing-opentelemetry_0_18_pkg"]
opentelemetry_0_19 = ["opentelemetry_0_19_pkg", "tracing-opentelemetry_0_19_pkg"]
opentelemetry_0_20 = ["opentelemetry_0_20_pkg", "tracing-opentelemetry_0_20_pkg"]
opentelemetry_0_21 = ["opentelemetry_0_21_pkg", "tracing-opentelemetry_0_22_pkg"]
[dependencies] [dependencies]
reqwest-middleware = { version = "0.2.0", path = "../reqwest-middleware" } reqwest-middleware = { version = "0.1.2", path = "../reqwest-middleware" }
anyhow = "1.0.70"
async-trait = "0.1.51" async-trait = "0.1.51"
matchit = "0.7.0" reqwest = { version = "0.11", default-features = false }
reqwest = { version = "0.11.0", default-features = false } tokio = { version = "1.6", features = ["time"] }
task-local-extensions = "0.1.4"
tracing = "0.1.26" tracing = "0.1.26"
task-local-extensions = "0.1.1"
opentelemetry_0_13_pkg = { package = "opentelemetry", version = "0.13.0", optional = true } opentelemetry_0_13_pkg = { package = "opentelemetry", version = "0.13", optional = true }
opentelemetry_0_14_pkg = { package = "opentelemetry", version = "0.14.0", optional = true } opentelemetry_0_14_pkg = { package = "opentelemetry", version = "0.14", optional = true }
opentelemetry_0_15_pkg = { package = "opentelemetry", version = "0.15.0", optional = true } opentelemetry_0_15_pkg = { package = "opentelemetry", version = "0.15", optional = true }
opentelemetry_0_16_pkg = { package = "opentelemetry", version = "0.16.0", optional = true } opentelemetry_0_16_pkg = { package = "opentelemetry", version = "0.16", optional = true }
opentelemetry_0_17_pkg = { package = "opentelemetry", version = "0.17.0", optional = true } tracing-opentelemetry_0_12_pkg = { package = "tracing-opentelemetry",version = "0.12", optional = true }
opentelemetry_0_18_pkg = { package = "opentelemetry", version = "0.18.0", optional = true } tracing-opentelemetry_0_13_pkg = { package = "tracing-opentelemetry", version = "0.13", optional = true }
opentelemetry_0_19_pkg = { package = "opentelemetry", version = "0.19.0", optional = true } tracing-opentelemetry_0_14_pkg = { package = "tracing-opentelemetry",version = "0.14", optional = true }
opentelemetry_0_20_pkg = { package = "opentelemetry", version = "0.20.0", optional = true } tracing-opentelemetry_0_16_pkg = { package = "tracing-opentelemetry",version = "0.16", optional = true }
opentelemetry_0_21_pkg = { package = "opentelemetry", version = "0.21.0", optional = true }
tracing-opentelemetry_0_12_pkg = { package = "tracing-opentelemetry", version = "0.12.0", optional = true }
tracing-opentelemetry_0_13_pkg = { package = "tracing-opentelemetry", version = "0.13.0", optional = true }
tracing-opentelemetry_0_14_pkg = { package = "tracing-opentelemetry", version = "0.14.0", optional = true }
tracing-opentelemetry_0_16_pkg = { package = "tracing-opentelemetry", version = "0.16.0", optional = true }
tracing-opentelemetry_0_17_pkg = { package = "tracing-opentelemetry", version = "0.17.0", optional = true }
tracing-opentelemetry_0_18_pkg = { package = "tracing-opentelemetry", version = "0.18.0", optional = true }
tracing-opentelemetry_0_19_pkg = { package = "tracing-opentelemetry", version = "0.19.0", optional = true }
tracing-opentelemetry_0_20_pkg = { package = "tracing-opentelemetry", version = "0.20.0", optional = true }
tracing-opentelemetry_0_22_pkg = { package = "tracing-opentelemetry", version = "0.22.0", optional = true }
[target.'cfg(target_arch = "wasm32")'.dependencies]
getrandom = { version = "0.2.0", features = ["js"] }
[dev-dependencies] [dev-dependencies]
tokio = { version = "1.0.0", features = ["macros"] } wiremock = "0.5"
tracing_subscriber_0_2 = { package = "tracing-subscriber", version = "0.2.0" } tokio = { version = "1", features = ["macros"] }
tracing_subscriber_0_3 = { package = "tracing-subscriber", version = "0.3.0" }
wiremock = "0.5.0"
opentelemetry_sdk_0_21 = { package = "opentelemetry_sdk", version = "0.21.0", features = ["trace"] }
opentelemetry_stdout_0_1 = { package = "opentelemetry-stdout", version = "0.1.0", features = ["trace"] }
opentelemetry_stdout_0_2 = { package = "opentelemetry-stdout", version = "0.2.0", features = ["trace"] }

View File

@ -16,44 +16,24 @@ Attach `TracingMiddleware` to your client to automatically trace HTTP requests:
# Cargo.toml # Cargo.toml
# ... # ...
[dependencies] [dependencies]
opentelemetry = "0.18" opentelemetry = "0.16"
reqwest = "0.11" reqwest = "0.11"
reqwest-middleware = "0.1.1" reqwest-middleware = "0.1.1"
reqwest-retry = "0.1.1" reqwest-retry = "0.1.1"
reqwest-tracing = { version = "0.3.1", features = ["opentelemetry_0_18"] } reqwest-tracing = { version = "0.1.2", features = ["opentelemetry_0_16"] }
tokio = { version = "1.12.0", features = ["macros", "rt-multi-thread"] } tokio = { version = "1.12.0", features = ["macros", "rt-multi-thread"] }
tracing = "0.1" tracing = "0.1"
tracing-opentelemetry = "0.18" tracing-opentelemetry = "0.15"
tracing-subscriber = "0.3" tracing-subscriber = "0.2"
task-local-extensions = "0.1.4"
``` ```
```rust,skip ```rust,skip
use reqwest_tracing::{default_on_request_end, reqwest_otel_span, ReqwestOtelSpanBackend, TracingMiddleware};
use opentelemetry::sdk::export::trace::stdout; use opentelemetry::sdk::export::trace::stdout;
use reqwest::{Request, Response}; use reqwest_middleware::ClientBuilder;
use reqwest_middleware::{ClientBuilder, Result}; use reqwest_tracing::TracingMiddleware;
use std::time::Instant;
use task_local_extensions::Extensions;
use tracing::Span;
use tracing_subscriber::layer::SubscriberExt; use tracing_subscriber::layer::SubscriberExt;
use tracing_subscriber::Registry; use tracing_subscriber::Registry;
pub struct TimeTrace;
impl ReqwestOtelSpanBackend for TimeTrace {
fn on_request_start(req: &Request, extension: &mut Extensions) -> Span {
extension.insert(Instant::now());
reqwest_otel_span!(name="example-request", req, time_elapsed = tracing::field::Empty)
}
fn on_request_end(span: &Span, outcome: &Result<Response>, extension: &mut Extensions) {
let time_elapsed = extension.get::<Instant>().unwrap().elapsed().as_millis() as i64;
default_on_request_end(span, outcome);
span.record("time_elapsed", &time_elapsed);
}
}
#[tokio::main] #[tokio::main]
async fn main() { async fn main() {
let tracer = stdout::new_pipeline().install_simple(); let tracer = stdout::new_pipeline().install_simple();
@ -66,7 +46,7 @@ async fn main() {
async fn run() { async fn run() {
let client = ClientBuilder::new(reqwest::Client::new()) let client = ClientBuilder::new(reqwest::Client::new())
.with(TracingMiddleware::<TimeTrace>::new()) .with(TracingMiddleware)
.build(); .build();
client.get("https://truelayer.com").send().await.unwrap(); client.get("https://truelayer.com").send().await.unwrap();
@ -89,12 +69,11 @@ an opentelemetry version feature:
```toml ```toml
[dependencies] [dependencies]
# ... # ...
reqwest-tracing = { version = "0.3.1", features = ["opentelemetry_0_18"] } reqwest-tracing = { version = "0.1.0", features = ["opentelemetry_0_16"] }
``` ```
Available opentelemetry features are `opentelemetry_0_21`, `opentelemetry_0_20`, Available opentelemetry features are `opentelemetry_0_16`, `opentelemetry_0_15`, `opentelemetry_0_14` and
`opentelemetry_0_19`, `opentelemetry_0_18`, `opentelemetry_0_17`, `opentelemetry_0_16`, `opentelemetry_0_13`.
`opentelemetry_0_15`, `opentelemetry_0_14` and `opentelemetry_0_13`.
#### License #### License

View File

@ -1,86 +1,6 @@
//! Opentracing middleware implementation for [`reqwest_middleware`]. //! Opentracing middleware implementation for [`reqwest-middleware`].
//! //!
//! Attach [`TracingMiddleware`] to your client to automatically trace HTTP requests. //! Attach [`TracingMiddleware`] to your client to automatically trace HTTP requests.
//!
//! The simplest possible usage:
//! ```no_run
//! # use reqwest_middleware::Result;
//! use reqwest_middleware::{ClientBuilder};
//! use reqwest_tracing::TracingMiddleware;
//!
//! # async fn example() -> Result<()> {
//! let reqwest_client = reqwest::Client::builder().build().unwrap();
//! let client = ClientBuilder::new(reqwest_client)
//! // Insert the tracing middleware
//! .with(TracingMiddleware::default())
//! .build();
//!
//! let resp = client.get("https://truelayer.com").send().await.unwrap();
//! # Ok(())
//! # }
//! ```
//!
//! To customise the span names use [`OtelName`].
//! ```no_run
//! # use reqwest_middleware::Result;
//! use reqwest_middleware::{ClientBuilder, Extension};
//! use reqwest_tracing::{
//! TracingMiddleware, OtelName
//! };
//! # async fn example() -> Result<()> {
//! let reqwest_client = reqwest::Client::builder().build().unwrap();
//! let client = ClientBuilder::new(reqwest_client)
//! // Inserts the extension before the request is started
//! .with_init(Extension(OtelName("my-client".into())))
//! // Makes use of that extension to specify the otel name
//! .with(TracingMiddleware::default())
//! .build();
//!
//! let resp = client.get("https://truelayer.com").send().await.unwrap();
//!
//! // Or specify it on the individual request (will take priority)
//! let resp = client.post("https://api.truelayer.com/payment")
//! .with_extension(OtelName("POST /payment".into()))
//! .send()
//! .await
//! .unwrap();
//! # Ok(())
//! # }
//! ```
//!
//! In this example we define a custom span builder to calculate the request time elapsed and we register the [`TracingMiddleware`].
//!
//! Note that Opentelemetry tracks start and stop already, there is no need to have a custom builder like this.
//! ```rust
//! use reqwest_middleware::Result;
//! use task_local_extensions::Extensions;
//! use reqwest::{Request, Response};
//! use reqwest_middleware::ClientBuilder;
//! use reqwest_tracing::{
//! default_on_request_end, reqwest_otel_span, ReqwestOtelSpanBackend, TracingMiddleware
//! };
//! use tracing::Span;
//! use std::time::{Duration, Instant};
//!
//! pub struct TimeTrace;
//!
//! impl ReqwestOtelSpanBackend for TimeTrace {
//! fn on_request_start(req: &Request, extension: &mut Extensions) -> Span {
//! extension.insert(Instant::now());
//! reqwest_otel_span!(name="example-request", req, time_elapsed = tracing::field::Empty)
//! }
//!
//! fn on_request_end(span: &Span, outcome: &Result<Response>, extension: &mut Extensions) {
//! let time_elapsed = extension.get::<Instant>().unwrap().elapsed().as_millis() as i64;
//! default_on_request_end(span, outcome);
//! span.record("time_elapsed", &time_elapsed);
//! }
//! }
//!
//! let http = ClientBuilder::new(reqwest::Client::new())
//! .with(TracingMiddleware::<TimeTrace>::new())
//! .build();
//! ```
mod middleware; mod middleware;
#[cfg(any( #[cfg(any(
@ -88,22 +8,7 @@ mod middleware;
feature = "opentelemetry_0_14", feature = "opentelemetry_0_14",
feature = "opentelemetry_0_15", feature = "opentelemetry_0_15",
feature = "opentelemetry_0_16", feature = "opentelemetry_0_16",
feature = "opentelemetry_0_17",
feature = "opentelemetry_0_18",
feature = "opentelemetry_0_19",
feature = "opentelemetry_0_20",
feature = "opentelemetry_0_21",
))] ))]
mod otel; mod otel;
mod reqwest_otel_span_builder;
pub use middleware::TracingMiddleware;
pub use reqwest_otel_span_builder::{
default_on_request_end, default_on_request_failure, default_on_request_success,
default_span_name, DefaultSpanBackend, DisableOtelPropagation, OtelName, OtelPathNames,
ReqwestOtelSpanBackend, SpanBackendWithUrl, ERROR_CAUSE_CHAIN, ERROR_MESSAGE, HTTP_HOST,
HTTP_METHOD, HTTP_SCHEME, HTTP_STATUS_CODE, HTTP_URL, HTTP_USER_AGENT, NET_HOST_PORT,
OTEL_KIND, OTEL_NAME, OTEL_STATUS_CODE,
};
#[doc(hidden)] pub use middleware::TracingMiddleware;
pub mod reqwest_otel_span_macro;

View File

@ -1,75 +1,130 @@
use reqwest::{Request, Response}; use reqwest::header::{HeaderMap, HeaderValue};
use reqwest_middleware::{Middleware, Next, Result}; use reqwest::{Request, Response, StatusCode as RequestStatusCode};
use reqwest_middleware::{Error, Middleware, Next, Result};
use task_local_extensions::Extensions; use task_local_extensions::Extensions;
use tracing::Instrument;
use crate::{DefaultSpanBackend, ReqwestOtelSpanBackend};
/// Middleware for tracing requests using the current Opentelemetry Context. /// Middleware for tracing requests using the current Opentelemetry Context.
pub struct TracingMiddleware<S: ReqwestOtelSpanBackend> { pub struct TracingMiddleware;
span_backend: std::marker::PhantomData<S>,
}
impl<S: ReqwestOtelSpanBackend> TracingMiddleware<S> { #[async_trait::async_trait]
pub fn new() -> TracingMiddleware<S> { impl Middleware for TracingMiddleware {
TracingMiddleware {
span_backend: Default::default(),
}
}
}
impl<S: ReqwestOtelSpanBackend> Clone for TracingMiddleware<S> {
fn clone(&self) -> Self {
Self::new()
}
}
impl Default for TracingMiddleware<DefaultSpanBackend> {
fn default() -> Self {
TracingMiddleware::new()
}
}
#[cfg_attr(not(target_arch = "wasm32"), async_trait::async_trait)]
#[cfg_attr(target_arch = "wasm32", async_trait::async_trait(?Send))]
impl<ReqwestOtelSpan> Middleware for TracingMiddleware<ReqwestOtelSpan>
where
ReqwestOtelSpan: ReqwestOtelSpanBackend + Sync + Send + 'static,
{
async fn handle( async fn handle(
&self, &self,
req: Request, req: Request,
extensions: &mut Extensions, extensions: &mut Extensions,
next: Next<'_>, next: Next<'_>,
) -> Result<Response> { ) -> Result<Response> {
let request_span = ReqwestOtelSpan::on_request_start(&req, extensions); let request_span = {
let method = req.method();
let scheme = req.url().scheme();
let host = req.url().host_str().unwrap_or("");
let host_port = req.url().port().unwrap_or(0) as i64;
let path = req.url().path();
let otel_name = format!("{} {}", method, path);
let outcome_future = async { tracing::info_span!(
"HTTP request",
http.method = %method,
http.scheme = %scheme,
http.host = %host,
net.host.port = %host_port,
otel.kind = "client",
otel.name = %otel_name,
otel.status_code = tracing::field::Empty,
http.user_agent = tracing::field::Empty,
http.status_code = tracing::field::Empty,
error.message = tracing::field::Empty,
error.cause_chain = tracing::field::Empty,
)
};
// Adds tracing headers to the given request to propagate the OpenTracing context to downstream revivers of the request.
// Spans added by downstream consumers will be part of the same trace.
#[cfg(any( #[cfg(any(
feature = "opentelemetry_0_13", feature = "opentelemetry_0_13",
feature = "opentelemetry_0_14", feature = "opentelemetry_0_14",
feature = "opentelemetry_0_15", feature = "opentelemetry_0_15",
feature = "opentelemetry_0_16", feature = "opentelemetry_0_16",
feature = "opentelemetry_0_17",
feature = "opentelemetry_0_18",
feature = "opentelemetry_0_19",
feature = "opentelemetry_0_20",
feature = "opentelemetry_0_21",
))] ))]
let req = if !extensions.contains::<crate::DisableOtelPropagation>() { let req = crate::otel::inject_opentracing_context_into_request(&request_span, req);
// Adds tracing headers to the given request to propagate the OpenTelemetry context to downstream revivers of the request.
// Spans added by downstream consumers will be part of the same trace.
crate::otel::inject_opentelemetry_context_into_request(req)
} else {
req
};
// Run the request // Run the request
let outcome = next.run(req, extensions).await; let outcome = next.run(req, extensions).await;
ReqwestOtelSpan::on_request_end(&request_span, &outcome, extensions); match &outcome {
Ok(response) => {
// The request ran successfully
let span_status = get_span_status(response.status());
let status_code = response.status().as_u16() as i64;
let user_agent = get_header_value("user_agent", response.headers());
if let Some(span_status) = span_status {
request_span.record("otel.status_code", &span_status);
}
request_span.record("http.status_code", &status_code);
request_span.record("http.user_agent", &user_agent.as_str());
}
Err(e) => {
// The request didn't run successfully
let error_message = e.to_string();
let error_cause_chain = format!("{:?}", e);
request_span.record("otel.status_code", &"ERROR");
request_span.record("error.message", &error_message.as_str());
request_span.record("error.cause_chain", &error_cause_chain.as_str());
if let Error::Reqwest(e) = e {
request_span.record(
"http.status_code",
&e.status()
.map(|s| s.to_string())
.unwrap_or_else(|| "".to_string())
.as_str(),
);
}
}
}
outcome outcome
}; }
}
outcome_future.instrument(request_span.clone()).await fn get_header_value(key: &str, headers: &HeaderMap) -> String {
let header_default = &HeaderValue::from_static("");
format!("{:?}", headers.get(key).unwrap_or(header_default)).replace("\"", "")
}
/// HTTP Mapping <https://github.com/open-telemetry/opentelemetry-specification/blob/c4b7f4307de79009c97b3a98563e91fee39b7ba3/work_in_progress/opencensus/HTTP.md#status>
// | HTTP code | Span status code |
// |-------------------------|-----------------------|
// | 100...299 | `Ok` |
// | 3xx redirect codes | `DeadlineExceeded` in case of loop (see above) [1], otherwise `Ok` |
// | 401 Unauthorized ⚠ | `Unauthenticated` ⚠ (Unauthorized actually means unauthenticated according to [RFC 7235][rfc-unauthorized]) |
// | 403 Forbidden | `PermissionDenied` |
// | 404 Not Found | `NotFound` |
// | 429 Too Many Requests | `ResourceExhausted` |
// | Other 4xx code | `InvalidArgument` [1] |
// | 501 Not Implemented | `Unimplemented` |
// | 503 Service Unavailable | `Unavailable` |
// | 504 Gateway Timeout | `DeadlineExceeded` |
// | Other 5xx code | `InternalError` [1] |
// | Any status code the client fails to interpret (e.g., 093 or 573) | `UnknownError` |
///
/// Maps the the http status to an Opentelemetry span status following the the specified convention above.
fn get_span_status(request_status: RequestStatusCode) -> Option<&'static str> {
match request_status.as_u16() {
100..=399 => Some("OK"),
400..=599 => Some("ERROR"),
_ => None,
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn get_header_value_for_span_attribute() {
let expect = "IMPORTANT_HEADER";
let mut header_map = HeaderMap::new();
header_map.insert("test", expect.parse().unwrap());
let value = get_header_value("test", &header_map);
assert_eq!(value, expect);
} }
} }

View File

@ -15,21 +15,6 @@ use opentelemetry_0_15_pkg as opentelemetry;
#[cfg(feature = "opentelemetry_0_16")] #[cfg(feature = "opentelemetry_0_16")]
use opentelemetry_0_16_pkg as opentelemetry; use opentelemetry_0_16_pkg as opentelemetry;
#[cfg(feature = "opentelemetry_0_17")]
use opentelemetry_0_17_pkg as opentelemetry;
#[cfg(feature = "opentelemetry_0_18")]
use opentelemetry_0_18_pkg as opentelemetry;
#[cfg(feature = "opentelemetry_0_19")]
use opentelemetry_0_19_pkg as opentelemetry;
#[cfg(feature = "opentelemetry_0_20")]
use opentelemetry_0_20_pkg as opentelemetry;
#[cfg(feature = "opentelemetry_0_21")]
use opentelemetry_0_21_pkg as opentelemetry;
#[cfg(feature = "opentelemetry_0_13")] #[cfg(feature = "opentelemetry_0_13")]
pub use tracing_opentelemetry_0_12_pkg as tracing_opentelemetry; pub use tracing_opentelemetry_0_12_pkg as tracing_opentelemetry;
@ -42,28 +27,14 @@ pub use tracing_opentelemetry_0_14_pkg as tracing_opentelemetry;
#[cfg(feature = "opentelemetry_0_16")] #[cfg(feature = "opentelemetry_0_16")]
pub use tracing_opentelemetry_0_16_pkg as tracing_opentelemetry; pub use tracing_opentelemetry_0_16_pkg as tracing_opentelemetry;
#[cfg(feature = "opentelemetry_0_17")]
pub use tracing_opentelemetry_0_17_pkg as tracing_opentelemetry;
#[cfg(feature = "opentelemetry_0_18")]
pub use tracing_opentelemetry_0_18_pkg as tracing_opentelemetry;
#[cfg(feature = "opentelemetry_0_19")]
pub use tracing_opentelemetry_0_19_pkg as tracing_opentelemetry;
#[cfg(feature = "opentelemetry_0_20")]
pub use tracing_opentelemetry_0_20_pkg as tracing_opentelemetry;
#[cfg(feature = "opentelemetry_0_21")]
pub use tracing_opentelemetry_0_22_pkg as tracing_opentelemetry;
use opentelemetry::global; use opentelemetry::global;
use opentelemetry::propagation::Injector; use opentelemetry::propagation::Injector;
use tracing_opentelemetry::OpenTelemetrySpanExt; use tracing_opentelemetry::OpenTelemetrySpanExt;
/// Injects the given OpenTelemetry Context into a reqwest::Request headers to allow propagation downstream. /// Injects the given Opentelemetry Context into a reqwest::Request headers to allow propagation downstream.
pub fn inject_opentelemetry_context_into_request(mut request: Request) -> Request { pub fn inject_opentracing_context_into_request(span: &Span, request: Request) -> Request {
let context = Span::current().context(); let context = span.context();
let mut request = request;
global::get_text_map_propagator(|injector| { global::get_text_map_propagator(|injector| {
injector.inject_context(&context, &mut RequestCarrier::new(&mut request)) injector.inject_context(&context, &mut RequestCarrier::new(&mut request))
@ -95,126 +66,3 @@ impl<'a> Injector for RequestCarrier<'a> {
self.request.headers_mut().insert(header_name, header_value); self.request.headers_mut().insert(header_name, header_value);
} }
} }
#[cfg(test)]
mod test {
use std::sync::OnceLock;
use super::*;
use crate::{DisableOtelPropagation, TracingMiddleware};
#[cfg(not(feature = "opentelemetry_0_21"))]
use opentelemetry::sdk::propagation::TraceContextPropagator;
#[cfg(feature = "opentelemetry_0_21")]
use opentelemetry_sdk_0_21::propagation::TraceContextPropagator;
use reqwest::Response;
use reqwest_middleware::{ClientBuilder, ClientWithMiddleware, Extension};
use tracing::{info_span, Instrument, Level};
#[cfg(any(
feature = "opentelemetry_0_13",
feature = "opentelemetry_0_14",
feature = "opentelemetry_0_15"
))]
use tracing_subscriber_0_2::{filter, layer::SubscriberExt, Registry};
#[cfg(not(any(
feature = "opentelemetry_0_13",
feature = "opentelemetry_0_14",
feature = "opentelemetry_0_15"
)))]
use tracing_subscriber_0_3::{filter, layer::SubscriberExt, Registry};
use wiremock::{matchers::any, Mock, MockServer, ResponseTemplate};
async fn make_echo_request_in_otel_context(client: ClientWithMiddleware) -> Response {
static TELEMETRY: OnceLock<()> = OnceLock::new();
TELEMETRY.get_or_init(|| {
#[cfg(all(
not(feature = "opentelemetry_0_20"),
not(feature = "opentelemetry_0_21")
))]
let tracer = opentelemetry::sdk::export::trace::stdout::new_pipeline()
.with_writer(std::io::sink())
.install_simple();
#[cfg(any(feature = "opentelemetry_0_20", feature = "opentelemetry_0_21"))]
let tracer = {
use opentelemetry::trace::TracerProvider;
#[cfg(feature = "opentelemetry_0_20")]
use opentelemetry_stdout_0_1::SpanExporterBuilder;
#[cfg(feature = "opentelemetry_0_21")]
use opentelemetry_stdout_0_2::SpanExporterBuilder;
let exporter = SpanExporterBuilder::default()
.with_writer(std::io::sink())
.build();
#[cfg(feature = "opentelemetry_0_20")]
let provider = opentelemetry::sdk::trace::TracerProvider::builder()
.with_simple_exporter(exporter)
.build();
#[cfg(feature = "opentelemetry_0_21")]
let provider = opentelemetry_sdk_0_21::trace::TracerProvider::builder()
.with_simple_exporter(exporter)
.build();
let tracer = provider.versioned_tracer("reqwest", None::<&str>, None::<&str>, None);
let _ = global::set_tracer_provider(provider);
tracer
};
let telemetry = tracing_opentelemetry::layer().with_tracer(tracer);
let subscriber = Registry::default()
.with(
filter::Targets::new().with_target("reqwest_tracing::otel::test", Level::DEBUG),
)
.with(telemetry);
tracing::subscriber::set_global_default(subscriber).unwrap();
global::set_text_map_propagator(TraceContextPropagator::new());
});
// Mock server - sends all request headers back in the response
let server = MockServer::start().await;
Mock::given(any())
.respond_with(|req: &wiremock::Request| {
req.headers
.iter()
.fold(ResponseTemplate::new(200), |resp, (k, v)| {
resp.append_header(k.clone(), v.clone())
})
})
.mount(&server)
.await;
client
.get(server.uri())
.send()
.instrument(info_span!("some_span"))
.await
.unwrap()
}
#[tokio::test]
async fn tracing_middleware_propagates_otel_data_even_when_the_span_is_disabled() {
let client = ClientBuilder::new(reqwest::Client::new())
.with(TracingMiddleware::default())
.build();
let resp = make_echo_request_in_otel_context(client).await;
assert!(
resp.headers().contains_key("traceparent"),
"by default, the tracing middleware will propagate otel contexts"
);
}
#[tokio::test]
async fn context_no_propagated() {
let client = ClientBuilder::new(reqwest::Client::new())
.with_init(Extension(DisableOtelPropagation))
.with(TracingMiddleware::default())
.build();
let resp = make_echo_request_in_otel_context(client).await;
assert!(
!resp.headers().contains_key("traceparent"),
"request should not contain traceparent if context propagation is disabled"
);
}
}

View File

@ -1,370 +0,0 @@
use std::borrow::Cow;
use matchit::Router;
use reqwest::header::{HeaderMap, HeaderValue};
use reqwest::{Request, Response, StatusCode as RequestStatusCode, Url};
use reqwest_middleware::{Error, Result};
use task_local_extensions::Extensions;
use tracing::{warn, Span};
use crate::reqwest_otel_span;
/// The `http.method` field added to the span by [`reqwest_otel_span`]
pub const HTTP_METHOD: &str = "http.method";
/// The `http.scheme` field added to the span by [`reqwest_otel_span`]
pub const HTTP_SCHEME: &str = "http.scheme";
/// The `http.host` field added to the span by [`reqwest_otel_span`]
pub const HTTP_HOST: &str = "http.host";
/// The `http.url` field added to the span by [`reqwest_otel_span`]
pub const HTTP_URL: &str = "http.url";
/// The `host.port` field added to the span by [`reqwest_otel_span`]
pub const NET_HOST_PORT: &str = "net.host.port";
/// The `otel.kind` field added to the span by [`reqwest_otel_span`]
pub const OTEL_KIND: &str = "otel.kind";
/// The `otel.name` field added to the span by [`reqwest_otel_span`]
pub const OTEL_NAME: &str = "otel.name";
/// The `otel.status_code` field added to the span by [`reqwest_otel_span`]
pub const OTEL_STATUS_CODE: &str = "otel.status_code";
/// The `error.message` field added to the span by [`reqwest_otel_span`]
pub const ERROR_MESSAGE: &str = "error.message";
/// The `error.cause_chain` field added to the span by [`reqwest_otel_span`]
pub const ERROR_CAUSE_CHAIN: &str = "error.cause_chain";
/// The `http.status_code` field added to the span by [`reqwest_otel_span`]
pub const HTTP_STATUS_CODE: &str = "http.status_code";
/// The `http.user_agent` added to the span by [`reqwest_otel_span`]
pub const HTTP_USER_AGENT: &str = "http.user_agent";
/// [`ReqwestOtelSpanBackend`] allows you to customise the span attached by
/// [`TracingMiddleware`] to incoming requests.
///
/// Check out [`reqwest_otel_span`] documentation for examples.
///
/// [`TracingMiddleware`]: crate::middleware::TracingMiddleware.
pub trait ReqwestOtelSpanBackend {
/// Initalized a new span before the request is executed.
fn on_request_start(req: &Request, extension: &mut Extensions) -> Span;
/// Runs after the request call has executed.
fn on_request_end(span: &Span, outcome: &Result<Response>, extension: &mut Extensions);
}
/// Populates default success/failure fields for a given [`reqwest_otel_span!`] span.
#[inline]
pub fn default_on_request_end(span: &Span, outcome: &Result<Response>) {
match outcome {
Ok(res) => default_on_request_success(span, res),
Err(err) => default_on_request_failure(span, err),
}
}
/// Populates default success fields for a given [`reqwest_otel_span!`] span.
#[inline]
pub fn default_on_request_success(span: &Span, response: &Response) {
let span_status = get_span_status(response.status());
let user_agent = get_header_value("user_agent", response.headers());
if let Some(span_status) = span_status {
span.record(OTEL_STATUS_CODE, span_status);
}
span.record(HTTP_STATUS_CODE, response.status().as_u16());
span.record(HTTP_USER_AGENT, user_agent.as_str());
}
/// Populates default failure fields for a given [`reqwest_otel_span!`] span.
#[inline]
pub fn default_on_request_failure(span: &Span, e: &Error) {
let error_message = e.to_string();
let error_cause_chain = format!("{:?}", e);
span.record(OTEL_STATUS_CODE, "ERROR");
span.record(ERROR_MESSAGE, error_message.as_str());
span.record(ERROR_CAUSE_CHAIN, error_cause_chain.as_str());
if let Error::Reqwest(e) = e {
if let Some(status) = e.status() {
span.record(HTTP_STATUS_CODE, status.as_u16());
}
}
}
/// Determine the name of the span that should be associated with this request.
///
/// This tries to be PII safe by default, not including any path information unless
/// specifically opted in using either [`OtelName`] or [`OtelPathNames`]
#[inline]
pub fn default_span_name<'a>(req: &'a Request, ext: &'a Extensions) -> Cow<'a, str> {
if let Some(name) = ext.get::<OtelName>() {
Cow::Borrowed(name.0.as_ref())
} else if let Some(path_names) = ext.get::<OtelPathNames>() {
path_names
.find(req.url().path())
.map(|path| Cow::Owned(format!("{} {}", req.method(), path)))
.unwrap_or_else(|| {
warn!("no OTEL path name found");
Cow::Owned(format!("{} UNKNOWN", req.method().as_str()))
})
} else {
Cow::Borrowed(req.method().as_str())
}
}
/// The default [`ReqwestOtelSpanBackend`] for [`TracingMiddleware`]. Note that it doesn't include
/// the `http.url` field in spans, you can use [`SpanBackendWithUrl`] to add it.
///
/// [`TracingMiddleware`]: crate::middleware::TracingMiddleware
pub struct DefaultSpanBackend;
impl ReqwestOtelSpanBackend for DefaultSpanBackend {
fn on_request_start(req: &Request, ext: &mut Extensions) -> Span {
let name = default_span_name(req, ext);
reqwest_otel_span!(name = name, req)
}
fn on_request_end(span: &Span, outcome: &Result<Response>, _: &mut Extensions) {
default_on_request_end(span, outcome)
}
}
fn get_header_value(key: &str, headers: &HeaderMap) -> String {
let header_default = &HeaderValue::from_static("");
format!("{:?}", headers.get(key).unwrap_or(header_default)).replace('"', "")
}
/// Similar to [`DefaultSpanBackend`] but also adds the `http.url` attribute to request spans.
///
/// [`TracingMiddleware`]: crate::middleware::TracingMiddleware
pub struct SpanBackendWithUrl;
impl ReqwestOtelSpanBackend for SpanBackendWithUrl {
fn on_request_start(req: &Request, ext: &mut Extensions) -> Span {
let name = default_span_name(req, ext);
reqwest_otel_span!(name = name, req, http.url = %remove_credentials(req.url()))
}
fn on_request_end(span: &Span, outcome: &Result<Response>, _: &mut Extensions) {
default_on_request_end(span, outcome)
}
}
/// HTTP Mapping <https://github.com/open-telemetry/opentelemetry-specification/blob/main/specification/trace/semantic_conventions/http.md#status>
///
/// Maps the the http status to an Opentelemetry span status following the the specified convention above.
fn get_span_status(request_status: RequestStatusCode) -> Option<&'static str> {
match request_status.as_u16() {
// Span Status MUST be left unset if HTTP status code was in the 1xx, 2xx or 3xx ranges, unless there was
// another error (e.g., network error receiving the response body; or 3xx codes with max redirects exceeded),
// in which case status MUST be set to Error.
100..=399 => None,
// For HTTP status codes in the 4xx range span status MUST be left unset in case of SpanKind.SERVER and MUST be
// set to Error in case of SpanKind.CLIENT.
400..=499 => Some("ERROR"),
// For HTTP status codes in the 5xx range, as well as any other code the client failed to interpret, span
// status MUST be set to Error.
_ => Some("ERROR"),
}
}
/// [`OtelName`] allows customisation of the name of the spans created by
/// [`DefaultSpanBackend`] and [`SpanBackendWithUrl`].
///
/// Usage:
/// ```no_run
/// # use reqwest_middleware::Result;
/// use reqwest_middleware::{ClientBuilder, Extension};
/// use reqwest_tracing::{
/// TracingMiddleware, OtelName
/// };
/// # async fn example() -> Result<()> {
/// let reqwest_client = reqwest::Client::builder().build().unwrap();
/// let client = ClientBuilder::new(reqwest_client)
/// // Inserts the extension before the request is started
/// .with_init(Extension(OtelName("my-client".into())))
/// // Makes use of that extension to specify the otel name
/// .with(TracingMiddleware::default())
/// .build();
///
/// let resp = client.get("https://truelayer.com").send().await.unwrap();
///
/// // Or specify it on the individual request (will take priority)
/// let resp = client.post("https://api.truelayer.com/payment")
/// .with_extension(OtelName("POST /payment".into()))
/// .send()
/// .await
/// .unwrap();
/// # Ok(())
/// # }
/// ```
#[derive(Clone)]
pub struct OtelName(pub Cow<'static, str>);
/// [`OtelPathNames`] allows including templated paths in the spans created by
/// [`DefaultSpanBackend`] and [`SpanBackendWithUrl`].
///
/// When creating spans this can be used to try to match the path against some
/// known paths. If the path matches value returned is the templated path. This
/// can be used in span names as it will not contain values that would
/// increase the cardinality.
///
/// ```
/// /// # use reqwest_middleware::Result;
/// use reqwest_middleware::{ClientBuilder, Extension};
/// use reqwest_tracing::{
/// TracingMiddleware, OtelPathNames
/// };
/// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
/// let reqwest_client = reqwest::Client::builder().build()?;
/// let client = ClientBuilder::new(reqwest_client)
/// // Inserts the extension before the request is started
/// .with_init(Extension(OtelPathNames::known_paths(["/payment/:paymentId"])?))
/// // Makes use of that extension to specify the otel name
/// .with(TracingMiddleware::default())
/// .build();
///
/// let resp = client.get("https://truelayer.com/payment/id-123").send().await?;
///
/// // Or specify it on the individual request (will take priority)
/// let resp = client.post("https://api.truelayer.com/payment/id-123/authorization-flow")
/// .with_extension(OtelPathNames::known_paths(["/payment/:paymentId/authorization-flow"])?)
/// .send()
/// .await?;
/// # Ok(())
/// # }
/// ```
#[derive(Clone)]
pub struct OtelPathNames(matchit::Router<String>);
impl OtelPathNames {
/// Create a new [`OtelPathNames`] from a set of known paths.
///
/// Paths in this set will be found with `find`.
///
/// Paths can have different parameters:
/// - Named parameters like `:paymentId` match anything until the next `/` or the end of the path.
/// - Catch-all parameters start with `*` and match everything after the `/`. They must be at the end of the route.
/// ```
/// # use reqwest_tracing::OtelPathNames;
/// OtelPathNames::known_paths([
/// "/",
/// "/payment",
/// "/payment/:paymentId",
/// "/payment/:paymentId/*action",
/// ]).unwrap();
/// ```
pub fn known_paths<Paths, Path>(paths: Paths) -> anyhow::Result<Self>
where
Paths: IntoIterator<Item = Path>,
Path: Into<String>,
{
let mut router = Router::new();
for path in paths {
let path = path.into();
router.insert(path.clone(), path)?;
}
Ok(Self(router))
}
/// Find the templated path from the actual path.
///
/// Returns the templated path if a match is found.
///
/// ```
/// # use reqwest_tracing::OtelPathNames;
/// let path_names = OtelPathNames::known_paths(["/payment/:paymentId"]).unwrap();
/// let path = path_names.find("/payment/payment-id-123");
/// assert_eq!(path, Some("/payment/:paymentId"));
/// ```
pub fn find(&self, path: &str) -> Option<&str> {
self.0.at(path).map(|mtch| mtch.value.as_str()).ok()
}
}
/// `DisableOtelPropagation` disables opentelemetry header propagation, while still tracing the HTTP request.
///
/// By default, the [`TracingMiddleware`](super::TracingMiddleware) middleware will also propagate any opentelemtry
/// contexts to the server. For any external facing requests, this can be problematic and it should be disabled.
///
/// Usage:
/// ```no_run
/// # use reqwest_middleware::Result;
/// use reqwest_middleware::{ClientBuilder, Extension};
/// use reqwest_tracing::{
/// TracingMiddleware, DisableOtelPropagation
/// };
/// # async fn example() -> Result<()> {
/// let reqwest_client = reqwest::Client::builder().build().unwrap();
/// let client = ClientBuilder::new(reqwest_client)
/// // Inserts the extension before the request is started
/// .with_init(Extension(DisableOtelPropagation))
/// // Makes use of that extension to specify the otel name
/// .with(TracingMiddleware::default())
/// .build();
///
/// let resp = client.get("https://truelayer.com").send().await.unwrap();
///
/// // Or specify it on the individual request (will take priority)
/// let resp = client.post("https://api.truelayer.com/payment")
/// .with_extension(DisableOtelPropagation)
/// .send()
/// .await
/// .unwrap();
/// # Ok(())
/// # }
/// ```
#[derive(Clone)]
pub struct DisableOtelPropagation;
/// Removes the username and/or password parts of the url, if present.
fn remove_credentials(url: &Url) -> Cow<'_, str> {
if !url.username().is_empty() || url.password().is_some() {
let mut url = url.clone();
// Errors settings username/password are set when the URL can't have credentials, so
// they're just ignored.
url.set_username("")
.and_then(|_| url.set_password(None))
.ok();
url.to_string().into()
} else {
url.as_ref().into()
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn get_header_value_for_span_attribute() {
let expect = "IMPORTANT_HEADER";
let mut header_map = HeaderMap::new();
header_map.insert("test", expect.parse().unwrap());
let value = get_header_value("test", &header_map);
assert_eq!(value, expect);
}
#[test]
fn remove_credentials_from_url_without_credentials_is_noop() {
let url = "http://nocreds.com/".parse().unwrap();
let clean = remove_credentials(&url);
assert_eq!(clean, "http://nocreds.com/");
}
#[test]
fn remove_credentials_removes_username_only() {
let url = "http://user@withuser.com/".parse().unwrap();
let clean = remove_credentials(&url);
assert_eq!(clean, "http://withuser.com/");
}
#[test]
fn remove_credentials_removes_password_only() {
let url = "http://:123@withpwd.com/".parse().unwrap();
let clean = remove_credentials(&url);
assert_eq!(clean, "http://withpwd.com/");
}
#[test]
fn remove_credentials_removes_username_and_password() {
let url = "http://user:123@both.com/".parse().unwrap();
let clean = remove_credentials(&url);
assert_eq!(clean, "http://both.com/");
}
}

View File

@ -1,175 +0,0 @@
#[macro_export]
/// [`reqwest_otel_span!`](crate::reqwest_otel_span) creates a new [`tracing::Span`].
/// It empowers you to add custom properties to the span on top of the default properties provided by the macro
///
/// Default Fields:
/// - http.method
/// - http.scheme
/// - http.host
/// - net.host
/// - otel.kind
/// - otel.name
/// - otel.status_code
/// - http.user_agent
/// - http.status_code
/// - error.message
/// - error.cause_chain
///
/// Here are some convenient functions to checkout [`default_on_request_success`], [`default_on_request_failure`],
/// and [`default_on_request_end`].
///
/// # Why a macro?
///
/// [`tracing`] requires all the properties attached to a span to be declared upfront, when the span is created.
/// You cannot add new ones afterwards.
/// This makes it extremely fast, but it pushes us to reach for macros when we need some level of composition.
///
/// # Macro syntax
///
/// The first argument is a [span name](https://opentelemetry.io/docs/reference/specification/trace/api/#span).
/// The second argument passed to [`reqwest_otel_span!`](crate::reqwest_otel_span) is a reference to an [`reqwest::Request`].
///
/// ```rust
/// use reqwest_middleware::Result;
/// use task_local_extensions::Extensions;
/// use reqwest::{Request, Response};
/// use reqwest_tracing::{
/// default_on_request_end, reqwest_otel_span, ReqwestOtelSpanBackend
/// };
/// use tracing::Span;
///
/// pub struct CustomReqwestOtelSpanBackend;
///
/// impl ReqwestOtelSpanBackend for CustomReqwestOtelSpanBackend {
/// fn on_request_start(req: &Request, _extension: &mut Extensions) -> Span {
/// reqwest_otel_span!(name = "reqwest-http-request", req)
/// }
///
/// fn on_request_end(span: &Span, outcome: &Result<Response>, _extension: &mut Extensions) {
/// default_on_request_end(span, outcome)
/// }
/// }
/// ```
///
/// If nothing else is specified, the span generated by `reqwest_otel_span!` is identical to the one you'd
/// get by using [`DefaultSpanBackend`]. Note that to avoid leaking sensitive information, the
/// macro doesn't include `http.url`, even though it's required by opentelemetry. You can add the
/// URL attribute explicitly by usng [`SpanBackendWithUrl`] instead of `DefaultSpanBackend` or
/// adding the field on your own implementation.
///
/// You can define new fields following the same syntax of [`tracing::info_span!`] for fields:
///
/// ```rust,should_panic
/// use reqwest_tracing::reqwest_otel_span;
/// # let request: &reqwest::Request = todo!();
///
/// // Define a `time_elapsed` field as empty. It might be populated later.
/// // (This example is just to show how to inject data - otel already tracks durations)
/// reqwest_otel_span!(name = "reqwest-http-request", request, time_elapsed = tracing::field::Empty);
///
/// // Define a `name` field with a known value, `AppName`.
/// reqwest_otel_span!(name = "reqwest-http-request", request, name = "AppName");
///
/// // Define an `app_id` field using the variable with the same name as value.
/// let app_id = "XYZ";
/// reqwest_otel_span!(name = "reqwest-http-request", request, app_id);
///
/// // All together
/// reqwest_otel_span!(name = "reqwest-http-request", request, time_elapsed = tracing::field::Empty, name = "AppName", app_id);
/// ```
///
/// You can also choose to customise the level of the generated span:
///
/// ```rust,should_panic
/// use reqwest_tracing::reqwest_otel_span;
/// use tracing::Level;
/// # let request: &reqwest::Request = todo!();
///
/// // Reduce the log level for service endpoints/probes
/// let level = if request.method().as_str() == "POST" {
/// Level::DEBUG
/// } else {
/// Level::INFO
/// };
///
/// // `level =` and name MUST come before the request, in this order
/// reqwest_otel_span!(level = level, name = "reqwest-http-request", request);
/// ```
///
///
/// [`DefaultSpanBackend`]: crate::reqwest_otel_span_builder::DefaultSpanBackend
/// [`SpanBackendWithUrl`]: crate::reqwest_otel_span_builder::DefaultSpanBackend
/// [`default_on_request_success`]: crate::reqwest_otel_span_builder::default_on_request_success
/// [`default_on_request_failure`]: crate::reqwest_otel_span_builder::default_on_request_failure
/// [`default_on_request_end`]: crate::reqwest_otel_span_builder::default_on_request_end
macro_rules! reqwest_otel_span {
// Vanilla root span at default INFO level, with no additional fields
(name=$name:expr, $request:ident) => {
reqwest_otel_span!(name=$name, $request,)
};
// Vanilla root span, with no additional fields but custom level
(level=$level:expr, name=$name:expr, $request:ident) => {
reqwest_otel_span!(level=$level, name=$name, $request,)
};
// Root span with additional fields, default INFO level
(name=$name:expr, $request:ident, $($field:tt)*) => {
reqwest_otel_span!(level=$crate::reqwest_otel_span_macro::private::Level::INFO, name=$name, $request, $($field)*)
};
// Root span with additional fields and custom level
(level=$level:expr, name=$name:expr, $request:ident, $($field:tt)*) => {
{
let method = $request.method();
let url = $request.url();
let scheme = url.scheme();
let host = url.host_str().unwrap_or("");
let host_port = url.port().unwrap_or(0) as i64;
let otel_name = $name.to_string();
macro_rules! request_span {
($lvl:expr) => {
$crate::reqwest_otel_span_macro::private::span!(
$lvl,
"HTTP request",
http.method = %method,
http.scheme = %scheme,
http.host = %host,
net.host.port = %host_port,
otel.kind = "client",
otel.name = %otel_name,
otel.status_code = tracing::field::Empty,
http.user_agent = tracing::field::Empty,
http.status_code = tracing::field::Empty,
error.message = tracing::field::Empty,
error.cause_chain = tracing::field::Empty,
$($field)*
)
}
}
let span = match $level {
$crate::reqwest_otel_span_macro::private::Level::TRACE => {
request_span!($crate::reqwest_otel_span_macro::private::Level::TRACE)
},
$crate::reqwest_otel_span_macro::private::Level::DEBUG => {
request_span!($crate::reqwest_otel_span_macro::private::Level::DEBUG)
},
$crate::reqwest_otel_span_macro::private::Level::INFO => {
request_span!($crate::reqwest_otel_span_macro::private::Level::INFO)
},
$crate::reqwest_otel_span_macro::private::Level::WARN => {
request_span!($crate::reqwest_otel_span_macro::private::Level::WARN)
},
$crate::reqwest_otel_span_macro::private::Level::ERROR => {
request_span!($crate::reqwest_otel_span_macro::private::Level::ERROR)
},
};
span
}
}
}
#[doc(hidden)]
pub mod private {
#[doc(hidden)]
pub use tracing::{span, Level};
}