diff --git a/src/db.rs b/src/db.rs index ab5a28a..f5b534a 100644 --- a/src/db.rs +++ b/src/db.rs @@ -36,14 +36,27 @@ pub fn artist_exists(conn: &Connection, mbid: &str) -> Result Result, rusqlite::Error> { +) -> Result, rusqlite::Error> { let mut stmt = conn.prepare( - "SELECT sa.similar_name, sa.match_score + "SELECT sa.similar_mbid, sa.similar_name, sa.match_score FROM similar_artists sa JOIN artists a ON a.mbid = sa.similar_mbid WHERE sa.artist_mbid = ?1 ORDER BY sa.match_score DESC", )?; + let rows = stmt.query_map([artist_mbid], |row| { + Ok((row.get(0)?, row.get(1)?, row.get(2)?)) + })?; + rows.collect() +} + +pub fn get_local_tracks_for_artist( + conn: &Connection, + artist_mbid: &str, +) -> Result)>, rusqlite::Error> { + let mut stmt = conn.prepare( + "SELECT path, recording_mbid FROM tracks WHERE artist_mbid = ?1", + )?; let rows = stmt.query_map([artist_mbid], |row| Ok((row.get(0)?, row.get(1)?)))?; rows.collect() } diff --git a/src/lastfm.rs b/src/lastfm.rs index b89a194..96d4002 100644 --- a/src/lastfm.rs +++ b/src/lastfm.rs @@ -12,6 +12,13 @@ pub struct SimilarArtist { pub match_score: f64, } +pub struct TopTrack { + pub name: String, + pub mbid: Option, + pub playcount: u64, + pub listeners: u64, +} + // Last.fm returns {"error": N, "message": "..."} on failure #[derive(Deserialize)] struct ApiError { @@ -40,6 +47,24 @@ struct ArtistEntry { match_score: String, } +#[derive(Deserialize)] +struct TopTracksResponse { + toptracks: TopTracksWrapper, +} + +#[derive(Deserialize)] +struct TopTracksWrapper { + track: Vec, +} + +#[derive(Deserialize)] +struct TrackEntry { + name: String, + mbid: Option, + playcount: String, + listeners: String, +} + impl LastfmClient { pub fn new(api_key: String) -> Self { Self { api_key } @@ -75,4 +100,33 @@ impl LastfmClient { }) .collect()) } + + pub fn get_top_tracks( + &self, + artist_mbid: &str, + ) -> Result, Box> { + let url = format!( + "{}?method=artist.getTopTracks&mbid={}&api_key={}&limit=1000&format=json", + BASE_URL, artist_mbid, self.api_key + ); + let body: String = ureq::get(&url).call()?.body_mut().read_to_string()?; + + if let Ok(err) = serde_json::from_str::(&body) { + eprintln!(" Last.fm: {}", err.message); + return Ok(Vec::new()); + } + + let resp: TopTracksResponse = serde_json::from_str(&body)?; + Ok(resp + .toptracks + .track + .into_iter() + .map(|t| TopTrack { + name: t.name, + mbid: t.mbid.filter(|s| !s.is_empty()), + playcount: t.playcount.parse().unwrap_or(0), + listeners: t.listeners.parse().unwrap_or(0), + }) + .collect()) + } } diff --git a/src/main.rs b/src/main.rs index fab6053..cae7efe 100644 --- a/src/main.rs +++ b/src/main.rs @@ -9,7 +9,7 @@ use std::path::Path; fn usage(program: &str) -> ! { eprintln!("Usage:"); eprintln!(" {program} index [-v] "); - eprintln!(" {program} build "); + eprintln!(" {program} build [-v] "); std::process::exit(1); } @@ -105,12 +105,15 @@ fn cmd_index(args: &[String]) { } fn cmd_build(args: &[String]) { - if args.len() != 3 { - eprintln!("Usage: {} build ", args[0]); + let verbose = args.iter().any(|a| a == "-v"); + let rest: Vec<&String> = args.iter().skip(2).filter(|a| *a != "-v").collect(); + + if rest.len() != 1 { + eprintln!("Usage: {} build [-v] ", args[0]); std::process::exit(1); } - let path = Path::new(&args[2]); + let path = Path::new(rest[0].as_str()); let artist_mbid = match metadata::read_artist_mbid(path) { Ok(Some(mbid)) => mbid, Ok(None) => { @@ -123,17 +126,103 @@ fn cmd_build(args: &[String]) { } }; - let conn = db::open("playlists.db").expect("failed to open database"); + dotenvy::dotenv().ok(); + let api_key = env::var("LASTFM_API_KEY").unwrap_or_default(); + if api_key.is_empty() { + eprintln!("Error: LASTFM_API_KEY not set"); + std::process::exit(1); + } - match db::get_available_similar_artists(&conn, &artist_mbid) { - Ok(artists) => { - for (name, score) in &artists { - println!("{name} ({score:.4})"); - } - } + let conn = db::open("playlists.db").expect("failed to open database"); + let lastfm = lastfm::LastfmClient::new(api_key); + + let seed_name = metadata::read_artist_name(path) + .ok() + .flatten() + .unwrap_or_else(|| artist_mbid.clone()); + + let similar = match db::get_available_similar_artists(&conn, &artist_mbid) { + Ok(a) => a, Err(e) => { eprintln!("DB error: {e}"); std::process::exit(1); } + }; + + // Seed artist + similar artists: (mbid, name, match_score) + let mut artists: Vec<(String, String, f64)> = vec![ + (artist_mbid.clone(), seed_name, 1.0), + ]; + artists.extend(similar); + + // Collect scored tracks: (total, popularity, match_score, artist_name, path) + let mut playlist: Vec<(f64, f64, f64, String, String)> = Vec::new(); + + for (mbid, name, match_score) in &artists { + let local_tracks = match db::get_local_tracks_for_artist(&conn, mbid) { + Ok(t) => t, + Err(e) => { + eprintln!("DB error for {name}: {e}"); + continue; + } + }; + + if local_tracks.is_empty() { + continue; + } + + // Fetch top tracks from Last.fm for popularity data + let top_tracks = match lastfm.get_top_tracks(mbid) { + Ok(t) => t, + Err(e) => { + eprintln!("Last.fm error for {name}: {e}"); + Vec::new() + } + }; + + // Build a map from recording_mbid -> playcount + let mut playcount_by_mbid: std::collections::HashMap = + std::collections::HashMap::new(); + for tt in &top_tracks { + if let Some(ref mbid) = tt.mbid { + playcount_by_mbid.insert(mbid.clone(), tt.playcount); + } + } + + // Find max ln(playcount) for this artist to normalize + let max_ln = top_tracks + .iter() + .filter(|t| t.playcount > 0) + .map(|t| (t.playcount as f64).ln()) + .fold(f64::NEG_INFINITY, f64::max); + let max_ln = if max_ln > 0.0 { max_ln } else { 1.0 }; + + for (track_path, recording_mbid) in &local_tracks { + let playcount = recording_mbid + .as_ref() + .and_then(|rec_mbid| playcount_by_mbid.get(rec_mbid).copied()); + + // Skip tracks not in the artist's top 1000 + let Some(playcount) = playcount else { continue }; + + let popularity = if playcount > 0 { + (playcount as f64).ln() / max_ln + } else { + 0.0 + }; + + let total = match_score * (1.0 + popularity); + playlist.push((total, popularity, *match_score, name.clone(), track_path.clone())); + } + } + + playlist.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal)); + + for (total, popularity, similarity, artist, track_path) in &playlist { + if verbose { + println!("{total:.4}\t{similarity:.4}\t{popularity:.4}\t{artist}\t{track_path}"); + } else { + println!("{total:.4}\t{artist}\t{track_path}"); + } } }