remove but inspired by tower

This commit is contained in:
Conrad Ludgate 2022-11-15 21:51:15 +00:00
parent df4990d62e
commit 8f3623eae0
No known key found for this signature in database
GPG key ID: 197E3CACA1C980B5
10 changed files with 277 additions and 120 deletions

View file

@ -34,7 +34,7 @@ tower = "0.4"
```rust
use reqwest::Response;
use reqwest_middleware::{ClientBuilder, ClientWithMiddleware, Error, MiddlewareRequest, RequestInitialiser, ReqwestService};
use reqwest_middleware::{ClientBuilder, ClientWithMiddleware, Error, Layer, RequestInitialiser, ReqwestService, Service};
use reqwest_retry::{RetryTransientMiddleware, policies::ExponentialBackoff};
use reqwest_tracing::TracingMiddleware;
@ -53,8 +53,8 @@ async fn main() {
async fn run<M, I>(client: ClientWithMiddleware<M, I>)
where
M: tower::Layer<ReqwestService>,
M::Service: tower::Service<MiddlewareRequest, Response = Response, Error = Error>,
M: Layer<ReqwestService>,
M::Service: Service,
I: RequestInitialiser,
{
client

View file

@ -18,7 +18,6 @@ reqwest = { version = "0.11", default-features = false, features = ["json", "mul
serde = "1"
task-local-extensions = "0.1.1"
thiserror = "1"
tower = { version = "0.4", features = ["util"] }
futures = "0.3"
[dev-dependencies]

View file

@ -6,20 +6,18 @@ use reqwest::{Body, Client, IntoUrl, Method, Request, Response};
use serde::Serialize;
use std::convert::TryFrom;
use std::fmt::{self, Display};
use std::task::{Context, Poll};
use std::time::Duration;
use task_local_extensions::Extensions;
use tower::layer::util::{Identity, Stack};
use tower::{Layer, Service, ServiceBuilder, ServiceExt};
// use tower::{Layer, Service, ServiceBuilder, ServiceExt};
use crate::{Error, MiddlewareRequest, RequestInitialiser, RequestStack};
use crate::{Error, Identity, Layer, RequestInitialiser, RequestStack, Service, Stack};
/// A `ClientBuilder` is used to build a [`ClientWithMiddleware`].
///
/// [`ClientWithMiddleware`]: crate::ClientWithMiddleware
pub struct ClientBuilder<M, I> {
client: Client,
middleware_stack: ServiceBuilder<M>,
middleware_stack: M,
initialiser_stack: I,
}
@ -27,8 +25,8 @@ impl ClientBuilder<Identity, Identity> {
pub fn new(client: Client) -> Self {
ClientBuilder {
client,
middleware_stack: ServiceBuilder::new(),
initialiser_stack: Identity::new(),
middleware_stack: Identity,
initialiser_stack: Identity,
}
}
}
@ -38,7 +36,10 @@ impl<M, I> ClientBuilder<M, I> {
pub fn with<T>(self, layer: T) -> ClientBuilder<Stack<T, M>, I> {
ClientBuilder {
client: self.client,
middleware_stack: self.middleware_stack.layer(layer),
middleware_stack: Stack {
inner: layer,
outer: self.middleware_stack,
},
initialiser_stack: self.initialiser_stack,
}
}
@ -70,14 +71,11 @@ impl<M, I> ClientBuilder<M, I> {
#[derive(Clone)]
pub struct ClientWithMiddleware<M, I> {
inner: reqwest::Client,
middleware_stack: ServiceBuilder<M>,
middleware_stack: M,
initialiser_stack: I,
}
impl<M: Layer<ReqwestService>, I: RequestInitialiser> ClientWithMiddleware<M, I>
where
M::Service: Service<MiddlewareRequest, Response = Response, Error = Error>,
{
impl<M: Layer<ReqwestService>, I: RequestInitialiser> ClientWithMiddleware<M, I> {
/// See [`Client::get`]
pub fn get<U: IntoUrl>(&self, url: U) -> RequestBuilder<M, I> {
self.request(Method::GET, url)
@ -122,12 +120,12 @@ where
}
/// Create a `ClientWithMiddleware` without any middleware.
impl From<Client> for ClientWithMiddleware<Identity, ()> {
impl From<Client> for ClientWithMiddleware<Identity, Identity> {
fn from(client: Client) -> Self {
ClientWithMiddleware {
inner: client,
middleware_stack: ServiceBuilder::new(),
initialiser_stack: (),
middleware_stack: Identity,
initialiser_stack: Identity,
}
}
}
@ -152,25 +150,18 @@ pub struct RequestBuilder<'client, M, I> {
#[derive(Clone)]
pub struct ReqwestService(Client);
impl Service<MiddlewareRequest> for ReqwestService {
type Response = Response;
type Error = Error;
impl Service for ReqwestService {
type Future = BoxFuture<'static, Result<Response, Error>>;
fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll<Result<(), Error>> {
Poll::Ready(Ok(()))
}
fn call(&mut self, req: MiddlewareRequest) -> Self::Future {
let req = req.request;
let client = self.0.clone();
async move { client.execute(req).await.map_err(Error::from) }.boxed()
fn call(&mut self, req: Request, _: &mut Extensions) -> Self::Future {
let fut = self.0.execute(req);
async { fut.await.map_err(Error::from) }.boxed()
}
}
impl<M: Layer<ReqwestService>, I: RequestInitialiser> RequestBuilder<'_, M, I>
where
M::Service: Service<MiddlewareRequest, Response = Response, Error = Error>,
M::Service: Service,
{
pub fn header<K, V>(self, key: K, value: V) -> Self
where
@ -274,17 +265,13 @@ where
let Self {
inner,
client,
extensions,
mut extensions,
} = self;
let req = inner.build()?;
client
let mut svc = client
.middleware_stack
.service(ReqwestService(client.inner.clone()))
.oneshot(MiddlewareRequest {
request: req,
extensions,
})
.await
.layer(ReqwestService(client.inner.clone()));
svc.call(req, &mut extensions).await
// client.execute_with_extensions(req, &mut extensions).await
}

View file

@ -7,7 +7,7 @@
//!
//! ```
//! use reqwest::{Client, Request, Response};
//! use reqwest_middleware::{ClientBuilder, Error, Extension, MiddlewareRequest};
//! use reqwest_middleware::{ClientBuilder, Error, Extension, Layer, Service};
//! use task_local_extensions::Extensions;
//! use futures::future::{BoxFuture, FutureExt};
//! use std::task::{Context, Poll};
@ -15,7 +15,7 @@
//! struct LoggingLayer;
//! struct LoggingService<S>(S);
//!
//! impl<S> tower::Layer<S> for LoggingLayer {
//! impl<S> Layer<S> for LoggingLayer {
//! type Service = LoggingService<S>;
//!
//! fn layer(&self, inner: S) -> Self::Service {
@ -23,25 +23,19 @@
//! }
//! }
//!
//! impl<S> tower::Service<MiddlewareRequest> for LoggingService<S>
//! impl<S> Service for LoggingService<S>
//! where
//! S: tower::Service<MiddlewareRequest, Response = Response, Error = Error>,
//! S: Service,
//! S::Future: Send + 'static,
//! {
//! type Response = Response;
//! type Error = Error;
//! type Future = BoxFuture<'static, Result<Response, Error>>;
//!
//! fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
//! self.0.poll_ready(cx)
//! }
//!
//! fn call(&mut self, req: MiddlewareRequest) -> Self::Future {
//! println!("Request started {:?}", &req.request);
//! let fut = self.0.call(req);
//! fn call(&mut self, req: Request, ext: &mut Extensions) -> Self::Future {
//! println!("Request started {req:?}");
//! let fut = self.0.call(req, ext);
//! async {
//! let res = fut.await;
//! println!("Result: {:?}", res);
//! println!("Result: {res:?}");
//! res
//! }.boxed()
//! }
@ -76,8 +70,49 @@ mod req_init;
pub use client::{ClientBuilder, ClientWithMiddleware, RequestBuilder, ReqwestService};
pub use error::Error;
pub use req_init::{Extension, RequestInitialiser, RequestStack};
use reqwest::{Request, Response};
use task_local_extensions::Extensions;
pub struct MiddlewareRequest {
pub request: reqwest::Request,
pub extensions: task_local_extensions::Extensions,
/// Two [`RequestInitialiser`]s or [`Service`]s chained together.
#[derive(Clone)]
pub struct Stack<Inner, Outer> {
pub(crate) inner: Inner,
pub(crate) outer: Outer,
}
pub trait Service {
type Future: std::future::Future<Output = Result<Response, Error>>;
fn call(&mut self, req: Request, ext: &mut Extensions) -> Self::Future;
}
pub struct Identity;
impl<S: Service> Layer<S> for Identity {
type Service = S;
fn layer(&self, inner: S) -> Self::Service {
inner
}
}
pub trait Layer<S> {
/// The wrapped service
type Service;
/// Wrap the given service with the middleware, returning a new service
/// that has been decorated with the middleware.
fn layer(&self, inner: S) -> Self::Service;
}
impl<S, Inner, Outer> Layer<S> for Stack<Inner, Outer>
where
Inner: Layer<S>,
Outer: Layer<Inner::Service>,
{
type Service = Outer::Service;
fn layer(&self, service: S) -> Self::Service {
let inner = self.inner.layer(service);
self.outer.layer(inner)
}
}

View file

@ -1,6 +1,7 @@
use reqwest::RequestBuilder;
use task_local_extensions::Extensions;
use tower::layer::util::Identity;
use crate::Identity;
/// When attached to a [`ClientWithMiddleware`] (generally using [`with_init`]), it is run
/// whenever the client starts building a request, in the order it was attached.
@ -56,8 +57,8 @@ where
/// This is a good way to inject extensions to middleware deeper in the stack
///
/// ```
/// use reqwest::{Client, RequestBuilder, Response};
/// use reqwest_middleware::{ClientBuilder, Error, Extension, MiddlewareRequest};
/// use reqwest::{Client, Request, Response};
/// use reqwest_middleware::{ClientBuilder, Error, Extension, Layer, Service};
/// use task_local_extensions::Extensions;
/// use futures::future::{BoxFuture, FutureExt};
/// use std::task::{Context, Poll};
@ -68,7 +69,7 @@ where
/// struct LoggingLayer;
/// struct LoggingService<S>(S);
///
/// impl<S> tower::Layer<S> for LoggingLayer {
/// impl<S> Layer<S> for LoggingLayer {
/// type Service = LoggingService<S>;
///
/// fn layer(&self, inner: S) -> Self::Service {
@ -76,28 +77,21 @@ where
/// }
/// }
///
/// impl<S> tower::Service<MiddlewareRequest> for LoggingService<S>
/// impl<S> Service for LoggingService<S>
/// where
/// S: tower::Service<MiddlewareRequest, Response = Response, Error = Error>,
/// S: Service,
/// S::Future: Send + 'static,
/// {
/// type Response = Response;
/// type Error = Error;
/// type Future = BoxFuture<'static, Result<Response, Error>>;
///
/// fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
/// self.0.poll_ready(cx)
/// }
///
/// fn call(&mut self, req: MiddlewareRequest) -> Self::Future {
/// fn call(&mut self, req: Request, ext: &mut Extensions) -> Self::Future {
/// // get the log name or default to "unknown"
/// let name = req
/// .extensions
/// let name = ext
/// .get()
/// .map(|&LogName(name)| name)
/// .unwrap_or("unknown");
/// println!("[{name}] Request started {:?}", &req.request);
/// let fut = self.0.call(req);
/// println!("[{name}] Request started {req:?}");
/// let fut = self.0.call(req, ext);
/// async move {
/// let res = fut.await;
/// println!("[{name}] Result: {res:?}");

View file

@ -23,7 +23,6 @@ retry-policies = "0.1"
task-local-extensions = "0.1.1"
tokio = { version = "1.6", features = ["time"] }
tracing = "0.1.26"
tower = { version = "0.4", features = ["retry"] }
pin-project-lite = "0.2"
[dev-dependencies]

View file

@ -7,13 +7,11 @@ use crate::retryable::Retryable;
use chrono::Utc;
use futures::Future;
use pin_project_lite::pin_project;
use reqwest::Response;
use reqwest_middleware::{Error, MiddlewareRequest};
use reqwest::{Request, Response};
use reqwest_middleware::{Error, Layer, Service};
use retry_policies::RetryPolicy;
use task_local_extensions::Extensions;
use tokio::time::Sleep;
use tower::retry::{Policy, Retry};
use tower::{Layer, Service};
/// `RetryTransientMiddleware` offers retry logic for requests that fail in a transient manner
/// and can be safely executed again.
@ -62,20 +60,20 @@ impl<T: RetryPolicy + Send + Sync> RetryTransientMiddleware<T> {
}
}
impl<T: RetryPolicy + Clone + Send + Sync + 'static, Svc> Layer<Svc> for RetryTransientMiddleware<T>
impl<T, Svc> Layer<Svc> for RetryTransientMiddleware<T>
where
Svc: Service<MiddlewareRequest, Response = Response, Error = Error>,
T: RetryPolicy + Clone + Send + Sync + 'static,
{
type Service = Retry<TowerRetryPolicy<T>, Svc>;
fn layer(&self, inner: Svc) -> Self::Service {
Retry::new(
TowerRetryPolicy {
Retry {
policy: TowerRetryPolicy {
n_past_retries: 0,
retry_policy: self.retry_policy.clone(),
},
inner,
)
service: inner,
}
}
}
@ -108,14 +106,10 @@ impl<T> Future for RetryFuture<T> {
}
}
impl<T: RetryPolicy + Clone> Policy<MiddlewareRequest, Response, Error> for TowerRetryPolicy<T> {
impl<T: RetryPolicy + Clone> Policy for TowerRetryPolicy<T> {
type Future = RetryFuture<T>;
fn retry(
&self,
_req: &MiddlewareRequest,
result: std::result::Result<&Response, &Error>,
) -> Option<Self::Future> {
fn retry(&self, _req: &Request, result: &Result<Response, Error>) -> Option<Self::Future> {
// We classify the response which will return None if not
// errors were returned.
match Retryable::from_reqwest_response(result) {
@ -147,10 +141,172 @@ impl<T: RetryPolicy + Clone> Policy<MiddlewareRequest, Response, Error> for Towe
}
}
fn clone_request(&self, req: &MiddlewareRequest) -> Option<MiddlewareRequest> {
Some(MiddlewareRequest {
request: req.request.try_clone()?,
extensions: Extensions::new(),
})
fn clone_request(&self, req: &Request) -> Option<Request> {
req.try_clone()
}
}
pub trait Policy: Sized {
/// The [`Future`] type returned by [`Policy::retry`].
type Future: Future<Output = Self>;
/// Check the policy if a certain request should be retried.
///
/// This method is passed a reference to the original request, and either
/// the [`Service::Response`] or [`Service::Error`] from the inner service.
///
/// If the request should **not** be retried, return `None`.
///
/// If the request *should* be retried, return `Some` future of a new
/// policy that would apply for the next request attempt.
///
/// [`Service::Response`]: crate::Service::Response
/// [`Service::Error`]: crate::Service::Error
fn retry(&self, req: &Request, result: &Result<Response, Error>) -> Option<Self::Future>;
/// Tries to clone a request before being passed to the inner service.
///
/// If the request cannot be cloned, return [`None`].
fn clone_request(&self, req: &Request) -> Option<Request>;
}
pin_project! {
/// Configure retrying requests of "failed" responses.
///
/// A [`Policy`] classifies what is a "failed" response.
#[derive(Clone, Debug)]
pub struct Retry<P, S> {
#[pin]
policy: P,
service: S,
}
}
impl<P, S> Service for Retry<P, S>
where
P: 'static + Policy + Clone,
S: 'static + Service + Clone,
{
type Future = ResponseFuture<P, S>;
fn call(&mut self, request: Request, ext: &mut Extensions) -> Self::Future {
let cloned = self.policy.clone_request(&request);
let future = self.service.call(request, ext);
ResponseFuture::new(cloned, self.clone(), future)
}
// fn call(&mut self, request: Request) -> Self::Future {
// let cloned = self.policy.clone_request(&request);
// let future = self.service.call(request);
// ResponseFuture::new(cloned, self.clone(), future)
// }
}
pin_project! {
/// The [`Future`] returned by a [`Retry`] service.
#[derive(Debug)]
pub struct ResponseFuture<P, S>
where
P: Policy,
S: Service,
{
request: Option<Request>,
#[pin]
retry: Retry<P, S>,
#[pin]
state: State<S::Future, P::Future>,
}
}
pin_project! {
#[project = StateProj]
#[derive(Debug)]
enum State<F, P> {
// Polling the future from [`Service::call`]
Called {
#[pin]
future: F
},
// Polling the future from [`Policy::retry`]
Checking {
#[pin]
checking: P
},
// Polling [`Service::poll_ready`] after [`Checking`] was OK.
Retrying,
}
}
impl<P, S> ResponseFuture<P, S>
where
P: Policy,
S: Service,
{
pub(crate) fn new(
request: Option<Request>,
retry: Retry<P, S>,
future: S::Future,
) -> ResponseFuture<P, S> {
ResponseFuture {
request,
retry,
state: State::Called { future },
}
}
}
impl<P, S> Future for ResponseFuture<P, S>
where
P: Policy + Clone,
S: Service + Clone,
{
type Output = Result<Response, Error>;
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let mut this = self.project();
loop {
match this.state.as_mut().project() {
StateProj::Called { future } => {
let result = ready!(future.poll(cx));
if let Some(ref req) = this.request {
match this.retry.policy.retry(req, &result) {
Some(checking) => {
this.state.set(State::Checking { checking });
}
None => return Poll::Ready(result),
}
} else {
// request wasn't cloned, so no way to retry it
return Poll::Ready(result);
}
}
StateProj::Checking { checking } => {
this.retry
.as_mut()
.project()
.policy
.set(ready!(checking.poll(cx)));
this.state.set(State::Retrying);
}
StateProj::Retrying => {
let req = this
.request
.take()
.expect("retrying requires cloned request");
*this.request = this.retry.policy.clone_request(&req);
this.state.set(State::Called {
future: this
.retry
.as_mut()
.project()
.service
.call(req, &mut Extensions::new()),
});
}
}
}
}
}

View file

@ -15,7 +15,7 @@ impl Retryable {
///
/// Returns `None` if the response object does not contain any errors.
///
pub fn from_reqwest_response(res: Result<&reqwest::Response, &Error>) -> Option<Self> {
pub fn from_reqwest_response(res: &Result<reqwest::Response, Error>) -> Option<Self> {
match res {
Ok(success) => {
let status = success.status();

View file

@ -25,7 +25,6 @@ async-trait = "0.1.51"
reqwest = { version = "0.11", default-features = false }
task-local-extensions = "0.1.1"
tracing = "0.1.26"
tower = "0.4"
pin-project-lite = "0.2"
opentelemetry_0_13_pkg = { package = "opentelemetry", version = "0.13", optional = true }

View file

@ -4,9 +4,10 @@ use std::{
};
use pin_project_lite::pin_project;
use reqwest::Response;
use reqwest_middleware::{Error, MiddlewareRequest};
use tower::{Layer, Service};
use reqwest::{Request, Response};
use reqwest_middleware::{Error, Layer, Service};
use task_local_extensions::Extensions;
// use tower::{Layer, Service};
use tracing::Span;
use crate::{DefaultSpanBackend, ReqwestOtelSpanBackend};
@ -41,6 +42,7 @@ impl Default for TracingMiddleware<DefaultSpanBackend> {
impl<ReqwestOtelSpan, Svc> Layer<Svc> for TracingMiddleware<ReqwestOtelSpan>
where
ReqwestOtelSpan: ReqwestOtelSpanBackend + Sync + Send + 'static,
Svc: Service,
{
type Service = TracingMiddlewareService<ReqwestOtelSpan, Svc>;
@ -58,26 +60,15 @@ pub struct TracingMiddlewareService<S: ReqwestOtelSpanBackend, Svc> {
service: Svc,
}
impl<ReqwestOtelSpan, Svc> Service<MiddlewareRequest>
for TracingMiddlewareService<ReqwestOtelSpan, Svc>
impl<ReqwestOtelSpan, Svc> Service for TracingMiddlewareService<ReqwestOtelSpan, Svc>
where
ReqwestOtelSpan: ReqwestOtelSpanBackend + Sync + Send + 'static,
Svc: Service<MiddlewareRequest, Response = Response, Error = Error>,
Svc: Service,
{
type Response = Response;
type Error = Error;
type Future = TracingMiddlewareFuture<ReqwestOtelSpan, Svc::Future>;
fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.service.poll_ready(cx)
}
fn call(&mut self, req: MiddlewareRequest) -> Self::Future {
let MiddlewareRequest {
request,
mut extensions,
} = req;
let (backend, span) = ReqwestOtelSpan::on_request_start(&request, &mut extensions);
fn call(&mut self, req: Request, ext: &mut Extensions) -> Self::Future {
let (backend, span) = ReqwestOtelSpan::on_request_start(&req, ext);
// Adds tracing headers to the given request to propagate the OpenTelemetry context to downstream revivers of the request.
// Spans added by downstream consumers will be part of the same trace.
#[cfg(any(
@ -90,10 +81,7 @@ where
))]
let request = crate::otel::inject_opentelemetry_context_into_request(request);
let future = self.service.call(MiddlewareRequest {
request,
extensions,
});
let future = self.service.call(req, ext);
TracingMiddlewareFuture {
span,