From 6f73bb87cea262e01bb8a5ab7df42fbd64a44cbe Mon Sep 17 00:00:00 2001 From: Connor Johnstone Date: Fri, 20 Mar 2026 18:09:47 -0400 Subject: [PATCH] Added the playlist generator --- Cargo.lock | 7 +- shanty-data/src/http.rs | 7 +- shanty-data/src/lastfm.rs | 213 ++++++++++++- shanty-data/src/lib.rs | 2 +- shanty-data/src/musicbrainz.rs | 33 +- shanty-data/src/traits.rs | 19 +- shanty-data/src/types.rs | 16 + shanty-db | 2 +- shanty-playlist/Cargo.toml | 12 +- shanty-playlist/src/lib.rs | 9 + shanty-playlist/src/ordering.rs | 57 ++++ shanty-playlist/src/scoring.rs | 151 +++++++++ shanty-playlist/src/selection.rs | 95 ++++++ shanty-playlist/src/strategies.rs | 490 ++++++++++++++++++++++++++++++ shanty-playlist/src/types.rs | 93 ++++++ shanty-playlist/tests/unit.rs | 333 ++++++++++++++++++++ shanty-search | 2 +- shanty-web | 2 +- src/main.rs | 4 +- 19 files changed, 1526 insertions(+), 21 deletions(-) create mode 100644 shanty-playlist/src/ordering.rs create mode 100644 shanty-playlist/src/scoring.rs create mode 100644 shanty-playlist/src/selection.rs create mode 100644 shanty-playlist/src/strategies.rs create mode 100644 shanty-playlist/src/types.rs create mode 100644 shanty-playlist/tests/unit.rs diff --git a/Cargo.lock b/Cargo.lock index 5447144..b82e20e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3229,8 +3229,12 @@ dependencies = [ name = "shanty-playlist" version = "0.1.0" dependencies = [ - "clap", + "chrono", + "rand 0.9.2", + "sea-orm", "serde", + "serde_json", + "shanty-data", "shanty-db", "thiserror", "tracing", @@ -3341,6 +3345,7 @@ dependencies = [ "shanty-dl", "shanty-index", "shanty-org", + "shanty-playlist", "shanty-search", "shanty-tag", "shanty-watch", diff --git a/shanty-data/src/http.rs b/shanty-data/src/http.rs index 124b504..1e4de12 100644 --- a/shanty-data/src/http.rs +++ b/shanty-data/src/http.rs @@ -1,18 +1,21 @@ +use std::sync::Arc; use std::time::Duration; use tokio::sync::Mutex; use tokio::time::Instant; /// A simple rate limiter that enforces a minimum interval between requests. +/// Can be cloned (via Arc) to share across multiple clients. +#[derive(Clone)] pub struct RateLimiter { - last_request: Mutex, + last_request: Arc>, interval: Duration, } impl RateLimiter { pub fn new(interval: Duration) -> Self { Self { - last_request: Mutex::new(Instant::now() - interval), + last_request: Arc::new(Mutex::new(Instant::now() - interval)), interval, } } diff --git a/shanty-data/src/lastfm.rs b/shanty-data/src/lastfm.rs index 4488aec..f7f36f7 100644 --- a/shanty-data/src/lastfm.rs +++ b/shanty-data/src/lastfm.rs @@ -1,7 +1,9 @@ -use crate::error::DataResult; +use serde::Deserialize; + +use crate::error::{DataError, DataResult}; use crate::http::{build_client, urlencoded}; -use crate::traits::ArtistBioFetcher; -use crate::types::ArtistInfo; +use crate::traits::{ArtistBioFetcher, SimilarArtistFetcher}; +use crate::types::{ArtistInfo, PopularTrack, SimilarArtist}; const USER_AGENT: &str = "Shanty/0.1.0 (shanty-music-app)"; @@ -51,6 +53,211 @@ impl ArtistBioFetcher for LastFmBioFetcher { } } +// --- Similar artist fetcher (ported from drift) --- + +const BASE_URL: &str = "https://ws.audioscrobbler.com/2.0/"; + +/// Fetches similar artists and top tracks from Last.fm. +pub struct LastFmSimilarFetcher { + api_key: String, + client: reqwest::Client, +} + +impl LastFmSimilarFetcher { + pub fn new(api_key: String) -> DataResult { + let client = build_client(USER_AGENT, 30)?; + Ok(Self { api_key, client }) + } + + /// Normalize Unicode hyphens/quotes to ASCII, then URL-encode. + fn normalize_name(name: &str) -> String { + name.replace( + [ + '\u{2010}', '\u{2011}', '\u{2012}', '\u{2013}', '\u{2014}', '\u{2015}', + ], + "-", + ) + .replace(['\u{2018}', '\u{2019}'], "'") + } + + /// Fetch a URL and return the body, or None if Last.fm returns an API error. + async fn fetch_or_none(&self, url: &str) -> DataResult> { + let resp = match self.client.get(url).send().await { + Ok(r) if r.status().is_success() => r, + Ok(r) => { + tracing::debug!(status = %r.status(), url = url, "Last.fm non-success"); + return Ok(None); + } + Err(e) => return Err(DataError::Http(e)), + }; + let body = resp.text().await?; + // Check for Last.fm error response + if serde_json::from_str::(&body).is_ok() { + return Ok(None); + } + Ok(Some(body)) + } + + /// Fetch by artist name. + async fn fetch_by_name( + &self, + method: &str, + artist_name: &str, + extra_params: &str, + ) -> DataResult> { + let name = Self::normalize_name(artist_name); + let encoded = urlencoded(&name); + let url = format!( + "{}?method={}&artist={}&api_key={}{}&format=json", + BASE_URL, method, encoded, self.api_key, extra_params + ); + self.fetch_or_none(&url).await + } + + /// Try MBID lookup then name lookup, returning whichever yields more results. + async fn dual_lookup( + &self, + method: &str, + artist_name: &str, + mbid: Option<&str>, + extra_params: &str, + parse: fn(&str) -> DataResult>, + ) -> DataResult> { + let mbid_results = if let Some(mbid) = mbid { + let url = format!( + "{}?method={}&mbid={}&api_key={}{}&format=json", + BASE_URL, method, mbid, self.api_key, extra_params + ); + match self.fetch_or_none(&url).await? { + Some(body) => parse(&body).unwrap_or_default(), + None => Vec::new(), + } + } else { + Vec::new() + }; + + let name_results = match self + .fetch_by_name(method, artist_name, extra_params) + .await? + { + Some(body) => parse(&body).unwrap_or_default(), + None => Vec::new(), + }; + + if name_results.len() > mbid_results.len() { + Ok(name_results) + } else { + Ok(mbid_results) + } + } + + fn parse_similar_artists(body: &str) -> DataResult> { + let resp: LfmSimilarArtistsResponse = serde_json::from_str(body)?; + Ok(resp + .similarartists + .artist + .into_iter() + .map(|a| { + let mbid = a.mbid.filter(|s| !s.is_empty()); + SimilarArtist { + name: a.name, + mbid, + match_score: a.match_score.parse().unwrap_or(0.0), + } + }) + .collect()) + } + + fn parse_top_tracks(body: &str) -> DataResult> { + let resp: LfmTopTracksResponse = serde_json::from_str(body)?; + Ok(resp + .toptracks + .track + .into_iter() + .map(|t| PopularTrack { + name: t.name, + mbid: t.mbid.filter(|s| !s.is_empty()), + playcount: t.playcount.parse().unwrap_or(0), + }) + .collect()) + } +} + +impl SimilarArtistFetcher for LastFmSimilarFetcher { + async fn get_similar_artists( + &self, + artist_name: &str, + mbid: Option<&str>, + ) -> DataResult> { + self.dual_lookup( + "artist.getSimilar", + artist_name, + mbid, + "&limit=500", + Self::parse_similar_artists, + ) + .await + } + + async fn get_top_tracks( + &self, + artist_name: &str, + mbid: Option<&str>, + ) -> DataResult> { + self.dual_lookup( + "artist.getTopTracks", + artist_name, + mbid, + "&limit=1000", + Self::parse_top_tracks, + ) + .await + } +} + +// Last.fm JSON response structs + +#[derive(Deserialize)] +struct LfmApiError { + #[allow(dead_code)] + error: u32, +} + +#[derive(Deserialize)] +struct LfmSimilarArtistsResponse { + similarartists: LfmSimilarArtistsWrapper, +} + +#[derive(Deserialize)] +struct LfmSimilarArtistsWrapper { + artist: Vec, +} + +#[derive(Deserialize)] +struct LfmArtistEntry { + name: String, + mbid: Option, + #[serde(rename = "match")] + match_score: String, +} + +#[derive(Deserialize)] +struct LfmTopTracksResponse { + toptracks: LfmTopTracksWrapper, +} + +#[derive(Deserialize)] +struct LfmTopTracksWrapper { + track: Vec, +} + +#[derive(Deserialize)] +struct LfmTrackEntry { + name: String, + mbid: Option, + playcount: String, +} + /// Strip HTML tags from a string with a simple approach. fn strip_html_tags(s: &str) -> String { let mut result = String::with_capacity(s.len()); diff --git a/shanty-data/src/lib.rs b/shanty-data/src/lib.rs index 22604c7..7f4c72c 100644 --- a/shanty-data/src/lib.rs +++ b/shanty-data/src/lib.rs @@ -12,7 +12,7 @@ pub mod wikipedia; pub use coverart::CoverArtArchiveFetcher; pub use error::{DataError, DataResult}; pub use fanarttv::FanartTvFetcher; -pub use lastfm::LastFmBioFetcher; +pub use lastfm::{LastFmBioFetcher, LastFmSimilarFetcher}; pub use lrclib::LrclibFetcher; pub use musicbrainz::MusicBrainzFetcher; pub use traits::*; diff --git a/shanty-data/src/musicbrainz.rs b/shanty-data/src/musicbrainz.rs index f5b221e..e8b174b 100644 --- a/shanty-data/src/musicbrainz.rs +++ b/shanty-data/src/musicbrainz.rs @@ -21,14 +21,21 @@ pub struct MusicBrainzFetcher { impl MusicBrainzFetcher { pub fn new() -> DataResult { + Self::with_limiter(RateLimiter::new(RATE_LIMIT)) + } + + /// Create a fetcher that shares a rate limiter with other MB clients. + pub fn with_limiter(limiter: RateLimiter) -> DataResult { let client = reqwest::Client::builder() .user_agent(USER_AGENT) .timeout(Duration::from_secs(30)) .build()?; - Ok(Self { - client, - limiter: RateLimiter::new(RATE_LIMIT), - }) + Ok(Self { client, limiter }) + } + + /// Get a clone of the rate limiter for sharing with other MB clients. + pub fn limiter(&self) -> RateLimiter { + self.limiter.clone() } async fn get_json(&self, url: &str) -> DataResult { @@ -84,6 +91,24 @@ impl MusicBrainzFetcher { urls, }) } + + /// Resolve a release-group MBID to a release MBID (first release in the group). + pub async fn resolve_release_from_group(&self, release_group_mbid: &str) -> DataResult { + let url = format!("{BASE_URL}/release?release-group={release_group_mbid}&fmt=json&limit=1"); + let resp: serde_json::Value = self.get_json(&url).await?; + + resp.get("releases") + .and_then(|r| r.as_array()) + .and_then(|arr| arr.first()) + .and_then(|r| r.get("id")) + .and_then(|id| id.as_str()) + .map(String::from) + .ok_or_else(|| { + DataError::Other(format!( + "no releases for release-group {release_group_mbid}" + )) + }) + } } impl MetadataFetcher for MusicBrainzFetcher { diff --git a/shanty-data/src/traits.rs b/shanty-data/src/traits.rs index cb5bc8e..eb9c433 100644 --- a/shanty-data/src/traits.rs +++ b/shanty-data/src/traits.rs @@ -2,8 +2,8 @@ use std::future::Future; use crate::error::DataResult; use crate::types::{ - ArtistInfo, ArtistSearchResult, DiscographyEntry, LyricsResult, RecordingDetails, - RecordingMatch, ReleaseGroupEntry, ReleaseMatch, ReleaseTrack, + ArtistInfo, ArtistSearchResult, DiscographyEntry, LyricsResult, PopularTrack, RecordingDetails, + RecordingMatch, ReleaseGroupEntry, ReleaseMatch, ReleaseTrack, SimilarArtist, }; /// Trait for metadata lookup backends. MusicBrainz is the default implementation; @@ -86,6 +86,21 @@ pub trait LyricsFetcher: Send + Sync { ) -> impl Future> + Send; } +/// Fetches similar artists and top tracks from an external source (e.g. Last.fm). +pub trait SimilarArtistFetcher: Send + Sync { + fn get_similar_artists( + &self, + artist_name: &str, + mbid: Option<&str>, + ) -> impl Future>> + Send; + + fn get_top_tracks( + &self, + artist_name: &str, + mbid: Option<&str>, + ) -> impl Future>> + Send; +} + /// Fetches cover art URLs for releases. pub trait CoverArtFetcher: Send + Sync { fn get_cover_art_url(&self, release_id: &str) -> Option; diff --git a/shanty-data/src/types.rs b/shanty-data/src/types.rs index c8e5389..c99267d 100644 --- a/shanty-data/src/types.rs +++ b/shanty-data/src/types.rs @@ -111,6 +111,22 @@ pub struct ReleaseTrack { pub duration_ms: Option, } +/// A similar artist returned by Last.fm or another provider. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct SimilarArtist { + pub name: String, + pub mbid: Option, + pub match_score: f64, +} + +/// A popular/top track for an artist from Last.fm. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct PopularTrack { + pub name: String, + pub mbid: Option, + pub playcount: u64, +} + /// Result from a lyrics lookup. #[derive(Debug, Clone, Serialize, Deserialize)] pub struct LyricsResult { diff --git a/shanty-db b/shanty-db index 8a1435d..f03f8f0 160000 --- a/shanty-db +++ b/shanty-db @@ -1 +1 @@ -Subproject commit 8a1435d9e94135617bd55c8bbb74d21a1622cc88 +Subproject commit f03f8f0362d84a45ea1f6cd75d5d39f88f8897fb diff --git a/shanty-playlist/Cargo.toml b/shanty-playlist/Cargo.toml index fc77739..b21b6ad 100644 --- a/shanty-playlist/Cargo.toml +++ b/shanty-playlist/Cargo.toml @@ -7,7 +7,11 @@ description = "Playlist generation for Shanty" [dependencies] shanty-db = { path = "../shanty-db" } -clap = { workspace = true } -serde = { workspace = true } -thiserror = { workspace = true } -tracing = { workspace = true } +shanty-data = { path = "../shanty-data" } +sea-orm = { version = "1", features = ["sqlx-sqlite", "runtime-tokio-native-tls"] } +serde = { version = "1", features = ["derive"] } +serde_json = "1" +rand = "0.9" +tracing = "0.1" +thiserror = "2" +chrono = "0.4" diff --git a/shanty-playlist/src/lib.rs b/shanty-playlist/src/lib.rs index 1116a2f..9475ea5 100644 --- a/shanty-playlist/src/lib.rs +++ b/shanty-playlist/src/lib.rs @@ -2,3 +2,12 @@ //! //! Generates playlists based on the indexed music library using strategies like //! similar artists, genre matching, smart rules, and weighted random selection. + +pub mod ordering; +pub mod scoring; +pub mod selection; +pub mod strategies; +pub mod types; + +pub use strategies::{PlaylistError, genre_based, random, similar_artists, smart, to_m3u}; +pub use types::{Candidate, PlaylistRequest, PlaylistResult, PlaylistTrack, SmartRules}; diff --git a/shanty-playlist/src/ordering.rs b/shanty-playlist/src/ordering.rs new file mode 100644 index 0000000..9709545 --- /dev/null +++ b/shanty-playlist/src/ordering.rs @@ -0,0 +1,57 @@ +use std::collections::BTreeMap; + +use rand::prelude::*; + +use crate::types::Candidate; + +/// Reorder tracks so that artists are evenly spread out. +/// Greedily picks from the artist with the most remaining tracks, +/// avoiding back-to-back repeats when possible. +/// Ported faithfully from drift's interleave_artists(). +pub fn interleave_artists(tracks: Vec) -> Vec { + let mut rng = rand::rng(); + + let mut by_artist: BTreeMap> = BTreeMap::new(); + for track in tracks { + by_artist + .entry(track.artist.clone()) + .or_default() + .push(track); + } + for group in by_artist.values_mut() { + group.shuffle(&mut rng); + } + + let mut result = Vec::new(); + let mut last_artist: Option = None; + + while !by_artist.is_empty() { + let mut artists: Vec = by_artist.keys().cloned().collect(); + artists.sort_by(|a, b| by_artist[b].len().cmp(&by_artist[a].len())); + + let pick = artists + .iter() + .find(|a| last_artist.as_ref() != Some(a)) + .or(artists.first()) + .cloned() + .unwrap(); + + let group = by_artist.get_mut(&pick).unwrap(); + let track = group.pop().unwrap(); + if group.is_empty() { + by_artist.remove(&pick); + } + + last_artist = Some(pick); + result.push(track); + } + + result +} + +/// Full random shuffle. +pub fn shuffle(mut tracks: Vec) -> Vec { + let mut rng = rand::rng(); + tracks.shuffle(&mut rng); + tracks +} diff --git a/shanty-playlist/src/scoring.rs b/shanty-playlist/src/scoring.rs new file mode 100644 index 0000000..2910398 --- /dev/null +++ b/shanty-playlist/src/scoring.rs @@ -0,0 +1,151 @@ +use std::collections::HashMap; + +use shanty_data::PopularTrack; +use shanty_db::entities::track::Model as Track; + +use crate::types::ScoredTrack; + +/// Popularity exponent curve (0-10 scale). +/// 0 = no preference, 10 = heavy popular bias. +const POPULARITY_EXPONENTS: [f64; 11] = [ + 0.0, 0.06, 0.17, 0.33, 0.67, 1.30, 1.50, 1.70, 1.94, 2.22, 2.50, +]; + +/// Score all tracks for the given artists, returning scored tracks for ranking. +/// +/// `artists` is a list of (mbid_or_name, display_name, similarity_score) tuples. +/// `tracks_by_artist` maps artist identifier -> their local tracks. +/// `top_tracks_by_artist` maps artist identifier -> their Last.fm top tracks. +pub fn score_tracks( + artists: &[(String, String, f64)], + tracks_by_artist: &HashMap>, + top_tracks_by_artist: &HashMap>, + popularity_bias: u8, +) -> Vec { + let bias = popularity_bias.min(10) as usize; + let mut scored = Vec::new(); + + for (artist_key, name, match_score) in artists { + let local_tracks = match tracks_by_artist.get(artist_key) { + Some(t) if !t.is_empty() => t, + _ => continue, + }; + + let top_tracks = top_tracks_by_artist + .get(artist_key) + .cloned() + .unwrap_or_default(); + + // Build playcount lookup by lowercase name + let playcount_by_name: HashMap = top_tracks + .iter() + .map(|t| (t.name.to_lowercase(), t.playcount)) + .collect(); + + let max_playcount = playcount_by_name + .values() + .copied() + .max() + .unwrap_or(1) + .max(1); + + for track in local_tracks { + let title_lower = track.title.as_ref().map(|t| t.to_lowercase()); + + let playcount = title_lower + .as_ref() + .and_then(|t| playcount_by_name.get(t).copied()) + .or_else(|| { + track + .musicbrainz_id + .as_ref() + .and_then(|id| playcount_by_name.get(id).copied()) + }); + + // If we have popularity data, require a match; otherwise assign uniform score + let (popularity, similarity, score) = if !playcount_by_name.is_empty() { + let Some(playcount) = playcount else { + continue; + }; + + let popularity = if playcount > 0 { + (playcount as f64 / max_playcount as f64).powf(POPULARITY_EXPONENTS[bias]) + } else { + 0.0 + }; + + let similarity = (match_score.exp()) / std::f64::consts::E; + let score = similarity * popularity; + (popularity, similarity, score) + } else { + // No top tracks data — use uniform scoring based on similarity only + let similarity = (match_score.exp()) / std::f64::consts::E; + (1.0, similarity, similarity) + }; + + scored.push(ScoredTrack { + track_id: track.id, + file_path: track.file_path.clone(), + title: track.title.clone(), + artist: name.clone(), + artist_mbid: track + .artist_id + .map(|_| artist_key.clone()) + .or_else(|| Some(artist_key.clone())), + album: track.album.clone(), + duration: track.duration, + score, + popularity, + similarity, + }); + } + } + + // Step 1: Cap tracks per artist based on popularity bias + let mut by_artist: HashMap> = HashMap::new(); + for t in scored { + let key = t.artist_mbid.clone().unwrap_or_else(|| t.artist.clone()); + by_artist.entry(key).or_default().push(t); + } + + let cap = if popularity_bias == 0 { + None + } else { + let b = popularity_bias as f64; + let c = if b <= 5.0 { + 90.0 - 12.8 * b + } else { + 26.0 - 3.2 * (b - 5.0) + }; + Some((c.round() as usize).max(1)) + }; + + for group in by_artist.values_mut() { + group.sort_by(|a, b| { + b.score + .partial_cmp(&a.score) + .unwrap_or(std::cmp::Ordering::Equal) + }); + if let Some(cap) = cap { + group.truncate(cap); + } + } + + // Step 2: Normalize so each artist's total weight = their similarity + let similarity_map: HashMap<&str, f64> = artists + .iter() + .map(|(key, _, sim)| (key.as_str(), *sim)) + .collect(); + + for (key, group) in &mut by_artist { + let total: f64 = group.iter().map(|t| t.score).sum(); + if total > 0.0 { + let sim = similarity_map.get(key.as_str()).copied().unwrap_or(1.0); + for t in group.iter_mut() { + t.score *= sim / total; + } + } + } + + by_artist.into_values().flatten().collect() +} diff --git a/shanty-playlist/src/selection.rs b/shanty-playlist/src/selection.rs new file mode 100644 index 0000000..e00eece --- /dev/null +++ b/shanty-playlist/src/selection.rs @@ -0,0 +1,95 @@ +use std::collections::{HashMap, HashSet}; + +use rand::distr::weighted::WeightedIndex; +use rand::prelude::*; + +use crate::types::Candidate; + +/// Weighted random sampling with per-artist caps and seed enforcement. +/// Ported faithfully from drift's generate_playlist(). +pub fn generate_playlist( + candidates: &[Candidate], + n: usize, + seed_names: &HashSet, +) -> Vec { + if candidates.is_empty() { + return Vec::new(); + } + + let mut rng = rand::rng(); + let mut pool: Vec<&Candidate> = candidates.iter().collect(); + let mut result: Vec = Vec::new(); + let mut artist_counts: HashMap = HashMap::new(); + + let seed_min = (n / 10).max(1); + + let distinct_artists: usize = { + let mut seen = HashSet::new(); + for c in &pool { + seen.insert(&c.artist); + } + seen.len() + }; + + let divisor = match distinct_artists { + 1 => 1, + 2 => 2, + 3 => 3, + 4 => 3, + 5 => 4, + _ => 5, + }; + let artist_cap = n.div_ceil(divisor).max(1); + + while result.len() < n && !pool.is_empty() { + let seed_count: usize = seed_names + .iter() + .map(|name| *artist_counts.get(name).unwrap_or(&0)) + .sum(); + let remaining = n - result.len(); + let seed_deficit = seed_min.saturating_sub(seed_count); + let force_seed = seed_deficit > 0 && remaining <= seed_deficit; + + let eligible: Vec = pool + .iter() + .enumerate() + .filter(|(_, c)| { + if force_seed { + seed_names.contains(&c.artist) + } else { + *artist_counts.get(&c.artist).unwrap_or(&0) < artist_cap + } + }) + .map(|(i, _)| i) + .collect(); + + let fallback_indices: Vec = (0..pool.len()).collect(); + let indices: &[usize] = if eligible.is_empty() { + &fallback_indices + } else { + &eligible + }; + + let weights: Vec = indices.iter().map(|&i| pool[i].score.max(0.001)).collect(); + let dist = match WeightedIndex::new(&weights) { + Ok(d) => d, + Err(_) => break, + }; + + let picked = indices[dist.sample(&mut rng)]; + let track = pool.remove(picked); + *artist_counts.entry(track.artist.clone()).or_insert(0) += 1; + result.push(Candidate { + score: track.score, + artist: track.artist.clone(), + artist_mbid: track.artist_mbid.clone(), + track_id: track.track_id, + file_path: track.file_path.clone(), + title: track.title.clone(), + album: track.album.clone(), + duration: track.duration, + }); + } + + result +} diff --git a/shanty-playlist/src/strategies.rs b/shanty-playlist/src/strategies.rs new file mode 100644 index 0000000..466fb07 --- /dev/null +++ b/shanty-playlist/src/strategies.rs @@ -0,0 +1,490 @@ +use std::collections::{HashMap, HashSet}; + +use sea_orm::DatabaseConnection; +use shanty_data::{PopularTrack, SimilarArtist, SimilarArtistFetcher}; +use shanty_db::queries; + +use crate::ordering; +use crate::scoring; +use crate::selection; +use crate::types::*; + +/// Cache TTL: 7 days in seconds. +const CACHE_TTL: i64 = 7 * 24 * 3600; + +/// Generate a playlist based on similar artists (the primary strategy). +/// +/// Flow: +/// 1. For each seed artist: resolve MBID from DB +/// 2. Fetch similar artists (check cache first, else call Last.fm, cache result) +/// 3. Merge multi-seed: accumulate scores, normalize by seed count +/// 4. Filter: only keep artists that have tracks in the local library +/// 5. Score all tracks +/// 6. Select via weighted sampling +/// 7. Order (interleave or shuffle) +pub async fn similar_artists( + conn: &DatabaseConnection, + fetcher: &impl SimilarArtistFetcher, + seed_artists: Vec, + count: usize, + popularity_bias: u8, + ordering: &str, +) -> Result { + if seed_artists.is_empty() { + return Err(PlaylistError::InvalidInput( + "at least one seed artist is required".into(), + )); + } + + let num_seeds = seed_artists.len() as f64; + + // Merge similar artists from all seeds: key -> (name, total_score) + let mut merged: HashMap = HashMap::new(); + // Track resolved seed names for enforcement (use DB names, not raw input) + let mut resolved_seed_names: HashSet = HashSet::new(); + + for seed in &seed_artists { + // Resolve the seed artist: try name lookup in DB + let (artist_name, artist_mbid) = resolve_artist(conn, seed).await?; + resolved_seed_names.insert(artist_name.clone()); + + // Insert the seed itself with score 1.0 + let key = artist_mbid + .clone() + .unwrap_or_else(|| artist_name.to_lowercase()); + let entry = merged + .entry(key) + .or_insert_with(|| (artist_name.clone(), 0.0)); + entry.1 += 1.0; + + // Fetch similar artists (cached or fresh) + let similar = fetch_cached_similar(conn, fetcher, &artist_name, artist_mbid.as_deref()) + .await + .unwrap_or_default(); + + for sa in similar { + let sa_key = sa.mbid.clone().unwrap_or_else(|| sa.name.to_lowercase()); + let entry = merged + .entry(sa_key) + .or_insert_with(|| (sa.name.clone(), 0.0)); + entry.1 += sa.match_score; + } + } + + // Normalize scores by seed count + let artists: Vec<(String, String, f64)> = merged + .into_iter() + .map(|(key, (name, total))| (key, name, total / num_seeds)) + .collect(); + + // Build track and top-track maps for scoring + let mut tracks_by_artist: HashMap> = + HashMap::new(); + let mut top_tracks_by_artist: HashMap> = HashMap::new(); + + for (key, name, _) in &artists { + // Get local tracks for this artist + let local = get_artist_tracks(conn, key, name).await; + if local.is_empty() { + continue; + } + tracks_by_artist.insert(key.clone(), local); + + // Get top tracks (cached or fresh) + let top = fetch_cached_top_tracks(conn, fetcher, name, Some(key.as_str())) + .await + .unwrap_or_default(); + if !top.is_empty() { + top_tracks_by_artist.insert(key.clone(), top); + } + } + + // Score + let scored = scoring::score_tracks( + &artists, + &tracks_by_artist, + &top_tracks_by_artist, + popularity_bias, + ); + + // Convert to candidates + let candidates: Vec = scored + .into_iter() + .map(|t| Candidate { + score: t.score, + artist: t.artist, + artist_mbid: t.artist_mbid, + track_id: t.track_id, + file_path: t.file_path, + title: t.title, + album: t.album, + duration: t.duration, + }) + .collect(); + + // Select (use resolved DB names for seed enforcement, not raw input) + let selected = selection::generate_playlist(&candidates, count, &resolved_seed_names); + + // Order + let ordered = apply_ordering(selected, ordering); + + Ok(PlaylistResult { + tracks: candidates_to_tracks(ordered), + strategy: "similar".to_string(), + resolved_seeds: resolved_seed_names.into_iter().collect(), + }) +} + +/// Generate a genre-based playlist. +pub async fn genre_based( + conn: &DatabaseConnection, + genres: Vec, + count: usize, + ordering: &str, +) -> Result { + if genres.is_empty() { + return Err(PlaylistError::InvalidInput( + "at least one genre is required".into(), + )); + } + + let mut all_tracks = Vec::new(); + for genre in &genres { + let tracks = queries::tracks::get_by_genre(conn, genre) + .await + .map_err(|e| PlaylistError::Db(e.to_string()))?; + all_tracks.extend(tracks); + } + + // Deduplicate by track ID + let mut seen = HashSet::new(); + all_tracks.retain(|t| seen.insert(t.id)); + + // Convert to candidates with uniform scoring + let candidates: Vec = all_tracks + .into_iter() + .map(|t| Candidate { + score: 1.0, + artist: t.artist.clone().unwrap_or_default(), + artist_mbid: t.musicbrainz_id.clone(), + track_id: t.id, + file_path: t.file_path.clone(), + title: t.title.clone(), + album: t.album.clone(), + duration: t.duration, + }) + .collect(); + + let seed_names = HashSet::new(); + let selected = selection::generate_playlist(&candidates, count, &seed_names); + let ordered = apply_ordering(selected, ordering); + + Ok(PlaylistResult { + tracks: candidates_to_tracks(ordered), + strategy: "genre".to_string(), + resolved_seeds: vec![], + }) +} + +/// Generate a random playlist. +pub async fn random( + conn: &DatabaseConnection, + count: usize, + no_repeat_artist: bool, +) -> Result { + let tracks = queries::tracks::get_random(conn, count as u64 * 2) + .await + .map_err(|e| PlaylistError::Db(e.to_string()))?; + + let mut result = Vec::new(); + let mut seen_artists: HashSet = HashSet::new(); + + for t in tracks { + if result.len() >= count { + break; + } + let artist = t.artist.clone().unwrap_or_default(); + if no_repeat_artist && !artist.is_empty() && !seen_artists.insert(artist.clone()) { + continue; + } + result.push(PlaylistTrack { + track_id: t.id, + file_path: t.file_path.clone(), + title: t.title.clone(), + artist: t.artist.clone(), + album: t.album.clone(), + score: 0.0, + duration: t.duration, + }); + } + + Ok(PlaylistResult { + tracks: result, + strategy: "random".to_string(), + resolved_seeds: vec![], + }) +} + +/// Generate a smart playlist based on rules. +pub async fn smart( + conn: &DatabaseConnection, + rules: SmartRules, + count: usize, +) -> Result { + let mut all_tracks: Vec = Vec::new(); + + // Genre filter + if !rules.genres.is_empty() { + for genre in &rules.genres { + let tracks = queries::tracks::get_by_genre(conn, genre) + .await + .map_err(|e| PlaylistError::Db(e.to_string()))?; + all_tracks.extend(tracks); + } + } + + // Artist filter + if !rules.artists.is_empty() { + for artist_name in &rules.artists { + let tracks = queries::tracks::get_by_artist_name(conn, artist_name) + .await + .map_err(|e| PlaylistError::Db(e.to_string()))?; + all_tracks.extend(tracks); + } + } + + // Recently added filter + if let Some(days) = rules.added_within_days { + let tracks = queries::tracks::get_recent(conn, days, 10000) + .await + .map_err(|e| PlaylistError::Db(e.to_string()))?; + if all_tracks.is_empty() { + all_tracks = tracks; + } else { + let recent_ids: HashSet = tracks.iter().map(|t| t.id).collect(); + all_tracks.retain(|t| recent_ids.contains(&t.id)); + } + } + + // Year range filter + if let Some((min_year, max_year)) = rules.year_range { + all_tracks.retain(|t| { + t.year + .map(|y| y >= min_year && y <= max_year) + .unwrap_or(false) + }); + } + + // If no filters were specified, get random tracks + if rules.genres.is_empty() + && rules.artists.is_empty() + && rules.added_within_days.is_none() + && rules.year_range.is_none() + { + all_tracks = queries::tracks::get_random(conn, count as u64 * 2) + .await + .map_err(|e| PlaylistError::Db(e.to_string()))?; + } + + // Deduplicate + let mut seen = HashSet::new(); + all_tracks.retain(|t| seen.insert(t.id)); + + // Convert to candidates + let candidates: Vec = all_tracks + .into_iter() + .map(|t| Candidate { + score: 1.0, + artist: t.artist.clone().unwrap_or_default(), + artist_mbid: t.musicbrainz_id.clone(), + track_id: t.id, + file_path: t.file_path.clone(), + title: t.title.clone(), + album: t.album.clone(), + duration: t.duration, + }) + .collect(); + + let seed_names = HashSet::new(); + let selected = selection::generate_playlist(&candidates, count, &seed_names); + let ordered = ordering::interleave_artists(selected); + + Ok(PlaylistResult { + tracks: candidates_to_tracks(ordered), + strategy: "smart".to_string(), + resolved_seeds: vec![], + }) +} + +/// Generate an M3U playlist string from tracks. +pub fn to_m3u(tracks: &[PlaylistTrack]) -> String { + let mut out = String::from("#EXTM3U\n"); + for t in tracks { + let duration = t.duration.unwrap_or(0.0) as i64; + let artist = t.artist.as_deref().unwrap_or("Unknown"); + let title = t.title.as_deref().unwrap_or("Unknown"); + out.push_str(&format!("#EXTINF:{duration},{artist} - {title}\n")); + out.push_str(&t.file_path); + out.push('\n'); + } + out +} + +/// Apply the chosen ordering mode to selected candidates. +fn apply_ordering(mut candidates: Vec, mode: &str) -> Vec { + match mode { + "random" => ordering::shuffle(candidates), + "score" => { + candidates.sort_by(|a, b| { + b.score + .partial_cmp(&a.score) + .unwrap_or(std::cmp::Ordering::Equal) + }); + candidates + } + _ => ordering::interleave_artists(candidates), // "interleave" is default + } +} + +// --- Helper functions --- + +/// Resolve an artist name or MBID to (name, optional_mbid). +async fn resolve_artist( + conn: &DatabaseConnection, + query: &str, +) -> Result<(String, Option), PlaylistError> { + // Try as MBID first (if it looks like a UUID) + if query.len() == 36 && query.contains('-') { + let tracks = queries::tracks::get_by_artist_mbid(conn, query) + .await + .map_err(|e| PlaylistError::Db(e.to_string()))?; + if let Some(t) = tracks.first() { + let name = t.artist.clone().unwrap_or_else(|| query.to_string()); + return Ok((name, Some(query.to_string()))); + } + } + + // Try by name in DB + if let Ok(Some(artist)) = queries::artists::find_by_name(conn, query).await { + return Ok((artist.name, artist.musicbrainz_id)); + } + + // Fall back to using the query as the name + Ok((query.to_string(), None)) +} + +/// Get local tracks for an artist by key (MBID) or name. +async fn get_artist_tracks( + conn: &DatabaseConnection, + key: &str, + name: &str, +) -> Vec { + // Try by MBID first + if let Ok(tracks) = queries::tracks::get_by_artist_mbid(conn, key).await + && !tracks.is_empty() + { + return tracks; + } + + // Try by name + queries::tracks::get_by_artist_name(conn, name) + .await + .unwrap_or_default() +} + +/// Fetch similar artists from cache or Last.fm. +async fn fetch_cached_similar( + conn: &DatabaseConnection, + fetcher: &impl SimilarArtistFetcher, + artist_name: &str, + mbid: Option<&str>, +) -> Result, PlaylistError> { + let cache_key = format!( + "lastfm_similar:{}", + mbid.unwrap_or(&artist_name.to_lowercase()) + ); + + // Check cache + if let Ok(Some(json)) = queries::cache::get(conn, &cache_key).await + && let Ok(cached) = serde_json::from_str::>(&json) + { + tracing::debug!(artist = artist_name, "using cached similar artists"); + return Ok(cached); + } + + // Fetch from provider + let similar = fetcher + .get_similar_artists(artist_name, mbid) + .await + .map_err(|e| PlaylistError::FetchError(e.to_string()))?; + + // Cache the result + if let Ok(json) = serde_json::to_string(&similar) { + let _ = queries::cache::set(conn, &cache_key, "lastfm", &json, CACHE_TTL).await; + } + + Ok(similar) +} + +/// Fetch top tracks from cache or Last.fm. +async fn fetch_cached_top_tracks( + conn: &DatabaseConnection, + fetcher: &impl SimilarArtistFetcher, + artist_name: &str, + mbid: Option<&str>, +) -> Result, PlaylistError> { + let cache_key = format!( + "lastfm_toptracks:{}", + mbid.unwrap_or(&artist_name.to_lowercase()) + ); + + // Check cache + if let Ok(Some(json)) = queries::cache::get(conn, &cache_key).await + && let Ok(cached) = serde_json::from_str::>(&json) + { + tracing::debug!(artist = artist_name, "using cached top tracks"); + return Ok(cached); + } + + // Fetch from provider + let tracks = fetcher + .get_top_tracks(artist_name, mbid) + .await + .map_err(|e| PlaylistError::FetchError(e.to_string()))?; + + // Cache the result + if let Ok(json) = serde_json::to_string(&tracks) { + let _ = queries::cache::set(conn, &cache_key, "lastfm", &json, CACHE_TTL).await; + } + + Ok(tracks) +} + +/// Convert candidates to playlist tracks. +fn candidates_to_tracks(candidates: Vec) -> Vec { + candidates + .into_iter() + .map(|c| PlaylistTrack { + track_id: c.track_id, + file_path: c.file_path, + title: c.title, + artist: Some(c.artist), + album: c.album, + score: c.score, + duration: c.duration, + }) + .collect() +} + +/// Playlist generation error. +#[derive(Debug, thiserror::Error)] +pub enum PlaylistError { + #[error("invalid input: {0}")] + InvalidInput(String), + + #[error("database error: {0}")] + Db(String), + + #[error("fetch error: {0}")] + FetchError(String), +} diff --git a/shanty-playlist/src/types.rs b/shanty-playlist/src/types.rs new file mode 100644 index 0000000..0555ced --- /dev/null +++ b/shanty-playlist/src/types.rs @@ -0,0 +1,93 @@ +use serde::{Deserialize, Serialize}; + +/// Request to generate a playlist. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct PlaylistRequest { + pub strategy: String, + #[serde(default)] + pub seed_artists: Vec, + #[serde(default)] + pub genres: Vec, + #[serde(default = "default_count")] + pub count: usize, + #[serde(default = "default_popularity_bias")] + pub popularity_bias: u8, + /// Ordering mode: "score" (by score), "interleave" (spread artists), "random" (full shuffle). + #[serde(default = "default_ordering")] + pub ordering: String, + #[serde(default)] + pub rules: Option, +} + +fn default_count() -> usize { + 50 +} + +fn default_popularity_bias() -> u8 { + 5 +} + +fn default_ordering() -> String { + "interleave".to_string() +} + +/// Result of generating a playlist. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct PlaylistResult { + pub tracks: Vec, + pub strategy: String, + /// Resolved seed artist names (for display — may differ from input query). + #[serde(default)] + pub resolved_seeds: Vec, +} + +/// A track in a generated playlist. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct PlaylistTrack { + pub track_id: i32, + pub file_path: String, + pub title: Option, + pub artist: Option, + pub album: Option, + pub score: f64, + pub duration: Option, +} + +/// A weighted candidate for playlist selection (internal). +#[derive(Debug, Clone)] +pub struct Candidate { + pub score: f64, + pub artist: String, + pub artist_mbid: Option, + pub track_id: i32, + pub file_path: String, + pub title: Option, + pub album: Option, + pub duration: Option, +} + +/// A scored track before candidate conversion (internal). +#[derive(Debug, Clone)] +pub struct ScoredTrack { + pub track_id: i32, + pub file_path: String, + pub title: Option, + pub artist: String, + pub artist_mbid: Option, + pub album: Option, + pub duration: Option, + pub score: f64, + pub popularity: f64, + pub similarity: f64, +} + +/// Rules for the "smart" playlist strategy. +#[derive(Debug, Clone, Serialize, Deserialize, Default)] +pub struct SmartRules { + #[serde(default)] + pub genres: Vec, + pub added_within_days: Option, + pub year_range: Option<(i32, i32)>, + #[serde(default)] + pub artists: Vec, +} diff --git a/shanty-playlist/tests/unit.rs b/shanty-playlist/tests/unit.rs new file mode 100644 index 0000000..64e46bf --- /dev/null +++ b/shanty-playlist/tests/unit.rs @@ -0,0 +1,333 @@ +use std::collections::{HashMap, HashSet}; + +use chrono::Utc; +use shanty_data::PopularTrack; +use shanty_db::entities::track::Model as Track; +use shanty_playlist::ordering::{interleave_artists, shuffle}; +use shanty_playlist::scoring::score_tracks; +use shanty_playlist::selection::generate_playlist; +use shanty_playlist::types::Candidate; + +fn make_track(id: i32, title: &str, artist_id: Option, mbid: Option<&str>) -> Track { + let now = Utc::now().naive_utc(); + Track { + id, + file_path: format!("/music/{title}.opus"), + title: Some(title.to_string()), + artist: Some("Test Artist".to_string()), + album: Some("Test Album".to_string()), + album_artist: None, + track_number: Some(id), + disc_number: None, + duration: Some(200.0), + codec: None, + bitrate: None, + genre: None, + year: None, + musicbrainz_id: mbid.map(String::from), + file_size: 1000, + fingerprint: None, + artist_id, + album_id: None, + added_at: now, + updated_at: now, + file_mtime: None, + } +} + +fn make_candidate(id: i32, artist: &str, score: f64) -> Candidate { + Candidate { + score, + artist: artist.to_string(), + artist_mbid: Some(format!("mbid-{artist}")), + track_id: id, + file_path: format!("/music/{id}.opus"), + title: Some(format!("Track {id}")), + album: None, + duration: Some(200.0), + } +} + +// --- Scoring tests --- + +#[test] +fn test_score_tracks_basic() { + let artists = vec![("artist-1".to_string(), "Artist One".to_string(), 1.0)]; + + let tracks: Vec = (1..=5) + .map(|i| make_track(i, &format!("Song {i}"), Some(1), None)) + .collect(); + + let top_tracks = vec![ + PopularTrack { + name: "Song 1".to_string(), + mbid: None, + playcount: 1000, + }, + PopularTrack { + name: "Song 2".to_string(), + mbid: None, + playcount: 500, + }, + PopularTrack { + name: "Song 3".to_string(), + mbid: None, + playcount: 100, + }, + ]; + + let mut tracks_map = HashMap::new(); + tracks_map.insert("artist-1".to_string(), tracks); + + let mut top_map = HashMap::new(); + top_map.insert("artist-1".to_string(), top_tracks); + + let scored = score_tracks(&artists, &tracks_map, &top_map, 5); + + // Should have 3 tracks (only ones matching top tracks) + assert_eq!(scored.len(), 3); + + // Higher playcount should yield higher score + let scores: Vec = scored.iter().map(|t| t.score).collect(); + let max_score = scores.iter().cloned().fold(f64::NEG_INFINITY, f64::max); + assert!(max_score > 0.0); +} + +#[test] +fn test_score_tracks_no_top_tracks_uses_uniform() { + let artists = vec![("artist-1".to_string(), "Artist One".to_string(), 0.8)]; + + let tracks: Vec = (1..=3) + .map(|i| make_track(i, &format!("Song {i}"), Some(1), None)) + .collect(); + + let mut tracks_map = HashMap::new(); + tracks_map.insert("artist-1".to_string(), tracks); + + let top_map = HashMap::new(); // No top tracks + + let scored = score_tracks(&artists, &tracks_map, &top_map, 5); + + // All 3 tracks should be included with uniform scoring + assert_eq!(scored.len(), 3); + // All should have the same score (similarity only) + let first_score = scored[0].score; + for t in &scored { + assert!((t.score - first_score).abs() < 1e-10); + } +} + +#[test] +fn test_score_tracks_per_artist_cap() { + let artists = vec![("artist-1".to_string(), "Artist One".to_string(), 1.0)]; + + // 50 tracks + let tracks: Vec = (1..=50) + .map(|i| make_track(i, &format!("Song {i}"), Some(1), None)) + .collect(); + + let top_tracks: Vec = (1..=50) + .map(|i| PopularTrack { + name: format!("Song {i}"), + mbid: None, + playcount: (50 - i + 1) as u64 * 100, + }) + .collect(); + + let mut tracks_map = HashMap::new(); + tracks_map.insert("artist-1".to_string(), tracks); + let mut top_map = HashMap::new(); + top_map.insert("artist-1".to_string(), top_tracks); + + // bias 10 → cap = 10 + let scored = score_tracks(&artists, &tracks_map, &top_map, 10); + assert!(scored.len() <= 10); + + // bias 0 → no cap + let scored_no_cap = score_tracks(&artists, &tracks_map, &top_map, 0); + assert_eq!(scored_no_cap.len(), 50); +} + +#[test] +fn test_similarity_transform() { + // Higher match_score should produce higher similarity + let artists = vec![ + ("high".to_string(), "High".to_string(), 0.9), + ("low".to_string(), "Low".to_string(), 0.1), + ]; + + let track_high = make_track(1, "Song", Some(1), None); + let track_low = make_track(2, "Song", Some(2), None); + + let mut tracks_map = HashMap::new(); + tracks_map.insert("high".to_string(), vec![track_high]); + tracks_map.insert("low".to_string(), vec![track_low]); + + let scored = score_tracks(&artists, &tracks_map, &HashMap::new(), 5); + assert_eq!(scored.len(), 2); + + let high_score = scored + .iter() + .find(|t| t.artist == "High") + .unwrap() + .similarity; + let low_score = scored + .iter() + .find(|t| t.artist == "Low") + .unwrap() + .similarity; + assert!(high_score > low_score); +} + +// --- Selection tests --- + +#[test] +fn test_generate_playlist_basic() { + let candidates: Vec = (1..=20) + .map(|i| make_candidate(i, &format!("Artist{}", i % 4), 1.0)) + .collect(); + + let seeds = HashSet::new(); + let result = generate_playlist(&candidates, 10, &seeds); + + assert_eq!(result.len(), 10); +} + +#[test] +fn test_generate_playlist_respects_count() { + let candidates: Vec = (1..=5).map(|i| make_candidate(i, "Artist", 1.0)).collect(); + + let seeds = HashSet::new(); + let result = generate_playlist(&candidates, 3, &seeds); + assert_eq!(result.len(), 3); +} + +#[test] +fn test_generate_playlist_not_more_than_available() { + let candidates: Vec = (1..=3).map(|i| make_candidate(i, "Artist", 1.0)).collect(); + + let seeds = HashSet::new(); + let result = generate_playlist(&candidates, 100, &seeds); + assert_eq!(result.len(), 3); +} + +#[test] +fn test_generate_playlist_empty_candidates() { + let candidates: Vec = vec![]; + let seeds = HashSet::new(); + let result = generate_playlist(&candidates, 10, &seeds); + assert!(result.is_empty()); +} + +#[test] +fn test_generate_playlist_per_artist_cap() { + // 20 tracks from one artist, 5 from another + let mut candidates: Vec = (1..=20) + .map(|i| make_candidate(i, "Prolific", 1.0)) + .collect(); + candidates.extend((21..=25).map(|i| make_candidate(i, "Minor", 1.0))); + + let seeds = HashSet::new(); + let result = generate_playlist(&candidates, 15, &seeds); + + let prolific_count = result.iter().filter(|c| c.artist == "Prolific").count(); + let minor_count = result.iter().filter(|c| c.artist == "Minor").count(); + + // With 2 artists and n=15, cap = ceil(15/2) = 8. + // Minor only has 5 tracks, so Prolific fills the rest via fallback. + // Key check: both artists are represented. + assert!( + minor_count >= 1, + "Minor should get at least 1 track, got {minor_count}" + ); + assert!( + prolific_count >= 1, + "Prolific should get at least 1 track, got {prolific_count}" + ); + assert_eq!(result.len(), 15); +} + +#[test] +fn test_generate_playlist_seed_enforcement() { + // Many tracks from "Other" with high scores, few from "Seed" with low scores + let mut candidates: Vec = + (1..=50).map(|i| make_candidate(i, "Other", 10.0)).collect(); + candidates.push(make_candidate(51, "Seed", 0.01)); + candidates.push(make_candidate(52, "Seed", 0.01)); + + let mut seeds = HashSet::new(); + seeds.insert("Seed".to_string()); + + let result = generate_playlist(&candidates, 10, &seeds); + let seed_count = result.iter().filter(|c| c.artist == "Seed").count(); + + // seed_min = (10/10).max(1) = 1, so at least 1 seed track + assert!( + seed_count >= 1, + "Expected at least 1 seed track, got {seed_count}" + ); +} + +// --- Ordering tests --- + +#[test] +fn test_interleave_no_back_to_back() { + let tracks: Vec = vec![ + make_candidate(1, "A", 1.0), + make_candidate(2, "A", 1.0), + make_candidate(3, "A", 1.0), + make_candidate(4, "B", 1.0), + make_candidate(5, "B", 1.0), + make_candidate(6, "B", 1.0), + ]; + + let result = interleave_artists(tracks); + assert_eq!(result.len(), 6); + + // Check no back-to-back same artist + for window in result.windows(2) { + assert_ne!( + window[0].artist, window[1].artist, + "Back-to-back: {} at positions", + window[0].artist + ); + } +} + +#[test] +fn test_interleave_single_artist() { + let tracks: Vec = (1..=5).map(|i| make_candidate(i, "Solo", 1.0)).collect(); + + let result = interleave_artists(tracks); + assert_eq!(result.len(), 5); + // All same artist, so back-to-back is unavoidable — just check count +} + +#[test] +fn test_shuffle_preserves_count() { + let tracks: Vec = (1..=10).map(|i| make_candidate(i, "Artist", 1.0)).collect(); + + let result = shuffle(tracks); + assert_eq!(result.len(), 10); +} + +#[test] +fn test_interleave_many_artists() { + let mut tracks = Vec::new(); + for artist_idx in 0..5 { + for track_idx in 0..4 { + let id = artist_idx * 4 + track_idx + 1; + tracks.push(make_candidate(id, &format!("Artist{artist_idx}"), 1.0)); + } + } + + let result = interleave_artists(tracks); + assert_eq!(result.len(), 20); + + // Count back-to-back violations (should be 0 with 5 artists) + let violations = result + .windows(2) + .filter(|w| w[0].artist == w[1].artist) + .count(); + assert_eq!(violations, 0, "Expected no back-to-back with 5 artists"); +} diff --git a/shanty-search b/shanty-search index cbd0243..b39dd6c 160000 --- a/shanty-search +++ b/shanty-search @@ -1 +1 @@ -Subproject commit cbd02435160f0eea652f2870c3a54ce3c000a6d6 +Subproject commit b39dd6cc8efd71bd1ae8182809553d8b6c44bd37 diff --git a/shanty-web b/shanty-web index 9d6c0e3..ea6a641 160000 --- a/shanty-web +++ b/shanty-web @@ -1 +1 @@ -Subproject commit 9d6c0e31c1f14b73dc70b9016fecd2b37e2aa30f +Subproject commit ea6a6410f3d22c5ff0fe30729f796caa5efe28ae diff --git a/src/main.rs b/src/main.rs index 334a6f8..20c8b85 100644 --- a/src/main.rs +++ b/src/main.rs @@ -58,7 +58,7 @@ async fn main() -> anyhow::Result<()> { let db = Database::new(&config.database_url).await?; let mb_client = MusicBrainzFetcher::new()?; - let search = MusicBrainzSearch::new()?; + let search = MusicBrainzSearch::with_limiter(mb_client.limiter())?; let wiki_fetcher = WikipediaFetcher::new()?; let bind = format!("{}:{}", config.web.bind, config.web.port); @@ -77,6 +77,8 @@ async fn main() -> anyhow::Result<()> { scheduler: tokio::sync::Mutex::new(shanty_web::state::SchedulerInfo { next_pipeline: None, next_monitor: None, + skip_pipeline: false, + skip_monitor: false, }), });