diff --git a/CHANGELOG.md b/CHANGELOG.md index bce41c7..25e230d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,13 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added - Implementation of `Debug` trait for `RequestBuilder`. +- A new `RequestInitialiser` trait that can be added to `ClientWithMiddleware` +- A new `Extension` initialiser that adds extensions to the request +- Adds `with_extension` method functionality to `RequestBuilder` that can add extensions for the `send` method to use - deprecating `send_with_extensions`. + +## [0.1.6] - 2022-04-21 + +Absolutely nothing changed ## [0.1.5] - 2022-02-21 diff --git a/reqwest-middleware/src/client.rs b/reqwest-middleware/src/client.rs index 3a5f247..58a1764 100644 --- a/reqwest-middleware/src/client.rs +++ b/reqwest-middleware/src/client.rs @@ -10,6 +10,7 @@ use task_local_extensions::Extensions; use crate::error::Result; use crate::middleware::{Middleware, Next}; +use crate::RequestInitialiser; /// A `ClientBuilder` is used to build a [`ClientWithMiddleware`]. /// @@ -17,6 +18,7 @@ use crate::middleware::{Middleware, Next}; pub struct ClientBuilder { client: Client, middleware_stack: Vec>, + initialiser_stack: Vec>, } impl ClientBuilder { @@ -24,6 +26,7 @@ impl ClientBuilder { ClientBuilder { client, middleware_stack: Vec::new(), + initialiser_stack: Vec::new(), } } @@ -47,9 +50,33 @@ impl ClientBuilder { self } + /// Convenience method to attach a request initialiser. + /// + /// If you need to keep a reference to the initialiser after attaching, use [`with_arc_init`]. + /// + /// [`with_arc_init`]: Self::with_arc_init + pub fn with_init(self, initialiser: I) -> Self + where + I: RequestInitialiser, + { + self.with_arc_init(Arc::new(initialiser)) + } + + /// Add a request initialiser to the chain. [`with_init`] is more ergonomic if you don't need the `Arc`. + /// + /// [`with_init`]: Self::with_init + pub fn with_arc_init(mut self, initialiser: Arc) -> Self { + self.initialiser_stack.push(initialiser); + self + } + /// Returns a `ClientWithMiddleware` using this builder configuration. pub fn build(self) -> ClientWithMiddleware { - ClientWithMiddleware::new(self.client, self.middleware_stack) + ClientWithMiddleware { + inner: self.client, + middleware_stack: self.middleware_stack.into_boxed_slice(), + initialiser_stack: self.initialiser_stack.into_boxed_slice(), + } } } @@ -59,6 +86,7 @@ impl ClientBuilder { pub struct ClientWithMiddleware { inner: reqwest::Client, middleware_stack: Box<[Arc]>, + initialiser_stack: Box<[Arc]>, } impl ClientWithMiddleware { @@ -70,6 +98,8 @@ impl ClientWithMiddleware { ClientWithMiddleware { inner: client, middleware_stack: middleware_stack.into(), + // TODO(conradludgate) - allow downstream code to control this manually if desired + initialiser_stack: Box::new([]), } } @@ -105,10 +135,14 @@ impl ClientWithMiddleware { /// See [`Client::request`] pub fn request(&self, method: Method, url: U) -> RequestBuilder { - RequestBuilder { + let req = RequestBuilder { inner: self.inner.request(method, url), client: self.clone(), - } + extensions: Extensions::new(), + }; + self.initialiser_stack + .iter() + .fold(req, |req, i| i.init(req)) } /// See [`Client::execute`] @@ -134,6 +168,7 @@ impl From for ClientWithMiddleware { ClientWithMiddleware { inner: client, middleware_stack: Box::new([]), + initialiser_stack: Box::new([]), } } } @@ -152,6 +187,7 @@ impl fmt::Debug for ClientWithMiddleware { pub struct RequestBuilder { inner: reqwest::RequestBuilder, client: ClientWithMiddleware, + extensions: Extensions, } impl RequestBuilder { @@ -164,14 +200,14 @@ impl RequestBuilder { { RequestBuilder { inner: self.inner.header(key, value), - client: self.client, + ..self } } pub fn headers(self, headers: HeaderMap) -> Self { RequestBuilder { inner: self.inner.headers(headers), - client: self.client, + ..self } } @@ -182,7 +218,7 @@ impl RequestBuilder { { RequestBuilder { inner: self.inner.basic_auth(username, password), - client: self.client, + ..self } } @@ -192,49 +228,49 @@ impl RequestBuilder { { RequestBuilder { inner: self.inner.bearer_auth(token), - client: self.client, + ..self } } pub fn body>(self, body: T) -> Self { RequestBuilder { inner: self.inner.body(body), - client: self.client, + ..self } } pub fn timeout(self, timeout: Duration) -> Self { RequestBuilder { inner: self.inner.timeout(timeout), - client: self.client, + ..self } } pub fn multipart(self, multipart: Form) -> Self { RequestBuilder { inner: self.inner.multipart(multipart), - client: self.client, + ..self } } pub fn query(self, query: &T) -> Self { RequestBuilder { inner: self.inner.query(query), - client: self.client, + ..self } } pub fn form(self, form: &T) -> Self { RequestBuilder { inner: self.inner.form(form), - client: self.client, + ..self } } pub fn json(self, json: &T) -> Self { RequestBuilder { inner: self.inner.json(json), - client: self.client, + ..self } } @@ -242,22 +278,44 @@ impl RequestBuilder { self.inner.build() } + /// Inserts the extension into this request builder + pub fn with_extension(mut self, extension: T) -> Self { + self.extensions.insert(extension); + self + } + + /// Returns a mutable reference to the internal set of extensions for this request + pub fn extensions(&mut self) -> &mut Extensions { + &mut self.extensions + } + pub async fn send(self) -> Result { - let req = self.inner.build()?; - self.client.execute(req).await + let Self { + inner, + client, + mut extensions, + } = self; + let req = inner.build()?; + client.execute_with_extensions(req, &mut extensions).await } /// Sends a request with initial [`Extensions`]. + #[deprecated = "use the with_extension method and send directly"] pub async fn send_with_extensions(self, ext: &mut Extensions) -> Result { - let req = self.inner.build()?; - self.client.execute_with_extensions(req, ext).await + let Self { inner, client, .. } = self; + let req = inner.build()?; + client.execute_with_extensions(req, ext).await } + // TODO(conradludgate): fix this method to take `&self`. It's currently useless as it is. + // I'm tempted to make this breaking change without a major bump, but I'll wait for now + #[deprecated = "This method was badly replicated from the base RequestBuilder. If you somehow made use of this method, it will break next major version"] pub fn try_clone(self) -> Option { - let client = self.client; - self.inner - .try_clone() - .map(|inner| RequestBuilder { inner, client }) + self.inner.try_clone().map(|inner| RequestBuilder { + inner, + client: self.client, + extensions: self.extensions, + }) } } diff --git a/reqwest-middleware/src/lib.rs b/reqwest-middleware/src/lib.rs index 6432f67..a92243d 100644 --- a/reqwest-middleware/src/lib.rs +++ b/reqwest-middleware/src/lib.rs @@ -52,7 +52,9 @@ pub struct ReadmeDoctests; mod client; mod error; mod middleware; +mod req_init; pub use client::{ClientBuilder, ClientWithMiddleware, RequestBuilder}; pub use error::{Error, Result}; pub use middleware::{Middleware, Next}; +pub use req_init::{Extension, RequestInitialiser}; diff --git a/reqwest-middleware/src/req_init.rs b/reqwest-middleware/src/req_init.rs new file mode 100644 index 0000000..92c8167 --- /dev/null +++ b/reqwest-middleware/src/req_init.rs @@ -0,0 +1,84 @@ +use crate::RequestBuilder; + +/// When attached to a [`ClientWithMiddleware`] (generally using [`with_init`]), it is run +/// whenever the client starts building a request, in the order it was attached. +/// +/// # Example +/// +/// ``` +/// use reqwest_middleware::{RequestInitialiser, RequestBuilder}; +/// +/// struct AuthInit; +/// +/// impl RequestInitialiser for AuthInit { +/// fn init(&self, req: RequestBuilder) -> RequestBuilder { +/// req.bearer_auth("my_auth_token") +/// } +/// } +/// ``` +/// +/// [`ClientWithMiddleware`]: crate::ClientWithMiddleware +/// [`with_init`]: crate::ClientBuilder::with_init +pub trait RequestInitialiser: 'static + Send + Sync { + fn init(&self, req: RequestBuilder) -> RequestBuilder; +} + +impl RequestInitialiser for F +where + F: Send + Sync + 'static + Fn(RequestBuilder) -> RequestBuilder, +{ + fn init(&self, req: RequestBuilder) -> RequestBuilder { + (self)(req) + } +} + +/// A middleware that inserts the value into the [`Extensions`](task_local_extensions::Extensions) during the call. +/// +/// This is a good way to inject extensions to middleware deeper in the stack +/// +/// ``` +/// use reqwest::{Client, Request, Response}; +/// use reqwest_middleware::{ClientBuilder, Middleware, Next, Result, Extension}; +/// use task_local_extensions::Extensions; +/// +/// #[derive(Clone)] +/// struct LogName(&'static str); +/// struct LoggingMiddleware; +/// +/// #[async_trait::async_trait] +/// impl Middleware for LoggingMiddleware { +/// async fn handle( +/// &self, +/// req: Request, +/// extensions: &mut Extensions, +/// next: Next<'_>, +/// ) -> Result { +/// // get the log name or default to "unknown" +/// let name = extensions +/// .get() +/// .map(|&LogName(name)| name) +/// .unwrap_or("unknown"); +/// println!("[{name}] Request started {req:?}"); +/// let res = next.run(req, extensions).await; +/// println!("[{name}] Result: {res:?}"); +/// res +/// } +/// } +/// +/// async fn run() { +/// let reqwest_client = Client::builder().build().unwrap(); +/// let client = ClientBuilder::new(reqwest_client) +/// .with_init(Extension(LogName("my-client"))) +/// .with(LoggingMiddleware) +/// .build(); +/// let resp = client.get("https://truelayer.com").send().await.unwrap(); +/// println!("TrueLayer page HTML: {}", resp.text().await.unwrap()); +/// } +/// ``` +pub struct Extension(pub T); + +impl RequestInitialiser for Extension { + fn init(&self, req: RequestBuilder) -> RequestBuilder { + req.with_extension(self.0.clone()) + } +}