From 31e2cab280b2c4b987e4ee6e7202741ec291ae84 Mon Sep 17 00:00:00 2001 From: Zynh Ludwig Date: Thu, 21 Nov 2024 10:01:32 -0800 Subject: [PATCH] feat: deny non hx-requests on fragment --- src/router/link.rs | 29 ++++++++++++++++++++++------- src/util/headers.rs | 35 +++++++++++++++++++++++++++++++++++ 2 files changed, 57 insertions(+), 7 deletions(-) diff --git a/src/router/link.rs b/src/router/link.rs index 3867e28..70b5ef0 100644 --- a/src/router/link.rs +++ b/src/router/link.rs @@ -4,9 +4,14 @@ use axum::{ routing::get, Router, }; +use axum_extra::TypedHeader; use reqwest::StatusCode; -use crate::{templates::DownloadLinkTemplate, AppState, AsyncRemoveRecord}; +use crate::{ + templates::{self, DownloadLinkTemplate}, + util::headers::HxRequest, + AppState, AsyncRemoveRecord, +}; pub fn get_link_router() -> Router { // Link pages @@ -62,12 +67,22 @@ async fn link_delete( async fn remaining( State(state): State, + hx_request: Option>, axum::extract::Path(id): axum::extract::Path, -) -> impl IntoResponse { - let records = state.records.lock().await; - if let Some(record) = records.get(&id) { - Html(crate::templates::get_downloads_remaining_text(record)) - } else { - Html("?".to_string()) +) -> Result { + if hx_request.is_none() { + return Err(( + StatusCode::BAD_REQUEST, + "Attempt to fetch html fragment from non-htmx request".to_string(), + )); } + + let records = state.records.lock().await; + + Ok(Html( + records + .get(&id) + .map(templates::get_downloads_remaining_text) + .unwrap_or_else(|| "?".to_string()), + )) } diff --git a/src/util/headers.rs b/src/util/headers.rs index 075fabf..fd00000 100644 --- a/src/util/headers.rs +++ b/src/util/headers.rs @@ -31,3 +31,38 @@ impl Header for ForwardedFor { values.extend(std::iter::once(HeaderValue::from_str(&self.0).unwrap())); } } + +#[derive(Debug)] +pub struct HxRequest; + +pub static HXR_TEXT: &str = "hx-request"; + +pub static HXR_NAME: HeaderName = HeaderName::from_static(HXR_TEXT); + +impl Header for HxRequest { + fn name() -> &'static HeaderName { + &FF_NAME + } + + fn decode<'i, I>(values: &mut I) -> Result + where + Self: Sized, + I: Iterator, + { + let value = values + .next() + .ok_or_else(headers::Error::invalid)? + .to_str() + .map_err(|_| headers::Error::invalid())? + .to_owned(); + + match &value[..] { + "true" => Ok(HxRequest), + _ => Err(headers::Error::invalid()), + } + } + + fn encode>(&self, values: &mut E) { + values.extend(std::iter::once(HeaderValue::from_static("true"))); + } +}