diff --git a/shanty-data/src/mb_hybrid.rs b/shanty-data/src/mb_hybrid.rs index ff25e2c..46d7f11 100644 --- a/shanty-data/src/mb_hybrid.rs +++ b/shanty-data/src/mb_hybrid.rs @@ -59,6 +59,13 @@ impl HybridMusicBrainzFetcher { self.remote.get_artist_by_mbid(mbid).await } + /// Get artist info from local DB only (no remote API fallback). + /// Returns `None` if the local DB is unavailable or doesn't have this artist. + pub fn get_artist_info_local(&self, mbid: &str) -> Option { + self.local_if_available() + .and_then(|l| l.get_artist_info_sync(mbid).ok()) + } + /// Get detailed artist info by MBID. Tries local first, then remote. pub async fn get_artist_info(&self, mbid: &str) -> DataResult { if let Some(local) = self.local_if_available() diff --git a/shanty-playlist/src/lib.rs b/shanty-playlist/src/lib.rs index 9475ea5..e445fc9 100644 --- a/shanty-playlist/src/lib.rs +++ b/shanty-playlist/src/lib.rs @@ -9,5 +9,9 @@ pub mod selection; pub mod strategies; pub mod types; -pub use strategies::{PlaylistError, genre_based, random, similar_artists, smart, to_m3u}; -pub use types::{Candidate, PlaylistRequest, PlaylistResult, PlaylistTrack, SmartRules}; +pub use strategies::{ + CountryLookup, PlaylistError, genre_based, random, similar_artists, smart, to_m3u, +}; +pub use types::{ + Candidate, PlaylistRequest, PlaylistResult, PlaylistTrack, SimilarConfig, SmartRules, +}; diff --git a/shanty-playlist/src/scoring.rs b/shanty-playlist/src/scoring.rs index 2910398..efb3ba7 100644 --- a/shanty-playlist/src/scoring.rs +++ b/shanty-playlist/src/scoring.rs @@ -21,6 +21,8 @@ pub fn score_tracks( tracks_by_artist: &HashMap>, top_tracks_by_artist: &HashMap>, popularity_bias: u8, + _global_popularity: u8, + max_tracks_per_artist: Option, ) -> Vec { let bias = popularity_bias.min(10) as usize; let mut scored = Vec::new(); @@ -108,7 +110,9 @@ pub fn score_tracks( by_artist.entry(key).or_default().push(t); } - let cap = if popularity_bias == 0 { + let cap = if let Some(explicit) = 
max_tracks_per_artist { + Some((explicit as usize).max(1)) + } else if popularity_bias == 0 { None } else { let b = popularity_bias as f64; @@ -147,5 +151,49 @@ pub fn score_tracks( } } - by_artist.into_values().flatten().collect() + let mut result: Vec = by_artist.into_values().flatten().collect(); + + // Step 3: Apply global popularity weighting + if _global_popularity > 0 { + let gp = _global_popularity.min(10) as usize; + let gp_exponent = POPULARITY_EXPONENTS[gp]; + let gp_strength = gp as f64 / 10.0; + + // Find max playcount across ALL artists + let global_max: u64 = top_tracks_by_artist + .values() + .flat_map(|tracks| tracks.iter().map(|t| t.playcount)) + .max() + .unwrap_or(1) + .max(1); + + // Build a global playcount lookup (lowercase name -> max playcount) + let mut global_playcounts: HashMap = HashMap::new(); + for tracks in top_tracks_by_artist.values() { + for t in tracks { + let key = t.name.to_lowercase(); + global_playcounts + .entry(key) + .and_modify(|c| *c = (*c).max(t.playcount)) + .or_insert(t.playcount); + } + } + + for t in &mut result { + let playcount = t + .title + .as_ref() + .and_then(|title| global_playcounts.get(&title.to_lowercase()).copied()) + .unwrap_or(0); + + if playcount > 0 { + let global_pop = (playcount as f64 / global_max as f64).powf(gp_exponent); + // lerp(1.0, global_pop, gp_strength); strength comes from the clamped level so factor never goes negative + let factor = 1.0 + gp_strength * (global_pop - 1.0); + t.score *= factor; + } + } + } + + result } diff --git a/shanty-playlist/src/selection.rs b/shanty-playlist/src/selection.rs index e00eece..3c7a891 100644 --- a/shanty-playlist/src/selection.rs +++ b/shanty-playlist/src/selection.rs @@ -11,6 +11,8 @@ pub fn generate_playlist( candidates: &[Candidate], n: usize, seed_names: &HashSet, + max_artists: Option, + skip_seed_enforcement: bool, ) -> Vec { if candidates.is_empty() { return Vec::new(); @@ -20,8 +22,14 @@ pub fn generate_playlist( let mut pool: Vec<&Candidate> = candidates.iter().collect(); let mut result: Vec = 
Vec::new(); let mut artist_counts: HashMap = HashMap::new(); + let mut distinct_artists_set: HashSet = HashSet::new(); + let max_distinct = max_artists.map(|m| (m as usize).max(1)); - let seed_min = (n / 10).max(1); + let seed_min = if skip_seed_enforcement { + 0 + } else { + (n / 10).max(1) + }; let distinct_artists: usize = { let mut seen = HashSet::new(); @@ -54,6 +62,13 @@ pub fn generate_playlist( .iter() .enumerate() .filter(|(_, c)| { + // Max distinct artists: reject new artists once we hit the cap + if let Some(max) = max_distinct + && distinct_artists_set.len() >= max + && !distinct_artists_set.contains(&c.artist) + { + return false; + } if force_seed { seed_names.contains(&c.artist) } else { @@ -79,6 +94,7 @@ pub fn generate_playlist( let picked = indices[dist.sample(&mut rng)]; let track = pool.remove(picked); *artist_counts.entry(track.artist.clone()).or_insert(0) += 1; + distinct_artists_set.insert(track.artist.clone()); result.push(Candidate { score: track.score, artist: track.artist.clone(), diff --git a/shanty-playlist/src/strategies.rs b/shanty-playlist/src/strategies.rs index 466fb07..c2af28f 100644 --- a/shanty-playlist/src/strategies.rs +++ b/shanty-playlist/src/strategies.rs @@ -1,4 +1,6 @@ use std::collections::{HashMap, HashSet}; +use std::future::Future; +use std::pin::Pin; use sea_orm::DatabaseConnection; use shanty_data::{PopularTrack, SimilarArtist, SimilarArtistFetcher}; @@ -12,6 +14,16 @@ use crate::types::*; /// Cache TTL: 7 days in seconds. const CACHE_TTL: i64 = 7 * 24 * 3600; +/// Trait for looking up an artist's country by MBID. +/// Implementations should return quickly (local DB or cache), never blocking +/// on rate-limited remote APIs during playlist generation. +pub trait CountryLookup: Send + Sync { + fn get_country<'a>( + &'a self, + mbid: &'a str, + ) -> Pin> + Send + 'a>>; +} + /// Generate a playlist based on similar artists (the primary strategy). 
/// /// Flow: @@ -26,9 +38,8 @@ pub async fn similar_artists( conn: &DatabaseConnection, fetcher: &impl SimilarArtistFetcher, seed_artists: Vec, - count: usize, - popularity_bias: u8, - ordering: &str, + config: &SimilarConfig, + _country_fetcher: Option<&dyn CountryLookup>, ) -> Result { if seed_artists.is_empty() { return Err(PlaylistError::InvalidInput( @@ -37,25 +48,32 @@ pub async fn similar_artists( } let num_seeds = seed_artists.len() as f64; + let seed_similarity = config.seed_weight as f64 * 0.2; // Merge similar artists from all seeds: key -> (name, total_score) let mut merged: HashMap = HashMap::new(); // Track resolved seed names for enforcement (use DB names, not raw input) let mut resolved_seed_names: HashSet = HashSet::new(); + // Track which keys are seeds (for country filter) + let mut seed_keys: HashSet = HashSet::new(); for seed in &seed_artists { // Resolve the seed artist: try name lookup in DB let (artist_name, artist_mbid) = resolve_artist(conn, seed).await?; resolved_seed_names.insert(artist_name.clone()); - // Insert the seed itself with score 1.0 let key = artist_mbid .clone() .unwrap_or_else(|| artist_name.to_lowercase()); - let entry = merged - .entry(key) - .or_insert_with(|| (artist_name.clone(), 0.0)); - entry.1 += 1.0; + seed_keys.insert(key.clone()); + + // Insert the seed itself with configured weight + if seed_similarity > 0.0 { + let entry = merged + .entry(key.clone()) + .or_insert_with(|| (artist_name.clone(), 0.0)); + entry.1 += seed_similarity; + } // Fetch similar artists (cached or fresh) let similar = fetch_cached_similar(conn, fetcher, &artist_name, artist_mbid.as_deref()) @@ -71,11 +89,41 @@ pub async fn similar_artists( } } - // Normalize scores by seed count - let artists: Vec<(String, String, f64)> = merged + // Normalize scores by seed count, sort by similarity descending + let mut artists: Vec<(String, String, f64)> = merged .into_iter() .map(|(key, (name, total))| (key, name, total / num_seeds)) .collect(); + 
artists.sort_by(|a, b| b.2.partial_cmp(&a.2).unwrap_or(std::cmp::Ordering::Equal)); + + // Apply discovery range: truncate to pool size + let pool_size = discovery_pool_size(config.discovery_range); + artists.truncate(pool_size); + + // Country filter: only keep artists from the same countries as seeds + if config.country_filter + && let Some(cf) = _country_fetcher + { + let mut seed_countries: HashSet = HashSet::new(); + for key in &seed_keys { + if let Some(country) = cf.get_country(key).await { + seed_countries.insert(country); + } + } + + if !seed_countries.is_empty() { + let mut filtered = Vec::new(); + for entry in artists { + let country = cf.get_country(&entry.0).await; + match country { + Some(c) if seed_countries.contains(&c) => filtered.push(entry), + None => filtered.push(entry), // unknown = pass through + _ => {} // known but different = exclude + } + } + artists = filtered; + } + } // Build track and top-track maps for scoring let mut tracks_by_artist: HashMap> = @@ -104,7 +152,9 @@ pub async fn similar_artists( &artists, &tracks_by_artist, &top_tracks_by_artist, - popularity_bias, + config.popularity_bias, + config.global_popularity, + config.max_tracks_per_artist, ); // Convert to candidates @@ -123,10 +173,17 @@ pub async fn similar_artists( .collect(); // Select (use resolved DB names for seed enforcement, not raw input) - let selected = selection::generate_playlist(&candidates, count, &resolved_seed_names); + let skip_seed_enforcement = config.seed_weight == 0; + let selected = selection::generate_playlist( + &candidates, + config.count, + &resolved_seed_names, + config.max_artists, + skip_seed_enforcement, + ); // Order - let ordered = apply_ordering(selected, ordering); + let ordered = apply_ordering(selected, &config.ordering); Ok(PlaylistResult { tracks: candidates_to_tracks(ordered), @@ -135,6 +192,14 @@ pub async fn similar_artists( }) } +/// Map discovery_range (0-10) to artist pool size. 
+/// 0 -> 15, 5 -> ~87, 10 -> 500 (exponential curve). +fn discovery_pool_size(range: u8) -> usize { + let r = range.min(10) as f64; + let size = 15.0 * (500.0_f64 / 15.0).powf(r / 10.0); + size.round() as usize +} + /// Generate a genre-based playlist. pub async fn genre_based( conn: &DatabaseConnection, @@ -176,7 +241,7 @@ pub async fn genre_based( .collect(); let seed_names = HashSet::new(); - let selected = selection::generate_playlist(&candidates, count, &seed_names); + let selected = selection::generate_playlist(&candidates, count, &seed_names, None, true); let ordered = apply_ordering(selected, ordering); Ok(PlaylistResult { @@ -306,7 +371,7 @@ pub async fn smart( .collect(); let seed_names = HashSet::new(); - let selected = selection::generate_playlist(&candidates, count, &seed_names); + let selected = selection::generate_playlist(&candidates, count, &seed_names, None, true); let ordered = ordering::interleave_artists(selected); Ok(PlaylistResult { diff --git a/shanty-playlist/src/types.rs b/shanty-playlist/src/types.rs index 0555ced..057141e 100644 --- a/shanty-playlist/src/types.rs +++ b/shanty-playlist/src/types.rs @@ -17,6 +17,58 @@ pub struct PlaylistRequest { pub ordering: String, #[serde(default)] pub rules: Option, + + /// Discovery range: how many similar artists to consider (0-10). + /// 0 = focused (~15), 10 = wide open (~500). Default: 5. + #[serde(default)] + pub discovery_range: Option, + /// Global popularity weighting (0-10). 0 = off, 10 = strong bias toward + /// globally popular tracks across all artists. Default: 0. + #[serde(default)] + pub global_popularity: Option, + /// Filter to same countries as seed artists. Default: false. + #[serde(default)] + pub country_filter: Option, + /// Seed artist weight (0-10). 0 = exclude seeds, 5 = normal (similarity 1.0), + /// 10 = double weight (similarity 2.0). Default: 5. + #[serde(default)] + pub seed_weight: Option, + /// Explicit per-artist track cap. 
None or 0 = auto (derived from popularity_bias). + #[serde(default)] + pub max_tracks_per_artist: Option, + /// Maximum distinct artists in the result. None or 0 = unlimited. + #[serde(default)] + pub max_artists: Option, +} + +/// Resolved configuration for the similar-artists strategy. +#[derive(Debug, Clone)] +pub struct SimilarConfig { + pub count: usize, + pub popularity_bias: u8, + pub ordering: String, + pub discovery_range: u8, + pub global_popularity: u8, + pub country_filter: bool, + pub seed_weight: u8, + pub max_tracks_per_artist: Option, + pub max_artists: Option, +} + +impl SimilarConfig { + pub fn from_request(req: &PlaylistRequest) -> Self { + Self { + count: req.count, + popularity_bias: req.popularity_bias, + ordering: req.ordering.clone(), + discovery_range: req.discovery_range.unwrap_or(5), + global_popularity: req.global_popularity.unwrap_or(0), + country_filter: req.country_filter.unwrap_or(false), + seed_weight: req.seed_weight.unwrap_or(5), + max_tracks_per_artist: req.max_tracks_per_artist.filter(|&v| v > 0), + max_artists: req.max_artists.filter(|&v| v > 0), + } + } } fn default_count() -> usize { diff --git a/shanty-playlist/tests/unit.rs b/shanty-playlist/tests/unit.rs index 822c6e4..2205d0b 100644 --- a/shanty-playlist/tests/unit.rs +++ b/shanty-playlist/tests/unit.rs @@ -83,7 +83,7 @@ fn test_score_tracks_basic() { let mut top_map = HashMap::new(); top_map.insert("artist-1".to_string(), top_tracks); - let scored = score_tracks(&artists, &tracks_map, &top_map, 5); + let scored = score_tracks(&artists, &tracks_map, &top_map, 5, 0, None); // Should have 3 tracks (only ones matching top tracks) assert_eq!(scored.len(), 3); @@ -107,7 +107,7 @@ fn test_score_tracks_no_top_tracks_uses_uniform() { let top_map = HashMap::new(); // No top tracks - let scored = score_tracks(&artists, &tracks_map, &top_map, 5); + let scored = score_tracks(&artists, &tracks_map, &top_map, 5, 0, None); // All 3 tracks should be included with uniform scoring 
assert_eq!(scored.len(), 3); @@ -141,11 +141,11 @@ fn test_score_tracks_per_artist_cap() { top_map.insert("artist-1".to_string(), top_tracks); // bias 10 → cap = 10 - let scored = score_tracks(&artists, &tracks_map, &top_map, 10); + let scored = score_tracks(&artists, &tracks_map, &top_map, 10, 0, None); assert!(scored.len() <= 10); // bias 0 → no cap - let scored_no_cap = score_tracks(&artists, &tracks_map, &top_map, 0); + let scored_no_cap = score_tracks(&artists, &tracks_map, &top_map, 0, 0, None); assert_eq!(scored_no_cap.len(), 50); } @@ -164,7 +164,7 @@ fn test_similarity_transform() { tracks_map.insert("high".to_string(), vec![track_high]); tracks_map.insert("low".to_string(), vec![track_low]); - let scored = score_tracks(&artists, &tracks_map, &HashMap::new(), 5); + let scored = score_tracks(&artists, &tracks_map, &HashMap::new(), 5, 0, None); assert_eq!(scored.len(), 2); let high_score = scored @@ -189,7 +189,7 @@ fn test_generate_playlist_basic() { .collect(); let seeds = HashSet::new(); - let result = generate_playlist(&candidates, 10, &seeds); + let result = generate_playlist(&candidates, 10, &seeds, None, false); assert_eq!(result.len(), 10); } @@ -199,7 +199,7 @@ fn test_generate_playlist_respects_count() { let candidates: Vec = (1..=5).map(|i| make_candidate(i, "Artist", 1.0)).collect(); let seeds = HashSet::new(); - let result = generate_playlist(&candidates, 3, &seeds); + let result = generate_playlist(&candidates, 3, &seeds, None, false); assert_eq!(result.len(), 3); } @@ -208,7 +208,7 @@ fn test_generate_playlist_not_more_than_available() { let candidates: Vec = (1..=3).map(|i| make_candidate(i, "Artist", 1.0)).collect(); let seeds = HashSet::new(); - let result = generate_playlist(&candidates, 100, &seeds); + let result = generate_playlist(&candidates, 100, &seeds, None, false); assert_eq!(result.len(), 3); } @@ -216,7 +216,7 @@ fn test_generate_playlist_not_more_than_available() { fn test_generate_playlist_empty_candidates() { let candidates: 
Vec = vec![]; let seeds = HashSet::new(); - let result = generate_playlist(&candidates, 10, &seeds); + let result = generate_playlist(&candidates, 10, &seeds, None, false); assert!(result.is_empty()); } @@ -229,7 +229,7 @@ fn test_generate_playlist_per_artist_cap() { candidates.extend((21..=25).map(|i| make_candidate(i, "Minor", 1.0))); let seeds = HashSet::new(); - let result = generate_playlist(&candidates, 15, &seeds); + let result = generate_playlist(&candidates, 15, &seeds, None, false); let prolific_count = result.iter().filter(|c| c.artist == "Prolific").count(); let minor_count = result.iter().filter(|c| c.artist == "Minor").count(); @@ -259,7 +259,7 @@ fn test_generate_playlist_seed_enforcement() { let mut seeds = HashSet::new(); seeds.insert("Seed".to_string()); - let result = generate_playlist(&candidates, 10, &seeds); + let result = generate_playlist(&candidates, 10, &seeds, None, false); let seed_count = result.iter().filter(|c| c.artist == "Seed").count(); // seed_min = (10/10).max(1) = 1, so at least 1 seed track diff --git a/shanty-web b/shanty-web index bd6656f..4c42cf0 160000 --- a/shanty-web +++ b/shanty-web @@ -1 +1 @@ -Subproject commit bd6656ff316a5ad02b3e3687042bcdc1decd9fb4 +Subproject commit 4c42cf0131284bdb2aa351645e02bd2f6a069df1