diff --git a/Cargo.lock b/Cargo.lock index c9327f2..f3937cc 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -292,9 +292,11 @@ name = "backend" version = "0.1.0" dependencies = [ "anyhow", + "async-trait", "axum", "axum-extra", "base64 0.22.1", + "bb8", "bcrypt", "bytes", "clap", @@ -356,6 +358,18 @@ version = "1.8.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2af50177e190e07a26ab74f8b1efbfe2ef87da2116221318cb1c2e82baf7de06" +[[package]] +name = "bb8" +version = "0.8.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d89aabfae550a5c44b43ab941844ffcd2e993cb6900b342debf59e9ea74acdb8" +dependencies = [ + "async-trait", + "futures-util", + "parking_lot", + "tokio", +] + [[package]] name = "bcrypt" version = "0.17.1" diff --git a/backend/Cargo.toml b/backend/Cargo.toml index c0d79c8..7358243 100644 --- a/backend/Cargo.toml +++ b/backend/Cargo.toml @@ -43,3 +43,5 @@ time = { version = "0.3.47", features = ["serde", "formatting", "parsing"] } tower_governor = "0.8.0" governor = "0.10.4" strum = { version = "0.25", features = ["derive", "strum_macros"] } +bb8 = "0.8" +async-trait = "0.1" \ No newline at end of file diff --git a/backend/src/handlers/mod.rs b/backend/src/handlers/mod.rs index c858a26..c6b33ea 100644 --- a/backend/src/handlers/mod.rs +++ b/backend/src/handlers/mod.rs @@ -71,7 +71,7 @@ pub async fn add_torrent_handler( "Received add_torrent request. URI length: {}", payload.uri.len() ); - let client = xmlrpc::RtorrentClient::new(&state.scgi_socket_path); + let client = xmlrpc::RtorrentClient::new(state.rtorrent_pool.clone()); let params = vec![RpcParam::from(""), RpcParam::from(payload.uri.as_str())]; match client.call("load.start", ¶ms).await { @@ -114,7 +114,7 @@ pub async fn handle_torrent_action( payload.hash ); - let client = xmlrpc::RtorrentClient::new(&state.scgi_socket_path); + let client = xmlrpc::RtorrentClient::new(state.rtorrent_pool.clone()); // Special handling for delete_with_data if payload.action == "delete_with_data" { @@ -298,7 +298,7 @@ async fn delete_torrent_with_data( ) )] pub async fn get_version_handler(State(state): State) -> impl IntoResponse { - let client = xmlrpc::RtorrentClient::new(&state.scgi_socket_path); + let client = xmlrpc::RtorrentClient::new(state.rtorrent_pool.clone()); match client.call("system.client_version", &[]).await { Ok(xml) => { let version = xmlrpc::parse_string_response(&xml).unwrap_or(xml); @@ -327,7 +327,7 @@ pub async fn get_files_handler( State(state): State, Path(hash): Path, ) -> impl IntoResponse { - let client = xmlrpc::RtorrentClient::new(&state.scgi_socket_path); + let client = xmlrpc::RtorrentClient::new(state.rtorrent_pool.clone()); let params = vec![ RpcParam::from(hash.as_str()), RpcParam::from(""), @@ -383,7 +383,7 @@ pub async fn get_peers_handler( State(state): State, Path(hash): Path, ) -> impl IntoResponse { - let client = xmlrpc::RtorrentClient::new(&state.scgi_socket_path); + let client = xmlrpc::RtorrentClient::new(state.rtorrent_pool.clone()); let params = vec![ RpcParam::from(hash.as_str()), RpcParam::from(""), @@ -439,7 +439,7 @@ pub async fn get_trackers_handler( State(state): State, Path(hash): Path, ) -> impl IntoResponse { - let client = xmlrpc::RtorrentClient::new(&state.scgi_socket_path); + let client = xmlrpc::RtorrentClient::new(state.rtorrent_pool.clone()); let params = vec![ RpcParam::from(hash.as_str()), RpcParam::from(""), @@ -493,7 +493,7 @@ pub async fn set_file_priority_handler( State(state): State, Json(payload): Json, ) -> impl IntoResponse { - let client = xmlrpc::RtorrentClient::new(&state.scgi_socket_path); + let client = xmlrpc::RtorrentClient::new(state.rtorrent_pool.clone()); // f.set_priority takes "hash", index, priority // Priority: 0 (off), 1 (normal), 2 (high) @@ -541,7 +541,7 @@ pub async fn set_label_handler( State(state): State, Json(payload): Json, ) -> impl IntoResponse { - let client = xmlrpc::RtorrentClient::new(&state.scgi_socket_path); + let client = xmlrpc::RtorrentClient::new(state.rtorrent_pool.clone()); let params = vec![ RpcParam::from(payload.hash.as_str()), RpcParam::from(payload.label), @@ -567,7 +567,7 @@ pub async fn set_label_handler( ) )] pub async fn get_global_limit_handler(State(state): State) -> impl IntoResponse { - let client = xmlrpc::RtorrentClient::new(&state.scgi_socket_path); + let client = xmlrpc::RtorrentClient::new(state.rtorrent_pool.clone()); // throttle.global_down.max_rate, throttle.global_up.max_rate let down_fut = client.call("throttle.global_down.max_rate", &[]); let up_fut = client.call("throttle.global_up.max_rate", &[]); @@ -604,7 +604,7 @@ pub async fn set_global_limit_handler( State(state): State, Json(payload): Json, ) -> impl IntoResponse { - let client = xmlrpc::RtorrentClient::new(&state.scgi_socket_path); + let client = xmlrpc::RtorrentClient::new(state.rtorrent_pool.clone()); // Use throttle.global_*.max_rate.set_kb which is more reliable than .set (which is buggy) // The .set_kb method expects KB/s, so we convert bytes to KB diff --git a/backend/src/main.rs b/backend/src/main.rs index bac6fea..f6322da 100644 --- a/backend/src/main.rs +++ b/backend/src/main.rs @@ -1,3 +1,4 @@ +use crate::scgi::{ScgiPool, create_pool}; mod db; mod diff; mod handlers; @@ -41,7 +42,7 @@ use utoipa_swagger_ui::SwaggerUi; pub struct AppState { pub tx: Arc>>, pub event_bus: broadcast::Sender, - pub scgi_socket_path: String, + pub rtorrent_pool: ScgiPool, pub db: db::Db, #[cfg(feature = "push-notifications")] pub push_store: push::PushSubscriptionStore, @@ -292,30 +293,16 @@ async fn main() { tracing::info!("Socket: {}", args.socket); tracing::info!("Port: {}", args.port); - // ... rest of the main function ... - // Startup Health Check - let socket_path = std::path::Path::new(&args.socket); - if !socket_path.exists() { - tracing::error!("CRITICAL: rTorrent socket not found at {:?}.", socket_path); - tracing::warn!( - "HINT: Make sure rTorrent is running and the SCGI socket is enabled in .rtorrent.rc" - ); - tracing::warn!( - "HINT: You can configure the socket path via --socket ARG or RTORRENT_SOCKET ENV." - ); - } else { - tracing::info!("Socket file exists. Testing connection..."); - let client = xmlrpc::RtorrentClient::new(&args.socket); - // We use a lightweight call to verify connectivity - let params: Vec = vec![]; - match client.call("system.client_version", ¶ms).await { - Ok(xml) => { - let version = xmlrpc::parse_string_response(&xml).unwrap_or(xml); - tracing::info!("Connected to rTorrent successfully. Version: {}", version); - } - Err(e) => tracing::error!("Socket exists but failed to connect to rTorrent: {}", e), + // Initialize SCGI connection pool + tracing::info!("Creating SCGI connection pool..."); + let rtorrent_pool = match create_pool(&args.socket, 10).await { + Ok(pool) => pool, + Err(e) => { + tracing::error!("Failed to create SCGI connection pool: {}", e); + std::process::exit(1); } - } + }; + tracing::info!("SCGI connection pool created successfully."); // Channel for latest state (for new clients) let (tx, _rx) = watch::channel(vec![]); @@ -339,7 +326,7 @@ async fn main() { let app_state = AppState { tx: tx.clone(), event_bus: event_bus.clone(), - scgi_socket_path: args.socket.clone(), + rtorrent_pool: rtorrent_pool.clone(), db: db.clone(), #[cfg(feature = "push-notifications")] push_store, @@ -348,12 +335,12 @@ async fn main() { // Spawn background task to poll rTorrent let tx_clone = tx.clone(); let event_bus_tx = event_bus.clone(); - let socket_path = args.socket.clone(); // Clone for background task + let rtorrent_pool_clone = rtorrent_pool.clone(); #[cfg(feature = "push-notifications")] let push_store_clone = app_state.push_store.clone(); tokio::spawn(async move { - let client = xmlrpc::RtorrentClient::new(&socket_path); + let client = xmlrpc::RtorrentClient::new(rtorrent_pool_clone); let mut previous_torrents: Vec = Vec::new(); let mut consecutive_errors = 0; let mut backoff_duration = Duration::from_secs(1); diff --git a/backend/src/scgi.rs b/backend/src/scgi.rs index 7231255..7810a98 100644 --- a/backend/src/scgi.rs +++ b/backend/src/scgi.rs @@ -1,3 +1,5 @@ +use async_trait::async_trait; +use bb8::ManageConnection; use bytes::Bytes; use std::collections::HashMap; use thiserror::Error; @@ -9,8 +11,8 @@ pub enum ScgiError { #[error("IO Error: {0}")] Io(#[from] std::io::Error), #[allow(dead_code)] - #[error("Protocol Error: {0}")] - Protocol(String), + #[error("Pool Error: {0}")] + Pool(String), } pub struct ScgiRequest { @@ -77,21 +79,45 @@ impl ScgiRequest { } } -pub async fn send_request(socket_path: &str, request: ScgiRequest) -> Result { - let mut stream = UnixStream::connect(socket_path).await?; - let data = request.encode(); - stream.write_all(&data).await?; +pub struct ScgiConnectionManager { + socket_path: String, +} - let mut response = Vec::new(); - stream.read_to_end(&mut response).await?; - - let double_newline = b"\r\n\r\n"; - if let Some(pos) = response - .windows(double_newline.len()) - .position(|window| window == double_newline) - { - Ok(Bytes::from(response.split_off(pos + double_newline.len()))) - } else { - Ok(Bytes::from(response)) +impl ScgiConnectionManager { + pub fn new(socket_path: &str) -> Self { + Self { + socket_path: socket_path.to_string(), + } } } + +#[async_trait] +impl ManageConnection for ScgiConnectionManager { + type Connection = UnixStream; + type Error = ScgiError; + + async fn connect(&self) -> Result { + let stream = UnixStream::connect(&self.socket_path).await?; + Ok(stream) + } + + async fn is_valid(&self, _conn: &mut Self::Connection) -> Result<(), Self::Error> { + Ok(()) + } + + fn has_broken(&self, _conn: &mut Self::Connection) -> bool { + false + } +} + +pub type ScgiPool = bb8::Pool; + +pub async fn create_pool(socket_path: &str, max_size: u32) -> Result { + let manager = ScgiConnectionManager::new(socket_path); + let pool = bb8::Pool::builder() + .max_size(max_size) + .min_idle(Some(max_size / 2)) + .build(manager) + .await?; + Ok(pool) +} diff --git a/backend/src/xmlrpc.rs b/backend/src/xmlrpc.rs index 9ca6022..1934003 100644 --- a/backend/src/xmlrpc.rs +++ b/backend/src/xmlrpc.rs @@ -1,8 +1,9 @@ -use crate::scgi::{send_request, ScgiError, ScgiRequest}; +use crate::scgi::{ScgiError, ScgiPool, ScgiRequest}; use quick_xml::de::from_str; use quick_xml::se::to_string; use serde::{Deserialize, Serialize}; use thiserror::Error; +use tokio::io::{AsyncReadExt, AsyncWriteExt}; #[derive(Error, Debug)] pub enum XmlRpcError { @@ -14,6 +15,8 @@ pub enum XmlRpcError { Deserialization(#[from] quick_xml::de::DeError), #[error("XML Parse Error: {0}")] Parse(String), + #[error("IO Error: {0}")] + Io(#[from] std::io::Error), } // --- Request Parameters Enum --- @@ -205,13 +208,18 @@ struct IntegerResponseValue { // --- Client Implementation --- pub struct RtorrentClient { - socket_path: String, + pool: ScgiPool, } impl RtorrentClient { - pub fn new(socket_path: &str) -> Self { + pub fn new(pool: ScgiPool) -> Self { + Self { pool } + } + + #[cfg(test)] + pub fn new_unittest() -> Self { Self { - socket_path: socket_path.to_string(), + pool: panic!("Pool not available in unit tests"), } } @@ -248,8 +256,23 @@ impl RtorrentClient { let xml = self.build_method_call(method, params)?; let req = ScgiRequest::new().body(xml.into_bytes()); - let bytes = send_request(&self.socket_path, req).await?; - let s = String::from_utf8_lossy(&bytes).to_string(); + let mut conn = self.pool.get().await.map_err(|e| XmlRpcError::Scgi(ScgiError::Pool(e.to_string())))?; + conn.write_all(&req.encode()).await.map_err(|e| XmlRpcError::Io(e))?; + + let mut response = Vec::new(); + conn.read_to_end(&mut response).await.map_err(|e| XmlRpcError::Io(e))?; + + let double_newline = b"\r\n\r\n"; + let result = if let Some(pos) = response + .windows(double_newline.len()) + .position(|window| window == double_newline) + { + response.split_off(pos + double_newline.len()) + } else { + response + }; + + let s = String::from_utf8_lossy(&result).to_string(); Ok(s) } } @@ -295,7 +318,7 @@ mod tests { #[test] fn test_build_method_call() { - let client = RtorrentClient::new("dummy"); + let client = RtorrentClient::new_unittest(); let params = vec![ RpcParam::String("".to_string()), RpcParam::String("main".to_string()), @@ -309,7 +332,7 @@ mod tests { #[test] fn test_build_method_call_int() { - let client = RtorrentClient::new("dummy"); + let client = RtorrentClient::new_unittest(); let params = vec![RpcParam::Int(1024)]; let xml = client.build_method_call("test.int", ¶ms).unwrap(); // Should produce 1024 @@ -319,27 +342,27 @@ mod tests { #[test] fn test_parse_multicall_response() { let xml = r#" - - - - - - - - -HASH123 -Ubuntu ISO -1024 - - - - - - - - - -"#; + + + + + + + + + HASH123 + Ubuntu ISO + 1024 + + + + + + + + + + "#; let result = parse_multicall_response(xml).expect("Failed to parse"); assert_eq!(result.len(), 1); assert_eq!(result[0][0], "HASH123");