Implement authentication system with SQLite: Add login/setup pages, auth middleware, and database integration
Some checks failed
Build MIPS Binary / build (push) Failing after 3m42s

This commit is contained in:
spinline
2026-02-07 14:43:25 +03:00
parent 92720c15b3
commit d53d661ad1
11 changed files with 1408 additions and 87 deletions

106
backend/src/db.rs Normal file
View File

@@ -0,0 +1,106 @@
use sqlx::{sqlite::SqlitePoolOptions, Pool, Sqlite, Row};
use std::time::Duration;
use anyhow::Result;
#[derive(Clone)]
pub struct Db {
pool: Pool<Sqlite>,
}
impl Db {
pub async fn new(db_url: &str) -> Result<Self> {
let pool = SqlitePoolOptions::new()
.max_connections(5)
.acquire_timeout(Duration::from_secs(3))
.connect(db_url)
.await?;
let db = Self { pool };
db.init().await?;
Ok(db)
}
async fn init(&self) -> Result<()> {
// Create users table
sqlx::query(
"CREATE TABLE IF NOT EXISTS users (
id INTEGER PRIMARY KEY,
username TEXT NOT NULL UNIQUE,
password_hash TEXT NOT NULL,
created_at DATETIME DEFAULT CURRENT_TIMESTAMP
)",
)
.execute(&self.pool)
.await?;
// Create sessions table
sqlx::query(
"CREATE TABLE IF NOT EXISTS sessions (
token TEXT PRIMARY KEY,
user_id INTEGER NOT NULL,
expires_at DATETIME NOT NULL,
FOREIGN KEY(user_id) REFERENCES users(id)
)",
)
.execute(&self.pool)
.await?;
Ok(())
}
// --- User Operations ---
pub async fn create_user(&self, username: &str, password_hash: &str) -> Result<()> {
sqlx::query("INSERT INTO users (username, password_hash) VALUES (?, ?)")
.bind(username)
.bind(password_hash)
.execute(&self.pool)
.await?;
Ok(())
}
pub async fn get_user_by_username(&self, username: &str) -> Result<Option<(i64, String)>> {
let row = sqlx::query("SELECT id, password_hash FROM users WHERE username = ?")
.bind(username)
.fetch_optional(&self.pool)
.await?;
Ok(row.map(|r| (r.get(0), r.get(1))))
}
pub async fn has_users(&self) -> Result<bool> {
let row: (i64,) = sqlx::query_as("SELECT COUNT(*) FROM users")
.fetch_one(&self.pool)
.await?;
Ok(row.0 > 0)
}
// --- Session Operations ---
pub async fn create_session(&self, user_id: i64, token: &str, expires_at: i64) -> Result<()> {
sqlx::query("INSERT INTO sessions (token, user_id, expires_at) VALUES (?, ?, datetime(?, 'unixepoch'))")
.bind(token)
.bind(user_id)
.bind(expires_at)
.execute(&self.pool)
.await?;
Ok(())
}
pub async fn get_session_user(&self, token: &str) -> Result<Option<i64>> {
let row = sqlx::query("SELECT user_id FROM sessions WHERE token = ? AND expires_at > datetime('now')")
.bind(token)
.fetch_optional(&self.pool)
.await?;
Ok(row.map(|r| r.get(0)))
}
pub async fn delete_session(&self, token: &str) -> Result<()> {
sqlx::query("DELETE FROM sessions WHERE token = ?")
.bind(token)
.execute(&self.pool)
.await?;
Ok(())
}
}

View File

@@ -18,6 +18,9 @@ use shared::{
};
use utoipa::ToSchema;
pub mod auth;
pub mod setup;
#[derive(RustEmbed)]
#[folder = "../frontend/dist"]
pub struct Asset;
@@ -709,8 +712,8 @@ pub async fn subscribe_push_handler(
Json(subscription): Json<push::PushSubscription>,
) -> impl IntoResponse {
tracing::info!("Received push subscription: {:?}", subscription);
state.push_store.add_subscription(subscription).await;
(StatusCode::OK, "Subscription saved").into_response()
}

View File

@@ -0,0 +1,66 @@
use crate::{db::Db, AppState};
use axum::{
extract::{State, Json},
http::StatusCode,
response::IntoResponse,
};
use serde::{Deserialize, Serialize};
use utoipa::ToSchema;
#[derive(Deserialize, ToSchema)]
pub struct SetupRequest {
username: String,
password: String,
}
#[derive(Serialize)]
pub struct SetupStatusResponse {
completed: bool,
}
pub async fn get_setup_status_handler(State(state): State<AppState>) -> impl IntoResponse {
let completed = match state.db.has_users().await {
Ok(has) => has,
Err(e) => {
tracing::error!("DB error checking users: {}", e);
false
}
};
Json(SetupStatusResponse { completed }).into_response()
}
pub async fn setup_handler(
State(state): State<AppState>,
Json(payload): Json<SetupRequest>,
) -> impl IntoResponse {
// 1. Check if setup is already completed (i.e., users exist)
match state.db.has_users().await {
Ok(true) => return (StatusCode::FORBIDDEN, "Setup already completed").into_response(),
Err(e) => {
tracing::error!("DB error checking users: {}", e);
return (StatusCode::INTERNAL_SERVER_ERROR, "Database error").into_response();
}
Ok(false) => {} // Proceed
}
// 2. Validate input
if payload.username.len() < 3 || payload.password.len() < 6 {
return (StatusCode::BAD_REQUEST, "Username must be at least 3 chars, password at least 6").into_response();
}
// 3. Create User
let password_hash = match bcrypt::hash(&payload.password, bcrypt::DEFAULT_COST) {
Ok(h) => h,
Err(e) => {
tracing::error!("Failed to hash password: {}", e);
return (StatusCode::INTERNAL_SERVER_ERROR, "Failed to process password").into_response();
}
};
if let Err(e) = state.db.create_user(&payload.username, &password_hash).await {
tracing::error!("Failed to create user: {}", e);
return (StatusCode::INTERNAL_SERVER_ERROR, "Failed to create user").into_response();
}
(StatusCode::OK, "Setup completed successfully").into_response()
}

View File

@@ -1,3 +1,4 @@
mod db;
mod diff;
mod handlers;
#[cfg(feature = "push-notifications")]
@@ -10,7 +11,12 @@ use axum::error_handling::HandleErrorLayer;
use axum::{
routing::{get, post},
Router,
middleware::{self, Next},
extract::Request,
response::Response,
http::StatusCode,
};
use axum_extra::extract::cookie::CookieJar;
use clap::Parser;
use dotenvy::dotenv;
use shared::{AppEvent, Torrent};
@@ -32,10 +38,40 @@ pub struct AppState {
pub tx: Arc<watch::Sender<Vec<Torrent>>>,
pub event_bus: broadcast::Sender<AppEvent>,
pub scgi_socket_path: String,
pub db: db::Db,
#[cfg(feature = "push-notifications")]
pub push_store: push::PushSubscriptionStore,
}
async fn auth_middleware(
state: axum::extract::State<AppState>,
jar: CookieJar,
request: Request,
next: Next,
) -> Result<Response, StatusCode> {
// Skip auth for public paths
let path = request.uri().path();
if path.starts_with("/api/auth/login")
|| path.starts_with("/api/auth/check") // Used by frontend to decide where to go
|| path.starts_with("/api/setup")
|| path.starts_with("/swagger-ui")
|| path.starts_with("/api-docs")
|| !path.starts_with("/api/") // Allow static files (frontend)
{
return Ok(next.run(request).await);
}
// Check token
if let Some(token) = jar.get("auth_token") {
match state.db.get_session_user(token.value()).await {
Ok(Some(_)) => return Ok(next.run(request).await),
_ => {} // Invalid
}
}
Err(StatusCode::UNAUTHORIZED)
}
#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
@@ -51,6 +87,10 @@ struct Args {
/// Port to listen on
#[arg(short, long, env = "PORT", default_value_t = 3000)]
port: u16,
/// Database URL
#[arg(long, env = "DATABASE_URL", default_value = "sqlite:vibetorrent.db")]
db_url: String,
}
#[cfg(feature = "push-notifications")]
@@ -68,7 +108,10 @@ struct Args {
handlers::get_global_limit_handler,
handlers::set_global_limit_handler,
handlers::get_push_public_key_handler,
handlers::subscribe_push_handler
handlers::subscribe_push_handler,
handlers::auth::login_handler,
handlers::setup::setup_handler,
handlers::setup::get_setup_status_handler
),
components(
schemas(
@@ -83,7 +126,9 @@ struct Args {
shared::SetLabelRequest,
shared::GlobalLimitRequest,
push::PushSubscription,
push::PushKeys
push::PushKeys,
handlers::auth::LoginRequest,
handlers::setup::SetupRequest
)
),
tags(
@@ -105,7 +150,10 @@ struct ApiDoc;
handlers::set_file_priority_handler,
handlers::set_label_handler,
handlers::get_global_limit_handler,
handlers::set_global_limit_handler
handlers::set_global_limit_handler,
handlers::auth::login_handler,
handlers::setup::setup_handler,
handlers::setup::get_setup_status_handler
),
components(
schemas(
@@ -118,7 +166,9 @@ struct ApiDoc;
shared::TorrentTracker,
shared::SetFilePriorityRequest,
shared::SetLabelRequest,
shared::GlobalLimitRequest
shared::GlobalLimitRequest,
handlers::auth::LoginRequest,
handlers::setup::SetupRequest
)
),
tags(
@@ -146,6 +196,29 @@ async fn main() {
tracing::info!("Socket: {}", args.socket);
tracing::info!("Port: {}", args.port);
// Initialize Database
tracing::info!("Connecting to database: {}", args.db_url);
// Ensure the db file exists if it's sqlite
if args.db_url.starts_with("sqlite:") {
let path = args.db_url.trim_start_matches("sqlite:");
if !std::path::Path::new(path).exists() {
tracing::info!("Database file not found, creating: {}", path);
match std::fs::File::create(path) {
Ok(_) => tracing::info!("Created empty database file"),
Err(e) => tracing::error!("Failed to create database file: {}", e),
}
}
}
let db = match db::Db::new(&args.db_url).await {
Ok(db) => db,
Err(e) => {
tracing::error!("Failed to connect to database: {}", e);
std::process::exit(1);
}
};
tracing::info!("Database connected successfully.");
// Startup Health Check
let socket_path = std::path::Path::new(&args.socket);
if !socket_path.exists() {
@@ -181,6 +254,7 @@ async fn main() {
tx: tx.clone(),
event_bus: event_bus.clone(),
scgi_socket_path: args.socket.clone(),
db: db.clone(),
#[cfg(feature = "push-notifications")]
push_store: push::PushSubscriptionStore::new(),
};
@@ -308,6 +382,13 @@ async fn main() {
let app = Router::new()
.merge(SwaggerUi::new("/swagger-ui").url("/api-docs/openapi.json", ApiDoc::openapi()))
// Setup & Auth Routes
.route("/api/setup/status", get(handlers::setup::get_setup_status_handler))
.route("/api/setup", post(handlers::setup::setup_handler))
.route("/api/auth/login", post(handlers::auth::login_handler))
.route("/api/auth/logout", post(handlers::auth::logout_handler))
.route("/api/auth/check", get(handlers::auth::check_auth_handler))
// App Routes
.route("/api/events", get(sse::sse_handler))
.route("/api/torrents/add", post(handlers::add_torrent_handler))
.route(
@@ -337,13 +418,14 @@ async fn main() {
get(handlers::get_global_limit_handler).post(handlers::set_global_limit_handler),
)
.fallback(handlers::static_handler); // Serve static files for everything else
#[cfg(feature = "push-notifications")]
let app = app
.route("/api/push/public-key", get(handlers::get_push_public_key_handler))
.route("/api/push/subscribe", post(handlers::subscribe_push_handler));
let app = app
.layer(middleware::from_fn_with_state(app_state.clone(), auth_middleware))
.layer(TraceLayer::new_for_http())
.layer(
CompressionLayer::new()