commit d5641493b9c64498e6115f1b327070dd0944804a Author: Connor Johnstone Date: Tue Mar 17 15:31:29 2026 -0400 Initial commit diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..360fdc9 --- /dev/null +++ b/.gitignore @@ -0,0 +1,4 @@ +target/ +.env +*.db +*.db-journal diff --git a/Cargo.toml b/Cargo.toml new file mode 100644 index 0000000..baab03c --- /dev/null +++ b/Cargo.toml @@ -0,0 +1,25 @@ +[package] +name = "shanty-dl" +version = "0.1.0" +edition = "2024" +license = "MIT" +description = "Music downloading for Shanty" +repository = "ssh://connor@git.rcjohnstone.com:2222/Shanty/dl.git" + +[dependencies] +shanty-db = { path = "../shanty-db" } +sea-orm = { version = "1", features = ["sqlx-sqlite", "runtime-tokio-native-tls"] } +clap = { version = "4", features = ["derive"] } +serde = { version = "1", features = ["derive"] } +serde_json = "1" +thiserror = "2" +tracing = "0.1" +tracing-subscriber = { version = "0.3", features = ["env-filter"] } +tokio = { version = "1", features = ["full", "process"] } +anyhow = "1" +chrono = { version = "0.4", features = ["serde"] } +dirs = "6" + +[dev-dependencies] +tokio = { version = "1", features = ["full", "test-util"] } +tempfile = "3" diff --git a/readme.md b/readme.md new file mode 100644 index 0000000..c007f74 --- /dev/null +++ b/readme.md @@ -0,0 +1,43 @@ +# shanty-dl + +Music downloading for [Shanty](ssh://connor@git.rcjohnstone.com:2222/Shanty/shanty.git). + +Downloads music files using pluggable backends. The default backend uses yt-dlp with +YouTube Music search, rate limiting, and configurable output format. + +## Prerequisites + +- [yt-dlp](https://github.com/yt-dlp/yt-dlp) must be installed and on PATH +- [ffmpeg](https://ffmpeg.org/) for audio conversion + +## Usage + +```sh +# Download a single song +shanty-dl download "Pink Floyd Time" + +# Download from a direct URL +shanty-dl download "https://www.youtube.com/watch?v=..." 
+ +# Add to download queue +shanty-dl queue add "Radiohead Creep" + +# Process the queue +shanty-dl queue process + +# List queue status +shanty-dl queue list + +# Retry failed downloads +shanty-dl queue retry +``` + +## YouTube Authentication + +For higher rate limits (2000/hr vs 500/hr), export cookies: + +```sh +yt-dlp --cookies-from-browser firefox --cookies ~/.config/shanty/cookies.txt +``` + +Then pass via `--cookies ~/.config/shanty/cookies.txt` or set `SHANTY_COOKIES`. diff --git a/src/backend.rs b/src/backend.rs new file mode 100644 index 0000000..e391e8a --- /dev/null +++ b/src/backend.rs @@ -0,0 +1,109 @@ +use std::path::PathBuf; + +use serde::{Deserialize, Serialize}; + +use crate::error::DlResult; + +/// A search result from a download backend. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct SearchResult { + /// URL or ID that can be passed to `download`. + pub url: String, + /// Title of the track. + pub title: String, + /// Artist/uploader name, if available. + pub artist: Option, + /// Duration in seconds, if available. + pub duration: Option, + /// Source identifier (e.g., "youtube_music", "youtube"). + pub source: String, +} + +/// What to download — either a direct URL or a search query. +#[derive(Debug, Clone)] +pub enum DownloadTarget { + /// A direct URL to download. + Url(String), + /// A search query to find and download. + Query(String), +} + +/// Result of a successful download. +#[derive(Debug, Clone)] +pub struct DownloadResult { + /// Path to the downloaded file. + pub file_path: PathBuf, + /// Title from the source. + pub title: String, + /// Artist/uploader from the source. + pub artist: Option, + /// Duration in seconds. + pub duration: Option, + /// The URL that was actually downloaded. + pub source_url: String, +} + +/// Configuration passed to the backend for a download. 
+#[derive(Debug, Clone)] +pub struct BackendConfig { + pub output_dir: PathBuf, + pub format: AudioFormat, + pub cookies_path: Option, +} + +/// Supported output audio formats. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum AudioFormat { + Opus, + Mp3, + Flac, + Best, +} + +impl AudioFormat { + pub fn as_ytdlp_arg(&self) -> &str { + match self { + AudioFormat::Opus => "opus", + AudioFormat::Mp3 => "mp3", + AudioFormat::Flac => "flac", + AudioFormat::Best => "best", + } + } +} + +impl std::str::FromStr for AudioFormat { + type Err = String; + + fn from_str(s: &str) -> Result { + match s.to_lowercase().as_str() { + "opus" => Ok(AudioFormat::Opus), + "mp3" => Ok(AudioFormat::Mp3), + "flac" => Ok(AudioFormat::Flac), + "best" => Ok(AudioFormat::Best), + _ => Err(format!("unsupported format: {s} (expected opus, mp3, flac, or best)")), + } + } +} + +impl std::fmt::Display for AudioFormat { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", self.as_ytdlp_arg()) + } +} + +/// Trait for download backends. yt-dlp is the default; others (torrents, Soulseek) +/// can be added later. +pub trait DownloadBackend: Send + Sync { + /// Check if this backend is available on the system. + fn check_available(&self) -> impl std::future::Future> + Send; + + /// Search for tracks matching a query. + fn search(&self, query: &str) -> impl std::future::Future>> + Send; + + /// Download a target to the configured output directory. 
+    fn download(
+        &self,
+        target: &DownloadTarget,
+        config: &BackendConfig,
+    ) -> impl std::future::Future<Output = DlResult<DownloadResult>> + Send;
+}
diff --git a/src/error.rs b/src/error.rs
new file mode 100644
index 0000000..2dab744
--- /dev/null
+++ b/src/error.rs
@@ -0,0 +1,45 @@
+use shanty_db::DbError;
+
+#[derive(Debug, thiserror::Error)]
+pub enum DlError {
+    #[error("database error: {0}")]
+    Db(#[from] DbError),
+
+    #[error("I/O error: {0}")]
+    Io(#[from] std::io::Error),
+
+    #[error("yt-dlp not found: {0}")]
+    BackendNotFound(String),
+
+    #[error("yt-dlp error: {0}")]
+    BackendError(String),
+
+    #[error("rate limited: {0}")]
+    RateLimited(String),
+
+    #[error("download failed: {0}")]
+    DownloadFailed(String),
+
+    #[error("search returned no results for: {0}")]
+    NoResults(String),
+
+    #[error("JSON parse error: {0}")]
+    Json(#[from] serde_json::Error),
+
+    #[error("{0}")]
+    Other(String),
+}
+
+impl DlError {
+    /// Whether this error is transient and worth retrying.
+    pub fn is_transient(&self) -> bool {
+        matches!(
+            self,
+            DlError::RateLimited(_)
+                | DlError::Io(_)
+                | DlError::BackendError(_)
+        )
+    }
+}
+
+pub type DlResult<T> = Result<T, DlError>;
diff --git a/src/lib.rs b/src/lib.rs
new file mode 100644
index 0000000..50ac56f
--- /dev/null
+++ b/src/lib.rs
@@ -0,0 +1,15 @@
+//! Music downloading for Shanty.
+//!
+//! Downloads music files using pluggable backends. The default backend uses
+//! yt-dlp with YouTube Music search, rate limiting, and configurable output format.
+
+pub mod backend;
+pub mod error;
+pub mod queue;
+pub mod rate_limit;
+pub mod ytdlp;
+
+pub use backend::{AudioFormat, BackendConfig, DownloadBackend, DownloadResult, DownloadTarget, SearchResult};
+pub use error::{DlError, DlResult};
+pub use queue::{DlStats, download_single, run_queue};
+pub use ytdlp::{SearchSource, YtDlpBackend};
diff --git a/src/main.rs b/src/main.rs
new file mode 100644
index 0000000..6a8b917
--- /dev/null
+++ b/src/main.rs
@@ -0,0 +1,292 @@
+use std::path::PathBuf;
+
+use clap::{Parser, Subcommand};
+use tracing_subscriber::EnvFilter;
+
+use shanty_db::Database;
+use shanty_db::entities::download_queue::DownloadStatus;
+use shanty_db::queries;
+use shanty_dl::{
+    AudioFormat, BackendConfig, DownloadBackend, DownloadTarget, SearchSource, YtDlpBackend,
+    download_single, run_queue,
+};
+
+#[derive(Parser)]
+#[command(name = "shanty-dl", about = "Download music files for Shanty")]
+struct Cli {
+    #[command(subcommand)]
+    command: Commands,
+
+    /// Database URL. Defaults to sqlite://<data_dir>/shanty/shanty.db?mode=rwc
+    #[arg(long, global = true, env = "SHANTY_DATABASE_URL")]
+    database: Option<String>,
+
+    /// Increase verbosity (-v info, -vv debug, -vvv trace).
+    #[arg(short, long, global = true, action = clap::ArgAction::Count)]
+    verbose: u8,
+}
+
+#[derive(Subcommand)]
+enum Commands {
+    /// Download a single song by query or URL.
+    Download {
+        /// Search query or direct URL.
+        query_or_url: String,
+
+        /// Output audio format.
+        #[arg(long, default_value = "opus")]
+        format: String,
+
+        /// Output directory for downloaded files.
+        #[arg(long)]
+        output: Option<PathBuf>,
+
+        /// Path to cookies.txt file for YouTube authentication.
+        #[arg(long, env = "SHANTY_COOKIES")]
+        cookies: Option<PathBuf>,
+
+        /// Search source (ytmusic or youtube).
+        #[arg(long, default_value = "ytmusic")]
+        search_source: String,
+
+        /// Requests per hour limit.
+        #[arg(long, default_value = "450")]
+        rate_limit: u32,
+
+        /// Preview what would be downloaded without doing it.
+        #[arg(long)]
+        dry_run: bool,
+    },
+    /// Manage the download queue.
+    Queue {
+        #[command(subcommand)]
+        action: QueueAction,
+    },
+}
+
+#[derive(Subcommand)]
+enum QueueAction {
+    /// Process all pending items in the download queue.
+    Process {
+        /// Output audio format.
+        #[arg(long, default_value = "opus")]
+        format: String,
+
+        /// Output directory for downloaded files.
+        #[arg(long)]
+        output: Option<PathBuf>,
+
+        /// Path to cookies.txt for YouTube authentication.
+        #[arg(long, env = "SHANTY_COOKIES")]
+        cookies: Option<PathBuf>,
+
+        /// Search source (ytmusic or youtube).
+        #[arg(long, default_value = "ytmusic")]
+        search_source: String,
+
+        /// Requests per hour limit.
+        #[arg(long, default_value = "450")]
+        rate_limit: u32,
+
+        /// Preview without downloading.
+        #[arg(long)]
+        dry_run: bool,
+    },
+    /// Add an item to the download queue.
+    Add {
+        /// Search query for the song to download.
+        query: String,
+    },
+    /// List items in the download queue.
+    List {
+        /// Filter by status (pending, downloading, completed, failed, cancelled, all).
+        #[arg(long, default_value = "all")]
+        status: String,
+    },
+    /// Retry all failed downloads.
+    Retry,
+}
+
+/// Build the default database URL under the platform data dir (e.g. ~/.local/share).
+fn default_database_url() -> String {
+    let data_dir = dirs::data_dir()
+        .unwrap_or_else(|| PathBuf::from("."))
+        .join("shanty");
+    // Best-effort: if the directory can't be created, sqlite will fail later
+    // with a clearer error.
+    std::fs::create_dir_all(&data_dir).ok();
+    let db_path = data_dir.join("shanty.db");
+    format!("sqlite://{}?mode=rwc", db_path.display())
+}
+
+/// Default location for downloaded audio files.
+fn default_output_dir() -> PathBuf {
+    let dir = dirs::data_dir()
+        .unwrap_or_else(|| PathBuf::from("."))
+        .join("shanty")
+        .join("downloads");
+    std::fs::create_dir_all(&dir).ok();
+    dir
+}
+
+/// Construct the yt-dlp backend from CLI options.
+fn make_backend(
+    cookies: &Option<PathBuf>,
+    search_source: &str,
+    rate_limit: u32,
+) -> anyhow::Result<YtDlpBackend> {
+    let source: SearchSource = search_source
+        .parse()
+        .map_err(|e: String| anyhow::anyhow!(e))?;
+
+    // Bump rate limit if cookies are provided (only when the user kept the
+    // default 450 — an explicit --rate-limit always wins)
+    let effective_rate = if cookies.is_some() && rate_limit == 450 {
+        tracing::info!("cookies provided — using authenticated rate limit (1800/hr)");
+        1800
+    } else {
+        rate_limit
+    };
+
+    Ok(YtDlpBackend::new(effective_rate, source, cookies.clone()))
+}
+
+/// Construct the per-download backend config from CLI options.
+fn make_backend_config(
+    format: &str,
+    output: &Option<PathBuf>,
+    cookies: &Option<PathBuf>,
+) -> anyhow::Result<BackendConfig> {
+    let fmt: AudioFormat = format
+        .parse()
+        .map_err(|e: String| anyhow::anyhow!(e))?;
+    Ok(BackendConfig {
+        output_dir: output.clone().unwrap_or_else(default_output_dir),
+        format: fmt,
+        cookies_path: cookies.clone(),
+    })
+}
+
+#[tokio::main]
+async fn main() -> anyhow::Result<()> {
+    let cli = Cli::parse();
+
+    // Set up tracing
+    let filter = match cli.verbose {
+        0 => "warn",
+        1 => "info,shanty_dl=info",
+        2 => "info,shanty_dl=debug",
+        _ => "debug,shanty_dl=trace",
+    };
+    tracing_subscriber::fmt()
+        .with_env_filter(
+            EnvFilter::try_from_default_env().unwrap_or_else(|_| EnvFilter::new(filter)),
+        )
+        .init();
+
+    match cli.command {
+        Commands::Download {
+            query_or_url,
+            format,
+            output,
+            cookies,
+            search_source,
+            rate_limit,
+            dry_run,
+        } => {
+            let backend = make_backend(&cookies, &search_source, rate_limit)?;
+            backend.check_available().await?;
+
+            let config = make_backend_config(&format, &output, &cookies)?;
+
+            // Determine if it's a URL or a search query
+            let target = if query_or_url.starts_with("http://")
+                || query_or_url.starts_with("https://")
+            {
+                DownloadTarget::Url(query_or_url)
+            } else {
+                DownloadTarget::Query(query_or_url)
+            };
+
+            download_single(&backend, target, &config, dry_run).await?;
+        }
+        Commands::Queue { action } => {
+            let database_url = cli.database.unwrap_or_else(default_database_url);
+            let db = Database::new(&database_url).await?;
+
+            match action {
+                QueueAction::Process {
+                    format,
+                    output,
+                    cookies,
+                    search_source,
+                    rate_limit,
+                    dry_run,
+                } => {
+                    let backend = make_backend(&cookies, &search_source, rate_limit)?;
+                    backend.check_available().await?;
+
+                    let config = make_backend_config(&format, &output, &cookies)?;
+
+                    if dry_run {
+                        println!("DRY RUN — no files will be downloaded");
+                    }
+
+                    let stats = run_queue(db.conn(), &backend, &config, dry_run).await?;
+                    println!("\nQueue processing complete: {stats}");
+                }
+                QueueAction::Add { query } => {
+                    let item = queries::downloads::enqueue(db.conn(), &query, None, "ytdlp").await?;
+                    println!("Added to queue: id={}, query=\"{}\"", item.id, item.query);
+                }
+                QueueAction::List { status } => {
+                    let filter = match status.to_lowercase().as_str() {
+                        "all" => None,
+                        "pending" => Some(DownloadStatus::Pending),
+                        "downloading" => Some(DownloadStatus::Downloading),
+                        "completed" => Some(DownloadStatus::Completed),
+                        "failed" => Some(DownloadStatus::Failed),
+                        "cancelled" => Some(DownloadStatus::Cancelled),
+                        _ => anyhow::bail!("unknown status: {status}"),
+                    };
+                    let items = queries::downloads::list(db.conn(), filter).await?;
+
+                    if items.is_empty() {
+                        println!("Queue is empty.");
+                    } else {
+                        println!(
+                            "{:<5} {:<12} {:<6} {:<40} {}",
+                            "ID", "STATUS", "RETRY", "QUERY", "ERROR"
+                        );
+                        for item in &items {
+                            println!(
+                                "{:<5} {:<12} {:<6} {:<40} {}",
+                                item.id,
+                                format!("{:?}", item.status),
+                                item.retry_count,
+                                truncate(&item.query, 40),
+                                item.error_message.as_deref().unwrap_or(""),
+                            );
+                        }
+                        println!("\n{} items total", items.len());
+                    }
+                }
+                QueueAction::Retry => {
+                    let failed =
+                        queries::downloads::list(db.conn(), Some(DownloadStatus::Failed)).await?;
+                    if failed.is_empty() {
+                        println!("No failed downloads to retry.");
+                    } else {
+                        for item in &failed {
+                            queries::downloads::retry_failed(db.conn(), item.id).await?;
+                        }
+                        println!("Requeued {} failed downloads.", failed.len());
+                    }
+                }
+            }
+        }
+    }
+
+    Ok(())
+}
+
+/// Truncate a string to `max` characters, appending an ellipsis when cut.
+/// Counts chars, not bytes: slicing at a byte offset (`&s[..max - 1]`) would
+/// panic when the cut lands inside a multi-byte UTF-8 character.
+fn truncate(s: &str, max: usize) -> String {
+    if s.chars().count() <= max {
+        s.to_string()
+    } else {
+        let head: String = s.chars().take(max.saturating_sub(1)).collect();
+        format!("{head}…")
+    }
+}
diff --git a/src/queue.rs b/src/queue.rs
new file mode 100644
index 0000000..65d4145
--- /dev/null
+++ b/src/queue.rs
@@ -0,0 +1,212 @@
+use std::fmt;
+use std::time::Duration;
+
+use sea_orm::DatabaseConnection;
+
+use shanty_db::entities::download_queue::DownloadStatus;
+use shanty_db::entities::wanted_item::WantedStatus;
+use shanty_db::queries;
+
+use crate::backend::{BackendConfig, DownloadBackend, DownloadTarget};
+use crate::error::{DlError, DlResult};
+
+/// Statistics from a queue processing run.
+#[derive(Debug, Default, Clone)]
+pub struct DlStats {
+    pub downloads_attempted: u64,
+    pub downloads_completed: u64,
+    pub downloads_failed: u64,
+    pub downloads_skipped: u64,
+}
+
+impl fmt::Display for DlStats {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        write!(
+            f,
+            "attempted: {}, completed: {}, failed: {}, skipped: {}",
+            self.downloads_attempted,
+            self.downloads_completed,
+            self.downloads_failed,
+            self.downloads_skipped,
+        )
+    }
+}
+
+const MAX_RETRIES: i32 = 3;
+const RETRY_DELAYS: [Duration; 3] = [
+    Duration::from_secs(30),
+    Duration::from_secs(120),
+    Duration::from_secs(600),
+];
+
+/// Process all pending items in the download queue.
+pub async fn run_queue( + conn: &DatabaseConnection, + backend: &impl DownloadBackend, + config: &BackendConfig, + dry_run: bool, +) -> DlResult { + let mut stats = DlStats::default(); + + loop { + let item = match queries::downloads::get_next_pending(conn).await? { + Some(item) => item, + None => break, + }; + + stats.downloads_attempted += 1; + + tracing::info!( + id = item.id, + query = %item.query, + retry = item.retry_count, + "processing download" + ); + + if dry_run { + tracing::info!(id = item.id, query = %item.query, "DRY RUN: would download"); + stats.downloads_skipped += 1; + // Mark as failed temporarily so we don't loop forever on the same item + queries::downloads::update_status( + conn, + item.id, + DownloadStatus::Failed, + Some("dry run"), + ) + .await?; + continue; + } + + // Mark as downloading + queries::downloads::update_status(conn, item.id, DownloadStatus::Downloading, None) + .await?; + + // Determine download target + let target = if let Some(ref url) = item.source_url { + DownloadTarget::Url(url.clone()) + } else { + DownloadTarget::Query(item.query.clone()) + }; + + // Attempt download + match backend.download(&target, config).await { + Ok(result) => { + tracing::info!( + id = item.id, + path = %result.file_path.display(), + title = %result.title, + "download completed" + ); + + queries::downloads::update_status( + conn, + item.id, + DownloadStatus::Completed, + None, + ) + .await?; + + // Update wanted item status if linked + if let Some(wanted_id) = item.wanted_item_id { + if let Err(e) = queries::wanted::update_status( + conn, + wanted_id, + WantedStatus::Downloaded, + ) + .await + { + tracing::warn!( + wanted_id = wanted_id, + error = %e, + "failed to update wanted item status" + ); + } + } + + stats.downloads_completed += 1; + } + Err(e) => { + let error_msg = e.to_string(); + tracing::error!( + id = item.id, + query = %item.query, + error = %error_msg, + "download failed" + ); + + queries::downloads::update_status( + conn, + 
item.id, + DownloadStatus::Failed, + Some(&error_msg), + ) + .await?; + + // Auto-retry transient errors + if e.is_transient() && item.retry_count < MAX_RETRIES { + let delay_idx = item.retry_count.min(RETRY_DELAYS.len() as i32 - 1) as usize; + let delay = RETRY_DELAYS[delay_idx]; + tracing::info!( + id = item.id, + retry = item.retry_count + 1, + delay_secs = delay.as_secs(), + "scheduling retry" + ); + queries::downloads::retry_failed(conn, item.id).await?; + + // If rate limited, pause before continuing + if matches!(e, DlError::RateLimited(_)) { + tracing::warn!("rate limited — pausing queue processing"); + tokio::time::sleep(delay).await; + } + } + + stats.downloads_failed += 1; + } + } + } + + tracing::info!(%stats, "queue processing complete"); + Ok(stats) +} + +/// Download a single item directly (not from queue). +pub async fn download_single( + backend: &impl DownloadBackend, + target: DownloadTarget, + config: &BackendConfig, + dry_run: bool, +) -> DlResult<()> { + if dry_run { + match &target { + DownloadTarget::Url(url) => { + tracing::info!(url = %url, "DRY RUN: would download URL"); + } + DownloadTarget::Query(q) => { + tracing::info!(query = %q, "DRY RUN: would search and download"); + + // Still run the search to show what would be downloaded + let results = backend.search(q).await?; + if let Some(best) = results.first() { + tracing::info!( + title = %best.title, + artist = ?best.artist, + url = %best.url, + "would download" + ); + } else { + tracing::warn!(query = %q, "no results found"); + } + } + } + return Ok(()); + } + + let result = backend.download(&target, config).await?; + println!( + "Downloaded: {} → {}", + result.title, + result.file_path.display() + ); + Ok(()) +} diff --git a/src/rate_limit.rs b/src/rate_limit.rs new file mode 100644 index 0000000..a57487a --- /dev/null +++ b/src/rate_limit.rs @@ -0,0 +1,103 @@ +use std::time::Duration; + +use tokio::sync::Mutex; +use tokio::time::Instant; + +/// Token bucket rate limiter for 
controlling request rates. +pub struct RateLimiter { + max_per_hour: u32, + state: Mutex, +} + +struct TokenState { + remaining: u32, + window_start: Instant, +} + +impl RateLimiter { + /// Create a new rate limiter with the given maximum requests per hour. + pub fn new(max_per_hour: u32) -> Self { + Self { + max_per_hour, + state: Mutex::new(TokenState { + remaining: max_per_hour, + window_start: Instant::now(), + }), + } + } + + /// Acquire a token, sleeping if necessary to stay within the rate limit. + pub async fn acquire(&self) { + let mut state = self.state.lock().await; + + // Check if the window has rolled over + let elapsed = state.window_start.elapsed(); + if elapsed >= Duration::from_secs(3600) { + // Reset window + state.remaining = self.max_per_hour; + state.window_start = Instant::now(); + } + + if state.remaining > 0 { + state.remaining -= 1; + + // Warn when approaching the limit + let pct_remaining = + (state.remaining as f64 / self.max_per_hour as f64) * 100.0; + if pct_remaining < 10.0 && pct_remaining > 0.0 { + tracing::warn!( + remaining = state.remaining, + max = self.max_per_hour, + "approaching rate limit" + ); + } + } else { + // No tokens left — wait for window to reset + let wait = Duration::from_secs(3600) - elapsed; + tracing::warn!( + wait_secs = wait.as_secs(), + "rate limit reached, waiting for window reset" + ); + drop(state); // release lock while sleeping + tokio::time::sleep(wait).await; + + // Re-acquire and reset + let mut state = self.state.lock().await; + state.remaining = self.max_per_hour - 1; + state.window_start = Instant::now(); + } + } + + /// Get the number of remaining tokens in the current window. 
+    pub async fn remaining(&self) -> u32 {
+        self.state.lock().await.remaining
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[tokio::test]
+    async fn test_rate_limiter_basic() {
+        let limiter = RateLimiter::new(10);
+        assert_eq!(limiter.remaining().await, 10);
+
+        limiter.acquire().await;
+        assert_eq!(limiter.remaining().await, 9);
+
+        for _ in 0..9 {
+            limiter.acquire().await;
+        }
+        assert_eq!(limiter.remaining().await, 0);
+    }
+
+    #[tokio::test]
+    async fn test_rate_limiter_high_volume() {
+        let limiter = RateLimiter::new(100);
+        for _ in 0..50 {
+            limiter.acquire().await;
+        }
+        assert_eq!(limiter.remaining().await, 50);
+    }
+}
diff --git a/src/ytdlp.rs b/src/ytdlp.rs
new file mode 100644
index 0000000..bcbd839
--- /dev/null
+++ b/src/ytdlp.rs
@@ -0,0 +1,371 @@
+use std::path::PathBuf;
+use std::process::Stdio;
+
+use serde::Deserialize;
+use tokio::process::Command;
+
+use crate::backend::{BackendConfig, DownloadBackend, DownloadResult, DownloadTarget, SearchResult};
+use crate::error::{DlError, DlResult};
+use crate::rate_limit::RateLimiter;
+
+/// Search source for yt-dlp.
+#[derive(Debug, Clone, Copy, PartialEq, Eq)]
+pub enum SearchSource {
+    YouTubeMusic,
+    YouTube,
+}
+
+impl SearchSource {
+    /// yt-dlp search-prefix (the trailing 5 caps results per query).
+    fn prefix(&self) -> &str {
+        match self {
+            SearchSource::YouTubeMusic => "ytmusicsearch5",
+            SearchSource::YouTube => "ytsearch5",
+        }
+    }
+
+    /// Identifier recorded in `SearchResult::source`.
+    fn source_name(&self) -> &str {
+        match self {
+            SearchSource::YouTubeMusic => "youtube_music",
+            SearchSource::YouTube => "youtube",
+        }
+    }
+}
+
+impl std::str::FromStr for SearchSource {
+    type Err = String;
+
+    fn from_str(s: &str) -> Result<Self, Self::Err> {
+        match s.to_lowercase().as_str() {
+            "ytmusic" | "youtube_music" | "youtubemusic" => Ok(SearchSource::YouTubeMusic),
+            "youtube" | "yt" => Ok(SearchSource::YouTube),
+            _ => Err(format!("unknown search source: {s} (expected ytmusic or youtube)")),
+        }
+    }
+}
+
+impl std::fmt::Display for SearchSource {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        match self {
+            SearchSource::YouTubeMusic => write!(f, "ytmusic"),
+            SearchSource::YouTube => write!(f, "youtube"),
+        }
+    }
+}
+
+/// yt-dlp backend for downloading music.
+pub struct YtDlpBackend {
+    rate_limiter: RateLimiter,
+    search_source: SearchSource,
+    fallback_search: bool,
+    cookies_path: Option<PathBuf>,
+}
+
+impl YtDlpBackend {
+    pub fn new(
+        rate_limit_per_hour: u32,
+        search_source: SearchSource,
+        cookies_path: Option<PathBuf>,
+    ) -> Self {
+        Self {
+            rate_limiter: RateLimiter::new(rate_limit_per_hour),
+            search_source,
+            fallback_search: true,
+            cookies_path,
+        }
+    }
+
+    /// Run a yt-dlp command and return stdout.
+    async fn run_ytdlp(&self, args: &[&str]) -> DlResult<String> {
+        // Every subprocess invocation counts against the hourly budget.
+        self.rate_limiter.acquire().await;
+
+        let mut cmd = Command::new("yt-dlp");
+        cmd.args(args)
+            .stdout(Stdio::piped())
+            .stderr(Stdio::piped());
+
+        // Add cookies if configured
+        if let Some(ref cookies) = self.cookies_path {
+            cmd.arg("--cookies").arg(cookies);
+        }
+
+        tracing::debug!(args = ?args, "running yt-dlp");
+
+        let output = cmd.output().await.map_err(|e| {
+            if e.kind() == std::io::ErrorKind::NotFound {
+                DlError::BackendNotFound("yt-dlp not found on PATH".into())
+            } else {
+                DlError::Io(e)
+            }
+        })?;
+
+        if !output.status.success() {
+            let stderr = String::from_utf8_lossy(&output.stderr).to_string();
+            tracing::error!(stderr = %stderr, "yt-dlp failed");
+
+            // Classify HTTP 429 so the queue can back off instead of hard-failing
+            if stderr.contains("429") || stderr.contains("Too Many Requests") {
+                return Err(DlError::RateLimited(stderr));
+            }
+            return Err(DlError::BackendError(stderr));
+        }
+
+        Ok(String::from_utf8_lossy(&output.stdout).to_string())
+    }
+
+    /// Search using a specific source prefix.
+    async fn search_with_source(
+        &self,
+        query: &str,
+        source: SearchSource,
+    ) -> DlResult<Vec<SearchResult>> {
+        let search_query = format!("{}:{}", source.prefix(), query);
+        let output = self
+            .run_ytdlp(&[
+                "--dump-json",
+                "--flat-playlist",
+                "--no-download",
+                &search_query,
+            ])
+            .await?;
+
+        // yt-dlp emits one JSON object per line; skip lines that fail to parse.
+        let mut results = Vec::new();
+        for line in output.lines() {
+            if line.trim().is_empty() {
+                continue;
+            }
+            match serde_json::from_str::<YtDlpSearchEntry>(line) {
+                Ok(entry) => {
+                    results.push(SearchResult {
+                        url: entry.url.or(entry.webpage_url).unwrap_or_default(),
+                        title: entry.title.unwrap_or_else(|| "Unknown".into()),
+                        artist: entry.artist.or(entry.uploader).or(entry.channel),
+                        duration: entry.duration,
+                        source: source.source_name().into(),
+                    });
+                }
+                Err(e) => {
+                    tracing::debug!(error = %e, "failed to parse search result line");
+                }
+            }
+        }
+
+        Ok(results)
+    }
+}
+
+impl DownloadBackend for YtDlpBackend {
+    async fn check_available(&self) -> DlResult<()> {
+        let output = self.run_ytdlp(&["--version"]).await?;
+        tracing::info!(version = output.trim(), "yt-dlp available");
+        Ok(())
+    }
+
+    async fn search(&self, query: &str) -> DlResult<Vec<SearchResult>> {
+        // Try primary search source
+        let results = self.search_with_source(query, self.search_source).await;
+
+        match results {
+            Ok(ref r) if !r.is_empty() => return results,
+            Ok(_) => {
+                tracing::debug!(source = %self.search_source, "no results from primary source");
+            }
+            Err(ref e) => {
+                tracing::debug!(source = %self.search_source, error = %e, "primary search failed");
+            }
+        }
+
+        // Fallback to the other source
+        if self.fallback_search && self.search_source == SearchSource::YouTubeMusic {
+            tracing::info!("falling back to YouTube search");
+            let fallback = self
+                .search_with_source(query, SearchSource::YouTube)
+                .await?;
+            if !fallback.is_empty() {
+                return Ok(fallback);
+            }
+        }
+
+        // If primary returned Ok([]) and fallback also empty, return that
+        // If primary returned Err, propagate it
+        results
+    }
+
+    async fn download(
+        &self,
+        target: &DownloadTarget,
+        config: &BackendConfig,
+    ) -> DlResult<DownloadResult> {
+        let url = match target {
+            DownloadTarget::Url(u) => u.clone(),
+            DownloadTarget::Query(q) => {
+                // Search first, take the best result
+                let results = self.search(q).await?;
+                let best = results
+                    .into_iter()
+                    .next()
+                    .ok_or_else(|| DlError::NoResults(q.clone()))?;
+                tracing::info!(
+                    title = %best.title,
+                    artist = ?best.artist,
+                    url = %best.url,
+                    "selected best search result"
+                );
+                best.url
+            }
+        };
+
+        // Build output template
+        let output_template = config
+            .output_dir
+            .join("%(title)s.%(ext)s")
+            .to_string_lossy()
+            .to_string();
+
+        let format_arg = config.format.as_ytdlp_arg();
+
+        let mut args = vec![
+            "--extract-audio",
+            "--audio-format",
+            format_arg,
+            "--audio-quality",
+            "0",
+            "--output",
+            &output_template,
+            "--print-json",
+            "--no-playlist",
+        ];
+
+        // Add cookies from backend config or backend's own cookies.
+        // NOTE(review): run_ytdlp also appends self.cookies_path, so when the
+        // backend has its own cookies the --cookies flag is passed twice —
+        // confirm yt-dlp's last-flag-wins semantics, or dedupe here.
+        let cookies_str;
+        let cookies_path = config
+            .cookies_path
+            .as_ref()
+            .or(self.cookies_path.as_ref());
+        if let Some(c) = cookies_path {
+            cookies_str = c.to_string_lossy().to_string();
+            args.push("--cookies");
+            args.push(&cookies_str);
+        }
+
+        args.push(&url);
+
+        let output = self.run_ytdlp(&args).await?;
+
+        // Parse the JSON output to get the actual file path
+        let info: YtDlpDownloadInfo = serde_json::from_str(output.trim()).map_err(|e| {
+            DlError::BackendError(format!("failed to parse yt-dlp output: {e}"))
+        })?;
+
+        // yt-dlp may change the extension after conversion
+        let file_path = if let Some(ref requested) = info.requested_downloads {
+            if let Some(first) = requested.first() {
+                PathBuf::from(&first.filepath)
+            } else {
+                PathBuf::from(&info.filename)
+            }
+        } else {
+            PathBuf::from(&info.filename)
+        };
+
+        tracing::info!(
+            path = %file_path.display(),
+            title = %info.title,
+            "download complete"
+        );
+
+        Ok(DownloadResult {
+            file_path,
+            title: info.title,
+            artist: info.artist.or(info.uploader).or(info.channel),
+            duration: info.duration,
+            source_url: url,
+        })
+    }
+}
+
+// --- yt-dlp JSON output types ---
+
+#[derive(Debug, Deserialize)]
+struct YtDlpSearchEntry {
+    url: Option<String>,
+    webpage_url: Option<String>,
+    title: Option<String>,
+    artist: Option<String>,
+    uploader: Option<String>,
+    channel: Option<String>,
+    duration: Option<f64>,
+}
+
+#[derive(Debug, Deserialize)]
+struct YtDlpDownloadInfo {
+    title: String,
+    artist: Option<String>,
+    uploader: Option<String>,
+    channel: Option<String>,
+    duration: Option<f64>,
+    filename: String,
+    requested_downloads: Option<Vec<YtDlpRequestedDownload>>,
+}
+
+#[derive(Debug, Deserialize)]
+struct YtDlpRequestedDownload {
+    filepath: String,
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use crate::backend::AudioFormat;
+
+    #[test]
+    fn test_search_source_parse() {
+        assert_eq!(
+            "ytmusic".parse::<SearchSource>().unwrap(),
+            SearchSource::YouTubeMusic
+        );
+        assert_eq!(
+            "youtube".parse::<SearchSource>().unwrap(),
+            SearchSource::YouTube
+        );
+        assert!("invalid".parse::<SearchSource>().is_err());
+    }
+
+    #[test]
+    fn test_audio_format_parse() {
+        assert_eq!("opus".parse::<AudioFormat>().unwrap(), AudioFormat::Opus);
+        assert_eq!("mp3".parse::<AudioFormat>().unwrap(), AudioFormat::Mp3);
+        assert_eq!("flac".parse::<AudioFormat>().unwrap(), AudioFormat::Flac);
+        assert_eq!("best".parse::<AudioFormat>().unwrap(), AudioFormat::Best);
+        assert!("wav".parse::<AudioFormat>().is_err());
+    }
+
+    #[test]
+    fn test_search_source_prefix() {
+        assert_eq!(SearchSource::YouTubeMusic.prefix(), "ytmusicsearch5");
+        assert_eq!(SearchSource::YouTube.prefix(), "ytsearch5");
+    }
+
+    #[test]
+    fn test_parse_search_entry() {
+        let json = r#"{"url": "https://youtube.com/watch?v=abc", "title": "Time", "artist": "Pink Floyd", "duration": 413.0}"#;
+        let entry: YtDlpSearchEntry = serde_json::from_str(json).unwrap();
+        assert_eq!(entry.title.as_deref(), Some("Time"));
+        assert_eq!(entry.artist.as_deref(), Some("Pink Floyd"));
+        assert_eq!(entry.duration, Some(413.0));
+    }
+
+    #[test]
+    fn test_parse_download_info() {
+        let json = r#"{
+            "title": "Time",
+            "artist": "Pink Floyd",
+            "uploader": null,
+            "channel": null,
+            "duration": 413.0,
+            "filename": "/tmp/Time.opus",
+            "requested_downloads": [{"filepath": "/tmp/Time.opus"}]
+        }"#;
+        let info: YtDlpDownloadInfo = serde_json::from_str(json).unwrap();
+        assert_eq!(info.title, "Time");
+        assert_eq!(info.filename, "/tmp/Time.opus");
+    }
+}
diff --git a/tests/integration.rs b/tests/integration.rs
new file mode 100644
index 0000000..a1beca0
--- /dev/null
+++ b/tests/integration.rs
@@ -0,0 +1,239 @@
+use shanty_db::entities::download_queue::DownloadStatus;
+use shanty_db::{Database, queries};
+use shanty_dl::backend::{AudioFormat, BackendConfig, DownloadBackend, DownloadResult, DownloadTarget, SearchResult};
+use shanty_dl::error::DlResult;
+use shanty_dl::queue::{run_queue, download_single};
+use tempfile::TempDir;
+
+/// Mock backend for testing without yt-dlp.
+struct MockBackend {
+    /// If true, downloads will fail.
+    should_fail: bool,
+}
+
+impl MockBackend {
+    fn new(should_fail: bool) -> Self {
+        Self { should_fail }
+    }
+}
+
+impl DownloadBackend for MockBackend {
+    async fn check_available(&self) -> DlResult<()> {
+        Ok(())
+    }
+
+    async fn search(&self, query: &str) -> DlResult<Vec<SearchResult>> {
+        Ok(vec![SearchResult {
+            url: format!("https://example.com/{}", query.replace(' ', "_")),
+            title: query.to_string(),
+            artist: Some("Test Artist".to_string()),
+            duration: Some(180.0),
+            source: "mock".to_string(),
+        }])
+    }
+
+    async fn download(
+        &self,
+        target: &DownloadTarget,
+        config: &BackendConfig,
+    ) -> DlResult<DownloadResult> {
+        if self.should_fail {
+            return Err(shanty_dl::DlError::DownloadFailed("mock failure".into()));
+        }
+
+        let title = match target {
+            DownloadTarget::Url(u) => u.clone(),
+            DownloadTarget::Query(q) => q.clone(),
+        };
+
+        let file_name = format!("{}.{}", title.replace(' ', "_"), config.format);
+        let file_path = config.output_dir.join(&file_name);
+        std::fs::write(&file_path, b"fake audio data").unwrap();
+
+        Ok(DownloadResult {
+            file_path,
+            title,
+            artist: Some("Test Artist".into()),
+            duration: Some(180.0),
+            source_url: "https://example.com/test".into(),
+        })
+    }
+}
+
+async fn test_db() -> Database {
+    Database::new("sqlite::memory:")
+        .await
+        .expect("failed to create test database")
+}
+
+#[tokio::test]
+async fn test_queue_process_success() {
+    let db = test_db().await;
+    let dir = TempDir::new().unwrap();
+    let backend = MockBackend::new(false);
+
+    let config = BackendConfig {
+        output_dir: dir.path().to_owned(),
+        format: AudioFormat::Opus,
+        cookies_path: None,
+    };
+
+    // Enqueue an item
+    queries::downloads::enqueue(db.conn(), "Test Song", None, "mock")
+        .await
+        .unwrap();
+
+    // Process queue
+    let stats = run_queue(db.conn(), &backend, &config, false).await.unwrap();
+    assert_eq!(stats.downloads_attempted, 1);
+    assert_eq!(stats.downloads_completed, 1);
+    assert_eq!(stats.downloads_failed, 0);
+
+    // Verify status in DB
+    let items = queries::downloads::list(db.conn(), Some(DownloadStatus::Completed))
+        .await
+        .unwrap();
+    assert_eq!(items.len(), 1);
+}
+
+#[tokio::test]
+async fn test_queue_process_failure() {
+    let db = test_db().await;
+    let dir = TempDir::new().unwrap();
+    let backend = MockBackend::new(true);
+
+    let config = BackendConfig {
+        output_dir: dir.path().to_owned(),
+        format: AudioFormat::Opus,
+        cookies_path: None,
+    };
+
+    queries::downloads::enqueue(db.conn(), "Failing Song", None, "mock")
+        .await
+        .unwrap();
+
+    let stats = run_queue(db.conn(), &backend, &config, false).await.unwrap();
+    assert_eq!(stats.downloads_attempted, 1);
+    assert_eq!(stats.downloads_failed, 1);
+
+    // Check it's marked as failed
+    let items = queries::downloads::list(db.conn(), Some(DownloadStatus::Failed))
+        .await
+        .unwrap();
+    assert_eq!(items.len(), 1);
+    assert!(items[0].error_message.is_some());
+}
+
+#[tokio::test]
+async fn test_queue_dry_run() {
+    let db = test_db().await;
+    let dir = TempDir::new().unwrap();
+    let backend = MockBackend::new(false);
+
+    let config = BackendConfig {
+        output_dir: dir.path().to_owned(),
+        format: AudioFormat::Opus,
+        cookies_path: None,
+    };
+
+    queries::downloads::enqueue(db.conn(), "Dry Run Song", None, "mock")
+        .await
+        .unwrap();
+
+    let stats = run_queue(db.conn(), &backend, &config, true).await.unwrap();
+    assert_eq!(stats.downloads_attempted, 1);
+    assert_eq!(stats.downloads_skipped, 1);
+    assert_eq!(stats.downloads_completed, 0);
+
+    // No files should exist
+    let entries: Vec<_> = std::fs::read_dir(dir.path()).unwrap().collect();
+    assert!(entries.is_empty());
+}
+
+#[tokio::test]
+async fn test_download_single_success() {
+    let dir = TempDir::new().unwrap();
+    let backend = MockBackend::new(false);
+
+    let config = BackendConfig {
+        output_dir: dir.path().to_owned(),
+        format: AudioFormat::Mp3,
+        cookies_path: None,
+    };
+
+    download_single(
+        &backend,
+        DownloadTarget::Query("Test Song".into()),
+        &config,
+        false,
+    )
+    .await
+    .unwrap();
+
+    // File should exist
+    let entries: Vec<_> = std::fs::read_dir(dir.path())
+        .unwrap()
+        .filter_map(|e| e.ok())
+        .collect();
+    assert_eq!(entries.len(), 1);
+}
+
+#[tokio::test]
+async fn test_queue_retry() {
+    let db = test_db().await;
+
+    // Enqueue and manually fail an item
+    let item = queries::downloads::enqueue(db.conn(), "Retry Song", None, "mock")
+        .await
+        .unwrap();
+    queries::downloads::update_status(db.conn(), item.id, DownloadStatus::Failed, Some("oops"))
+        .await
+        .unwrap();
+
+    // Retry it
+    queries::downloads::retry_failed(db.conn(), item.id)
+        .await
+        .unwrap();
+
+    // Should be pending again with retry_count = 1
+    let pending = queries::downloads::list(db.conn(), Some(DownloadStatus::Pending))
+        .await
+        .unwrap();
+    assert_eq!(pending.len(), 1);
+    assert_eq!(pending[0].retry_count, 1);
+}
+
+#[tokio::test]
+async fn test_wanted_item_status_updated_on_download() {
+    use shanty_db::entities::wanted_item::{ItemType, WantedStatus};
+
+    let db = test_db().await;
+    let dir = TempDir::new().unwrap();
+    let backend = MockBackend::new(false);
+
+    let config = BackendConfig {
+        output_dir:
dir.path().to_owned(), + format: AudioFormat::Opus, + cookies_path: None, + }; + + // Create a wanted item + let wanted = queries::wanted::add(db.conn(), ItemType::Track, None, None, None) + .await + .unwrap(); + assert_eq!(wanted.status, WantedStatus::Wanted); + + // Enqueue download linked to the wanted item + queries::downloads::enqueue(db.conn(), "Wanted Song", Some(wanted.id), "mock") + .await + .unwrap(); + + // Process queue + run_queue(db.conn(), &backend, &config, false).await.unwrap(); + + // Wanted item should now be Downloaded + let updated = queries::wanted::get_by_id(db.conn(), wanted.id) + .await + .unwrap(); + assert_eq!(updated.status, WantedStatus::Downloaded); +}