remove but inspired by tower

This commit is contained in:
Conrad Ludgate 2022-11-15 21:51:15 +00:00
parent df4990d62e
commit 8f3623eae0
No known key found for this signature in database
GPG key ID: 197E3CACA1C980B5
10 changed files with 277 additions and 120 deletions

View file

@ -34,7 +34,7 @@ tower = "0.4"
```rust ```rust
use reqwest::Response; use reqwest::Response;
use reqwest_middleware::{ClientBuilder, ClientWithMiddleware, Error, MiddlewareRequest, RequestInitialiser, ReqwestService}; use reqwest_middleware::{ClientBuilder, ClientWithMiddleware, Error, Layer, RequestInitialiser, ReqwestService, Service};
use reqwest_retry::{RetryTransientMiddleware, policies::ExponentialBackoff}; use reqwest_retry::{RetryTransientMiddleware, policies::ExponentialBackoff};
use reqwest_tracing::TracingMiddleware; use reqwest_tracing::TracingMiddleware;
@ -53,8 +53,8 @@ async fn main() {
async fn run<M, I>(client: ClientWithMiddleware<M, I>) async fn run<M, I>(client: ClientWithMiddleware<M, I>)
where where
M: tower::Layer<ReqwestService>, M: Layer<ReqwestService>,
M::Service: tower::Service<MiddlewareRequest, Response = Response, Error = Error>, M::Service: Service,
I: RequestInitialiser, I: RequestInitialiser,
{ {
client client

View file

@ -18,7 +18,6 @@ reqwest = { version = "0.11", default-features = false, features = ["json", "mul
serde = "1" serde = "1"
task-local-extensions = "0.1.1" task-local-extensions = "0.1.1"
thiserror = "1" thiserror = "1"
tower = { version = "0.4", features = ["util"] }
futures = "0.3" futures = "0.3"
[dev-dependencies] [dev-dependencies]

View file

@ -6,20 +6,18 @@ use reqwest::{Body, Client, IntoUrl, Method, Request, Response};
use serde::Serialize; use serde::Serialize;
use std::convert::TryFrom; use std::convert::TryFrom;
use std::fmt::{self, Display}; use std::fmt::{self, Display};
use std::task::{Context, Poll};
use std::time::Duration; use std::time::Duration;
use task_local_extensions::Extensions; use task_local_extensions::Extensions;
use tower::layer::util::{Identity, Stack}; // use tower::{Layer, Service, ServiceBuilder, ServiceExt};
use tower::{Layer, Service, ServiceBuilder, ServiceExt};
use crate::{Error, MiddlewareRequest, RequestInitialiser, RequestStack}; use crate::{Error, Identity, Layer, RequestInitialiser, RequestStack, Service, Stack};
/// A `ClientBuilder` is used to build a [`ClientWithMiddleware`]. /// A `ClientBuilder` is used to build a [`ClientWithMiddleware`].
/// ///
/// [`ClientWithMiddleware`]: crate::ClientWithMiddleware /// [`ClientWithMiddleware`]: crate::ClientWithMiddleware
pub struct ClientBuilder<M, I> { pub struct ClientBuilder<M, I> {
client: Client, client: Client,
middleware_stack: ServiceBuilder<M>, middleware_stack: M,
initialiser_stack: I, initialiser_stack: I,
} }
@ -27,8 +25,8 @@ impl ClientBuilder<Identity, Identity> {
pub fn new(client: Client) -> Self { pub fn new(client: Client) -> Self {
ClientBuilder { ClientBuilder {
client, client,
middleware_stack: ServiceBuilder::new(), middleware_stack: Identity,
initialiser_stack: Identity::new(), initialiser_stack: Identity,
} }
} }
} }
@ -38,7 +36,10 @@ impl<M, I> ClientBuilder<M, I> {
pub fn with<T>(self, layer: T) -> ClientBuilder<Stack<T, M>, I> { pub fn with<T>(self, layer: T) -> ClientBuilder<Stack<T, M>, I> {
ClientBuilder { ClientBuilder {
client: self.client, client: self.client,
middleware_stack: self.middleware_stack.layer(layer), middleware_stack: Stack {
inner: layer,
outer: self.middleware_stack,
},
initialiser_stack: self.initialiser_stack, initialiser_stack: self.initialiser_stack,
} }
} }
@ -70,14 +71,11 @@ impl<M, I> ClientBuilder<M, I> {
#[derive(Clone)] #[derive(Clone)]
pub struct ClientWithMiddleware<M, I> { pub struct ClientWithMiddleware<M, I> {
inner: reqwest::Client, inner: reqwest::Client,
middleware_stack: ServiceBuilder<M>, middleware_stack: M,
initialiser_stack: I, initialiser_stack: I,
} }
impl<M: Layer<ReqwestService>, I: RequestInitialiser> ClientWithMiddleware<M, I> impl<M: Layer<ReqwestService>, I: RequestInitialiser> ClientWithMiddleware<M, I> {
where
M::Service: Service<MiddlewareRequest, Response = Response, Error = Error>,
{
/// See [`Client::get`] /// See [`Client::get`]
pub fn get<U: IntoUrl>(&self, url: U) -> RequestBuilder<M, I> { pub fn get<U: IntoUrl>(&self, url: U) -> RequestBuilder<M, I> {
self.request(Method::GET, url) self.request(Method::GET, url)
@ -122,12 +120,12 @@ where
} }
/// Create a `ClientWithMiddleware` without any middleware. /// Create a `ClientWithMiddleware` without any middleware.
impl From<Client> for ClientWithMiddleware<Identity, ()> { impl From<Client> for ClientWithMiddleware<Identity, Identity> {
fn from(client: Client) -> Self { fn from(client: Client) -> Self {
ClientWithMiddleware { ClientWithMiddleware {
inner: client, inner: client,
middleware_stack: ServiceBuilder::new(), middleware_stack: Identity,
initialiser_stack: (), initialiser_stack: Identity,
} }
} }
} }
@ -152,25 +150,18 @@ pub struct RequestBuilder<'client, M, I> {
#[derive(Clone)] #[derive(Clone)]
pub struct ReqwestService(Client); pub struct ReqwestService(Client);
impl Service<MiddlewareRequest> for ReqwestService { impl Service for ReqwestService {
type Response = Response;
type Error = Error;
type Future = BoxFuture<'static, Result<Response, Error>>; type Future = BoxFuture<'static, Result<Response, Error>>;
fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll<Result<(), Error>> { fn call(&mut self, req: Request, _: &mut Extensions) -> Self::Future {
Poll::Ready(Ok(())) let fut = self.0.execute(req);
} async { fut.await.map_err(Error::from) }.boxed()
fn call(&mut self, req: MiddlewareRequest) -> Self::Future {
let req = req.request;
let client = self.0.clone();
async move { client.execute(req).await.map_err(Error::from) }.boxed()
} }
} }
impl<M: Layer<ReqwestService>, I: RequestInitialiser> RequestBuilder<'_, M, I> impl<M: Layer<ReqwestService>, I: RequestInitialiser> RequestBuilder<'_, M, I>
where where
M::Service: Service<MiddlewareRequest, Response = Response, Error = Error>, M::Service: Service,
{ {
pub fn header<K, V>(self, key: K, value: V) -> Self pub fn header<K, V>(self, key: K, value: V) -> Self
where where
@ -274,17 +265,13 @@ where
let Self { let Self {
inner, inner,
client, client,
extensions, mut extensions,
} = self; } = self;
let req = inner.build()?; let req = inner.build()?;
client let mut svc = client
.middleware_stack .middleware_stack
.service(ReqwestService(client.inner.clone())) .layer(ReqwestService(client.inner.clone()));
.oneshot(MiddlewareRequest { svc.call(req, &mut extensions).await
request: req,
extensions,
})
.await
// client.execute_with_extensions(req, &mut extensions).await // client.execute_with_extensions(req, &mut extensions).await
} }

View file

@ -7,7 +7,7 @@
//! //!
//! ``` //! ```
//! use reqwest::{Client, Request, Response}; //! use reqwest::{Client, Request, Response};
//! use reqwest_middleware::{ClientBuilder, Error, Extension, MiddlewareRequest}; //! use reqwest_middleware::{ClientBuilder, Error, Extension, Layer, Service};
//! use task_local_extensions::Extensions; //! use task_local_extensions::Extensions;
//! use futures::future::{BoxFuture, FutureExt}; //! use futures::future::{BoxFuture, FutureExt};
//! use std::task::{Context, Poll}; //! use std::task::{Context, Poll};
@ -15,7 +15,7 @@
//! struct LoggingLayer; //! struct LoggingLayer;
//! struct LoggingService<S>(S); //! struct LoggingService<S>(S);
//! //!
//! impl<S> tower::Layer<S> for LoggingLayer { //! impl<S> Layer<S> for LoggingLayer {
//! type Service = LoggingService<S>; //! type Service = LoggingService<S>;
//! //!
//! fn layer(&self, inner: S) -> Self::Service { //! fn layer(&self, inner: S) -> Self::Service {
@ -23,25 +23,19 @@
//! } //! }
//! } //! }
//! //!
//! impl<S> tower::Service<MiddlewareRequest> for LoggingService<S> //! impl<S> Service for LoggingService<S>
//! where //! where
//! S: tower::Service<MiddlewareRequest, Response = Response, Error = Error>, //! S: Service,
//! S::Future: Send + 'static, //! S::Future: Send + 'static,
//! { //! {
//! type Response = Response;
//! type Error = Error;
//! type Future = BoxFuture<'static, Result<Response, Error>>; //! type Future = BoxFuture<'static, Result<Response, Error>>;
//! //!
//! fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { //! fn call(&mut self, req: Request, ext: &mut Extensions) -> Self::Future {
//! self.0.poll_ready(cx) //! println!("Request started {req:?}");
//! } //! let fut = self.0.call(req, ext);
//!
//! fn call(&mut self, req: MiddlewareRequest) -> Self::Future {
//! println!("Request started {:?}", &req.request);
//! let fut = self.0.call(req);
//! async { //! async {
//! let res = fut.await; //! let res = fut.await;
//! println!("Result: {:?}", res); //! println!("Result: {res:?}");
//! res //! res
//! }.boxed() //! }.boxed()
//! } //! }
@ -76,8 +70,49 @@ mod req_init;
pub use client::{ClientBuilder, ClientWithMiddleware, RequestBuilder, ReqwestService}; pub use client::{ClientBuilder, ClientWithMiddleware, RequestBuilder, ReqwestService};
pub use error::Error; pub use error::Error;
pub use req_init::{Extension, RequestInitialiser, RequestStack}; pub use req_init::{Extension, RequestInitialiser, RequestStack};
use reqwest::{Request, Response};
use task_local_extensions::Extensions;
pub struct MiddlewareRequest { /// Two [`RequestInitialiser`]s or [`Service`]s chained together.
pub request: reqwest::Request, #[derive(Clone)]
pub extensions: task_local_extensions::Extensions, pub struct Stack<Inner, Outer> {
pub(crate) inner: Inner,
pub(crate) outer: Outer,
}
pub trait Service {
type Future: std::future::Future<Output = Result<Response, Error>>;
fn call(&mut self, req: Request, ext: &mut Extensions) -> Self::Future;
}
pub struct Identity;
impl<S: Service> Layer<S> for Identity {
type Service = S;
fn layer(&self, inner: S) -> Self::Service {
inner
}
}
pub trait Layer<S> {
/// The wrapped service
type Service;
/// Wrap the given service with the middleware, returning a new service
/// that has been decorated with the middleware.
fn layer(&self, inner: S) -> Self::Service;
}
impl<S, Inner, Outer> Layer<S> for Stack<Inner, Outer>
where
Inner: Layer<S>,
Outer: Layer<Inner::Service>,
{
type Service = Outer::Service;
fn layer(&self, service: S) -> Self::Service {
let inner = self.inner.layer(service);
self.outer.layer(inner)
}
} }

View file

@ -1,6 +1,7 @@
use reqwest::RequestBuilder; use reqwest::RequestBuilder;
use task_local_extensions::Extensions; use task_local_extensions::Extensions;
use tower::layer::util::Identity;
use crate::Identity;
/// When attached to a [`ClientWithMiddleware`] (generally using [`with_init`]), it is run /// When attached to a [`ClientWithMiddleware`] (generally using [`with_init`]), it is run
/// whenever the client starts building a request, in the order it was attached. /// whenever the client starts building a request, in the order it was attached.
@ -56,8 +57,8 @@ where
/// This is a good way to inject extensions to middleware deeper in the stack /// This is a good way to inject extensions to middleware deeper in the stack
/// ///
/// ``` /// ```
/// use reqwest::{Client, RequestBuilder, Response}; /// use reqwest::{Client, Request, Response};
/// use reqwest_middleware::{ClientBuilder, Error, Extension, MiddlewareRequest}; /// use reqwest_middleware::{ClientBuilder, Error, Extension, Layer, Service};
/// use task_local_extensions::Extensions; /// use task_local_extensions::Extensions;
/// use futures::future::{BoxFuture, FutureExt}; /// use futures::future::{BoxFuture, FutureExt};
/// use std::task::{Context, Poll}; /// use std::task::{Context, Poll};
@ -68,7 +69,7 @@ where
/// struct LoggingLayer; /// struct LoggingLayer;
/// struct LoggingService<S>(S); /// struct LoggingService<S>(S);
/// ///
/// impl<S> tower::Layer<S> for LoggingLayer { /// impl<S> Layer<S> for LoggingLayer {
/// type Service = LoggingService<S>; /// type Service = LoggingService<S>;
/// ///
/// fn layer(&self, inner: S) -> Self::Service { /// fn layer(&self, inner: S) -> Self::Service {
@ -76,28 +77,21 @@ where
/// } /// }
/// } /// }
/// ///
/// impl<S> tower::Service<MiddlewareRequest> for LoggingService<S> /// impl<S> Service for LoggingService<S>
/// where /// where
/// S: tower::Service<MiddlewareRequest, Response = Response, Error = Error>, /// S: Service,
/// S::Future: Send + 'static, /// S::Future: Send + 'static,
/// { /// {
/// type Response = Response;
/// type Error = Error;
/// type Future = BoxFuture<'static, Result<Response, Error>>; /// type Future = BoxFuture<'static, Result<Response, Error>>;
/// ///
/// fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { /// fn call(&mut self, req: Request, ext: &mut Extensions) -> Self::Future {
/// self.0.poll_ready(cx)
/// }
///
/// fn call(&mut self, req: MiddlewareRequest) -> Self::Future {
/// // get the log name or default to "unknown" /// // get the log name or default to "unknown"
/// let name = req /// let name = ext
/// .extensions
/// .get() /// .get()
/// .map(|&LogName(name)| name) /// .map(|&LogName(name)| name)
/// .unwrap_or("unknown"); /// .unwrap_or("unknown");
/// println!("[{name}] Request started {:?}", &req.request); /// println!("[{name}] Request started {req:?}");
/// let fut = self.0.call(req); /// let fut = self.0.call(req, ext);
/// async move { /// async move {
/// let res = fut.await; /// let res = fut.await;
/// println!("[{name}] Result: {res:?}"); /// println!("[{name}] Result: {res:?}");

View file

@ -23,7 +23,6 @@ retry-policies = "0.1"
task-local-extensions = "0.1.1" task-local-extensions = "0.1.1"
tokio = { version = "1.6", features = ["time"] } tokio = { version = "1.6", features = ["time"] }
tracing = "0.1.26" tracing = "0.1.26"
tower = { version = "0.4", features = ["retry"] }
pin-project-lite = "0.2" pin-project-lite = "0.2"
[dev-dependencies] [dev-dependencies]

View file

@ -7,13 +7,11 @@ use crate::retryable::Retryable;
use chrono::Utc; use chrono::Utc;
use futures::Future; use futures::Future;
use pin_project_lite::pin_project; use pin_project_lite::pin_project;
use reqwest::Response; use reqwest::{Request, Response};
use reqwest_middleware::{Error, MiddlewareRequest}; use reqwest_middleware::{Error, Layer, Service};
use retry_policies::RetryPolicy; use retry_policies::RetryPolicy;
use task_local_extensions::Extensions; use task_local_extensions::Extensions;
use tokio::time::Sleep; use tokio::time::Sleep;
use tower::retry::{Policy, Retry};
use tower::{Layer, Service};
/// `RetryTransientMiddleware` offers retry logic for requests that fail in a transient manner /// `RetryTransientMiddleware` offers retry logic for requests that fail in a transient manner
/// and can be safely executed again. /// and can be safely executed again.
@ -62,20 +60,20 @@ impl<T: RetryPolicy + Send + Sync> RetryTransientMiddleware<T> {
} }
} }
impl<T: RetryPolicy + Clone + Send + Sync + 'static, Svc> Layer<Svc> for RetryTransientMiddleware<T> impl<T, Svc> Layer<Svc> for RetryTransientMiddleware<T>
where where
Svc: Service<MiddlewareRequest, Response = Response, Error = Error>, T: RetryPolicy + Clone + Send + Sync + 'static,
{ {
type Service = Retry<TowerRetryPolicy<T>, Svc>; type Service = Retry<TowerRetryPolicy<T>, Svc>;
fn layer(&self, inner: Svc) -> Self::Service { fn layer(&self, inner: Svc) -> Self::Service {
Retry::new( Retry {
TowerRetryPolicy { policy: TowerRetryPolicy {
n_past_retries: 0, n_past_retries: 0,
retry_policy: self.retry_policy.clone(), retry_policy: self.retry_policy.clone(),
}, },
inner, service: inner,
) }
} }
} }
@ -108,14 +106,10 @@ impl<T> Future for RetryFuture<T> {
} }
} }
impl<T: RetryPolicy + Clone> Policy<MiddlewareRequest, Response, Error> for TowerRetryPolicy<T> { impl<T: RetryPolicy + Clone> Policy for TowerRetryPolicy<T> {
type Future = RetryFuture<T>; type Future = RetryFuture<T>;
fn retry( fn retry(&self, _req: &Request, result: &Result<Response, Error>) -> Option<Self::Future> {
&self,
_req: &MiddlewareRequest,
result: std::result::Result<&Response, &Error>,
) -> Option<Self::Future> {
// We classify the response which will return None if not // We classify the response which will return None if not
// errors were returned. // errors were returned.
match Retryable::from_reqwest_response(result) { match Retryable::from_reqwest_response(result) {
@ -147,10 +141,172 @@ impl<T: RetryPolicy + Clone> Policy<MiddlewareRequest, Response, Error> for Towe
} }
} }
fn clone_request(&self, req: &MiddlewareRequest) -> Option<MiddlewareRequest> { fn clone_request(&self, req: &Request) -> Option<Request> {
Some(MiddlewareRequest { req.try_clone()
request: req.request.try_clone()?, }
extensions: Extensions::new(), }
})
pub trait Policy: Sized {
/// The [`Future`] type returned by [`Policy::retry`].
type Future: Future<Output = Self>;
/// Check the policy if a certain request should be retried.
///
/// This method is passed a reference to the original request, and either
/// the [`Service::Response`] or [`Service::Error`] from the inner service.
///
/// If the request should **not** be retried, return `None`.
///
/// If the request *should* be retried, return `Some` future of a new
/// policy that would apply for the next request attempt.
///
/// [`Service::Response`]: crate::Service::Response
/// [`Service::Error`]: crate::Service::Error
fn retry(&self, req: &Request, result: &Result<Response, Error>) -> Option<Self::Future>;
/// Tries to clone a request before being passed to the inner service.
///
/// If the request cannot be cloned, return [`None`].
fn clone_request(&self, req: &Request) -> Option<Request>;
}
pin_project! {
/// Configure retrying requests of "failed" responses.
///
/// A [`Policy`] classifies what is a "failed" response.
#[derive(Clone, Debug)]
pub struct Retry<P, S> {
#[pin]
policy: P,
service: S,
}
}
impl<P, S> Service for Retry<P, S>
where
P: 'static + Policy + Clone,
S: 'static + Service + Clone,
{
type Future = ResponseFuture<P, S>;
fn call(&mut self, request: Request, ext: &mut Extensions) -> Self::Future {
let cloned = self.policy.clone_request(&request);
let future = self.service.call(request, ext);
ResponseFuture::new(cloned, self.clone(), future)
}
// fn call(&mut self, request: Request) -> Self::Future {
// let cloned = self.policy.clone_request(&request);
// let future = self.service.call(request);
// ResponseFuture::new(cloned, self.clone(), future)
// }
}
pin_project! {
/// The [`Future`] returned by a [`Retry`] service.
#[derive(Debug)]
pub struct ResponseFuture<P, S>
where
P: Policy,
S: Service,
{
request: Option<Request>,
#[pin]
retry: Retry<P, S>,
#[pin]
state: State<S::Future, P::Future>,
}
}
pin_project! {
#[project = StateProj]
#[derive(Debug)]
enum State<F, P> {
// Polling the future from [`Service::call`]
Called {
#[pin]
future: F
},
// Polling the future from [`Policy::retry`]
Checking {
#[pin]
checking: P
},
// Polling [`Service::poll_ready`] after [`Checking`] was OK.
Retrying,
}
}
impl<P, S> ResponseFuture<P, S>
where
P: Policy,
S: Service,
{
pub(crate) fn new(
request: Option<Request>,
retry: Retry<P, S>,
future: S::Future,
) -> ResponseFuture<P, S> {
ResponseFuture {
request,
retry,
state: State::Called { future },
}
}
}
impl<P, S> Future for ResponseFuture<P, S>
where
P: Policy + Clone,
S: Service + Clone,
{
type Output = Result<Response, Error>;
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let mut this = self.project();
loop {
match this.state.as_mut().project() {
StateProj::Called { future } => {
let result = ready!(future.poll(cx));
if let Some(ref req) = this.request {
match this.retry.policy.retry(req, &result) {
Some(checking) => {
this.state.set(State::Checking { checking });
}
None => return Poll::Ready(result),
}
} else {
// request wasn't cloned, so no way to retry it
return Poll::Ready(result);
}
}
StateProj::Checking { checking } => {
this.retry
.as_mut()
.project()
.policy
.set(ready!(checking.poll(cx)));
this.state.set(State::Retrying);
}
StateProj::Retrying => {
let req = this
.request
.take()
.expect("retrying requires cloned request");
*this.request = this.retry.policy.clone_request(&req);
this.state.set(State::Called {
future: this
.retry
.as_mut()
.project()
.service
.call(req, &mut Extensions::new()),
});
}
}
}
} }
} }

View file

@ -15,7 +15,7 @@ impl Retryable {
/// ///
/// Returns `None` if the response object does not contain any errors. /// Returns `None` if the response object does not contain any errors.
/// ///
pub fn from_reqwest_response(res: Result<&reqwest::Response, &Error>) -> Option<Self> { pub fn from_reqwest_response(res: &Result<reqwest::Response, Error>) -> Option<Self> {
match res { match res {
Ok(success) => { Ok(success) => {
let status = success.status(); let status = success.status();

View file

@ -25,7 +25,6 @@ async-trait = "0.1.51"
reqwest = { version = "0.11", default-features = false } reqwest = { version = "0.11", default-features = false }
task-local-extensions = "0.1.1" task-local-extensions = "0.1.1"
tracing = "0.1.26" tracing = "0.1.26"
tower = "0.4"
pin-project-lite = "0.2" pin-project-lite = "0.2"
opentelemetry_0_13_pkg = { package = "opentelemetry", version = "0.13", optional = true } opentelemetry_0_13_pkg = { package = "opentelemetry", version = "0.13", optional = true }

View file

@ -4,9 +4,10 @@ use std::{
}; };
use pin_project_lite::pin_project; use pin_project_lite::pin_project;
use reqwest::Response; use reqwest::{Request, Response};
use reqwest_middleware::{Error, MiddlewareRequest}; use reqwest_middleware::{Error, Layer, Service};
use tower::{Layer, Service}; use task_local_extensions::Extensions;
// use tower::{Layer, Service};
use tracing::Span; use tracing::Span;
use crate::{DefaultSpanBackend, ReqwestOtelSpanBackend}; use crate::{DefaultSpanBackend, ReqwestOtelSpanBackend};
@ -41,6 +42,7 @@ impl Default for TracingMiddleware<DefaultSpanBackend> {
impl<ReqwestOtelSpan, Svc> Layer<Svc> for TracingMiddleware<ReqwestOtelSpan> impl<ReqwestOtelSpan, Svc> Layer<Svc> for TracingMiddleware<ReqwestOtelSpan>
where where
ReqwestOtelSpan: ReqwestOtelSpanBackend + Sync + Send + 'static, ReqwestOtelSpan: ReqwestOtelSpanBackend + Sync + Send + 'static,
Svc: Service,
{ {
type Service = TracingMiddlewareService<ReqwestOtelSpan, Svc>; type Service = TracingMiddlewareService<ReqwestOtelSpan, Svc>;
@ -58,26 +60,15 @@ pub struct TracingMiddlewareService<S: ReqwestOtelSpanBackend, Svc> {
service: Svc, service: Svc,
} }
impl<ReqwestOtelSpan, Svc> Service<MiddlewareRequest> impl<ReqwestOtelSpan, Svc> Service for TracingMiddlewareService<ReqwestOtelSpan, Svc>
for TracingMiddlewareService<ReqwestOtelSpan, Svc>
where where
ReqwestOtelSpan: ReqwestOtelSpanBackend + Sync + Send + 'static, ReqwestOtelSpan: ReqwestOtelSpanBackend + Sync + Send + 'static,
Svc: Service<MiddlewareRequest, Response = Response, Error = Error>, Svc: Service,
{ {
type Response = Response;
type Error = Error;
type Future = TracingMiddlewareFuture<ReqwestOtelSpan, Svc::Future>; type Future = TracingMiddlewareFuture<ReqwestOtelSpan, Svc::Future>;
fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { fn call(&mut self, req: Request, ext: &mut Extensions) -> Self::Future {
self.service.poll_ready(cx) let (backend, span) = ReqwestOtelSpan::on_request_start(&req, ext);
}
fn call(&mut self, req: MiddlewareRequest) -> Self::Future {
let MiddlewareRequest {
request,
mut extensions,
} = req;
let (backend, span) = ReqwestOtelSpan::on_request_start(&request, &mut extensions);
// Adds tracing headers to the given request to propagate the OpenTelemetry context to downstream revivers of the request. // Adds tracing headers to the given request to propagate the OpenTelemetry context to downstream revivers of the request.
// Spans added by downstream consumers will be part of the same trace. // Spans added by downstream consumers will be part of the same trace.
#[cfg(any( #[cfg(any(
@ -90,10 +81,7 @@ where
))] ))]
let request = crate::otel::inject_opentelemetry_context_into_request(request); let request = crate::otel::inject_opentelemetry_context_into_request(request);
let future = self.service.call(MiddlewareRequest { let future = self.service.call(req, ext);
request,
extensions,
});
TracingMiddlewareFuture { TracingMiddlewareFuture {
span, span,