update to the playlists. testing
This commit is contained in:
@@ -1,4 +1,6 @@
|
||||
use std::collections::{HashMap, HashSet};
|
||||
use std::future::Future;
|
||||
use std::pin::Pin;
|
||||
|
||||
use sea_orm::DatabaseConnection;
|
||||
use shanty_data::{PopularTrack, SimilarArtist, SimilarArtistFetcher};
|
||||
@@ -12,6 +14,16 @@ use crate::types::*;
|
||||
/// Cache TTL: 7 days in seconds.
|
||||
const CACHE_TTL: i64 = 7 * 24 * 3600;
|
||||
|
||||
/// Trait for looking up an artist's country by MBID.
|
||||
/// Implementations should return quickly (local DB or cache), never blocking
|
||||
/// on rate-limited remote APIs during playlist generation.
|
||||
pub trait CountryLookup: Send + Sync {
|
||||
fn get_country<'a>(
|
||||
&'a self,
|
||||
mbid: &'a str,
|
||||
) -> Pin<Box<dyn Future<Output = Option<String>> + Send + 'a>>;
|
||||
}
|
||||
|
||||
/// Generate a playlist based on similar artists (the primary strategy).
|
||||
///
|
||||
/// Flow:
|
||||
@@ -26,9 +38,8 @@ pub async fn similar_artists(
|
||||
conn: &DatabaseConnection,
|
||||
fetcher: &impl SimilarArtistFetcher,
|
||||
seed_artists: Vec<String>,
|
||||
count: usize,
|
||||
popularity_bias: u8,
|
||||
ordering: &str,
|
||||
config: &SimilarConfig,
|
||||
_country_fetcher: Option<&dyn CountryLookup>,
|
||||
) -> Result<PlaylistResult, PlaylistError> {
|
||||
if seed_artists.is_empty() {
|
||||
return Err(PlaylistError::InvalidInput(
|
||||
@@ -37,25 +48,32 @@ pub async fn similar_artists(
|
||||
}
|
||||
|
||||
let num_seeds = seed_artists.len() as f64;
|
||||
let seed_similarity = config.seed_weight as f64 * 0.2;
|
||||
|
||||
// Merge similar artists from all seeds: key -> (name, total_score)
|
||||
let mut merged: HashMap<String, (String, f64)> = HashMap::new();
|
||||
// Track resolved seed names for enforcement (use DB names, not raw input)
|
||||
let mut resolved_seed_names: HashSet<String> = HashSet::new();
|
||||
// Track which keys are seeds (for country filter)
|
||||
let mut seed_keys: HashSet<String> = HashSet::new();
|
||||
|
||||
for seed in &seed_artists {
|
||||
// Resolve the seed artist: try name lookup in DB
|
||||
let (artist_name, artist_mbid) = resolve_artist(conn, seed).await?;
|
||||
resolved_seed_names.insert(artist_name.clone());
|
||||
|
||||
// Insert the seed itself with score 1.0
|
||||
let key = artist_mbid
|
||||
.clone()
|
||||
.unwrap_or_else(|| artist_name.to_lowercase());
|
||||
let entry = merged
|
||||
.entry(key)
|
||||
.or_insert_with(|| (artist_name.clone(), 0.0));
|
||||
entry.1 += 1.0;
|
||||
seed_keys.insert(key.clone());
|
||||
|
||||
// Insert the seed itself with configured weight
|
||||
if seed_similarity > 0.0 {
|
||||
let entry = merged
|
||||
.entry(key.clone())
|
||||
.or_insert_with(|| (artist_name.clone(), 0.0));
|
||||
entry.1 += seed_similarity;
|
||||
}
|
||||
|
||||
// Fetch similar artists (cached or fresh)
|
||||
let similar = fetch_cached_similar(conn, fetcher, &artist_name, artist_mbid.as_deref())
|
||||
@@ -71,11 +89,41 @@ pub async fn similar_artists(
|
||||
}
|
||||
}
|
||||
|
||||
// Normalize scores by seed count
|
||||
let artists: Vec<(String, String, f64)> = merged
|
||||
// Normalize scores by seed count, sort by similarity descending
|
||||
let mut artists: Vec<(String, String, f64)> = merged
|
||||
.into_iter()
|
||||
.map(|(key, (name, total))| (key, name, total / num_seeds))
|
||||
.collect();
|
||||
artists.sort_by(|a, b| b.2.partial_cmp(&a.2).unwrap_or(std::cmp::Ordering::Equal));
|
||||
|
||||
// Apply discovery range: truncate to pool size
|
||||
let pool_size = discovery_pool_size(config.discovery_range);
|
||||
artists.truncate(pool_size);
|
||||
|
||||
// Country filter: only keep artists from the same countries as seeds
|
||||
if config.country_filter
|
||||
&& let Some(cf) = _country_fetcher
|
||||
{
|
||||
let mut seed_countries: HashSet<String> = HashSet::new();
|
||||
for key in &seed_keys {
|
||||
if let Some(country) = cf.get_country(key).await {
|
||||
seed_countries.insert(country);
|
||||
}
|
||||
}
|
||||
|
||||
if !seed_countries.is_empty() {
|
||||
let mut filtered = Vec::new();
|
||||
for entry in artists {
|
||||
let country = cf.get_country(&entry.0).await;
|
||||
match country {
|
||||
Some(c) if seed_countries.contains(&c) => filtered.push(entry),
|
||||
None => filtered.push(entry), // unknown = pass through
|
||||
_ => {} // known but different = exclude
|
||||
}
|
||||
}
|
||||
artists = filtered;
|
||||
}
|
||||
}
|
||||
|
||||
// Build track and top-track maps for scoring
|
||||
let mut tracks_by_artist: HashMap<String, Vec<shanty_db::entities::track::Model>> =
|
||||
@@ -104,7 +152,9 @@ pub async fn similar_artists(
|
||||
&artists,
|
||||
&tracks_by_artist,
|
||||
&top_tracks_by_artist,
|
||||
popularity_bias,
|
||||
config.popularity_bias,
|
||||
config.global_popularity,
|
||||
config.max_tracks_per_artist,
|
||||
);
|
||||
|
||||
// Convert to candidates
|
||||
@@ -123,10 +173,17 @@ pub async fn similar_artists(
|
||||
.collect();
|
||||
|
||||
// Select (use resolved DB names for seed enforcement, not raw input)
|
||||
let selected = selection::generate_playlist(&candidates, count, &resolved_seed_names);
|
||||
let skip_seed_enforcement = config.seed_weight == 0;
|
||||
let selected = selection::generate_playlist(
|
||||
&candidates,
|
||||
config.count,
|
||||
&resolved_seed_names,
|
||||
config.max_artists,
|
||||
skip_seed_enforcement,
|
||||
);
|
||||
|
||||
// Order
|
||||
let ordered = apply_ordering(selected, ordering);
|
||||
let ordered = apply_ordering(selected, &config.ordering);
|
||||
|
||||
Ok(PlaylistResult {
|
||||
tracks: candidates_to_tracks(ordered),
|
||||
@@ -135,6 +192,14 @@ pub async fn similar_artists(
|
||||
})
|
||||
}
|
||||
|
||||
/// Map discovery_range (0-10) to artist pool size.
|
||||
/// 0 -> 15, 5 -> ~100, 10 -> 500 (exponential curve).
|
||||
fn discovery_pool_size(range: u8) -> usize {
|
||||
let r = range.min(10) as f64;
|
||||
let size = 15.0 * (500.0_f64 / 15.0).powf(r / 10.0);
|
||||
size.round() as usize
|
||||
}
|
||||
|
||||
/// Generate a genre-based playlist.
|
||||
pub async fn genre_based(
|
||||
conn: &DatabaseConnection,
|
||||
@@ -176,7 +241,7 @@ pub async fn genre_based(
|
||||
.collect();
|
||||
|
||||
let seed_names = HashSet::new();
|
||||
let selected = selection::generate_playlist(&candidates, count, &seed_names);
|
||||
let selected = selection::generate_playlist(&candidates, count, &seed_names, None, true);
|
||||
let ordered = apply_ordering(selected, ordering);
|
||||
|
||||
Ok(PlaylistResult {
|
||||
@@ -306,7 +371,7 @@ pub async fn smart(
|
||||
.collect();
|
||||
|
||||
let seed_names = HashSet::new();
|
||||
let selected = selection::generate_playlist(&candidates, count, &seed_names);
|
||||
let selected = selection::generate_playlist(&candidates, count, &seed_names, None, true);
|
||||
let ordered = ordering::interleave_artists(selected);
|
||||
|
||||
Ok(PlaylistResult {
|
||||
|
||||
Reference in New Issue
Block a user