diff --git a/backend/src/handlers/mod.rs b/backend/src/handlers/mod.rs index c3e117f..2fca00d 100644 --- a/backend/src/handlers/mod.rs +++ b/backend/src/handlers/mod.rs @@ -690,8 +690,10 @@ pub async fn handle_timeout_error(err: BoxError) -> (StatusCode, &'static str) { (status = 200, description = "VAPID public key", body = String) ) )] -pub async fn get_push_public_key_handler() -> impl IntoResponse { - let public_key = push::get_vapid_public_key(); +pub async fn get_push_public_key_handler( + State(state): State, +) -> impl IntoResponse { + let public_key = state.push_store.get_public_key(); (StatusCode::OK, Json(serde_json::json!({ "publicKey": public_key }))).into_response() } diff --git a/backend/src/push.rs b/backend/src/push.rs index 7d3b257..d68ba2a 100644 --- a/backend/src/push.rs +++ b/backend/src/push.rs @@ -5,6 +5,7 @@ use utoipa::ToSchema; use web_push::{ HyperWebPushClient, SubscriptionInfo, VapidSignatureBuilder, WebPushClient, WebPushMessageBuilder, }; +use futures::StreamExt; use crate::db::Db; @@ -20,17 +21,34 @@ pub struct PushKeys { pub auth: String, } +#[derive(Clone)] +pub struct VapidConfig { + pub private_key: String, + pub public_key: String, + pub email: String, +} + #[derive(Clone)] pub struct PushSubscriptionStore { db: Option, subscriptions: Arc>>, + vapid_config: VapidConfig, } impl PushSubscriptionStore { pub fn new() -> Self { + let private_key = std::env::var("VAPID_PRIVATE_KEY").expect("VAPID_PRIVATE_KEY must be set in .env"); + let public_key = std::env::var("VAPID_PUBLIC_KEY").expect("VAPID_PUBLIC_KEY must be set in .env"); + let email = std::env::var("VAPID_EMAIL").expect("VAPID_EMAIL must be set in .env"); + Self { db: None, subscriptions: Arc::new(RwLock::new(Vec::new())), + vapid_config: VapidConfig { + private_key, + public_key, + email, + }, } } @@ -47,9 +65,18 @@ impl PushSubscriptionStore { } tracing::info!("Loaded {} push subscriptions from database", subscriptions_vec.len()); + let private_key = std::env::var("VAPID_PRIVATE_KEY").expect("VAPID_PRIVATE_KEY must be set in .env"); + let public_key = std::env::var("VAPID_PUBLIC_KEY").expect("VAPID_PUBLIC_KEY must be set in .env"); + let email = std::env::var("VAPID_EMAIL").expect("VAPID_EMAIL must be set in .env"); + Ok(Self { db: Some(db.clone()), subscriptions: Arc::new(RwLock::new(subscriptions_vec)), + vapid_config: VapidConfig { + private_key, + public_key, + email, + }, }) } @@ -91,6 +118,10 @@ impl PushSubscriptionStore { pub async fn get_all_subscriptions(&self) -> Vec { self.subscriptions.read().await.clone() } + + pub fn get_public_key(&self) -> &str { + &self.vapid_config.public_key + } } /// Send push notification to all subscribed clients @@ -116,50 +147,69 @@ pub async fn send_push_notification( "tag": "vibetorrent" }); - let client = HyperWebPushClient::new(); + let client = Arc::new(HyperWebPushClient::new()); + let vapid_config = store.vapid_config.clone(); + let payload_str = payload.to_string(); - let vapid_private_key = std::env::var("VAPID_PRIVATE_KEY").expect("VAPID_PRIVATE_KEY must be set in .env"); - let vapid_email = std::env::var("VAPID_EMAIL").expect("VAPID_EMAIL must be set in .env"); + // Send notifications concurrently + futures::stream::iter(subscriptions) + .for_each_concurrent(10, |subscription| { + let client = client.clone(); + let vapid_config = vapid_config.clone(); + let payload_str = payload_str.clone(); - for subscription in subscriptions { - let subscription_info = SubscriptionInfo { - endpoint: subscription.endpoint.clone(), - keys: web_push::SubscriptionKeys { - p256dh: subscription.keys.p256dh.clone(), - auth: subscription.keys.auth.clone(), - }, - }; + async move { + let subscription_info = SubscriptionInfo { + endpoint: subscription.endpoint.clone(), + keys: web_push::SubscriptionKeys { + p256dh: subscription.keys.p256dh.clone(), + auth: subscription.keys.auth.clone(), + }, + }; - let mut sig_builder = VapidSignatureBuilder::from_base64( - &vapid_private_key, - web_push::URL_SAFE_NO_PAD, - &subscription_info, - )?; + let sig_res = VapidSignatureBuilder::from_base64( + &vapid_config.private_key, + web_push::URL_SAFE_NO_PAD, + &subscription_info, + ); - sig_builder.add_claim("sub", vapid_email.as_str()); - sig_builder.add_claim("aud", subscription.endpoint.as_str()); - let signature = sig_builder.build()?; + match sig_res { + Ok(mut sig_builder) => { + sig_builder.add_claim("sub", vapid_config.email.as_str()); + sig_builder.add_claim("aud", subscription.endpoint.as_str()); + + match sig_builder.build() { + Ok(signature) => { + let mut builder = WebPushMessageBuilder::new(&subscription_info); + builder.set_vapid_signature(signature); + builder.set_payload(web_push::ContentEncoding::Aes128Gcm, payload_str.as_bytes()); - let mut builder = WebPushMessageBuilder::new(&subscription_info); - builder.set_vapid_signature(signature); - - let payload_str = payload.to_string(); - builder.set_payload(web_push::ContentEncoding::Aes128Gcm, payload_str.as_bytes()); - - match client.send(builder.build()?).await { - Ok(_) => { - tracing::debug!("Push notification sent to: {}", subscription.endpoint); + match builder.build() { + Ok(msg) => { + match client.send(msg).await { + Ok(_) => { + tracing::debug!("Push notification sent to: {}", subscription.endpoint); + } + Err(e) => { + tracing::error!("Failed to send push notification to {}: {}", subscription.endpoint, e); + } + } + } + Err(e) => tracing::error!("Failed to build push message: {}", e), + } + } + Err(e) => tracing::error!("Failed to build VAPID signature: {}", e), + } + } + Err(e) => tracing::error!("Failed to create VAPID signature builder: {}", e), + } } - Err(e) => { - tracing::error!("Failed to send push notification: {}", e); - // TODO: Remove invalid subscriptions - } - } - } + }) + .await; Ok(()) } pub fn get_vapid_public_key() -> String { std::env::var("VAPID_PUBLIC_KEY").expect("VAPID_PUBLIC_KEY must be set in .env") -} +} \ No newline at end of file