diff --git a/README.md b/README.md index b987255..3e07e7b 100644 --- a/README.md +++ b/README.md @@ -34,7 +34,7 @@ tower = "0.4" ```rust use reqwest::Response; -use reqwest_middleware::{ClientBuilder, ClientWithMiddleware, Error, MiddlewareRequest, RequestInitialiser, ReqwestService}; +use reqwest_middleware::{ClientBuilder, ClientWithMiddleware, Error, Layer, RequestInitialiser, ReqwestService, Service}; use reqwest_retry::{RetryTransientMiddleware, policies::ExponentialBackoff}; use reqwest_tracing::TracingMiddleware; @@ -53,8 +53,8 @@ async fn main() { async fn run(client: ClientWithMiddleware) where - M: tower::Layer, - M::Service: tower::Service, + M: Layer, + M::Service: Service, I: RequestInitialiser, { client diff --git a/reqwest-middleware/Cargo.toml b/reqwest-middleware/Cargo.toml index 5b969d1..8e81a82 100644 --- a/reqwest-middleware/Cargo.toml +++ b/reqwest-middleware/Cargo.toml @@ -18,7 +18,6 @@ reqwest = { version = "0.11", default-features = false, features = ["json", "mul serde = "1" task-local-extensions = "0.1.1" thiserror = "1" -tower = { version = "0.4", features = ["util"] } futures = "0.3" [dev-dependencies] diff --git a/reqwest-middleware/src/client.rs b/reqwest-middleware/src/client.rs index 033179c..6f2964c 100644 --- a/reqwest-middleware/src/client.rs +++ b/reqwest-middleware/src/client.rs @@ -6,20 +6,18 @@ use reqwest::{Body, Client, IntoUrl, Method, Request, Response}; use serde::Serialize; use std::convert::TryFrom; use std::fmt::{self, Display}; -use std::task::{Context, Poll}; use std::time::Duration; use task_local_extensions::Extensions; -use tower::layer::util::{Identity, Stack}; -use tower::{Layer, Service, ServiceBuilder, ServiceExt}; +// use tower::{Layer, Service, ServiceBuilder, ServiceExt}; -use crate::{Error, MiddlewareRequest, RequestInitialiser, RequestStack}; +use crate::{Error, Identity, Layer, RequestInitialiser, RequestStack, Service, Stack}; /// A `ClientBuilder` is used to build a [`ClientWithMiddleware`]. /// /// [`ClientWithMiddleware`]: crate::ClientWithMiddleware pub struct ClientBuilder { client: Client, - middleware_stack: ServiceBuilder, + middleware_stack: M, initialiser_stack: I, } @@ -27,8 +25,8 @@ impl ClientBuilder { pub fn new(client: Client) -> Self { ClientBuilder { client, - middleware_stack: ServiceBuilder::new(), - initialiser_stack: Identity::new(), + middleware_stack: Identity, + initialiser_stack: Identity, } } } @@ -38,7 +36,10 @@ impl ClientBuilder { pub fn with(self, layer: T) -> ClientBuilder, I> { ClientBuilder { client: self.client, - middleware_stack: self.middleware_stack.layer(layer), + middleware_stack: Stack { + inner: layer, + outer: self.middleware_stack, + }, initialiser_stack: self.initialiser_stack, } } @@ -70,14 +71,11 @@ impl ClientBuilder { #[derive(Clone)] pub struct ClientWithMiddleware { inner: reqwest::Client, - middleware_stack: ServiceBuilder, + middleware_stack: M, initialiser_stack: I, } -impl, I: RequestInitialiser> ClientWithMiddleware -where - M::Service: Service, -{ +impl, I: RequestInitialiser> ClientWithMiddleware { /// See [`Client::get`] pub fn get(&self, url: U) -> RequestBuilder { self.request(Method::GET, url) @@ -122,12 +120,12 @@ where } /// Create a `ClientWithMiddleware` without any middleware. -impl From for ClientWithMiddleware { +impl From for ClientWithMiddleware { fn from(client: Client) -> Self { ClientWithMiddleware { inner: client, - middleware_stack: ServiceBuilder::new(), - initialiser_stack: (), + middleware_stack: Identity, + initialiser_stack: Identity, } } } @@ -152,25 +150,18 @@ pub struct RequestBuilder<'client, M, I> { #[derive(Clone)] pub struct ReqwestService(Client); -impl Service for ReqwestService { - type Response = Response; - type Error = Error; +impl Service for ReqwestService { type Future = BoxFuture<'static, Result>; - fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll> { - Poll::Ready(Ok(())) - } - - fn call(&mut self, req: MiddlewareRequest) -> Self::Future { - let req = req.request; - let client = self.0.clone(); - async move { client.execute(req).await.map_err(Error::from) }.boxed() + fn call(&mut self, req: Request, _: &mut Extensions) -> Self::Future { + let fut = self.0.execute(req); + async { fut.await.map_err(Error::from) }.boxed() } } impl, I: RequestInitialiser> RequestBuilder<'_, M, I> where - M::Service: Service, + M::Service: Service, { pub fn header(self, key: K, value: V) -> Self where @@ -274,17 +265,13 @@ where let Self { inner, client, - extensions, + mut extensions, } = self; let req = inner.build()?; - client + let mut svc = client .middleware_stack - .service(ReqwestService(client.inner.clone())) - .oneshot(MiddlewareRequest { - request: req, - extensions, - }) - .await + .layer(ReqwestService(client.inner.clone())); + svc.call(req, &mut extensions).await // client.execute_with_extensions(req, &mut extensions).await } diff --git a/reqwest-middleware/src/lib.rs b/reqwest-middleware/src/lib.rs index d0fd72c..d0894db 100644 --- a/reqwest-middleware/src/lib.rs +++ b/reqwest-middleware/src/lib.rs @@ -7,7 +7,7 @@ //! //! ``` //! use reqwest::{Client, Request, Response}; -//! use reqwest_middleware::{ClientBuilder, Error, Extension, MiddlewareRequest}; +//! use reqwest_middleware::{ClientBuilder, Error, Extension, Layer, Service}; //! use task_local_extensions::Extensions; //! use futures::future::{BoxFuture, FutureExt}; //! use std::task::{Context, Poll}; @@ -15,7 +15,7 @@ //! struct LoggingLayer; //! struct LoggingService(S); //! -//! impl tower::Layer for LoggingLayer { +//! impl Layer for LoggingLayer { //! type Service = LoggingService; //! //! fn layer(&self, inner: S) -> Self::Service { @@ -23,25 +23,19 @@ //! } //! } //! -//! impl tower::Service for LoggingService +//! impl Service for LoggingService //! where -//! S: tower::Service, +//! S: Service, //! S::Future: Send + 'static, //! { -//! type Response = Response; -//! type Error = Error; //! type Future = BoxFuture<'static, Result>; -//! -//! fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { -//! self.0.poll_ready(cx) -//! } //! -//! fn call(&mut self, req: MiddlewareRequest) -> Self::Future { -//! println!("Request started {:?}", &req.request); -//! let fut = self.0.call(req); +//! fn call(&mut self, req: Request, ext: &mut Extensions) -> Self::Future { +//! println!("Request started {req:?}"); +//! let fut = self.0.call(req, ext); //! async { //! let res = fut.await; -//! println!("Result: {:?}", res); +//! println!("Result: {res:?}"); //! res //! }.boxed() //! } @@ -76,8 +70,49 @@ mod req_init; pub use client::{ClientBuilder, ClientWithMiddleware, RequestBuilder, ReqwestService}; pub use error::Error; pub use req_init::{Extension, RequestInitialiser, RequestStack}; +use reqwest::{Request, Response}; +use task_local_extensions::Extensions; -pub struct MiddlewareRequest { - pub request: reqwest::Request, - pub extensions: task_local_extensions::Extensions, +/// Two [`RequestInitialiser`]s or [`Service`]s chained together. +#[derive(Clone)] +pub struct Stack { + pub(crate) inner: Inner, + pub(crate) outer: Outer, +} + +pub trait Service { + type Future: std::future::Future>; + fn call(&mut self, req: Request, ext: &mut Extensions) -> Self::Future; +} + +pub struct Identity; + +impl Layer for Identity { + type Service = S; + + fn layer(&self, inner: S) -> Self::Service { + inner + } +} + +pub trait Layer { + /// The wrapped service + type Service; + /// Wrap the given service with the middleware, returning a new service + /// that has been decorated with the middleware. + fn layer(&self, inner: S) -> Self::Service; +} + +impl Layer for Stack +where + Inner: Layer, + Outer: Layer, +{ + type Service = Outer::Service; + + fn layer(&self, service: S) -> Self::Service { + let inner = self.inner.layer(service); + + self.outer.layer(inner) + } } diff --git a/reqwest-middleware/src/req_init.rs b/reqwest-middleware/src/req_init.rs index 75d8aa3..625927e 100644 --- a/reqwest-middleware/src/req_init.rs +++ b/reqwest-middleware/src/req_init.rs @@ -1,6 +1,7 @@ use reqwest::RequestBuilder; use task_local_extensions::Extensions; -use tower::layer::util::Identity; + +use crate::Identity; /// When attached to a [`ClientWithMiddleware`] (generally using [`with_init`]), it is run /// whenever the client starts building a request, in the order it was attached. @@ -56,8 +57,8 @@ where /// This is a good way to inject extensions to middleware deeper in the stack /// /// ``` -/// use reqwest::{Client, RequestBuilder, Response}; -/// use reqwest_middleware::{ClientBuilder, Error, Extension, MiddlewareRequest}; +/// use reqwest::{Client, Request, Response}; +/// use reqwest_middleware::{ClientBuilder, Error, Extension, Layer, Service}; /// use task_local_extensions::Extensions; /// use futures::future::{BoxFuture, FutureExt}; /// use std::task::{Context, Poll}; @@ -68,7 +69,7 @@ where /// struct LoggingLayer; /// struct LoggingService(S); /// -/// impl tower::Layer for LoggingLayer { +/// impl Layer for LoggingLayer { /// type Service = LoggingService; /// /// fn layer(&self, inner: S) -> Self::Service { @@ -76,28 +77,21 @@ where /// } /// } /// -/// impl tower::Service for LoggingService +/// impl Service for LoggingService /// where -/// S: tower::Service, +/// S: Service, /// S::Future: Send + 'static, /// { -/// type Response = Response; -/// type Error = Error; /// type Future = BoxFuture<'static, Result>; -/// -/// fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { -/// self.0.poll_ready(cx) -/// } /// -/// fn call(&mut self, req: MiddlewareRequest) -> Self::Future { +/// fn call(&mut self, req: Request, ext: &mut Extensions) -> Self::Future { /// // get the log name or default to "unknown" -/// let name = req -/// .extensions +/// let name = ext /// .get() /// .map(|&LogName(name)| name) /// .unwrap_or("unknown"); -/// println!("[{name}] Request started {:?}", &req.request); -/// let fut = self.0.call(req); +/// println!("[{name}] Request started {req:?}"); +/// let fut = self.0.call(req, ext); /// async move { /// let res = fut.await; /// println!("[{name}] Result: {res:?}"); diff --git a/reqwest-retry/Cargo.toml b/reqwest-retry/Cargo.toml index a0098a7..9c2e060 100644 --- a/reqwest-retry/Cargo.toml +++ b/reqwest-retry/Cargo.toml @@ -23,7 +23,6 @@ retry-policies = "0.1" task-local-extensions = "0.1.1" tokio = { version = "1.6", features = ["time"] } tracing = "0.1.26" -tower = { version = "0.4", features = ["retry"] } pin-project-lite = "0.2" [dev-dependencies] diff --git a/reqwest-retry/src/middleware.rs b/reqwest-retry/src/middleware.rs index 65f07bd..6c2d31c 100644 --- a/reqwest-retry/src/middleware.rs +++ b/reqwest-retry/src/middleware.rs @@ -7,13 +7,11 @@ use crate::retryable::Retryable; use chrono::Utc; use futures::Future; use pin_project_lite::pin_project; -use reqwest::Response; -use reqwest_middleware::{Error, MiddlewareRequest}; +use reqwest::{Request, Response}; +use reqwest_middleware::{Error, Layer, Service}; use retry_policies::RetryPolicy; use task_local_extensions::Extensions; use tokio::time::Sleep; -use tower::retry::{Policy, Retry}; -use tower::{Layer, Service}; /// `RetryTransientMiddleware` offers retry logic for requests that fail in a transient manner /// and can be safely executed again. @@ -62,20 +60,20 @@ impl RetryTransientMiddleware { } } -impl Layer for RetryTransientMiddleware +impl Layer for RetryTransientMiddleware where - Svc: Service, + T: RetryPolicy + Clone + Send + Sync + 'static, { type Service = Retry, Svc>; fn layer(&self, inner: Svc) -> Self::Service { - Retry::new( - TowerRetryPolicy { + Retry { + policy: TowerRetryPolicy { n_past_retries: 0, retry_policy: self.retry_policy.clone(), }, - inner, - ) + service: inner, + } } } @@ -108,14 +106,10 @@ impl Future for RetryFuture { } } -impl Policy for TowerRetryPolicy { +impl Policy for TowerRetryPolicy { type Future = RetryFuture; - fn retry( - &self, - _req: &MiddlewareRequest, - result: std::result::Result<&Response, &Error>, - ) -> Option { + fn retry(&self, _req: &Request, result: &Result) -> Option { // We classify the response which will return None if not // errors were returned. match Retryable::from_reqwest_response(result) { @@ -147,10 +141,172 @@ impl Policy for Towe } } - fn clone_request(&self, req: &MiddlewareRequest) -> Option { - Some(MiddlewareRequest { - request: req.request.try_clone()?, - extensions: Extensions::new(), - }) + fn clone_request(&self, req: &Request) -> Option { + req.try_clone() + } +} + +pub trait Policy: Sized { + /// The [`Future`] type returned by [`Policy::retry`]. + type Future: Future; + + /// Check the policy if a certain request should be retried. + /// + /// This method is passed a reference to the original request, and either + /// the [`Service::Response`] or [`Service::Error`] from the inner service. + /// + /// If the request should **not** be retried, return `None`. + /// + /// If the request *should* be retried, return `Some` future of a new + /// policy that would apply for the next request attempt. + /// + /// [`Service::Response`]: crate::Service::Response + /// [`Service::Error`]: crate::Service::Error + fn retry(&self, req: &Request, result: &Result) -> Option; + + /// Tries to clone a request before being passed to the inner service. + /// + /// If the request cannot be cloned, return [`None`]. + fn clone_request(&self, req: &Request) -> Option; +} + +pin_project! { + /// Configure retrying requests of "failed" responses. + /// + /// A [`Policy`] classifies what is a "failed" response. + #[derive(Clone, Debug)] + pub struct Retry { + #[pin] + policy: P, + service: S, + } +} + +impl Service for Retry +where + P: 'static + Policy + Clone, + S: 'static + Service + Clone, +{ + type Future = ResponseFuture; + + fn call(&mut self, request: Request, ext: &mut Extensions) -> Self::Future { + let cloned = self.policy.clone_request(&request); + let future = self.service.call(request, ext); + + ResponseFuture::new(cloned, self.clone(), future) + } + + // fn call(&mut self, request: Request) -> Self::Future { + // let cloned = self.policy.clone_request(&request); + // let future = self.service.call(request); + + // ResponseFuture::new(cloned, self.clone(), future) + // } +} + +pin_project! { + /// The [`Future`] returned by a [`Retry`] service. + #[derive(Debug)] + pub struct ResponseFuture + where + P: Policy, + S: Service, + { + request: Option, + #[pin] + retry: Retry, + #[pin] + state: State, + } +} + +pin_project! { + #[project = StateProj] + #[derive(Debug)] + enum State { + // Polling the future from [`Service::call`] + Called { + #[pin] + future: F + }, + // Polling the future from [`Policy::retry`] + Checking { + #[pin] + checking: P + }, + // Polling [`Service::poll_ready`] after [`Checking`] was OK. + Retrying, + } +} + +impl ResponseFuture +where + P: Policy, + S: Service, +{ + pub(crate) fn new( + request: Option, + retry: Retry, + future: S::Future, + ) -> ResponseFuture { + ResponseFuture { + request, + retry, + state: State::Called { future }, + } + } +} + +impl Future for ResponseFuture +where + P: Policy + Clone, + S: Service + Clone, +{ + type Output = Result; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let mut this = self.project(); + + loop { + match this.state.as_mut().project() { + StateProj::Called { future } => { + let result = ready!(future.poll(cx)); + if let Some(ref req) = this.request { + match this.retry.policy.retry(req, &result) { + Some(checking) => { + this.state.set(State::Checking { checking }); + } + None => return Poll::Ready(result), + } + } else { + // request wasn't cloned, so no way to retry it + return Poll::Ready(result); + } + } + StateProj::Checking { checking } => { + this.retry + .as_mut() + .project() + .policy + .set(ready!(checking.poll(cx))); + this.state.set(State::Retrying); + } + StateProj::Retrying => { + let req = this + .request + .take() + .expect("retrying requires cloned request"); + *this.request = this.retry.policy.clone_request(&req); + this.state.set(State::Called { + future: this + .retry + .as_mut() + .project() + .service + .call(req, &mut Extensions::new()), + }); + } + } + } } } diff --git a/reqwest-retry/src/retryable.rs b/reqwest-retry/src/retryable.rs index 2a36deb..d6e4cae 100644 --- a/reqwest-retry/src/retryable.rs +++ b/reqwest-retry/src/retryable.rs @@ -15,7 +15,7 @@ impl Retryable { /// /// Returns `None` if the response object does not contain any errors. /// - pub fn from_reqwest_response(res: Result<&reqwest::Response, &Error>) -> Option { + pub fn from_reqwest_response(res: &Result) -> Option { match res { Ok(success) => { let status = success.status(); diff --git a/reqwest-tracing/Cargo.toml b/reqwest-tracing/Cargo.toml index 9e214d2..785dadf 100644 --- a/reqwest-tracing/Cargo.toml +++ b/reqwest-tracing/Cargo.toml @@ -25,7 +25,6 @@ async-trait = "0.1.51" reqwest = { version = "0.11", default-features = false } task-local-extensions = "0.1.1" tracing = "0.1.26" -tower = "0.4" pin-project-lite = "0.2" opentelemetry_0_13_pkg = { package = "opentelemetry", version = "0.13", optional = true } diff --git a/reqwest-tracing/src/middleware.rs b/reqwest-tracing/src/middleware.rs index 762bd32..f53f67e 100644 --- a/reqwest-tracing/src/middleware.rs +++ b/reqwest-tracing/src/middleware.rs @@ -4,9 +4,10 @@ use std::{ }; use pin_project_lite::pin_project; -use reqwest::Response; -use reqwest_middleware::{Error, MiddlewareRequest}; -use tower::{Layer, Service}; +use reqwest::{Request, Response}; +use reqwest_middleware::{Error, Layer, Service}; +use task_local_extensions::Extensions; +// use tower::{Layer, Service}; use tracing::Span; use crate::{DefaultSpanBackend, ReqwestOtelSpanBackend}; @@ -41,6 +42,7 @@ impl Default for TracingMiddleware { impl Layer for TracingMiddleware where ReqwestOtelSpan: ReqwestOtelSpanBackend + Sync + Send + 'static, + Svc: Service, { type Service = TracingMiddlewareService; @@ -58,26 +60,15 @@ pub struct TracingMiddlewareService { service: Svc, } -impl Service - for TracingMiddlewareService +impl Service for TracingMiddlewareService where ReqwestOtelSpan: ReqwestOtelSpanBackend + Sync + Send + 'static, - Svc: Service, + Svc: Service, { - type Response = Response; - type Error = Error; type Future = TracingMiddlewareFuture; - fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { - self.service.poll_ready(cx) - } - - fn call(&mut self, req: MiddlewareRequest) -> Self::Future { - let MiddlewareRequest { - request, - mut extensions, - } = req; - let (backend, span) = ReqwestOtelSpan::on_request_start(&request, &mut extensions); + fn call(&mut self, req: Request, ext: &mut Extensions) -> Self::Future { + let (backend, span) = ReqwestOtelSpan::on_request_start(&req, ext); // Adds tracing headers to the given request to propagate the OpenTelemetry context to downstream revivers of the request. // Spans added by downstream consumers will be part of the same trace. #[cfg(any( @@ -90,10 +81,7 @@ where ))] let request = crate::otel::inject_opentelemetry_context_into_request(request); - let future = self.service.call(MiddlewareRequest { - request, - extensions, - }); + let future = self.service.call(req, ext); TracingMiddlewareFuture { span,