update to the playlists. testing
CI / check (push) Successful in 1m15s
CI / docker (push) Successful in 2m11s

This commit is contained in:
Connor Johnstone
2026-04-01 22:12:58 -04:00
parent f77cea47b1
commit b2f030b52d
8 changed files with 224 additions and 32 deletions
+7
View File
@@ -59,6 +59,13 @@ impl HybridMusicBrainzFetcher {
self.remote.get_artist_by_mbid(mbid).await self.remote.get_artist_by_mbid(mbid).await
} }
/// Get artist info from local DB only (no remote API fallback).
/// Returns `None` if the local DB is unavailable or doesn't have this artist.
pub fn get_artist_info_local(&self, mbid: &str) -> Option<ArtistInfo> {
self.local_if_available()
.and_then(|l| l.get_artist_info_sync(mbid).ok())
}
/// Get detailed artist info by MBID. Tries local first, then remote. /// Get detailed artist info by MBID. Tries local first, then remote.
pub async fn get_artist_info(&self, mbid: &str) -> DataResult<ArtistInfo> { pub async fn get_artist_info(&self, mbid: &str) -> DataResult<ArtistInfo> {
if let Some(local) = self.local_if_available() if let Some(local) = self.local_if_available()
+6 -2
View File
@@ -9,5 +9,9 @@ pub mod selection;
pub mod strategies; pub mod strategies;
pub mod types; pub mod types;
pub use strategies::{PlaylistError, genre_based, random, similar_artists, smart, to_m3u}; pub use strategies::{
pub use types::{Candidate, PlaylistRequest, PlaylistResult, PlaylistTrack, SmartRules}; CountryLookup, PlaylistError, genre_based, random, similar_artists, smart, to_m3u,
};
pub use types::{
Candidate, PlaylistRequest, PlaylistResult, PlaylistTrack, SimilarConfig, SmartRules,
};
+50 -2
View File
@@ -21,6 +21,8 @@ pub fn score_tracks(
tracks_by_artist: &HashMap<String, Vec<Track>>, tracks_by_artist: &HashMap<String, Vec<Track>>,
top_tracks_by_artist: &HashMap<String, Vec<PopularTrack>>, top_tracks_by_artist: &HashMap<String, Vec<PopularTrack>>,
popularity_bias: u8, popularity_bias: u8,
_global_popularity: u8,
max_tracks_per_artist: Option<u8>,
) -> Vec<ScoredTrack> { ) -> Vec<ScoredTrack> {
let bias = popularity_bias.min(10) as usize; let bias = popularity_bias.min(10) as usize;
let mut scored = Vec::new(); let mut scored = Vec::new();
@@ -108,7 +110,9 @@ pub fn score_tracks(
by_artist.entry(key).or_default().push(t); by_artist.entry(key).or_default().push(t);
} }
let cap = if popularity_bias == 0 { let cap = if let Some(explicit) = max_tracks_per_artist {
Some((explicit as usize).max(1))
} else if popularity_bias == 0 {
None None
} else { } else {
let b = popularity_bias as f64; let b = popularity_bias as f64;
@@ -147,5 +151,49 @@ pub fn score_tracks(
} }
} }
by_artist.into_values().flatten().collect() let mut result: Vec<ScoredTrack> = by_artist.into_values().flatten().collect();
// Step 3: Apply global popularity weighting
if _global_popularity > 0 {
let gp = _global_popularity.min(10) as usize;
let gp_exponent = POPULARITY_EXPONENTS[gp];
let gp_strength = _global_popularity as f64 / 10.0;
// Find max playcount across ALL artists
let global_max: u64 = top_tracks_by_artist
.values()
.flat_map(|tracks| tracks.iter().map(|t| t.playcount))
.max()
.unwrap_or(1)
.max(1);
// Build a global playcount lookup (lowercase name -> max playcount)
let mut global_playcounts: HashMap<String, u64> = HashMap::new();
for tracks in top_tracks_by_artist.values() {
for t in tracks {
let key = t.name.to_lowercase();
global_playcounts
.entry(key)
.and_modify(|c| *c = (*c).max(t.playcount))
.or_insert(t.playcount);
}
}
for t in &mut result {
let playcount = t
.title
.as_ref()
.and_then(|title| global_playcounts.get(&title.to_lowercase()).copied())
.unwrap_or(0);
if playcount > 0 {
let global_pop = (playcount as f64 / global_max as f64).powf(gp_exponent);
// lerp(1.0, global_pop, gp_strength)
let factor = 1.0 + gp_strength * (global_pop - 1.0);
t.score *= factor;
}
}
}
result
} }
+17 -1
View File
@@ -11,6 +11,8 @@ pub fn generate_playlist(
candidates: &[Candidate], candidates: &[Candidate],
n: usize, n: usize,
seed_names: &HashSet<String>, seed_names: &HashSet<String>,
max_artists: Option<u8>,
skip_seed_enforcement: bool,
) -> Vec<Candidate> { ) -> Vec<Candidate> {
if candidates.is_empty() { if candidates.is_empty() {
return Vec::new(); return Vec::new();
@@ -20,8 +22,14 @@ pub fn generate_playlist(
let mut pool: Vec<&Candidate> = candidates.iter().collect(); let mut pool: Vec<&Candidate> = candidates.iter().collect();
let mut result: Vec<Candidate> = Vec::new(); let mut result: Vec<Candidate> = Vec::new();
let mut artist_counts: HashMap<String, usize> = HashMap::new(); let mut artist_counts: HashMap<String, usize> = HashMap::new();
let mut distinct_artists_set: HashSet<String> = HashSet::new();
let max_distinct = max_artists.map(|m| (m as usize).max(1));
let seed_min = (n / 10).max(1); let seed_min = if skip_seed_enforcement {
0
} else {
(n / 10).max(1)
};
let distinct_artists: usize = { let distinct_artists: usize = {
let mut seen = HashSet::new(); let mut seen = HashSet::new();
@@ -54,6 +62,13 @@ pub fn generate_playlist(
.iter() .iter()
.enumerate() .enumerate()
.filter(|(_, c)| { .filter(|(_, c)| {
// Max distinct artists: reject new artists once we hit the cap
if let Some(max) = max_distinct
&& distinct_artists_set.len() >= max
&& !distinct_artists_set.contains(&c.artist)
{
return false;
}
if force_seed { if force_seed {
seed_names.contains(&c.artist) seed_names.contains(&c.artist)
} else { } else {
@@ -79,6 +94,7 @@ pub fn generate_playlist(
let picked = indices[dist.sample(&mut rng)]; let picked = indices[dist.sample(&mut rng)];
let track = pool.remove(picked); let track = pool.remove(picked);
*artist_counts.entry(track.artist.clone()).or_insert(0) += 1; *artist_counts.entry(track.artist.clone()).or_insert(0) += 1;
distinct_artists_set.insert(track.artist.clone());
result.push(Candidate { result.push(Candidate {
score: track.score, score: track.score,
artist: track.artist.clone(), artist: track.artist.clone(),
+80 -15
View File
@@ -1,4 +1,6 @@
use std::collections::{HashMap, HashSet}; use std::collections::{HashMap, HashSet};
use std::future::Future;
use std::pin::Pin;
use sea_orm::DatabaseConnection; use sea_orm::DatabaseConnection;
use shanty_data::{PopularTrack, SimilarArtist, SimilarArtistFetcher}; use shanty_data::{PopularTrack, SimilarArtist, SimilarArtistFetcher};
@@ -12,6 +14,16 @@ use crate::types::*;
/// Cache TTL: 7 days in seconds. /// Cache TTL: 7 days in seconds.
const CACHE_TTL: i64 = 7 * 24 * 3600; const CACHE_TTL: i64 = 7 * 24 * 3600;
/// Trait for looking up an artist's country by MBID.
/// Implementations should return quickly (local DB or cache), never blocking
/// on rate-limited remote APIs during playlist generation.
pub trait CountryLookup: Send + Sync {
fn get_country<'a>(
&'a self,
mbid: &'a str,
) -> Pin<Box<dyn Future<Output = Option<String>> + Send + 'a>>;
}
/// Generate a playlist based on similar artists (the primary strategy). /// Generate a playlist based on similar artists (the primary strategy).
/// ///
/// Flow: /// Flow:
@@ -26,9 +38,8 @@ pub async fn similar_artists(
conn: &DatabaseConnection, conn: &DatabaseConnection,
fetcher: &impl SimilarArtistFetcher, fetcher: &impl SimilarArtistFetcher,
seed_artists: Vec<String>, seed_artists: Vec<String>,
count: usize, config: &SimilarConfig,
popularity_bias: u8, _country_fetcher: Option<&dyn CountryLookup>,
ordering: &str,
) -> Result<PlaylistResult, PlaylistError> { ) -> Result<PlaylistResult, PlaylistError> {
if seed_artists.is_empty() { if seed_artists.is_empty() {
return Err(PlaylistError::InvalidInput( return Err(PlaylistError::InvalidInput(
@@ -37,25 +48,32 @@ pub async fn similar_artists(
} }
let num_seeds = seed_artists.len() as f64; let num_seeds = seed_artists.len() as f64;
let seed_similarity = config.seed_weight as f64 * 0.2;
// Merge similar artists from all seeds: key -> (name, total_score) // Merge similar artists from all seeds: key -> (name, total_score)
let mut merged: HashMap<String, (String, f64)> = HashMap::new(); let mut merged: HashMap<String, (String, f64)> = HashMap::new();
// Track resolved seed names for enforcement (use DB names, not raw input) // Track resolved seed names for enforcement (use DB names, not raw input)
let mut resolved_seed_names: HashSet<String> = HashSet::new(); let mut resolved_seed_names: HashSet<String> = HashSet::new();
// Track which keys are seeds (for country filter)
let mut seed_keys: HashSet<String> = HashSet::new();
for seed in &seed_artists { for seed in &seed_artists {
// Resolve the seed artist: try name lookup in DB // Resolve the seed artist: try name lookup in DB
let (artist_name, artist_mbid) = resolve_artist(conn, seed).await?; let (artist_name, artist_mbid) = resolve_artist(conn, seed).await?;
resolved_seed_names.insert(artist_name.clone()); resolved_seed_names.insert(artist_name.clone());
// Insert the seed itself with score 1.0
let key = artist_mbid let key = artist_mbid
.clone() .clone()
.unwrap_or_else(|| artist_name.to_lowercase()); .unwrap_or_else(|| artist_name.to_lowercase());
let entry = merged seed_keys.insert(key.clone());
.entry(key)
.or_insert_with(|| (artist_name.clone(), 0.0)); // Insert the seed itself with configured weight
entry.1 += 1.0; if seed_similarity > 0.0 {
let entry = merged
.entry(key.clone())
.or_insert_with(|| (artist_name.clone(), 0.0));
entry.1 += seed_similarity;
}
// Fetch similar artists (cached or fresh) // Fetch similar artists (cached or fresh)
let similar = fetch_cached_similar(conn, fetcher, &artist_name, artist_mbid.as_deref()) let similar = fetch_cached_similar(conn, fetcher, &artist_name, artist_mbid.as_deref())
@@ -71,11 +89,41 @@ pub async fn similar_artists(
} }
} }
// Normalize scores by seed count // Normalize scores by seed count, sort by similarity descending
let artists: Vec<(String, String, f64)> = merged let mut artists: Vec<(String, String, f64)> = merged
.into_iter() .into_iter()
.map(|(key, (name, total))| (key, name, total / num_seeds)) .map(|(key, (name, total))| (key, name, total / num_seeds))
.collect(); .collect();
artists.sort_by(|a, b| b.2.partial_cmp(&a.2).unwrap_or(std::cmp::Ordering::Equal));
// Apply discovery range: truncate to pool size
let pool_size = discovery_pool_size(config.discovery_range);
artists.truncate(pool_size);
// Country filter: only keep artists from the same countries as seeds
if config.country_filter
&& let Some(cf) = _country_fetcher
{
let mut seed_countries: HashSet<String> = HashSet::new();
for key in &seed_keys {
if let Some(country) = cf.get_country(key).await {
seed_countries.insert(country);
}
}
if !seed_countries.is_empty() {
let mut filtered = Vec::new();
for entry in artists {
let country = cf.get_country(&entry.0).await;
match country {
Some(c) if seed_countries.contains(&c) => filtered.push(entry),
None => filtered.push(entry), // unknown = pass through
_ => {} // known but different = exclude
}
}
artists = filtered;
}
}
// Build track and top-track maps for scoring // Build track and top-track maps for scoring
let mut tracks_by_artist: HashMap<String, Vec<shanty_db::entities::track::Model>> = let mut tracks_by_artist: HashMap<String, Vec<shanty_db::entities::track::Model>> =
@@ -104,7 +152,9 @@ pub async fn similar_artists(
&artists, &artists,
&tracks_by_artist, &tracks_by_artist,
&top_tracks_by_artist, &top_tracks_by_artist,
popularity_bias, config.popularity_bias,
config.global_popularity,
config.max_tracks_per_artist,
); );
// Convert to candidates // Convert to candidates
@@ -123,10 +173,17 @@ pub async fn similar_artists(
.collect(); .collect();
// Select (use resolved DB names for seed enforcement, not raw input) // Select (use resolved DB names for seed enforcement, not raw input)
let selected = selection::generate_playlist(&candidates, count, &resolved_seed_names); let skip_seed_enforcement = config.seed_weight == 0;
let selected = selection::generate_playlist(
&candidates,
config.count,
&resolved_seed_names,
config.max_artists,
skip_seed_enforcement,
);
// Order // Order
let ordered = apply_ordering(selected, ordering); let ordered = apply_ordering(selected, &config.ordering);
Ok(PlaylistResult { Ok(PlaylistResult {
tracks: candidates_to_tracks(ordered), tracks: candidates_to_tracks(ordered),
@@ -135,6 +192,14 @@ pub async fn similar_artists(
}) })
} }
/// Map discovery_range (0-10) to artist pool size.
/// 0 -> 15, 5 -> ~100, 10 -> 500 (exponential curve).
fn discovery_pool_size(range: u8) -> usize {
let r = range.min(10) as f64;
let size = 15.0 * (500.0_f64 / 15.0).powf(r / 10.0);
size.round() as usize
}
/// Generate a genre-based playlist. /// Generate a genre-based playlist.
pub async fn genre_based( pub async fn genre_based(
conn: &DatabaseConnection, conn: &DatabaseConnection,
@@ -176,7 +241,7 @@ pub async fn genre_based(
.collect(); .collect();
let seed_names = HashSet::new(); let seed_names = HashSet::new();
let selected = selection::generate_playlist(&candidates, count, &seed_names); let selected = selection::generate_playlist(&candidates, count, &seed_names, None, true);
let ordered = apply_ordering(selected, ordering); let ordered = apply_ordering(selected, ordering);
Ok(PlaylistResult { Ok(PlaylistResult {
@@ -306,7 +371,7 @@ pub async fn smart(
.collect(); .collect();
let seed_names = HashSet::new(); let seed_names = HashSet::new();
let selected = selection::generate_playlist(&candidates, count, &seed_names); let selected = selection::generate_playlist(&candidates, count, &seed_names, None, true);
let ordered = ordering::interleave_artists(selected); let ordered = ordering::interleave_artists(selected);
Ok(PlaylistResult { Ok(PlaylistResult {
+52
View File
@@ -17,6 +17,58 @@ pub struct PlaylistRequest {
pub ordering: String, pub ordering: String,
#[serde(default)] #[serde(default)]
pub rules: Option<SmartRules>, pub rules: Option<SmartRules>,
/// Discovery range: how many similar artists to consider (0-10).
/// 0 = focused (~15), 10 = wide open (~500). Default: 5.
#[serde(default)]
pub discovery_range: Option<u8>,
/// Global popularity weighting (0-10). 0 = off, 10 = strong bias toward
/// globally popular tracks across all artists. Default: 0.
#[serde(default)]
pub global_popularity: Option<u8>,
/// Filter to same countries as seed artists. Default: false.
#[serde(default)]
pub country_filter: Option<bool>,
/// Seed artist weight (0-10). 0 = exclude seeds, 5 = normal (similarity 1.0),
/// 10 = double weight (similarity 2.0). Default: 5.
#[serde(default)]
pub seed_weight: Option<u8>,
/// Explicit per-artist track cap. None or 0 = auto (derived from popularity_bias).
#[serde(default)]
pub max_tracks_per_artist: Option<u8>,
/// Maximum distinct artists in the result. None or 0 = unlimited.
#[serde(default)]
pub max_artists: Option<u8>,
}
/// Resolved configuration for the similar-artists strategy.
#[derive(Debug, Clone)]
pub struct SimilarConfig {
pub count: usize,
pub popularity_bias: u8,
pub ordering: String,
pub discovery_range: u8,
pub global_popularity: u8,
pub country_filter: bool,
pub seed_weight: u8,
pub max_tracks_per_artist: Option<u8>,
pub max_artists: Option<u8>,
}
impl SimilarConfig {
pub fn from_request(req: &PlaylistRequest) -> Self {
Self {
count: req.count,
popularity_bias: req.popularity_bias,
ordering: req.ordering.clone(),
discovery_range: req.discovery_range.unwrap_or(5),
global_popularity: req.global_popularity.unwrap_or(0),
country_filter: req.country_filter.unwrap_or(false),
seed_weight: req.seed_weight.unwrap_or(5),
max_tracks_per_artist: req.max_tracks_per_artist.filter(|&v| v > 0),
max_artists: req.max_artists.filter(|&v| v > 0),
}
}
} }
fn default_count() -> usize { fn default_count() -> usize {
+11 -11
View File
@@ -83,7 +83,7 @@ fn test_score_tracks_basic() {
let mut top_map = HashMap::new(); let mut top_map = HashMap::new();
top_map.insert("artist-1".to_string(), top_tracks); top_map.insert("artist-1".to_string(), top_tracks);
let scored = score_tracks(&artists, &tracks_map, &top_map, 5); let scored = score_tracks(&artists, &tracks_map, &top_map, 5, 0, None);
// Should have 3 tracks (only ones matching top tracks) // Should have 3 tracks (only ones matching top tracks)
assert_eq!(scored.len(), 3); assert_eq!(scored.len(), 3);
@@ -107,7 +107,7 @@ fn test_score_tracks_no_top_tracks_uses_uniform() {
let top_map = HashMap::new(); // No top tracks let top_map = HashMap::new(); // No top tracks
let scored = score_tracks(&artists, &tracks_map, &top_map, 5); let scored = score_tracks(&artists, &tracks_map, &top_map, 5, 0, None);
// All 3 tracks should be included with uniform scoring // All 3 tracks should be included with uniform scoring
assert_eq!(scored.len(), 3); assert_eq!(scored.len(), 3);
@@ -141,11 +141,11 @@ fn test_score_tracks_per_artist_cap() {
top_map.insert("artist-1".to_string(), top_tracks); top_map.insert("artist-1".to_string(), top_tracks);
// bias 10 → cap = 10 // bias 10 → cap = 10
let scored = score_tracks(&artists, &tracks_map, &top_map, 10); let scored = score_tracks(&artists, &tracks_map, &top_map, 10, 0, None);
assert!(scored.len() <= 10); assert!(scored.len() <= 10);
// bias 0 → no cap // bias 0 → no cap
let scored_no_cap = score_tracks(&artists, &tracks_map, &top_map, 0); let scored_no_cap = score_tracks(&artists, &tracks_map, &top_map, 0, 0, None);
assert_eq!(scored_no_cap.len(), 50); assert_eq!(scored_no_cap.len(), 50);
} }
@@ -164,7 +164,7 @@ fn test_similarity_transform() {
tracks_map.insert("high".to_string(), vec![track_high]); tracks_map.insert("high".to_string(), vec![track_high]);
tracks_map.insert("low".to_string(), vec![track_low]); tracks_map.insert("low".to_string(), vec![track_low]);
let scored = score_tracks(&artists, &tracks_map, &HashMap::new(), 5); let scored = score_tracks(&artists, &tracks_map, &HashMap::new(), 5, 0, None);
assert_eq!(scored.len(), 2); assert_eq!(scored.len(), 2);
let high_score = scored let high_score = scored
@@ -189,7 +189,7 @@ fn test_generate_playlist_basic() {
.collect(); .collect();
let seeds = HashSet::new(); let seeds = HashSet::new();
let result = generate_playlist(&candidates, 10, &seeds); let result = generate_playlist(&candidates, 10, &seeds, None, false);
assert_eq!(result.len(), 10); assert_eq!(result.len(), 10);
} }
@@ -199,7 +199,7 @@ fn test_generate_playlist_respects_count() {
let candidates: Vec<Candidate> = (1..=5).map(|i| make_candidate(i, "Artist", 1.0)).collect(); let candidates: Vec<Candidate> = (1..=5).map(|i| make_candidate(i, "Artist", 1.0)).collect();
let seeds = HashSet::new(); let seeds = HashSet::new();
let result = generate_playlist(&candidates, 3, &seeds); let result = generate_playlist(&candidates, 3, &seeds, None, false);
assert_eq!(result.len(), 3); assert_eq!(result.len(), 3);
} }
@@ -208,7 +208,7 @@ fn test_generate_playlist_not_more_than_available() {
let candidates: Vec<Candidate> = (1..=3).map(|i| make_candidate(i, "Artist", 1.0)).collect(); let candidates: Vec<Candidate> = (1..=3).map(|i| make_candidate(i, "Artist", 1.0)).collect();
let seeds = HashSet::new(); let seeds = HashSet::new();
let result = generate_playlist(&candidates, 100, &seeds); let result = generate_playlist(&candidates, 100, &seeds, None, false);
assert_eq!(result.len(), 3); assert_eq!(result.len(), 3);
} }
@@ -216,7 +216,7 @@ fn test_generate_playlist_not_more_than_available() {
fn test_generate_playlist_empty_candidates() { fn test_generate_playlist_empty_candidates() {
let candidates: Vec<Candidate> = vec![]; let candidates: Vec<Candidate> = vec![];
let seeds = HashSet::new(); let seeds = HashSet::new();
let result = generate_playlist(&candidates, 10, &seeds); let result = generate_playlist(&candidates, 10, &seeds, None, false);
assert!(result.is_empty()); assert!(result.is_empty());
} }
@@ -229,7 +229,7 @@ fn test_generate_playlist_per_artist_cap() {
candidates.extend((21..=25).map(|i| make_candidate(i, "Minor", 1.0))); candidates.extend((21..=25).map(|i| make_candidate(i, "Minor", 1.0)));
let seeds = HashSet::new(); let seeds = HashSet::new();
let result = generate_playlist(&candidates, 15, &seeds); let result = generate_playlist(&candidates, 15, &seeds, None, false);
let prolific_count = result.iter().filter(|c| c.artist == "Prolific").count(); let prolific_count = result.iter().filter(|c| c.artist == "Prolific").count();
let minor_count = result.iter().filter(|c| c.artist == "Minor").count(); let minor_count = result.iter().filter(|c| c.artist == "Minor").count();
@@ -259,7 +259,7 @@ fn test_generate_playlist_seed_enforcement() {
let mut seeds = HashSet::new(); let mut seeds = HashSet::new();
seeds.insert("Seed".to_string()); seeds.insert("Seed".to_string());
let result = generate_playlist(&candidates, 10, &seeds); let result = generate_playlist(&candidates, 10, &seeds, None, false);
let seed_count = result.iter().filter(|c| c.artist == "Seed").count(); let seed_count = result.iter().filter(|c| c.artist == "Seed").count();
// seed_min = (10/10).max(1) = 1, so at least 1 seed track // seed_min = (10/10).max(1) = 1, so at least 1 seed track