diff --git a/backend/Cargo.toml b/backend/Cargo.toml index 1169999..8e906a2 100644 --- a/backend/Cargo.toml +++ b/backend/Cargo.toml @@ -22,3 +22,4 @@ clap = { version = "4.4", features = ["derive"] } rust-embed = "8.2" mime_guess = "2.0" shared = { path = "../shared" } +thiserror = "2.0.18" diff --git a/backend/src/main.rs b/backend/src/main.rs index 84fb7e6..0d9f000 100644 --- a/backend/src/main.rs +++ b/backend/src/main.rs @@ -3,30 +3,27 @@ mod scgi; mod sse; mod xmlrpc; -use clap::Parser; -use rust_embed::RustEmbed; +use axum::{error_handling::HandleErrorLayer, BoxError}; use axum::{ extract::State, http::{header, StatusCode, Uri}, response::IntoResponse, routing::{get, post}, - Router, Json, + Json, Router, }; +use clap::Parser; +use rust_embed::RustEmbed; +use serde::Deserialize; +use shared::{AppEvent, Torrent, TorrentActionRequest}; // shared crates imports +use std::net::SocketAddr; +use std::sync::Arc; +use tokio::sync::{broadcast, watch}; +use tower::ServiceBuilder; use tower_http::{ + compression::{CompressionLayer, CompressionLevel}, cors::CorsLayer, trace::TraceLayer, - compression::{CompressionLayer, CompressionLevel}, }; -use axum::{ - error_handling::HandleErrorLayer, - BoxError, -}; -use tower::ServiceBuilder; -use serde::Deserialize; -use std::net::SocketAddr; -use shared::{Torrent, TorrentActionRequest, AppEvent}; // shared crates imports -use tokio::sync::{watch, broadcast}; -use std::sync::Arc; #[derive(Clone)] pub struct AppState { @@ -90,7 +87,10 @@ async fn add_torrent_handler( State(state): State, Json(payload): Json, ) -> StatusCode { - tracing::info!("Received add_torrent request. URI length: {}", payload.uri.len()); + tracing::info!( + "Received add_torrent request. URI length: {}", + payload.uri.len() + ); let client = xmlrpc::RtorrentClient::new(&state.scgi_socket_path); match client.call("load.start", &["", &payload.uri]).await { Ok(response) => { @@ -100,7 +100,7 @@ async fn add_torrent_handler( return StatusCode::INTERNAL_SERVER_ERROR; } StatusCode::OK - }, + } Err(e) => { tracing::error!("Failed to add torrent: {}", e); StatusCode::INTERNAL_SERVER_ERROR @@ -112,23 +112,25 @@ async fn add_torrent_handler( async fn main() { // initialize tracing with env filter (default to info) tracing_subscriber::fmt() - .with_env_filter(tracing_subscriber::EnvFilter::from_default_env() - .add_directive(tracing::Level::INFO.into())) + .with_env_filter( + tracing_subscriber::EnvFilter::from_default_env() + .add_directive(tracing::Level::INFO.into()), + ) .init(); - + // Parse CLI Args let args = Args::parse(); tracing::info!("Starting VibeTorrent Backend..."); tracing::info!("Socket: {}", args.socket); tracing::info!("Port: {}", args.port); - + // Channel for latest state (for new clients) let (tx, _rx) = watch::channel(vec![]); let tx = Arc::new(tx); // Channel for Events (Diffs) let (event_bus, _) = broadcast::channel::(1024); - + let app_state = AppState { tx: tx.clone(), event_bus: event_bus.clone(), @@ -151,7 +153,10 @@ async fn main() { let _ = tx_clone.send(new_torrents.clone()); // 2. Calculate Diff and Broadcasting - let now = std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH).unwrap().as_secs(); + let now = std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap() + .as_secs(); let mut structural_change = false; if previous_torrents.len() != new_torrents.len() { @@ -167,8 +172,8 @@ async fn main() { } if structural_change { - // Structural change -> Send FullList - let _ = event_bus_tx.send(AppEvent::FullList(new_torrents.clone(), now)); + // Structural change -> Send FullList + let _ = event_bus_tx.send(AppEvent::FullList(new_torrents.clone(), now)); } else { // Same structure -> Calculate partial updates let updates = diff::diff_torrents(&previous_torrents, &new_torrents); @@ -195,13 +200,16 @@ async fn main() { .route("/api/torrents/action", post(handle_torrent_action)) .fallback(static_handler) // Serve static files for everything else .layer(TraceLayer::new_for_http()) - .layer(CompressionLayer::new() - .br(false) - .gzip(true) - .quality(CompressionLevel::Fastest)) - .layer(ServiceBuilder::new() - .layer(HandleErrorLayer::new(handle_timeout_error)) - .layer(tower::timeout::TimeoutLayer::new(Duration::from_secs(30))) + .layer( + CompressionLayer::new() + .br(false) + .gzip(true) + .quality(CompressionLevel::Fastest), + ) + .layer( + ServiceBuilder::new() + .layer(HandleErrorLayer::new(handle_timeout_error)) + .layer(tower::timeout::TimeoutLayer::new(Duration::from_secs(30))), ) .layer(CorsLayer::permissive()) .with_state(app_state); @@ -216,62 +224,160 @@ async fn handle_torrent_action( State(state): State, Json(payload): Json, ) -> impl IntoResponse { - tracing::info!("Received action: {} for hash: {}", payload.action, payload.hash); - + tracing::info!( + "Received action: {} for hash: {}", + payload.action, + payload.hash + ); + // Special handling for delete_with_data if payload.action == "delete_with_data" { let client = xmlrpc::RtorrentClient::new(&state.scgi_socket_path); - + // 1. Get Base Path let path_xml = match client.call("d.base_path", &[&payload.hash]).await { Ok(xml) => xml, - Err(e) => return (StatusCode::INTERNAL_SERVER_ERROR, format!("Failed to call rTorrent: {}", e)).into_response(), + Err(e) => { + return ( + StatusCode::INTERNAL_SERVER_ERROR, + format!("Failed to call rTorrent: {}", e), + ) + .into_response() + } }; let path = match xmlrpc::parse_string_response(&path_xml) { Ok(p) => p, - Err(e) => return (StatusCode::INTERNAL_SERVER_ERROR, format!("Failed to parse path: {}", e)).into_response(), + Err(e) => { + return ( + StatusCode::INTERNAL_SERVER_ERROR, + format!("Failed to parse path: {}", e), + ) + .into_response() + } }; - + let path_buf = std::path::Path::new(&path); // 1.5 Get Default Download Directory (Sandbox Root) let root_xml = match client.call("directory.default", &[]).await { Ok(xml) => xml, - Err(e) => return (StatusCode::INTERNAL_SERVER_ERROR, format!("Failed to get valid download root: {}", e)).into_response(), + Err(e) => { + return ( + StatusCode::INTERNAL_SERVER_ERROR, + format!("Failed to get valid download root: {}", e), + ) + .into_response() + } }; - + let root_path_str = match xmlrpc::parse_string_response(&root_xml) { Ok(p) => p, - Err(e) => return (StatusCode::INTERNAL_SERVER_ERROR, format!("Failed to parse root path: {}", e)).into_response(), + Err(e) => { + return ( + StatusCode::INTERNAL_SERVER_ERROR, + format!("Failed to parse root path: {}", e), + ) + .into_response() + } }; - - let root_path = std::path::Path::new(&root_path_str); - tracing::info!("Delete request: Path='{}', Root='{}'", path, root_path_str); + // Resolve Paths (Canonicalize) to prevent .. traversal and symlink attacks + let root_path = match std::fs::canonicalize(std::path::Path::new(&root_path_str)) { + Ok(p) => p, + Err(e) => { + return ( + StatusCode::INTERNAL_SERVER_ERROR, + format!("Invalid download root configuration (on server): {}", e), + ) + .into_response() + } + }; + + // Check if target path exists before trying to resolve it + let target_path_raw = std::path::Path::new(&path); + if !target_path_raw.exists() { + tracing::warn!( + "Data path not found: {:?}. Removing torrent only.", + target_path_raw + ); + // If file doesn't exist, we just remove the torrent entry + if let Err(e) = client.call("d.erase", &[&payload.hash]).await { + return ( + StatusCode::INTERNAL_SERVER_ERROR, + format!("Failed to erase torrent: {}", e), + ) + .into_response(); + } + return (StatusCode::OK, "Torrent removed (Data not found)").into_response(); + } + + let target_path = match std::fs::canonicalize(target_path_raw) { + Ok(p) => p, + Err(e) => { + return ( + StatusCode::INTERNAL_SERVER_ERROR, + format!("Invalid data path: {}", e), + ) + .into_response() + } + }; + + tracing::info!( + "Delete request: Target='{:?}', Root='{:?}'", + target_path, + root_path + ); // SECURITY CHECK: Ensure path is inside root_path - if !path_buf.starts_with(root_path) { - tracing::error!("Security Risk: Attempted to delete path outside download directory: {}", path); - return (StatusCode::FORBIDDEN, "Security Error: Cannot delete files outside default download directory").into_response(); + if !target_path.starts_with(&root_path) { + tracing::error!( + "Security Risk: Attempted to delete path outside download directory: {:?}", + target_path + ); + return ( + StatusCode::FORBIDDEN, + "Security Error: Cannot delete files outside default download directory", + ) + .into_response(); } - + // SECURITY CHECK: Ensure we are not deleting the root itself - if path_buf == root_path { - return (StatusCode::BAD_REQUEST, "Security Error: Cannot delete the download root directory itself").into_response(); + if target_path == root_path { + return ( + StatusCode::BAD_REQUEST, + "Security Error: Cannot delete the download root directory itself", + ) + .into_response(); } - // 2. Erase Torrent first (so rTorrent releases locks?) + // 2. Erase Torrent first if let Err(e) = client.call("d.erase", &[&payload.hash]).await { - tracing::warn!("Failed to erase torrent entry: {}", e); - // Proceed anyway to delete files? Maybe not. + tracing::warn!("Failed to erase torrent entry: {}", e); + return ( + StatusCode::INTERNAL_SERVER_ERROR, + format!("Failed to erase torrent: {}", e), + ) + .into_response(); } - // 3. Delete Files via rTorrent (execute.throw.bg) - // Command: rm -rf - match client.call("execute.throw.bg", &["", "rm", "-rf", &path]).await { + // 3. Delete Files via Native FS + let delete_result = if target_path.is_dir() { + std::fs::remove_dir_all(&target_path) + } else { + std::fs::remove_file(&target_path) + }; + + match delete_result { Ok(_) => return (StatusCode::OK, "Torrent and data deleted").into_response(), - Err(e) => return (StatusCode::INTERNAL_SERVER_ERROR, format!("Failed to delete data: {}", e)).into_response(), + Err(e) => { + tracing::error!("Failed to delete data at {:?}: {}", target_path, e); + return ( + StatusCode::INTERNAL_SERVER_ERROR, + format!("Failed to delete data: {}", e), + ) + .into_response(); + } } } @@ -282,11 +388,16 @@ async fn handle_torrent_action( _ => return (StatusCode::BAD_REQUEST, "Invalid action").into_response(), }; - match scgi::system_call(&state.scgi_socket_path, method, vec![&payload.hash]).await { + let client = xmlrpc::RtorrentClient::new(&state.scgi_socket_path); + match client.call(method, &[&payload.hash]).await { Ok(_) => (StatusCode::OK, "Action executed").into_response(), Err(e) => { - tracing::error!("SCGI error: {:?}", e); - (StatusCode::INTERNAL_SERVER_ERROR, "Failed to execute action").into_response() + tracing::error!("RPC error: {}", e); + ( + StatusCode::INTERNAL_SERVER_ERROR, + "Failed to execute action", + ) + .into_response() } } } @@ -295,6 +406,9 @@ async fn handle_timeout_error(err: BoxError) -> (StatusCode, &'static str) { if err.is::() { (StatusCode::REQUEST_TIMEOUT, "Request timed out") } else { - (StatusCode::INTERNAL_SERVER_ERROR, "Unhandled internal error") + ( + StatusCode::INTERNAL_SERVER_ERROR, + "Unhandled internal error", + ) } } diff --git a/backend/src/scgi.rs b/backend/src/scgi.rs index 0620ec3..9131737 100644 --- a/backend/src/scgi.rs +++ b/backend/src/scgi.rs @@ -1,22 +1,17 @@ use bytes::Bytes; use std::collections::HashMap; +use thiserror::Error; use tokio::io::{AsyncReadExt, AsyncWriteExt}; use tokio::net::UnixStream; -#[derive(Debug)] +#[derive(Error, Debug)] pub enum ScgiError { - #[allow(dead_code)] - Io(std::io::Error), - #[allow(dead_code)] + #[error("IO Error: {0}")] + Io(#[from] std::io::Error), + #[error("Protocol Error: {0}")] Protocol(String), } -impl From for ScgiError { - fn from(err: std::io::Error) -> Self { - ScgiError::Io(err) - } -} - pub struct ScgiRequest { headers: HashMap, body: Vec, @@ -46,23 +41,18 @@ impl ScgiRequest { pub fn encode(&self) -> Vec { let mut headers_data = Vec::new(); - - // SCGI Spec: The first header must be "CONTENT_LENGTH" - // The second header must be "SCGI" with value "1" - - // We handle CONTENT_LENGTH and SCGI explicitly first + let content_len = self.body.len().to_string(); headers_data.extend_from_slice(b"CONTENT_LENGTH"); headers_data.push(0); headers_data.extend_from_slice(content_len.as_bytes()); headers_data.push(0); - + headers_data.extend_from_slice(b"SCGI"); headers_data.push(0); headers_data.extend_from_slice(b"1"); headers_data.push(0); - // Add remaining headers (excluding the ones we just added if they exist in the map) for (k, v) in &self.headers { if k == "CONTENT_LENGTH" || k == "SCGI" { continue; @@ -86,10 +76,7 @@ impl ScgiRequest { } } -pub async fn send_request( - socket_path: &str, - request: ScgiRequest, -) -> Result { +pub async fn send_request(socket_path: &str, request: ScgiRequest) -> Result { let mut stream = UnixStream::connect(socket_path).await?; let data = request.encode(); stream.write_all(&data).await?; @@ -97,9 +84,6 @@ pub async fn send_request( let mut response = Vec::new(); stream.read_to_end(&mut response).await?; - // The response is usually HTTP-like: headers\r\n\r\nbody - // We strictly want the body for XML-RPC - // Find double newline let double_newline = b"\r\n\r\n"; if let Some(pos) = response .windows(double_newline.len()) @@ -107,40 +91,6 @@ pub async fn send_request( { Ok(Bytes::from(response.split_off(pos + double_newline.len()))) } else { - // Fallback: rTorrent sometimes sends raw XML without headers if configured poorly, - // but SCGI usually implies headers. - // If we don't find headers, maybe it's all body? - // But usually there's at least "Status: 200 OK" - // Let's return everything if we can't find the split, or error. - // For now, assume everything is body if no headers found might be unsafe, - // but valid for simple XML-RPC dumping. - Ok(Bytes::from(response)) + Ok(Bytes::from(response)) } } -pub async fn system_call( - socket_path: &str, - method: &str, - params: Vec<&str>, -) -> Result { - // Construct XML-RPC payload manually for simplicity - // methodval... - let mut xml = String::from(""); - xml.push_str(&format!("{}", method)); - for param in params { - // Use CDATA for safety with special chars in magnet links - xml.push_str(&format!("", param)); - } - xml.push_str(""); - - tracing::debug!("Sending XML-RPC Payload: {}", xml); - - let req = ScgiRequest::new().body(xml.clone().into_bytes()); - let response_bytes = send_request(socket_path, req).await?; - let response_str = String::from_utf8_lossy(&response_bytes).to_string(); - - // Ideally parse the response, but for actions we just check if it executed without SCGI error - // rTorrent usually returns 0 for success or fault. - // For now, returning the raw string is fine for debugging/logging in main. - - Ok(response_str) -} diff --git a/backend/src/sse.rs b/backend/src/sse.rs index 3fbb2f2..a48c5c0 100644 --- a/backend/src/sse.rs +++ b/backend/src/sse.rs @@ -1,116 +1,110 @@ +use crate::xmlrpc::{parse_multicall_response, RtorrentClient, XmlRpcError}; use axum::response::sse::{Event, Sse}; use futures::stream::{self, Stream}; +use shared::{AppEvent, Torrent, TorrentStatus}; use std::convert::Infallible; use tokio_stream::StreamExt; -use shared::{AppEvent, Torrent, TorrentStatus}; -use crate::xmlrpc::{RtorrentClient, parse_multicall_response}; // Helper (should be moved to utils) fn parse_size(s: &str) -> i64 { s.parse().unwrap_or(0) } - -pub async fn fetch_torrents(client: &RtorrentClient) -> Result, String> { +pub async fn fetch_torrents(client: &RtorrentClient) -> Result, XmlRpcError> { // d.multicall2("", "main", ...) let params = vec![ - "", - "main", - "d.hash=", - "d.name=", - "d.size_bytes=", - "d.bytes_done=", - "d.down.rate=", + "", + "main", + "d.hash=", + "d.name=", + "d.size_bytes=", + "d.bytes_done=", + "d.down.rate=", "d.up.rate=", - "d.state=", // 6 - "d.complete=", // 7 - "d.message=", // 8 - "d.left_bytes=", // 9 - "d.creation_date=", // 10 - "d.hashing=", // 11 + "d.state=", // 6 + "d.complete=", // 7 + "d.message=", // 8 + "d.left_bytes=", // 9 + "d.creation_date=", // 10 + "d.hashing=", // 11 ]; - match client.call("d.multicall2", ¶ms).await { - Ok(xml) => { - if xml.trim().is_empty() { - return Err("Empty response from SCGI".to_string()); - } - match parse_multicall_response(&xml) { - Ok(rows) => { - let torrents = rows.into_iter().map(|row| { - // row map indexes: - // 0: hash, 1: name, 2: size, 3: completed, 4: down_rate, 5: up_rate - // 6: state, 7: complete, 8: message, 9: left_bytes, 10: added, 11: hashing - - let hash = row.get(0).cloned().unwrap_or_default(); - let name = row.get(1).cloned().unwrap_or_default(); - let size = parse_size(row.get(2).unwrap_or(&"0".to_string())); - let completed = parse_size(row.get(3).unwrap_or(&"0".to_string())); - let down_rate = parse_size(row.get(4).unwrap_or(&"0".to_string())); - let up_rate = parse_size(row.get(5).unwrap_or(&"0".to_string())); - - let state = parse_size(row.get(6).unwrap_or(&"0".to_string())); - let is_complete = parse_size(row.get(7).unwrap_or(&"0".to_string())); - let message = row.get(8).cloned().unwrap_or_default(); - let left_bytes = parse_size(row.get(9).unwrap_or(&"0".to_string())); - let added_date = parse_size(row.get(10).unwrap_or(&"0".to_string())); - let is_hashing = parse_size(row.get(11).unwrap_or(&"0".to_string())); + let xml = client.call("d.multicall2", ¶ms).await?; - let percent_complete = if size > 0 { - (completed as f64 / size as f64) * 100.0 - } else { - 0.0 - }; - - // Status Logic - let status = if !message.is_empty() { - TorrentStatus::Error - } else if is_hashing != 0 { - TorrentStatus::Checking - } else if state == 0 { - TorrentStatus::Paused - } else if is_complete != 0 { - TorrentStatus::Seeding - } else { - TorrentStatus::Downloading - }; - - // ETA Logic (seconds) - let eta = if down_rate > 0 && left_bytes > 0 { - left_bytes / down_rate - } else { - 0 - }; - - Torrent { - hash, - name, - size, - completed, - down_rate, - up_rate, - eta, - percent_complete, - status, - error_message: message, - added_date, - } - }).collect(); - Ok(torrents) - }, - Err(e) => { - Err(format!("XML Parse Error: {}", e)) - } - } - }, - Err(e) => { - Err(format!("RPC Error: {}", e)) - } + if xml.trim().is_empty() { + return Err(XmlRpcError::Parse("Empty response from SCGI".to_string())); } + + let rows = parse_multicall_response(&xml)?; + + let torrents = rows + .into_iter() + .map(|row| { + // row map indexes: + // 0: hash, 1: name, 2: size, 3: completed, 4: down_rate, 5: up_rate + // 6: state, 7: complete, 8: message, 9: left_bytes, 10: added, 11: hashing + + let hash = row.get(0).cloned().unwrap_or_default(); + let name = row.get(1).cloned().unwrap_or_default(); + let size = parse_size(row.get(2).unwrap_or(&"0".to_string())); + let completed = parse_size(row.get(3).unwrap_or(&"0".to_string())); + let down_rate = parse_size(row.get(4).unwrap_or(&"0".to_string())); + let up_rate = parse_size(row.get(5).unwrap_or(&"0".to_string())); + + let state = parse_size(row.get(6).unwrap_or(&"0".to_string())); + let is_complete = parse_size(row.get(7).unwrap_or(&"0".to_string())); + let message = row.get(8).cloned().unwrap_or_default(); + let left_bytes = parse_size(row.get(9).unwrap_or(&"0".to_string())); + let added_date = parse_size(row.get(10).unwrap_or(&"0".to_string())); + let is_hashing = parse_size(row.get(11).unwrap_or(&"0".to_string())); + + let percent_complete = if size > 0 { + (completed as f64 / size as f64) * 100.0 + } else { + 0.0 + }; + + // Status Logic + let status = if !message.is_empty() { + TorrentStatus::Error + } else if is_hashing != 0 { + TorrentStatus::Checking + } else if state == 0 { + TorrentStatus::Paused + } else if is_complete != 0 { + TorrentStatus::Seeding + } else { + TorrentStatus::Downloading + }; + + // ETA Logic (seconds) + let eta = if down_rate > 0 && left_bytes > 0 { + left_bytes / down_rate + } else { + 0 + }; + + Torrent { + hash, + name, + size, + completed, + down_rate, + up_rate, + eta, + percent_complete, + status, + error_message: message, + added_date, + } + }) + .collect(); + + Ok(torrents) } -use axum::extract::State; -use crate::AppState; // Import from crate root +use crate::AppState; +use axum::extract::State; // Import from crate root pub async fn sse_handler( State(state): State, @@ -118,31 +112,34 @@ pub async fn sse_handler( // Get initial value synchronously (from the watch channel's current state) let initial_rx = state.tx.subscribe(); let initial_torrents = initial_rx.borrow().clone(); - + let initial_event = { - let timestamp = std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH).unwrap().as_secs(); + let timestamp = std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap() + .as_secs(); let event_data = AppEvent::FullList(initial_torrents, timestamp); match serde_json::to_string(&event_data) { - Ok(json) => Event::default().data(json), - Err(_) => Event::default().comment("init_error"), + Ok(json) => Event::default().data(json), + Err(_) => Event::default().comment("init_error"), } }; - + // Stream that yields the initial event once let initial_stream = stream::once(async { Ok::(initial_event) }); - - // Stream that waits for subsequent changes + // Stream that waits for subsequent changes via Broadcast channel let rx = state.event_bus.subscribe(); let update_stream = stream::unfold(rx, |mut rx| async move { match rx.recv().await { - Ok(event) => { - match serde_json::to_string(&event) { - Ok(json) => Some((Ok::(Event::default().data(json)), rx)), - Err(e) => { - tracing::warn!("Failed to serialize SSE event: {}", e); - Some((Ok::(Event::default().comment("error")), rx)) - }, + Ok(event) => match serde_json::to_string(&event) { + Ok(json) => Some((Ok::(Event::default().data(json)), rx)), + Err(e) => { + tracing::warn!("Failed to serialize SSE event: {}", e); + Some(( + Ok::(Event::default().comment("error")), + rx, + )) } }, Err(e) => { @@ -152,7 +149,7 @@ pub async fn sse_handler( } } }); - + Sse::new(initial_stream.chain(update_stream)) .keep_alive(axum::response::sse::KeepAlive::default()) } diff --git a/backend/src/xmlrpc.rs b/backend/src/xmlrpc.rs index 64d1760..7b71473 100644 --- a/backend/src/xmlrpc.rs +++ b/backend/src/xmlrpc.rs @@ -1,7 +1,20 @@ -use crate::scgi::{send_request, ScgiRequest}; +use crate::scgi::{send_request, ScgiError, ScgiRequest}; use quick_xml::de::from_str; use quick_xml::se::to_string; use serde::{Deserialize, Serialize}; +use thiserror::Error; + +#[derive(Error, Debug)] +pub enum XmlRpcError { + #[error("SCGI Error: {0}")] + Scgi(#[from] ScgiError), + #[error("Serialization Error: {0}")] + Serialization(String), // quick_xml errors are tricky to wrap directly due to versions/features + #[error("Deserialization Error: {0}")] + Deserialization(#[from] quick_xml::de::DeError), + #[error("XML Parse Error: {0}")] + Parse(String), +} // --- Request Models --- @@ -49,7 +62,6 @@ struct MulticallResponseParam { value: MulticallResponseValueArray, } -// Top level array in d.multicall2 response #[derive(Debug, Deserialize)] struct MulticallResponseValueArray { array: MulticallResponseDataOuter, @@ -145,7 +157,7 @@ impl RtorrentClient { } /// Helper to build and serialize XML-RPC method call - fn build_method_call(&self, method: &str, params: &[&str]) -> Result { + fn build_method_call(&self, method: &str, params: &[&str]) -> Result { let req_params = RequestParams { param: params .iter() @@ -163,27 +175,22 @@ impl RtorrentClient { params: req_params, }; - let xml_body = to_string(&call).map_err(|e| format!("Serialization error: {}", e))?; + let xml_body = to_string(&call).map_err(|e| XmlRpcError::Serialization(e.to_string()))?; Ok(format!("\n{}", xml_body)) } - pub async fn call(&self, method: &str, params: &[&str]) -> Result { + pub async fn call(&self, method: &str, params: &[&str]) -> Result { let xml = self.build_method_call(method, params)?; let req = ScgiRequest::new().body(xml.into_bytes()); - match send_request(&self.socket_path, req).await { - Ok(bytes) => { - let s = String::from_utf8_lossy(&bytes).to_string(); - Ok(s) - } - Err(e) => Err(format!("SCGI Error: {:?}", e)), - } + let bytes = send_request(&self.socket_path, req).await?; + let s = String::from_utf8_lossy(&bytes).to_string(); + Ok(s) } } -pub fn parse_multicall_response(xml: &str) -> Result>, String> { - let response: MulticallResponse = - from_str(xml).map_err(|e| format!("XML Parse Error: {}", e))?; +pub fn parse_multicall_response(xml: &str) -> Result>, XmlRpcError> { + let response: MulticallResponse = from_str(xml)?; let mut result = Vec::new(); @@ -198,8 +205,8 @@ pub fn parse_multicall_response(xml: &str) -> Result>, String> { Ok(result) } -pub fn parse_string_response(xml: &str) -> Result { - let response: StringResponse = from_str(xml).map_err(|e| format!("XML Parse Error: {}", e))?; +pub fn parse_string_response(xml: &str) -> Result { + let response: StringResponse = from_str(xml)?; Ok(response.params.param.value.string) } @@ -214,10 +221,7 @@ mod tests { .build_method_call("d.multicall2", &["", "main", "d.name="]) .unwrap(); - println!("Generated XML: {}", xml); - assert!(xml.contains("d.multicall2")); - // With struct option serialization, it should produce ... assert!(xml.contains("main")); }