use std::collections::{HashMap, HashSet}; use sea_orm::DatabaseConnection; use shanty_data::{PopularTrack, SimilarArtist, SimilarArtistFetcher}; use shanty_db::queries; use crate::ordering; use crate::scoring; use crate::selection; use crate::types::*; /// Cache TTL: 7 days in seconds. const CACHE_TTL: i64 = 7 * 24 * 3600; /// Generate a playlist based on similar artists (the primary strategy). /// /// Flow: /// 1. For each seed artist: resolve MBID from DB /// 2. Fetch similar artists (check cache first, else call Last.fm, cache result) /// 3. Merge multi-seed: accumulate scores, normalize by seed count /// 4. Filter: only keep artists that have tracks in the local library /// 5. Score all tracks /// 6. Select via weighted sampling /// 7. Order (interleave or shuffle) pub async fn similar_artists( conn: &DatabaseConnection, fetcher: &impl SimilarArtistFetcher, seed_artists: Vec, count: usize, popularity_bias: u8, ordering: &str, ) -> Result { if seed_artists.is_empty() { return Err(PlaylistError::InvalidInput( "at least one seed artist is required".into(), )); } let num_seeds = seed_artists.len() as f64; // Merge similar artists from all seeds: key -> (name, total_score) let mut merged: HashMap = HashMap::new(); // Track resolved seed names for enforcement (use DB names, not raw input) let mut resolved_seed_names: HashSet = HashSet::new(); for seed in &seed_artists { // Resolve the seed artist: try name lookup in DB let (artist_name, artist_mbid) = resolve_artist(conn, seed).await?; resolved_seed_names.insert(artist_name.clone()); // Insert the seed itself with score 1.0 let key = artist_mbid .clone() .unwrap_or_else(|| artist_name.to_lowercase()); let entry = merged .entry(key) .or_insert_with(|| (artist_name.clone(), 0.0)); entry.1 += 1.0; // Fetch similar artists (cached or fresh) let similar = fetch_cached_similar(conn, fetcher, &artist_name, artist_mbid.as_deref()) .await .unwrap_or_default(); for sa in similar { let sa_key = sa.mbid.clone().unwrap_or_else(|| sa.name.to_lowercase()); let entry = merged .entry(sa_key) .or_insert_with(|| (sa.name.clone(), 0.0)); entry.1 += sa.match_score; } } // Normalize scores by seed count let artists: Vec<(String, String, f64)> = merged .into_iter() .map(|(key, (name, total))| (key, name, total / num_seeds)) .collect(); // Build track and top-track maps for scoring let mut tracks_by_artist: HashMap> = HashMap::new(); let mut top_tracks_by_artist: HashMap> = HashMap::new(); for (key, name, _) in &artists { // Get local tracks for this artist let local = get_artist_tracks(conn, key, name).await; if local.is_empty() { continue; } tracks_by_artist.insert(key.clone(), local); // Get top tracks (cached or fresh) let top = fetch_cached_top_tracks(conn, fetcher, name, Some(key.as_str())) .await .unwrap_or_default(); if !top.is_empty() { top_tracks_by_artist.insert(key.clone(), top); } } // Score let scored = scoring::score_tracks( &artists, &tracks_by_artist, &top_tracks_by_artist, popularity_bias, ); // Convert to candidates let candidates: Vec = scored .into_iter() .map(|t| Candidate { score: t.score, artist: t.artist, artist_mbid: t.artist_mbid, track_id: t.track_id, file_path: t.file_path, title: t.title, album: t.album, duration: t.duration, }) .collect(); // Select (use resolved DB names for seed enforcement, not raw input) let selected = selection::generate_playlist(&candidates, count, &resolved_seed_names); // Order let ordered = apply_ordering(selected, ordering); Ok(PlaylistResult { tracks: candidates_to_tracks(ordered), strategy: "similar".to_string(), resolved_seeds: resolved_seed_names.into_iter().collect(), }) } /// Generate a genre-based playlist. pub async fn genre_based( conn: &DatabaseConnection, genres: Vec, count: usize, ordering: &str, ) -> Result { if genres.is_empty() { return Err(PlaylistError::InvalidInput( "at least one genre is required".into(), )); } let mut all_tracks = Vec::new(); for genre in &genres { let tracks = queries::tracks::get_by_genre(conn, genre) .await .map_err(|e| PlaylistError::Db(e.to_string()))?; all_tracks.extend(tracks); } // Deduplicate by track ID let mut seen = HashSet::new(); all_tracks.retain(|t| seen.insert(t.id)); // Convert to candidates with uniform scoring let candidates: Vec = all_tracks .into_iter() .map(|t| Candidate { score: 1.0, artist: t.artist.clone().unwrap_or_default(), artist_mbid: t.musicbrainz_id.clone(), track_id: t.id, file_path: t.file_path.clone(), title: t.title.clone(), album: t.album.clone(), duration: t.duration, }) .collect(); let seed_names = HashSet::new(); let selected = selection::generate_playlist(&candidates, count, &seed_names); let ordered = apply_ordering(selected, ordering); Ok(PlaylistResult { tracks: candidates_to_tracks(ordered), strategy: "genre".to_string(), resolved_seeds: vec![], }) } /// Generate a random playlist. pub async fn random( conn: &DatabaseConnection, count: usize, no_repeat_artist: bool, ) -> Result { let tracks = queries::tracks::get_random(conn, count as u64 * 2) .await .map_err(|e| PlaylistError::Db(e.to_string()))?; let mut result = Vec::new(); let mut seen_artists: HashSet = HashSet::new(); for t in tracks { if result.len() >= count { break; } let artist = t.artist.clone().unwrap_or_default(); if no_repeat_artist && !artist.is_empty() && !seen_artists.insert(artist.clone()) { continue; } result.push(PlaylistTrack { track_id: t.id, file_path: t.file_path.clone(), title: t.title.clone(), artist: t.artist.clone(), album: t.album.clone(), score: 0.0, duration: t.duration, }); } Ok(PlaylistResult { tracks: result, strategy: "random".to_string(), resolved_seeds: vec![], }) } /// Generate a smart playlist based on rules. pub async fn smart( conn: &DatabaseConnection, rules: SmartRules, count: usize, ) -> Result { let mut all_tracks: Vec = Vec::new(); // Genre filter if !rules.genres.is_empty() { for genre in &rules.genres { let tracks = queries::tracks::get_by_genre(conn, genre) .await .map_err(|e| PlaylistError::Db(e.to_string()))?; all_tracks.extend(tracks); } } // Artist filter if !rules.artists.is_empty() { for artist_name in &rules.artists { let tracks = queries::tracks::get_by_artist_name(conn, artist_name) .await .map_err(|e| PlaylistError::Db(e.to_string()))?; all_tracks.extend(tracks); } } // Recently added filter if let Some(days) = rules.added_within_days { let tracks = queries::tracks::get_recent(conn, days, 10000) .await .map_err(|e| PlaylistError::Db(e.to_string()))?; if all_tracks.is_empty() { all_tracks = tracks; } else { let recent_ids: HashSet = tracks.iter().map(|t| t.id).collect(); all_tracks.retain(|t| recent_ids.contains(&t.id)); } } // Year range filter if let Some((min_year, max_year)) = rules.year_range { all_tracks.retain(|t| { t.year .map(|y| y >= min_year && y <= max_year) .unwrap_or(false) }); } // If no filters were specified, get random tracks if rules.genres.is_empty() && rules.artists.is_empty() && rules.added_within_days.is_none() && rules.year_range.is_none() { all_tracks = queries::tracks::get_random(conn, count as u64 * 2) .await .map_err(|e| PlaylistError::Db(e.to_string()))?; } // Deduplicate let mut seen = HashSet::new(); all_tracks.retain(|t| seen.insert(t.id)); // Convert to candidates let candidates: Vec = all_tracks .into_iter() .map(|t| Candidate { score: 1.0, artist: t.artist.clone().unwrap_or_default(), artist_mbid: t.musicbrainz_id.clone(), track_id: t.id, file_path: t.file_path.clone(), title: t.title.clone(), album: t.album.clone(), duration: t.duration, }) .collect(); let seed_names = HashSet::new(); let selected = selection::generate_playlist(&candidates, count, &seed_names); let ordered = ordering::interleave_artists(selected); Ok(PlaylistResult { tracks: candidates_to_tracks(ordered), strategy: "smart".to_string(), resolved_seeds: vec![], }) } /// Generate an M3U playlist string from tracks. pub fn to_m3u(tracks: &[PlaylistTrack]) -> String { let mut out = String::from("#EXTM3U\n"); for t in tracks { let duration = t.duration.unwrap_or(0.0) as i64; let artist = t.artist.as_deref().unwrap_or("Unknown"); let title = t.title.as_deref().unwrap_or("Unknown"); out.push_str(&format!("#EXTINF:{duration},{artist} - {title}\n")); out.push_str(&t.file_path); out.push('\n'); } out } /// Apply the chosen ordering mode to selected candidates. fn apply_ordering(mut candidates: Vec, mode: &str) -> Vec { match mode { "random" => ordering::shuffle(candidates), "score" => { candidates.sort_by(|a, b| { b.score .partial_cmp(&a.score) .unwrap_or(std::cmp::Ordering::Equal) }); candidates } _ => ordering::interleave_artists(candidates), // "interleave" is default } } // --- Helper functions --- /// Resolve an artist name or MBID to (name, optional_mbid). async fn resolve_artist( conn: &DatabaseConnection, query: &str, ) -> Result<(String, Option), PlaylistError> { // Try as MBID first (if it looks like a UUID) if query.len() == 36 && query.contains('-') { let tracks = queries::tracks::get_by_artist_mbid(conn, query) .await .map_err(|e| PlaylistError::Db(e.to_string()))?; if let Some(t) = tracks.first() { let name = t.artist.clone().unwrap_or_else(|| query.to_string()); return Ok((name, Some(query.to_string()))); } } // Try by name in DB if let Ok(Some(artist)) = queries::artists::find_by_name(conn, query).await { return Ok((artist.name, artist.musicbrainz_id)); } // Fall back to using the query as the name Ok((query.to_string(), None)) } /// Get local tracks for an artist by key (MBID) or name. async fn get_artist_tracks( conn: &DatabaseConnection, key: &str, name: &str, ) -> Vec { // Try by MBID first if let Ok(tracks) = queries::tracks::get_by_artist_mbid(conn, key).await && !tracks.is_empty() { return tracks; } // Try by name queries::tracks::get_by_artist_name(conn, name) .await .unwrap_or_default() } /// Fetch similar artists from cache or Last.fm. async fn fetch_cached_similar( conn: &DatabaseConnection, fetcher: &impl SimilarArtistFetcher, artist_name: &str, mbid: Option<&str>, ) -> Result, PlaylistError> { let cache_key = format!( "lastfm_similar:{}", mbid.unwrap_or(&artist_name.to_lowercase()) ); // Check cache if let Ok(Some(json)) = queries::cache::get(conn, &cache_key).await && let Ok(cached) = serde_json::from_str::>(&json) { tracing::debug!(artist = artist_name, "using cached similar artists"); return Ok(cached); } // Fetch from provider let similar = fetcher .get_similar_artists(artist_name, mbid) .await .map_err(|e| PlaylistError::FetchError(e.to_string()))?; // Cache the result if let Ok(json) = serde_json::to_string(&similar) { let _ = queries::cache::set(conn, &cache_key, "lastfm", &json, CACHE_TTL).await; } Ok(similar) } /// Fetch top tracks from cache or Last.fm. async fn fetch_cached_top_tracks( conn: &DatabaseConnection, fetcher: &impl SimilarArtistFetcher, artist_name: &str, mbid: Option<&str>, ) -> Result, PlaylistError> { let cache_key = format!( "lastfm_toptracks:{}", mbid.unwrap_or(&artist_name.to_lowercase()) ); // Check cache if let Ok(Some(json)) = queries::cache::get(conn, &cache_key).await && let Ok(cached) = serde_json::from_str::>(&json) { tracing::debug!(artist = artist_name, "using cached top tracks"); return Ok(cached); } // Fetch from provider let tracks = fetcher .get_top_tracks(artist_name, mbid) .await .map_err(|e| PlaylistError::FetchError(e.to_string()))?; // Cache the result if let Ok(json) = serde_json::to_string(&tracks) { let _ = queries::cache::set(conn, &cache_key, "lastfm", &json, CACHE_TTL).await; } Ok(tracks) } /// Convert candidates to playlist tracks. fn candidates_to_tracks(candidates: Vec) -> Vec { candidates .into_iter() .map(|c| PlaylistTrack { track_id: c.track_id, file_path: c.file_path, title: c.title, artist: Some(c.artist), album: c.album, score: c.score, duration: c.duration, }) .collect() } /// Playlist generation error. #[derive(Debug, thiserror::Error)] pub enum PlaylistError { #[error("invalid input: {0}")] InvalidInput(String), #[error("database error: {0}")] Db(String), #[error("fetch error: {0}")] FetchError(String), }