diff --git a/src/api.rs b/src/api.rs index a4e89ed..40847aa 100644 --- a/src/api.rs +++ b/src/api.rs @@ -125,16 +125,9 @@ async fn add_drink( async fn ada_subscribe( State(state): State, ) -> Sse>> { - let stream = tokio_stream::wrappers::BroadcastStream::new(state.ada_sender.subscribe()) - .map(|r| r.map(|s| Event::default().event("ada").data(s))); + let stream = + tokio_stream::wrappers::BroadcastStream::new(state.sse_handler.ada_sender.subscribe()) + .map(|r| r.map(|s| Event::default().event("ada").data(s))); Sse::new(stream).keep_alive(KeepAlive::default()) } - -async fn get_ada_list() -> String { - let mut buf = Vec::new(); - - crate::templates::components::ada_list_html(&mut buf).unwrap(); - - String::from_utf8(buf).unwrap() -} diff --git a/src/main.rs b/src/main.rs index 095dcd2..e4dd26d 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,5 +1,6 @@ mod api; mod axum_ructe; +mod sse_handler; use axum_ructe::render; @@ -26,7 +27,8 @@ use diesel_async::{ }; use dotenvy::dotenv; -use std::net::SocketAddr; +use sse_handler::SseHandler; +use std::{net::SocketAddr, time::Duration}; use tower_http::services::{ServeDir, ServeFile}; use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt}; @@ -45,22 +47,17 @@ fn establish_connection() -> Pool { .expect("Error making connection pool") } -#[derive(Clone)] -enum AdaUpdate { - RefreshDancers, -} - #[derive(Clone)] pub(crate) struct AppState { connection: Pool, - ada_sender: tokio::sync::broadcast::Sender, + sse_handler: SseHandler, } impl AppState { fn init() -> Self { Self { connection: establish_connection(), - ada_sender: tokio::sync::broadcast::channel(10).0, + sse_handler: SseHandler::init(), } } } @@ -78,6 +75,19 @@ async fn main() { let state = AppState::init(); + tokio::spawn({ + let sender = state.clone().sse_handler.sse_sender; + async move { + loop { + sender + .send(sse_handler::SseMessage::RefreshAda) + .await + .expect("Failed to send sse message"); + tokio::time::sleep(Duration::from_secs(5)).await; + } + } + }); + let fallback_handler = ServeDir::new("dist").not_found_service(ServeFile::new("dist/404.html")); // build our application with a route diff --git a/src/sse_handler.rs b/src/sse_handler.rs new file mode 100644 index 0000000..8aee8d3 --- /dev/null +++ b/src/sse_handler.rs @@ -0,0 +1,45 @@ +#[derive(Debug, Clone)] +pub enum SseMessage { + RefreshAda, +} + +#[derive(Debug, Clone)] +pub struct SseHandler { + pub ada_sender: tokio::sync::broadcast::Sender, + pub sse_sender: tokio::sync::mpsc::Sender, +} + +impl SseHandler { + pub fn init() -> Self { + let (ada_sender, _) = tokio::sync::broadcast::channel(10); + let (sse_sender, mut sse_receiver) = tokio::sync::mpsc::channel(10); + + tokio::spawn({ + let sender = ada_sender.clone(); + async move { + while let Some(message) = sse_receiver.recv().await { + match message { + SseMessage::RefreshAda => { + if sender.receiver_count() > 0 { + sender.send(get_ada_list().await).unwrap(); + } + } + }; + } + } + }); + + Self { + ada_sender, + sse_sender, + } + } +} + +async fn get_ada_list() -> String { + let mut buf = Vec::new(); + + crate::templates::components::ada_list_html(&mut buf).unwrap(); + + String::from_utf8(buf).unwrap() +}