Added the playlist generator
This commit is contained in:
@@ -0,0 +1,490 @@
|
||||
use std::collections::{HashMap, HashSet};
|
||||
|
||||
use sea_orm::DatabaseConnection;
|
||||
use shanty_data::{PopularTrack, SimilarArtist, SimilarArtistFetcher};
|
||||
use shanty_db::queries;
|
||||
|
||||
use crate::ordering;
|
||||
use crate::scoring;
|
||||
use crate::selection;
|
||||
use crate::types::*;
|
||||
|
||||
/// Cache TTL: 7 days in seconds.
|
||||
const CACHE_TTL: i64 = 7 * 24 * 3600;
|
||||
|
||||
/// Generate a playlist based on similar artists (the primary strategy).
|
||||
///
|
||||
/// Flow:
|
||||
/// 1. For each seed artist: resolve MBID from DB
|
||||
/// 2. Fetch similar artists (check cache first, else call Last.fm, cache result)
|
||||
/// 3. Merge multi-seed: accumulate scores, normalize by seed count
|
||||
/// 4. Filter: only keep artists that have tracks in the local library
|
||||
/// 5. Score all tracks
|
||||
/// 6. Select via weighted sampling
|
||||
/// 7. Order (interleave or shuffle)
|
||||
pub async fn similar_artists(
|
||||
conn: &DatabaseConnection,
|
||||
fetcher: &impl SimilarArtistFetcher,
|
||||
seed_artists: Vec<String>,
|
||||
count: usize,
|
||||
popularity_bias: u8,
|
||||
ordering: &str,
|
||||
) -> Result<PlaylistResult, PlaylistError> {
|
||||
if seed_artists.is_empty() {
|
||||
return Err(PlaylistError::InvalidInput(
|
||||
"at least one seed artist is required".into(),
|
||||
));
|
||||
}
|
||||
|
||||
let num_seeds = seed_artists.len() as f64;
|
||||
|
||||
// Merge similar artists from all seeds: key -> (name, total_score)
|
||||
let mut merged: HashMap<String, (String, f64)> = HashMap::new();
|
||||
// Track resolved seed names for enforcement (use DB names, not raw input)
|
||||
let mut resolved_seed_names: HashSet<String> = HashSet::new();
|
||||
|
||||
for seed in &seed_artists {
|
||||
// Resolve the seed artist: try name lookup in DB
|
||||
let (artist_name, artist_mbid) = resolve_artist(conn, seed).await?;
|
||||
resolved_seed_names.insert(artist_name.clone());
|
||||
|
||||
// Insert the seed itself with score 1.0
|
||||
let key = artist_mbid
|
||||
.clone()
|
||||
.unwrap_or_else(|| artist_name.to_lowercase());
|
||||
let entry = merged
|
||||
.entry(key)
|
||||
.or_insert_with(|| (artist_name.clone(), 0.0));
|
||||
entry.1 += 1.0;
|
||||
|
||||
// Fetch similar artists (cached or fresh)
|
||||
let similar = fetch_cached_similar(conn, fetcher, &artist_name, artist_mbid.as_deref())
|
||||
.await
|
||||
.unwrap_or_default();
|
||||
|
||||
for sa in similar {
|
||||
let sa_key = sa.mbid.clone().unwrap_or_else(|| sa.name.to_lowercase());
|
||||
let entry = merged
|
||||
.entry(sa_key)
|
||||
.or_insert_with(|| (sa.name.clone(), 0.0));
|
||||
entry.1 += sa.match_score;
|
||||
}
|
||||
}
|
||||
|
||||
// Normalize scores by seed count
|
||||
let artists: Vec<(String, String, f64)> = merged
|
||||
.into_iter()
|
||||
.map(|(key, (name, total))| (key, name, total / num_seeds))
|
||||
.collect();
|
||||
|
||||
// Build track and top-track maps for scoring
|
||||
let mut tracks_by_artist: HashMap<String, Vec<shanty_db::entities::track::Model>> =
|
||||
HashMap::new();
|
||||
let mut top_tracks_by_artist: HashMap<String, Vec<PopularTrack>> = HashMap::new();
|
||||
|
||||
for (key, name, _) in &artists {
|
||||
// Get local tracks for this artist
|
||||
let local = get_artist_tracks(conn, key, name).await;
|
||||
if local.is_empty() {
|
||||
continue;
|
||||
}
|
||||
tracks_by_artist.insert(key.clone(), local);
|
||||
|
||||
// Get top tracks (cached or fresh)
|
||||
let top = fetch_cached_top_tracks(conn, fetcher, name, Some(key.as_str()))
|
||||
.await
|
||||
.unwrap_or_default();
|
||||
if !top.is_empty() {
|
||||
top_tracks_by_artist.insert(key.clone(), top);
|
||||
}
|
||||
}
|
||||
|
||||
// Score
|
||||
let scored = scoring::score_tracks(
|
||||
&artists,
|
||||
&tracks_by_artist,
|
||||
&top_tracks_by_artist,
|
||||
popularity_bias,
|
||||
);
|
||||
|
||||
// Convert to candidates
|
||||
let candidates: Vec<Candidate> = scored
|
||||
.into_iter()
|
||||
.map(|t| Candidate {
|
||||
score: t.score,
|
||||
artist: t.artist,
|
||||
artist_mbid: t.artist_mbid,
|
||||
track_id: t.track_id,
|
||||
file_path: t.file_path,
|
||||
title: t.title,
|
||||
album: t.album,
|
||||
duration: t.duration,
|
||||
})
|
||||
.collect();
|
||||
|
||||
// Select (use resolved DB names for seed enforcement, not raw input)
|
||||
let selected = selection::generate_playlist(&candidates, count, &resolved_seed_names);
|
||||
|
||||
// Order
|
||||
let ordered = apply_ordering(selected, ordering);
|
||||
|
||||
Ok(PlaylistResult {
|
||||
tracks: candidates_to_tracks(ordered),
|
||||
strategy: "similar".to_string(),
|
||||
resolved_seeds: resolved_seed_names.into_iter().collect(),
|
||||
})
|
||||
}
|
||||
|
||||
/// Generate a genre-based playlist.
|
||||
pub async fn genre_based(
|
||||
conn: &DatabaseConnection,
|
||||
genres: Vec<String>,
|
||||
count: usize,
|
||||
ordering: &str,
|
||||
) -> Result<PlaylistResult, PlaylistError> {
|
||||
if genres.is_empty() {
|
||||
return Err(PlaylistError::InvalidInput(
|
||||
"at least one genre is required".into(),
|
||||
));
|
||||
}
|
||||
|
||||
let mut all_tracks = Vec::new();
|
||||
for genre in &genres {
|
||||
let tracks = queries::tracks::get_by_genre(conn, genre)
|
||||
.await
|
||||
.map_err(|e| PlaylistError::Db(e.to_string()))?;
|
||||
all_tracks.extend(tracks);
|
||||
}
|
||||
|
||||
// Deduplicate by track ID
|
||||
let mut seen = HashSet::new();
|
||||
all_tracks.retain(|t| seen.insert(t.id));
|
||||
|
||||
// Convert to candidates with uniform scoring
|
||||
let candidates: Vec<Candidate> = all_tracks
|
||||
.into_iter()
|
||||
.map(|t| Candidate {
|
||||
score: 1.0,
|
||||
artist: t.artist.clone().unwrap_or_default(),
|
||||
artist_mbid: t.musicbrainz_id.clone(),
|
||||
track_id: t.id,
|
||||
file_path: t.file_path.clone(),
|
||||
title: t.title.clone(),
|
||||
album: t.album.clone(),
|
||||
duration: t.duration,
|
||||
})
|
||||
.collect();
|
||||
|
||||
let seed_names = HashSet::new();
|
||||
let selected = selection::generate_playlist(&candidates, count, &seed_names);
|
||||
let ordered = apply_ordering(selected, ordering);
|
||||
|
||||
Ok(PlaylistResult {
|
||||
tracks: candidates_to_tracks(ordered),
|
||||
strategy: "genre".to_string(),
|
||||
resolved_seeds: vec![],
|
||||
})
|
||||
}
|
||||
|
||||
/// Generate a random playlist.
|
||||
pub async fn random(
|
||||
conn: &DatabaseConnection,
|
||||
count: usize,
|
||||
no_repeat_artist: bool,
|
||||
) -> Result<PlaylistResult, PlaylistError> {
|
||||
let tracks = queries::tracks::get_random(conn, count as u64 * 2)
|
||||
.await
|
||||
.map_err(|e| PlaylistError::Db(e.to_string()))?;
|
||||
|
||||
let mut result = Vec::new();
|
||||
let mut seen_artists: HashSet<String> = HashSet::new();
|
||||
|
||||
for t in tracks {
|
||||
if result.len() >= count {
|
||||
break;
|
||||
}
|
||||
let artist = t.artist.clone().unwrap_or_default();
|
||||
if no_repeat_artist && !artist.is_empty() && !seen_artists.insert(artist.clone()) {
|
||||
continue;
|
||||
}
|
||||
result.push(PlaylistTrack {
|
||||
track_id: t.id,
|
||||
file_path: t.file_path.clone(),
|
||||
title: t.title.clone(),
|
||||
artist: t.artist.clone(),
|
||||
album: t.album.clone(),
|
||||
score: 0.0,
|
||||
duration: t.duration,
|
||||
});
|
||||
}
|
||||
|
||||
Ok(PlaylistResult {
|
||||
tracks: result,
|
||||
strategy: "random".to_string(),
|
||||
resolved_seeds: vec![],
|
||||
})
|
||||
}
|
||||
|
||||
/// Generate a smart playlist based on rules.
|
||||
pub async fn smart(
|
||||
conn: &DatabaseConnection,
|
||||
rules: SmartRules,
|
||||
count: usize,
|
||||
) -> Result<PlaylistResult, PlaylistError> {
|
||||
let mut all_tracks: Vec<shanty_db::entities::track::Model> = Vec::new();
|
||||
|
||||
// Genre filter
|
||||
if !rules.genres.is_empty() {
|
||||
for genre in &rules.genres {
|
||||
let tracks = queries::tracks::get_by_genre(conn, genre)
|
||||
.await
|
||||
.map_err(|e| PlaylistError::Db(e.to_string()))?;
|
||||
all_tracks.extend(tracks);
|
||||
}
|
||||
}
|
||||
|
||||
// Artist filter
|
||||
if !rules.artists.is_empty() {
|
||||
for artist_name in &rules.artists {
|
||||
let tracks = queries::tracks::get_by_artist_name(conn, artist_name)
|
||||
.await
|
||||
.map_err(|e| PlaylistError::Db(e.to_string()))?;
|
||||
all_tracks.extend(tracks);
|
||||
}
|
||||
}
|
||||
|
||||
// Recently added filter
|
||||
if let Some(days) = rules.added_within_days {
|
||||
let tracks = queries::tracks::get_recent(conn, days, 10000)
|
||||
.await
|
||||
.map_err(|e| PlaylistError::Db(e.to_string()))?;
|
||||
if all_tracks.is_empty() {
|
||||
all_tracks = tracks;
|
||||
} else {
|
||||
let recent_ids: HashSet<i32> = tracks.iter().map(|t| t.id).collect();
|
||||
all_tracks.retain(|t| recent_ids.contains(&t.id));
|
||||
}
|
||||
}
|
||||
|
||||
// Year range filter
|
||||
if let Some((min_year, max_year)) = rules.year_range {
|
||||
all_tracks.retain(|t| {
|
||||
t.year
|
||||
.map(|y| y >= min_year && y <= max_year)
|
||||
.unwrap_or(false)
|
||||
});
|
||||
}
|
||||
|
||||
// If no filters were specified, get random tracks
|
||||
if rules.genres.is_empty()
|
||||
&& rules.artists.is_empty()
|
||||
&& rules.added_within_days.is_none()
|
||||
&& rules.year_range.is_none()
|
||||
{
|
||||
all_tracks = queries::tracks::get_random(conn, count as u64 * 2)
|
||||
.await
|
||||
.map_err(|e| PlaylistError::Db(e.to_string()))?;
|
||||
}
|
||||
|
||||
// Deduplicate
|
||||
let mut seen = HashSet::new();
|
||||
all_tracks.retain(|t| seen.insert(t.id));
|
||||
|
||||
// Convert to candidates
|
||||
let candidates: Vec<Candidate> = all_tracks
|
||||
.into_iter()
|
||||
.map(|t| Candidate {
|
||||
score: 1.0,
|
||||
artist: t.artist.clone().unwrap_or_default(),
|
||||
artist_mbid: t.musicbrainz_id.clone(),
|
||||
track_id: t.id,
|
||||
file_path: t.file_path.clone(),
|
||||
title: t.title.clone(),
|
||||
album: t.album.clone(),
|
||||
duration: t.duration,
|
||||
})
|
||||
.collect();
|
||||
|
||||
let seed_names = HashSet::new();
|
||||
let selected = selection::generate_playlist(&candidates, count, &seed_names);
|
||||
let ordered = ordering::interleave_artists(selected);
|
||||
|
||||
Ok(PlaylistResult {
|
||||
tracks: candidates_to_tracks(ordered),
|
||||
strategy: "smart".to_string(),
|
||||
resolved_seeds: vec![],
|
||||
})
|
||||
}
|
||||
|
||||
/// Generate an M3U playlist string from tracks.
|
||||
pub fn to_m3u(tracks: &[PlaylistTrack]) -> String {
|
||||
let mut out = String::from("#EXTM3U\n");
|
||||
for t in tracks {
|
||||
let duration = t.duration.unwrap_or(0.0) as i64;
|
||||
let artist = t.artist.as_deref().unwrap_or("Unknown");
|
||||
let title = t.title.as_deref().unwrap_or("Unknown");
|
||||
out.push_str(&format!("#EXTINF:{duration},{artist} - {title}\n"));
|
||||
out.push_str(&t.file_path);
|
||||
out.push('\n');
|
||||
}
|
||||
out
|
||||
}
|
||||
|
||||
/// Apply the chosen ordering mode to selected candidates.
|
||||
fn apply_ordering(mut candidates: Vec<Candidate>, mode: &str) -> Vec<Candidate> {
|
||||
match mode {
|
||||
"random" => ordering::shuffle(candidates),
|
||||
"score" => {
|
||||
candidates.sort_by(|a, b| {
|
||||
b.score
|
||||
.partial_cmp(&a.score)
|
||||
.unwrap_or(std::cmp::Ordering::Equal)
|
||||
});
|
||||
candidates
|
||||
}
|
||||
_ => ordering::interleave_artists(candidates), // "interleave" is default
|
||||
}
|
||||
}
|
||||
|
||||
// --- Helper functions ---
|
||||
|
||||
/// Resolve an artist name or MBID to (name, optional_mbid).
|
||||
async fn resolve_artist(
|
||||
conn: &DatabaseConnection,
|
||||
query: &str,
|
||||
) -> Result<(String, Option<String>), PlaylistError> {
|
||||
// Try as MBID first (if it looks like a UUID)
|
||||
if query.len() == 36 && query.contains('-') {
|
||||
let tracks = queries::tracks::get_by_artist_mbid(conn, query)
|
||||
.await
|
||||
.map_err(|e| PlaylistError::Db(e.to_string()))?;
|
||||
if let Some(t) = tracks.first() {
|
||||
let name = t.artist.clone().unwrap_or_else(|| query.to_string());
|
||||
return Ok((name, Some(query.to_string())));
|
||||
}
|
||||
}
|
||||
|
||||
// Try by name in DB
|
||||
if let Ok(Some(artist)) = queries::artists::find_by_name(conn, query).await {
|
||||
return Ok((artist.name, artist.musicbrainz_id));
|
||||
}
|
||||
|
||||
// Fall back to using the query as the name
|
||||
Ok((query.to_string(), None))
|
||||
}
|
||||
|
||||
/// Get local tracks for an artist by key (MBID) or name.
|
||||
async fn get_artist_tracks(
|
||||
conn: &DatabaseConnection,
|
||||
key: &str,
|
||||
name: &str,
|
||||
) -> Vec<shanty_db::entities::track::Model> {
|
||||
// Try by MBID first
|
||||
if let Ok(tracks) = queries::tracks::get_by_artist_mbid(conn, key).await
|
||||
&& !tracks.is_empty()
|
||||
{
|
||||
return tracks;
|
||||
}
|
||||
|
||||
// Try by name
|
||||
queries::tracks::get_by_artist_name(conn, name)
|
||||
.await
|
||||
.unwrap_or_default()
|
||||
}
|
||||
|
||||
/// Fetch similar artists from cache or Last.fm.
|
||||
async fn fetch_cached_similar(
|
||||
conn: &DatabaseConnection,
|
||||
fetcher: &impl SimilarArtistFetcher,
|
||||
artist_name: &str,
|
||||
mbid: Option<&str>,
|
||||
) -> Result<Vec<SimilarArtist>, PlaylistError> {
|
||||
let cache_key = format!(
|
||||
"lastfm_similar:{}",
|
||||
mbid.unwrap_or(&artist_name.to_lowercase())
|
||||
);
|
||||
|
||||
// Check cache
|
||||
if let Ok(Some(json)) = queries::cache::get(conn, &cache_key).await
|
||||
&& let Ok(cached) = serde_json::from_str::<Vec<SimilarArtist>>(&json)
|
||||
{
|
||||
tracing::debug!(artist = artist_name, "using cached similar artists");
|
||||
return Ok(cached);
|
||||
}
|
||||
|
||||
// Fetch from provider
|
||||
let similar = fetcher
|
||||
.get_similar_artists(artist_name, mbid)
|
||||
.await
|
||||
.map_err(|e| PlaylistError::FetchError(e.to_string()))?;
|
||||
|
||||
// Cache the result
|
||||
if let Ok(json) = serde_json::to_string(&similar) {
|
||||
let _ = queries::cache::set(conn, &cache_key, "lastfm", &json, CACHE_TTL).await;
|
||||
}
|
||||
|
||||
Ok(similar)
|
||||
}
|
||||
|
||||
/// Fetch top tracks from cache or Last.fm.
|
||||
async fn fetch_cached_top_tracks(
|
||||
conn: &DatabaseConnection,
|
||||
fetcher: &impl SimilarArtistFetcher,
|
||||
artist_name: &str,
|
||||
mbid: Option<&str>,
|
||||
) -> Result<Vec<PopularTrack>, PlaylistError> {
|
||||
let cache_key = format!(
|
||||
"lastfm_toptracks:{}",
|
||||
mbid.unwrap_or(&artist_name.to_lowercase())
|
||||
);
|
||||
|
||||
// Check cache
|
||||
if let Ok(Some(json)) = queries::cache::get(conn, &cache_key).await
|
||||
&& let Ok(cached) = serde_json::from_str::<Vec<PopularTrack>>(&json)
|
||||
{
|
||||
tracing::debug!(artist = artist_name, "using cached top tracks");
|
||||
return Ok(cached);
|
||||
}
|
||||
|
||||
// Fetch from provider
|
||||
let tracks = fetcher
|
||||
.get_top_tracks(artist_name, mbid)
|
||||
.await
|
||||
.map_err(|e| PlaylistError::FetchError(e.to_string()))?;
|
||||
|
||||
// Cache the result
|
||||
if let Ok(json) = serde_json::to_string(&tracks) {
|
||||
let _ = queries::cache::set(conn, &cache_key, "lastfm", &json, CACHE_TTL).await;
|
||||
}
|
||||
|
||||
Ok(tracks)
|
||||
}
|
||||
|
||||
/// Convert candidates to playlist tracks.
|
||||
fn candidates_to_tracks(candidates: Vec<Candidate>) -> Vec<PlaylistTrack> {
|
||||
candidates
|
||||
.into_iter()
|
||||
.map(|c| PlaylistTrack {
|
||||
track_id: c.track_id,
|
||||
file_path: c.file_path,
|
||||
title: c.title,
|
||||
artist: Some(c.artist),
|
||||
album: c.album,
|
||||
score: c.score,
|
||||
duration: c.duration,
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Playlist generation error.
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
pub enum PlaylistError {
|
||||
#[error("invalid input: {0}")]
|
||||
InvalidInput(String),
|
||||
|
||||
#[error("database error: {0}")]
|
||||
Db(String),
|
||||
|
||||
#[error("fetch error: {0}")]
|
||||
FetchError(String),
|
||||
}
|
||||
Reference in New Issue
Block a user