Initial commit

This commit is contained in:
Connor Johnstone
2026-03-17 15:31:29 -04:00
commit d5641493b9
11 changed files with 1458 additions and 0 deletions

4
.gitignore vendored Normal file
View File

@@ -0,0 +1,4 @@
target/
.env
*.db
*.db-journal

25
Cargo.toml Normal file
View File

@@ -0,0 +1,25 @@
[package]
name = "shanty-dl"
version = "0.1.0"
edition = "2024"
license = "MIT"
description = "Music downloading for Shanty"
repository = "ssh://connor@git.rcjohnstone.com:2222/Shanty/dl.git"
[dependencies]
shanty-db = { path = "../shanty-db" }
sea-orm = { version = "1", features = ["sqlx-sqlite", "runtime-tokio-native-tls"] }
clap = { version = "4", features = ["derive"] }
serde = { version = "1", features = ["derive"] }
serde_json = "1"
thiserror = "2"
tracing = "0.1"
tracing-subscriber = { version = "0.3", features = ["env-filter"] }
tokio = { version = "1", features = ["full", "process"] }
anyhow = "1"
chrono = { version = "0.4", features = ["serde"] }
dirs = "6"
[dev-dependencies]
tokio = { version = "1", features = ["full", "test-util"] }
tempfile = "3"

43
readme.md Normal file
View File

@@ -0,0 +1,43 @@
# shanty-dl
Music downloading for [Shanty](ssh://connor@git.rcjohnstone.com:2222/Shanty/shanty.git).
Downloads music files using pluggable backends. The default backend uses yt-dlp with
YouTube Music search, rate limiting, and configurable output format.
## Prerequisites
- [yt-dlp](https://github.com/yt-dlp/yt-dlp) must be installed and on PATH
- [ffmpeg](https://ffmpeg.org/) for audio conversion
## Usage
```sh
# Download a single song
shanty-dl download "Pink Floyd Time"
# Download from a direct URL
shanty-dl download "https://www.youtube.com/watch?v=..."
# Add to download queue
shanty-dl queue add "Radiohead Creep"
# Process the queue
shanty-dl queue process
# List queue status
shanty-dl queue list
# Retry failed downloads
shanty-dl queue retry
```
## YouTube Authentication
For higher rate limits (1800/hr with cookies vs the default 450/hr), export cookies:
```sh
yt-dlp --cookies-from-browser firefox --cookies ~/.config/shanty/cookies.txt
```
Then pass via `--cookies ~/.config/shanty/cookies.txt` or set `SHANTY_COOKIES`.

109
src/backend.rs Normal file
View File

@@ -0,0 +1,109 @@
use std::path::PathBuf;
use serde::{Deserialize, Serialize};
use crate::error::DlResult;
/// A search result from a download backend.
///
/// Serde-derived so results can round-trip as JSON (e.g. for logging or
/// caching).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SearchResult {
    /// URL or ID that can be passed to `download`.
    pub url: String,
    /// Title of the track.
    pub title: String,
    /// Artist/uploader name, if available.
    pub artist: Option<String>,
    /// Duration in seconds, if available.
    pub duration: Option<f64>,
    /// Source identifier (e.g., "youtube_music", "youtube").
    pub source: String,
}
/// What to download — either a direct URL or a search query.
#[derive(Debug, Clone)]
pub enum DownloadTarget {
    /// A direct URL to download.
    Url(String),
    /// A search query; the backend resolves it to a URL before downloading.
    Query(String),
}
/// Result of a successful download, returned by `DownloadBackend::download`.
#[derive(Debug, Clone)]
pub struct DownloadResult {
    /// Path to the downloaded file.
    pub file_path: PathBuf,
    /// Title from the source.
    pub title: String,
    /// Artist/uploader from the source.
    pub artist: Option<String>,
    /// Duration in seconds.
    pub duration: Option<f64>,
    /// The URL that was actually downloaded.
    pub source_url: String,
}
/// Configuration passed to the backend for a download.
#[derive(Debug, Clone)]
pub struct BackendConfig {
    /// Directory the finished file is written into.
    pub output_dir: PathBuf,
    /// Desired output audio format.
    pub format: AudioFormat,
    /// Optional cookies.txt for authenticated requests.
    pub cookies_path: Option<PathBuf>,
}
/// Supported output audio formats.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AudioFormat {
    Opus,
    Mp3,
    Flac,
    Best,
}

impl AudioFormat {
    /// The token passed to yt-dlp's `--audio-format` flag.
    pub fn as_ytdlp_arg(&self) -> &str {
        match *self {
            Self::Opus => "opus",
            Self::Mp3 => "mp3",
            Self::Flac => "flac",
            Self::Best => "best",
        }
    }
}

impl std::str::FromStr for AudioFormat {
    type Err = String;

    /// Parse a format name, case-insensitively.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        let lowered = s.to_lowercase();
        match lowered.as_str() {
            "opus" => Ok(Self::Opus),
            "mp3" => Ok(Self::Mp3),
            "flac" => Ok(Self::Flac),
            "best" => Ok(Self::Best),
            _ => Err(format!("unsupported format: {s} (expected opus, mp3, flac, or best)")),
        }
    }
}

impl std::fmt::Display for AudioFormat {
    /// Displays the same token yt-dlp expects (e.g. "opus").
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(self.as_ytdlp_arg())
    }
}
/// Trait for download backends. yt-dlp is the default; others (torrents, Soulseek)
/// can be added later.
///
/// Methods return `impl Future + Send` (return-position impl trait) rather
/// than `async fn` so the futures are guaranteed `Send` and usable from
/// multi-threaded runtimes.
pub trait DownloadBackend: Send + Sync {
    /// Check if this backend is available on the system.
    fn check_available(&self) -> impl std::future::Future<Output = DlResult<()>> + Send;
    /// Search for tracks matching a query.
    fn search(&self, query: &str) -> impl std::future::Future<Output = DlResult<Vec<SearchResult>>> + Send;
    /// Download a target to the configured output directory.
    fn download(
        &self,
        target: &DownloadTarget,
        config: &BackendConfig,
    ) -> impl std::future::Future<Output = DlResult<DownloadResult>> + Send;
}

45
src/error.rs Normal file
View File

@@ -0,0 +1,45 @@
use shanty_db::DbError;
/// Error type for all shanty-dl operations.
#[derive(Debug, thiserror::Error)]
pub enum DlError {
    /// Propagated from the shanty-db layer.
    #[error("database error: {0}")]
    Db(#[from] DbError),
    /// Filesystem or process I/O failure.
    #[error("I/O error: {0}")]
    Io(#[from] std::io::Error),
    /// The backend binary is not installed / not on PATH.
    #[error("yt-dlp not found: {0}")]
    BackendNotFound(String),
    /// The backend ran but exited non-zero (payload is its stderr).
    #[error("yt-dlp error: {0}")]
    BackendError(String),
    /// The remote service rejected the request for sending too many requests.
    #[error("rate limited: {0}")]
    RateLimited(String),
    /// A download attempt failed (see message).
    #[error("download failed: {0}")]
    DownloadFailed(String),
    /// A search produced no candidates for the given query.
    #[error("search returned no results for: {0}")]
    NoResults(String),
    /// JSON (de)serialization failure.
    #[error("JSON parse error: {0}")]
    Json(#[from] serde_json::Error),
    /// Catch-all for errors that fit no other variant.
    #[error("{0}")]
    Other(String),
}
impl DlError {
/// Whether this error is transient and worth retrying.
pub fn is_transient(&self) -> bool {
matches!(
self,
DlError::RateLimited(_)
| DlError::Io(_)
| DlError::BackendError(_)
)
}
}
pub type DlResult<T> = Result<T, DlError>;

15
src/lib.rs Normal file
View File

@@ -0,0 +1,15 @@
//! Music downloading for Shanty.
//!
//! Downloads music files using pluggable backends. The default backend uses
//! yt-dlp with YouTube Music search, rate limiting, and configurable output format.
pub mod backend;
pub mod error;
pub mod queue;
pub mod rate_limit;
pub mod ytdlp;
pub use backend::{AudioFormat, BackendConfig, DownloadBackend, DownloadResult, DownloadTarget, SearchResult};
pub use error::{DlError, DlResult};
pub use queue::{DlStats, download_single, run_queue};
pub use ytdlp::{SearchSource, YtDlpBackend};

292
src/main.rs Normal file
View File

@@ -0,0 +1,292 @@
use std::path::PathBuf;
use clap::{Parser, Subcommand};
use tracing_subscriber::EnvFilter;
use shanty_db::Database;
use shanty_db::entities::download_queue::DownloadStatus;
use shanty_db::queries;
use shanty_dl::{
AudioFormat, BackendConfig, DownloadBackend, DownloadTarget, SearchSource, YtDlpBackend,
download_single, run_queue,
};
// CLI entry definition. The `///` doc comments on fields double as clap help
// text, so review comments here use `//` to avoid changing `--help` output.
#[derive(Parser)]
#[command(name = "shanty-dl", about = "Download music files for Shanty")]
struct Cli {
    #[command(subcommand)]
    command: Commands,
    /// Database URL. Defaults to sqlite://<XDG_DATA_HOME>/shanty/shanty.db?mode=rwc
    #[arg(long, global = true, env = "SHANTY_DATABASE_URL")]
    database: Option<String>,
    /// Increase verbosity (-v info, -vv debug, -vvv trace).
    #[arg(short, long, global = true, action = clap::ArgAction::Count)]
    verbose: u8,
}
// Top-level subcommands. `///` doc comments below are clap help text —
// editing them changes `--help` output.
#[derive(Subcommand)]
enum Commands {
    /// Download a single song by query or URL.
    Download {
        /// Search query or direct URL.
        query_or_url: String,
        /// Output audio format.
        #[arg(long, default_value = "opus")]
        format: String,
        /// Output directory for downloaded files.
        #[arg(long)]
        output: Option<PathBuf>,
        /// Path to cookies.txt file for YouTube authentication.
        #[arg(long, env = "SHANTY_COOKIES")]
        cookies: Option<PathBuf>,
        /// Search source (ytmusic or youtube).
        #[arg(long, default_value = "ytmusic")]
        search_source: String,
        /// Requests per hour limit.
        #[arg(long, default_value = "450")]
        rate_limit: u32,
        /// Preview what would be downloaded without doing it.
        #[arg(long)]
        dry_run: bool,
    },
    /// Manage the download queue.
    Queue {
        #[command(subcommand)]
        action: QueueAction,
    },
}
// Subcommands under `shanty-dl queue`. `Process` mirrors `Download`'s options
// because it drives the same backend. Doc comments are clap help text.
#[derive(Subcommand)]
enum QueueAction {
    /// Process all pending items in the download queue.
    Process {
        /// Output audio format.
        #[arg(long, default_value = "opus")]
        format: String,
        /// Output directory for downloaded files.
        #[arg(long)]
        output: Option<PathBuf>,
        /// Path to cookies.txt for YouTube authentication.
        #[arg(long, env = "SHANTY_COOKIES")]
        cookies: Option<PathBuf>,
        /// Search source (ytmusic or youtube).
        #[arg(long, default_value = "ytmusic")]
        search_source: String,
        /// Requests per hour limit.
        #[arg(long, default_value = "450")]
        rate_limit: u32,
        /// Preview without downloading.
        #[arg(long)]
        dry_run: bool,
    },
    /// Add an item to the download queue.
    Add {
        /// Search query for the song to download.
        query: String,
    },
    /// List items in the download queue.
    List {
        /// Filter by status (pending, downloading, completed, failed, cancelled, all).
        #[arg(long, default_value = "all")]
        status: String,
    },
    /// Retry all failed downloads.
    Retry,
}
/// Build the fallback SQLite URL under the platform data directory,
/// creating `<data_dir>/shanty/` on a best-effort basis.
fn default_database_url() -> String {
    let base = dirs::data_dir().unwrap_or_else(|| PathBuf::from("."));
    let shanty_dir = base.join("shanty");
    // Best-effort: a failure here surfaces later when opening the database.
    let _ = std::fs::create_dir_all(&shanty_dir);
    format!("sqlite://{}?mode=rwc", shanty_dir.join("shanty.db").display())
}
/// Fallback download directory: `<data_dir>/shanty/downloads`, created
/// best-effort (creation errors surface later on file writes).
fn default_output_dir() -> PathBuf {
    let base = dirs::data_dir().unwrap_or_else(|| PathBuf::from("."));
    let downloads = base.join("shanty").join("downloads");
    let _ = std::fs::create_dir_all(&downloads);
    downloads
}
/// Construct the yt-dlp backend from CLI options.
///
/// When cookies are supplied and the rate limit is still at the CLI default
/// (450/hr), bump it to the authenticated limit (1800/hr).
/// NOTE: an explicit `--rate-limit 450` is indistinguishable from the default
/// at this point and will also be bumped.
fn make_backend(
    cookies: &Option<PathBuf>,
    search_source: &str,
    rate_limit: u32,
) -> anyhow::Result<YtDlpBackend> {
    let source = search_source
        .parse::<SearchSource>()
        .map_err(|e: String| anyhow::anyhow!(e))?;
    let mut effective_rate = rate_limit;
    if cookies.is_some() && rate_limit == 450 {
        tracing::info!("cookies provided — using authenticated rate limit (1800/hr)");
        effective_rate = 1800;
    }
    Ok(YtDlpBackend::new(effective_rate, source, cookies.clone()))
}
fn make_backend_config(
format: &str,
output: &Option<PathBuf>,
cookies: &Option<PathBuf>,
) -> anyhow::Result<BackendConfig> {
let fmt: AudioFormat = format
.parse()
.map_err(|e: String| anyhow::anyhow!(e))?;
Ok(BackendConfig {
output_dir: output.clone().unwrap_or_else(default_output_dir),
format: fmt,
cookies_path: cookies.clone(),
})
}
#[tokio::main]
async fn main() -> anyhow::Result<()> {
    let cli = Cli::parse();
    // Set up tracing: RUST_LOG (if set) wins; otherwise the -v count picks a
    // default filter.
    let filter = match cli.verbose {
        0 => "warn",
        1 => "info,shanty_dl=info",
        2 => "info,shanty_dl=debug",
        _ => "debug,shanty_dl=trace",
    };
    tracing_subscriber::fmt()
        .with_env_filter(
            EnvFilter::try_from_default_env().unwrap_or_else(|_| EnvFilter::new(filter)),
        )
        .init();
    match cli.command {
        Commands::Download {
            query_or_url,
            format,
            output,
            cookies,
            search_source,
            rate_limit,
            dry_run,
        } => {
            let backend = make_backend(&cookies, &search_source, rate_limit)?;
            // Fail fast if yt-dlp is missing from PATH.
            backend.check_available().await?;
            let config = make_backend_config(&format, &output, &cookies)?;
            // Determine if it's a URL or a search query
            let target = if query_or_url.starts_with("http://")
                || query_or_url.starts_with("https://")
            {
                DownloadTarget::Url(query_or_url)
            } else {
                DownloadTarget::Query(query_or_url)
            };
            download_single(&backend, target, &config, dry_run).await?;
        }
        Commands::Queue { action } => {
            // Only queue subcommands need database access; `download` does not.
            let database_url = cli.database.unwrap_or_else(default_database_url);
            let db = Database::new(&database_url).await?;
            match action {
                QueueAction::Process {
                    format,
                    output,
                    cookies,
                    search_source,
                    rate_limit,
                    dry_run,
                } => {
                    let backend = make_backend(&cookies, &search_source, rate_limit)?;
                    backend.check_available().await?;
                    let config = make_backend_config(&format, &output, &cookies)?;
                    if dry_run {
                        println!("DRY RUN — no files will be downloaded");
                    }
                    let stats = run_queue(db.conn(), &backend, &config, dry_run).await?;
                    println!("\nQueue processing complete: {stats}");
                }
                QueueAction::Add { query } => {
                    // "ytdlp" records which backend the item is intended for.
                    let item = queries::downloads::enqueue(db.conn(), &query, None, "ytdlp").await?;
                    println!("Added to queue: id={}, query=\"{}\"", item.id, item.query);
                }
                QueueAction::List { status } => {
                    // Map the textual filter to a DB status; "all" lists everything.
                    let filter = match status.to_lowercase().as_str() {
                        "all" => None,
                        "pending" => Some(DownloadStatus::Pending),
                        "downloading" => Some(DownloadStatus::Downloading),
                        "completed" => Some(DownloadStatus::Completed),
                        "failed" => Some(DownloadStatus::Failed),
                        "cancelled" => Some(DownloadStatus::Cancelled),
                        _ => anyhow::bail!("unknown status: {status}"),
                    };
                    let items = queries::downloads::list(db.conn(), filter).await?;
                    if items.is_empty() {
                        println!("Queue is empty.");
                    } else {
                        // Fixed-width columns; the query column is clipped to 40.
                        println!(
                            "{:<5} {:<12} {:<6} {:<40} {}",
                            "ID", "STATUS", "RETRY", "QUERY", "ERROR"
                        );
                        for item in &items {
                            println!(
                                "{:<5} {:<12} {:<6} {:<40} {}",
                                item.id,
                                format!("{:?}", item.status),
                                item.retry_count,
                                truncate(&item.query, 40),
                                item.error_message.as_deref().unwrap_or(""),
                            );
                        }
                        println!("\n{} items total", items.len());
                    }
                }
                QueueAction::Retry => {
                    let failed =
                        queries::downloads::list(db.conn(), Some(DownloadStatus::Failed)).await?;
                    if failed.is_empty() {
                        println!("No failed downloads to retry.");
                    } else {
                        for item in &failed {
                            queries::downloads::retry_failed(db.conn(), item.id).await?;
                        }
                        println!("Requeued {} failed downloads.", failed.len());
                    }
                }
            }
        }
    }
    Ok(())
}
/// Clip `s` for fixed-width table display.
///
/// Returns `s` unchanged when it is at most `max` characters long; otherwise
/// returns the first `max - 1` characters. Operates on `char`s rather than
/// byte offsets, so multi-byte UTF-8 queries cannot trigger a slice-boundary
/// panic (the original `&s[..max - 1]` panicked for non-ASCII input, and
/// underflowed when `max == 0`).
fn truncate(s: &str, max: usize) -> String {
    if s.chars().count() <= max {
        s.to_string()
    } else {
        s.chars().take(max.saturating_sub(1)).collect()
    }
}

212
src/queue.rs Normal file
View File

@@ -0,0 +1,212 @@
use std::fmt;
use std::time::Duration;
use sea_orm::DatabaseConnection;
use shanty_db::entities::download_queue::DownloadStatus;
use shanty_db::entities::wanted_item::WantedStatus;
use shanty_db::queries;
use crate::backend::{BackendConfig, DownloadBackend, DownloadTarget};
use crate::error::{DlError, DlResult};
/// Statistics from a queue processing run.
#[derive(Debug, Default, Clone)]
pub struct DlStats {
    pub downloads_attempted: u64,
    pub downloads_completed: u64,
    pub downloads_failed: u64,
    pub downloads_skipped: u64,
}

impl fmt::Display for DlStats {
    /// Renders a one-line summary, e.g. for the end-of-run log message.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(
            f,
            "attempted: {attempted}, completed: {completed}, failed: {failed}, skipped: {skipped}",
            attempted = self.downloads_attempted,
            completed = self.downloads_completed,
            failed = self.downloads_failed,
            skipped = self.downloads_skipped,
        )
    }
}

/// Give up on an item after this many automatic retries.
const MAX_RETRIES: i32 = 3;
/// Backoff schedule indexed by retry count (clamped to the last entry).
const RETRY_DELAYS: [Duration; 3] = [
    Duration::from_secs(30),
    Duration::from_secs(120),
    Duration::from_secs(600),
];
/// Process all pending items in the download queue.
///
/// Pops pending items one at a time via `get_next_pending` until none remain,
/// downloading each with `backend` and recording the outcome in the database.
/// Transient failures are requeued (up to `MAX_RETRIES`); a rate-limit error
/// additionally pauses processing for the scheduled backoff delay.
///
/// Returns aggregate statistics for the run.
pub async fn run_queue(
    conn: &DatabaseConnection,
    backend: &impl DownloadBackend,
    config: &BackendConfig,
    dry_run: bool,
) -> DlResult<DlStats> {
    let mut stats = DlStats::default();
    loop {
        let item = match queries::downloads::get_next_pending(conn).await? {
            Some(item) => item,
            None => break,
        };
        stats.downloads_attempted += 1;
        tracing::info!(
            id = item.id,
            query = %item.query,
            retry = item.retry_count,
            "processing download"
        );
        if dry_run {
            tracing::info!(id = item.id, query = %item.query, "DRY RUN: would download");
            stats.downloads_skipped += 1;
            // Mark as failed temporarily so we don't loop forever on the same item
            // NOTE(review): this leaves dry-run items in Failed state after the
            // run — confirm callers expect that (see test_queue_dry_run).
            queries::downloads::update_status(
                conn,
                item.id,
                DownloadStatus::Failed,
                Some("dry run"),
            )
            .await?;
            continue;
        }
        // Mark as downloading
        queries::downloads::update_status(conn, item.id, DownloadStatus::Downloading, None)
            .await?;
        // Determine download target: prefer a stored URL over re-searching.
        let target = if let Some(ref url) = item.source_url {
            DownloadTarget::Url(url.clone())
        } else {
            DownloadTarget::Query(item.query.clone())
        };
        // Attempt download
        match backend.download(&target, config).await {
            Ok(result) => {
                tracing::info!(
                    id = item.id,
                    path = %result.file_path.display(),
                    title = %result.title,
                    "download completed"
                );
                queries::downloads::update_status(
                    conn,
                    item.id,
                    DownloadStatus::Completed,
                    None,
                )
                .await?;
                // Update wanted item status if linked. A failure here is only
                // logged — the download itself already succeeded.
                if let Some(wanted_id) = item.wanted_item_id {
                    if let Err(e) = queries::wanted::update_status(
                        conn,
                        wanted_id,
                        WantedStatus::Downloaded,
                    )
                    .await
                    {
                        tracing::warn!(
                            wanted_id = wanted_id,
                            error = %e,
                            "failed to update wanted item status"
                        );
                    }
                }
                stats.downloads_completed += 1;
            }
            Err(e) => {
                let error_msg = e.to_string();
                tracing::error!(
                    id = item.id,
                    query = %item.query,
                    error = %error_msg,
                    "download failed"
                );
                queries::downloads::update_status(
                    conn,
                    item.id,
                    DownloadStatus::Failed,
                    Some(&error_msg),
                )
                .await?;
                // Auto-retry transient errors: the item was marked Failed
                // above, then `retry_failed` returns it to pending with an
                // incremented retry count.
                if e.is_transient() && item.retry_count < MAX_RETRIES {
                    // Clamp the index so retry counts beyond the table reuse
                    // the longest delay.
                    let delay_idx = item.retry_count.min(RETRY_DELAYS.len() as i32 - 1) as usize;
                    let delay = RETRY_DELAYS[delay_idx];
                    tracing::info!(
                        id = item.id,
                        retry = item.retry_count + 1,
                        delay_secs = delay.as_secs(),
                        "scheduling retry"
                    );
                    queries::downloads::retry_failed(conn, item.id).await?;
                    // If rate limited, pause before continuing
                    if matches!(e, DlError::RateLimited(_)) {
                        tracing::warn!("rate limited — pausing queue processing");
                        tokio::time::sleep(delay).await;
                    }
                }
                stats.downloads_failed += 1;
            }
        }
    }
    tracing::info!(%stats, "queue processing complete");
    Ok(stats)
}
/// Download a single item directly (not from the queue).
///
/// In dry-run mode no file is written; for a query target the search is still
/// executed so the user can see what would have been downloaded.
pub async fn download_single(
    backend: &impl DownloadBackend,
    target: DownloadTarget,
    config: &BackendConfig,
    dry_run: bool,
) -> DlResult<()> {
    if dry_run {
        match &target {
            DownloadTarget::Url(url) => {
                tracing::info!(url = %url, "DRY RUN: would download URL");
            }
            DownloadTarget::Query(q) => {
                tracing::info!(query = %q, "DRY RUN: would search and download");
                // Still run the search to show what would be downloaded
                let results = backend.search(q).await?;
                if let Some(best) = results.first() {
                    tracing::info!(
                        title = %best.title,
                        artist = ?best.artist,
                        url = %best.url,
                        "would download"
                    );
                } else {
                    tracing::warn!(query = %q, "no results found");
                }
            }
        }
        return Ok(());
    }
    let result = backend.download(&target, config).await?;
    // Bug fix: the original format string "Downloaded: {}{}" ran the title
    // and the file path together with no separator.
    println!(
        "Downloaded: {} -> {}",
        result.title,
        result.file_path.display()
    );
    Ok(())
}

103
src/rate_limit.rs Normal file
View File

@@ -0,0 +1,103 @@
use std::time::Duration;
use tokio::sync::Mutex;
use tokio::time::Instant;
/// Fixed-window rate limiter for controlling request rates.
///
/// (Despite the original "token bucket" label, tokens do not refill
/// continuously — `remaining` resets in full once per hour-long window.)
pub struct RateLimiter {
    /// Maximum number of acquisitions granted per one-hour window.
    max_per_hour: u32,
    /// Mutable window state; an async mutex so `acquire` can be awaited
    /// concurrently from multiple tasks.
    state: Mutex<TokenState>,
}

/// Interior state: tokens left in the current window and when it began.
struct TokenState {
    remaining: u32,
    window_start: Instant,
}
impl RateLimiter {
    /// Create a new rate limiter with the given maximum requests per hour.
    pub fn new(max_per_hour: u32) -> Self {
        Self {
            max_per_hour,
            state: Mutex::new(TokenState {
                remaining: max_per_hour,
                window_start: Instant::now(),
            }),
        }
    }

    /// Acquire a token, sleeping if necessary to stay within the rate limit.
    pub async fn acquire(&self) {
        let mut state = self.state.lock().await;
        // Roll the window over if an hour has passed since it started.
        let elapsed = state.window_start.elapsed();
        if elapsed >= Duration::from_secs(3600) {
            state.remaining = self.max_per_hour;
            state.window_start = Instant::now();
        }
        if state.remaining > 0 {
            state.remaining -= 1;
            // Warn when fewer than 10% of tokens remain.
            let pct_remaining =
                (state.remaining as f64 / self.max_per_hour as f64) * 100.0;
            if pct_remaining < 10.0 && pct_remaining > 0.0 {
                tracing::warn!(
                    remaining = state.remaining,
                    max = self.max_per_hour,
                    "approaching rate limit"
                );
            }
        } else {
            // No tokens left — wait for the window to reset.
            let wait = Duration::from_secs(3600) - elapsed;
            tracing::warn!(
                wait_secs = wait.as_secs(),
                "rate limit reached, waiting for window reset"
            );
            drop(state); // release lock while sleeping
            tokio::time::sleep(wait).await;
            // Re-acquire the lock. Another task may have already reset the
            // window while we slept, so only start a fresh window if it is
            // still stale. `saturating_sub` fixes a panic in the original
            // (`max_per_hour - 1` underflowed when max_per_hour == 0).
            let mut state = self.state.lock().await;
            if state.window_start.elapsed() >= Duration::from_secs(3600) {
                state.remaining = self.max_per_hour.saturating_sub(1);
                state.window_start = Instant::now();
            } else if state.remaining > 0 {
                state.remaining -= 1;
            }
        }
    }

    /// Get the number of remaining tokens in the current window.
    pub async fn remaining(&self) -> u32 {
        self.state.lock().await.remaining
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    // Tokens decrement one per acquire and bottom out at zero.
    #[tokio::test]
    async fn test_rate_limiter_basic() {
        let limiter = RateLimiter::new(10);
        assert_eq!(limiter.remaining().await, 10);
        limiter.acquire().await;
        assert_eq!(limiter.remaining().await, 9);
        for _ in 0..9 {
            limiter.acquire().await;
        }
        assert_eq!(limiter.remaining().await, 0);
    }

    // While tokens remain in the window, acquire never sleeps.
    #[tokio::test]
    async fn test_rate_limiter_high_volume() {
        let limiter = RateLimiter::new(100);
        for _ in 0..50 {
            limiter.acquire().await;
        }
        assert_eq!(limiter.remaining().await, 50);
    }
}

371
src/ytdlp.rs Normal file
View File

@@ -0,0 +1,371 @@
use std::path::PathBuf;
use std::process::Stdio;
use serde::Deserialize;
use tokio::process::Command;
use crate::backend::{BackendConfig, DownloadBackend, DownloadResult, DownloadTarget, SearchResult};
use crate::error::{DlError, DlResult};
use crate::rate_limit::RateLimiter;
/// Search source for yt-dlp.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SearchSource {
    YouTubeMusic,
    YouTube,
}

impl SearchSource {
    /// yt-dlp search prefix (returns the top 5 matches).
    fn prefix(&self) -> &str {
        match *self {
            Self::YouTubeMusic => "ytmusicsearch5",
            Self::YouTube => "ytsearch5",
        }
    }

    /// Identifier recorded in `SearchResult::source`.
    fn source_name(&self) -> &str {
        match *self {
            Self::YouTubeMusic => "youtube_music",
            Self::YouTube => "youtube",
        }
    }
}

impl std::str::FromStr for SearchSource {
    type Err = String;

    /// Parse a source name, case-insensitively, accepting common aliases.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        let lowered = s.to_lowercase();
        match lowered.as_str() {
            "ytmusic" | "youtube_music" | "youtubemusic" => Ok(Self::YouTubeMusic),
            "youtube" | "yt" => Ok(Self::YouTube),
            _ => Err(format!("unknown search source: {s} (expected ytmusic or youtube)")),
        }
    }
}

impl std::fmt::Display for SearchSource {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let name = match *self {
            Self::YouTubeMusic => "ytmusic",
            Self::YouTube => "youtube",
        };
        f.write_str(name)
    }
}
/// yt-dlp backend for downloading music.
pub struct YtDlpBackend {
    /// Every yt-dlp invocation acquires a token from this limiter first.
    rate_limiter: RateLimiter,
    /// Primary search source (YouTube Music or plain YouTube).
    search_source: SearchSource,
    /// When true and the primary source is YouTube Music, fall back to a
    /// plain YouTube search if the primary yields nothing. Always set to
    /// true by `new` today.
    fallback_search: bool,
    /// Optional cookies.txt passed to yt-dlp via `--cookies`.
    cookies_path: Option<PathBuf>,
}
impl YtDlpBackend {
    /// Create a backend with the given hourly rate limit, search source, and
    /// optional cookies file. Search fallback is always enabled.
    pub fn new(
        rate_limit_per_hour: u32,
        search_source: SearchSource,
        cookies_path: Option<PathBuf>,
    ) -> Self {
        Self {
            rate_limiter: RateLimiter::new(rate_limit_per_hour),
            search_source,
            fallback_search: true,
            cookies_path,
        }
    }

    /// Run a yt-dlp command and return stdout.
    ///
    /// Acquires a rate-limit token first (may sleep), then maps failures:
    /// missing binary -> `BackendNotFound`, "429"/"Too Many Requests" in
    /// stderr -> `RateLimited`, any other non-zero exit -> `BackendError`.
    async fn run_ytdlp(&self, args: &[&str]) -> DlResult<String> {
        self.rate_limiter.acquire().await;
        let mut cmd = Command::new("yt-dlp");
        cmd.args(args)
            .stdout(Stdio::piped())
            .stderr(Stdio::piped());
        // Add cookies if configured.
        // NOTE(review): `download` may already have pushed a --cookies flag
        // into `args`, in which case yt-dlp receives it twice — confirm which
        // occurrence wins and consider deduplicating.
        if let Some(ref cookies) = self.cookies_path {
            cmd.arg("--cookies").arg(cookies);
        }
        tracing::debug!(args = ?args, "running yt-dlp");
        let output = cmd.output().await.map_err(|e| {
            if e.kind() == std::io::ErrorKind::NotFound {
                DlError::BackendNotFound("yt-dlp not found on PATH".into())
            } else {
                DlError::Io(e)
            }
        })?;
        if !output.status.success() {
            let stderr = String::from_utf8_lossy(&output.stderr).to_string();
            tracing::error!(stderr = %stderr, "yt-dlp failed");
            if stderr.contains("429") || stderr.contains("Too Many Requests") {
                return Err(DlError::RateLimited(stderr));
            }
            return Err(DlError::BackendError(stderr));
        }
        Ok(String::from_utf8_lossy(&output.stdout).to_string())
    }

    /// Search using a specific source prefix.
    ///
    /// Runs `yt-dlp --dump-json --flat-playlist` on `<prefix>:<query>` and
    /// parses one JSON entry per stdout line; unparseable lines are logged at
    /// debug level and skipped rather than failing the whole search.
    async fn search_with_source(
        &self,
        query: &str,
        source: SearchSource,
    ) -> DlResult<Vec<SearchResult>> {
        let search_query = format!("{}:{}", source.prefix(), query);
        let output = self
            .run_ytdlp(&[
                "--dump-json",
                "--flat-playlist",
                "--no-download",
                &search_query,
            ])
            .await?;
        let mut results = Vec::new();
        for line in output.lines() {
            if line.trim().is_empty() {
                continue;
            }
            match serde_json::from_str::<YtDlpSearchEntry>(line) {
                Ok(entry) => {
                    // Missing fields get sensible fallbacks: empty URL,
                    // "Unknown" title, first available artist-like field.
                    results.push(SearchResult {
                        url: entry.url.or(entry.webpage_url).unwrap_or_default(),
                        title: entry.title.unwrap_or_else(|| "Unknown".into()),
                        artist: entry.artist.or(entry.uploader).or(entry.channel),
                        duration: entry.duration,
                        source: source.source_name().into(),
                    });
                }
                Err(e) => {
                    tracing::debug!(error = %e, "failed to parse search result line");
                }
            }
        }
        Ok(results)
    }
}
impl DownloadBackend for YtDlpBackend {
    /// Verify yt-dlp is installed by running `yt-dlp --version`.
    /// Note: this consumes a rate-limit token like any other invocation.
    async fn check_available(&self) -> DlResult<()> {
        let output = self.run_ytdlp(&["--version"]).await?;
        tracing::info!(version = output.trim(), "yt-dlp available");
        Ok(())
    }

    /// Search the primary source, falling back to plain YouTube when the
    /// primary is YouTube Music and produced no usable results.
    async fn search(&self, query: &str) -> DlResult<Vec<SearchResult>> {
        // Try primary search source
        let results = self.search_with_source(query, self.search_source).await;
        match results {
            Ok(ref r) if !r.is_empty() => return results,
            Ok(_) => {
                tracing::debug!(source = %self.search_source, "no results from primary source");
            }
            Err(ref e) => {
                tracing::debug!(source = %self.search_source, error = %e, "primary search failed");
            }
        }
        // Fallback to the other source
        if self.fallback_search && self.search_source == SearchSource::YouTubeMusic {
            tracing::info!("falling back to YouTube search");
            let fallback = self
                .search_with_source(query, SearchSource::YouTube)
                .await?;
            if !fallback.is_empty() {
                return Ok(fallback);
            }
        }
        // If primary returned Ok([]) and fallback also empty, return that
        // If primary returned Err, propagate it
        results
    }

    /// Download `target` into `config.output_dir`, extracting audio via
    /// yt-dlp. A query target is resolved through `search` first and the top
    /// result is taken.
    async fn download(
        &self,
        target: &DownloadTarget,
        config: &BackendConfig,
    ) -> DlResult<DownloadResult> {
        let url = match target {
            DownloadTarget::Url(u) => u.clone(),
            DownloadTarget::Query(q) => {
                // Search first, take the best result
                let results = self.search(q).await?;
                let best = results
                    .into_iter()
                    .next()
                    .ok_or_else(|| DlError::NoResults(q.clone()))?;
                tracing::info!(
                    title = %best.title,
                    artist = ?best.artist,
                    url = %best.url,
                    "selected best search result"
                );
                best.url
            }
        };
        // Build output template (yt-dlp expands %(title)s / %(ext)s).
        let output_template = config
            .output_dir
            .join("%(title)s.%(ext)s")
            .to_string_lossy()
            .to_string();
        let format_arg = config.format.as_ytdlp_arg();
        let mut args = vec![
            "--extract-audio",
            "--audio-format",
            format_arg,
            "--audio-quality",
            "0",
            "--output",
            &output_template,
            "--print-json",
            "--no-playlist",
        ];
        // Add cookies from backend config or backend's own cookies.
        // NOTE(review): run_ytdlp also appends self.cookies_path when set, so
        // yt-dlp can receive --cookies twice here — confirm and dedupe.
        let cookies_str;
        let cookies_path = config
            .cookies_path
            .as_ref()
            .or(self.cookies_path.as_ref());
        if let Some(c) = cookies_path {
            cookies_str = c.to_string_lossy().to_string();
            args.push("--cookies");
            args.push(&cookies_str);
        }
        args.push(&url);
        let output = self.run_ytdlp(&args).await?;
        // Parse the JSON output to get the actual file path
        let info: YtDlpDownloadInfo = serde_json::from_str(output.trim()).map_err(|e| {
            DlError::BackendError(format!("failed to parse yt-dlp output: {e}"))
        })?;
        // yt-dlp may change the extension after conversion, so prefer the
        // filepath reported in requested_downloads over the top-level filename.
        let file_path = if let Some(ref requested) = info.requested_downloads {
            if let Some(first) = requested.first() {
                PathBuf::from(&first.filepath)
            } else {
                PathBuf::from(&info.filename)
            }
        } else {
            PathBuf::from(&info.filename)
        };
        tracing::info!(
            path = %file_path.display(),
            title = %info.title,
            "download complete"
        );
        Ok(DownloadResult {
            file_path,
            title: info.title,
            artist: info.artist.or(info.uploader).or(info.channel),
            duration: info.duration,
            source_url: url,
        })
    }
}
// --- yt-dlp JSON output types ---

/// One line of `--dump-json --flat-playlist` search output. All fields are
/// optional; missing values are defaulted when building a `SearchResult`.
#[derive(Debug, Deserialize)]
struct YtDlpSearchEntry {
    url: Option<String>,
    webpage_url: Option<String>,
    title: Option<String>,
    artist: Option<String>,
    uploader: Option<String>,
    channel: Option<String>,
    duration: Option<f64>,
}

/// The `--print-json` object emitted after a download.
#[derive(Debug, Deserialize)]
struct YtDlpDownloadInfo {
    title: String,
    artist: Option<String>,
    uploader: Option<String>,
    channel: Option<String>,
    duration: Option<f64>,
    /// Pre-conversion file name; may carry the wrong extension after
    /// audio conversion.
    filename: String,
    /// Post-processing results; preferred source of the final file path.
    requested_downloads: Option<Vec<YtDlpRequestedDownload>>,
}

/// Entry of `requested_downloads` carrying the final file path.
#[derive(Debug, Deserialize)]
struct YtDlpRequestedDownload {
    filepath: String,
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::backend::AudioFormat;

    // FromStr accepts the documented names and rejects unknown sources.
    #[test]
    fn test_search_source_parse() {
        assert_eq!(
            "ytmusic".parse::<SearchSource>().unwrap(),
            SearchSource::YouTubeMusic
        );
        assert_eq!(
            "youtube".parse::<SearchSource>().unwrap(),
            SearchSource::YouTube
        );
        assert!("invalid".parse::<SearchSource>().is_err());
    }

    // Every supported output format parses; unknown names are rejected.
    #[test]
    fn test_audio_format_parse() {
        assert_eq!("opus".parse::<AudioFormat>().unwrap(), AudioFormat::Opus);
        assert_eq!("mp3".parse::<AudioFormat>().unwrap(), AudioFormat::Mp3);
        assert_eq!("flac".parse::<AudioFormat>().unwrap(), AudioFormat::Flac);
        assert_eq!("best".parse::<AudioFormat>().unwrap(), AudioFormat::Best);
        assert!("wav".parse::<AudioFormat>().is_err());
    }

    // Prefixes request the top five results from each source.
    #[test]
    fn test_search_source_prefix() {
        assert_eq!(SearchSource::YouTubeMusic.prefix(), "ytmusicsearch5");
        assert_eq!(SearchSource::YouTube.prefix(), "ytsearch5");
    }

    // A single flat-playlist JSON line deserializes into YtDlpSearchEntry.
    #[test]
    fn test_parse_search_entry() {
        let json = r#"{"url": "https://youtube.com/watch?v=abc", "title": "Time", "artist": "Pink Floyd", "duration": 413.0}"#;
        let entry: YtDlpSearchEntry = serde_json::from_str(json).unwrap();
        assert_eq!(entry.title.as_deref(), Some("Time"));
        assert_eq!(entry.artist.as_deref(), Some("Pink Floyd"));
        assert_eq!(entry.duration, Some(413.0));
    }

    // The --print-json download object deserializes, including
    // requested_downloads.
    #[test]
    fn test_parse_download_info() {
        let json = r#"{
            "title": "Time",
            "artist": "Pink Floyd",
            "uploader": null,
            "channel": null,
            "duration": 413.0,
            "filename": "/tmp/Time.opus",
            "requested_downloads": [{"filepath": "/tmp/Time.opus"}]
        }"#;
        let info: YtDlpDownloadInfo = serde_json::from_str(json).unwrap();
        assert_eq!(info.title, "Time");
        assert_eq!(info.filename, "/tmp/Time.opus");
    }
}

239
tests/integration.rs Normal file
View File

@@ -0,0 +1,239 @@
use shanty_db::entities::download_queue::DownloadStatus;
use shanty_db::{Database, queries};
use shanty_dl::backend::{AudioFormat, BackendConfig, DownloadBackend, DownloadResult, DownloadTarget, SearchResult};
use shanty_dl::error::DlResult;
use shanty_dl::queue::{run_queue, download_single};
use tempfile::TempDir;
/// Mock backend for testing without yt-dlp.
struct MockBackend {
    /// If true, downloads will fail.
    should_fail: bool,
}

impl MockBackend {
    /// Create a mock; `should_fail` makes every download return an error.
    fn new(should_fail: bool) -> Self {
        Self { should_fail }
    }
}
impl DownloadBackend for MockBackend {
    /// The mock is always "installed".
    async fn check_available(&self) -> DlResult<()> {
        Ok(())
    }

    /// Return a single deterministic result derived from the query.
    async fn search(&self, query: &str) -> DlResult<Vec<SearchResult>> {
        Ok(vec![SearchResult {
            url: format!("https://example.com/{}", query.replace(' ', "_")),
            title: query.to_string(),
            artist: Some("Test Artist".to_string()),
            duration: Some(180.0),
            source: "mock".to_string(),
        }])
    }

    /// Write a small fake file into the configured output directory, or fail
    /// if the mock was constructed with `should_fail`.
    async fn download(
        &self,
        target: &DownloadTarget,
        config: &BackendConfig,
    ) -> DlResult<DownloadResult> {
        if self.should_fail {
            return Err(shanty_dl::DlError::DownloadFailed("mock failure".into()));
        }
        let title = match target {
            DownloadTarget::Url(u) => u.clone(),
            DownloadTarget::Query(q) => q.clone(),
        };
        let file_name = format!("{}.{}", title.replace(' ', "_"), config.format);
        let file_path = config.output_dir.join(&file_name);
        // unwrap is acceptable in test code; a write failure should abort the test.
        std::fs::write(&file_path, b"fake audio data").unwrap();
        Ok(DownloadResult {
            file_path,
            title,
            artist: Some("Test Artist".into()),
            duration: Some(180.0),
            source_url: "https://example.com/test".into(),
        })
    }
}
/// Open a fresh in-memory SQLite database for a single test.
async fn test_db() -> Database {
    let db = Database::new("sqlite::memory:").await;
    db.expect("failed to create test database")
}
// Happy path: one queued item downloads and ends up marked Completed.
#[tokio::test]
async fn test_queue_process_success() {
    let db = test_db().await;
    let dir = TempDir::new().unwrap();
    let backend = MockBackend::new(false);
    let config = BackendConfig {
        output_dir: dir.path().to_owned(),
        format: AudioFormat::Opus,
        cookies_path: None,
    };
    // Enqueue an item
    queries::downloads::enqueue(db.conn(), "Test Song", None, "mock")
        .await
        .unwrap();
    // Process queue
    let stats = run_queue(db.conn(), &backend, &config, false).await.unwrap();
    assert_eq!(stats.downloads_attempted, 1);
    assert_eq!(stats.downloads_completed, 1);
    assert_eq!(stats.downloads_failed, 0);
    // Verify status in DB
    let items = queries::downloads::list(db.conn(), Some(DownloadStatus::Completed))
        .await
        .unwrap();
    assert_eq!(items.len(), 1);
}
// A failing backend marks the item Failed and records an error message.
#[tokio::test]
async fn test_queue_process_failure() {
    let db = test_db().await;
    let dir = TempDir::new().unwrap();
    let backend = MockBackend::new(true);
    let config = BackendConfig {
        output_dir: dir.path().to_owned(),
        format: AudioFormat::Opus,
        cookies_path: None,
    };
    queries::downloads::enqueue(db.conn(), "Failing Song", None, "mock")
        .await
        .unwrap();
    let stats = run_queue(db.conn(), &backend, &config, false).await.unwrap();
    assert_eq!(stats.downloads_attempted, 1);
    assert_eq!(stats.downloads_failed, 1);
    // Check it's marked as failed
    let items = queries::downloads::list(db.conn(), Some(DownloadStatus::Failed))
        .await
        .unwrap();
    assert_eq!(items.len(), 1);
    assert!(items[0].error_message.is_some());
}
// Dry run counts the item as skipped and writes nothing to disk.
#[tokio::test]
async fn test_queue_dry_run() {
    let db = test_db().await;
    let dir = TempDir::new().unwrap();
    let backend = MockBackend::new(false);
    let config = BackendConfig {
        output_dir: dir.path().to_owned(),
        format: AudioFormat::Opus,
        cookies_path: None,
    };
    queries::downloads::enqueue(db.conn(), "Dry Run Song", None, "mock")
        .await
        .unwrap();
    let stats = run_queue(db.conn(), &backend, &config, true).await.unwrap();
    assert_eq!(stats.downloads_attempted, 1);
    assert_eq!(stats.downloads_skipped, 1);
    assert_eq!(stats.downloads_completed, 0);
    // No files should exist
    let entries: Vec<_> = std::fs::read_dir(dir.path()).unwrap().collect();
    assert!(entries.is_empty());
}
// download_single resolves a query via the backend and writes one file.
#[tokio::test]
async fn test_download_single_success() {
    let dir = TempDir::new().unwrap();
    let backend = MockBackend::new(false);
    let config = BackendConfig {
        output_dir: dir.path().to_owned(),
        format: AudioFormat::Mp3,
        cookies_path: None,
    };
    download_single(
        &backend,
        DownloadTarget::Query("Test Song".into()),
        &config,
        false,
    )
    .await
    .unwrap();
    // File should exist
    let entries: Vec<_> = std::fs::read_dir(dir.path())
        .unwrap()
        .filter_map(|e| e.ok())
        .collect();
    assert_eq!(entries.len(), 1);
}
// retry_failed returns a Failed item to Pending and bumps its retry_count.
#[tokio::test]
async fn test_queue_retry() {
    let db = test_db().await;
    // Enqueue and manually fail an item
    let item = queries::downloads::enqueue(db.conn(), "Retry Song", None, "mock")
        .await
        .unwrap();
    queries::downloads::update_status(db.conn(), item.id, DownloadStatus::Failed, Some("oops"))
        .await
        .unwrap();
    // Retry it
    queries::downloads::retry_failed(db.conn(), item.id)
        .await
        .unwrap();
    // Should be pending again with retry_count = 1
    let pending = queries::downloads::list(db.conn(), Some(DownloadStatus::Pending))
        .await
        .unwrap();
    assert_eq!(pending.len(), 1);
    assert_eq!(pending[0].retry_count, 1);
}
// Completing a linked download flips the wanted item to Downloaded.
#[tokio::test]
async fn test_wanted_item_status_updated_on_download() {
    use shanty_db::entities::wanted_item::{ItemType, WantedStatus};
    let db = test_db().await;
    let dir = TempDir::new().unwrap();
    let backend = MockBackend::new(false);
    let config = BackendConfig {
        output_dir: dir.path().to_owned(),
        format: AudioFormat::Opus,
        cookies_path: None,
    };
    // Create a wanted item
    let wanted = queries::wanted::add(db.conn(), ItemType::Track, None, None, None)
        .await
        .unwrap();
    assert_eq!(wanted.status, WantedStatus::Wanted);
    // Enqueue download linked to the wanted item
    queries::downloads::enqueue(db.conn(), "Wanted Song", Some(wanted.id), "mock")
        .await
        .unwrap();
    // Process queue
    run_queue(db.conn(), &backend, &config, false).await.unwrap();
    // Wanted item should now be Downloaded
    let updated = queries::wanted::get_by_id(db.conn(), wanted.id)
        .await
        .unwrap();
    assert_eq!(updated.status, WantedStatus::Downloaded);
}