Added scoring

This commit is contained in:
Connor Johnstone
2026-03-02 22:42:30 -05:00
parent 6a16cb1395
commit c3b4c946cb
3 changed files with 169 additions and 13 deletions

View File

@@ -36,14 +36,27 @@ pub fn artist_exists(conn: &Connection, mbid: &str) -> Result<bool, rusqlite::Er
pub fn get_available_similar_artists(
conn: &Connection,
artist_mbid: &str,
) -> Result<Vec<(String, f64)>, rusqlite::Error> {
) -> Result<Vec<(String, String, f64)>, rusqlite::Error> {
let mut stmt = conn.prepare(
"SELECT sa.similar_name, sa.match_score
"SELECT sa.similar_mbid, sa.similar_name, sa.match_score
FROM similar_artists sa
JOIN artists a ON a.mbid = sa.similar_mbid
WHERE sa.artist_mbid = ?1
ORDER BY sa.match_score DESC",
)?;
let rows = stmt.query_map([artist_mbid], |row| {
Ok((row.get(0)?, row.get(1)?, row.get(2)?))
})?;
rows.collect()
}
pub fn get_local_tracks_for_artist(
conn: &Connection,
artist_mbid: &str,
) -> Result<Vec<(String, Option<String>)>, rusqlite::Error> {
let mut stmt = conn.prepare(
"SELECT path, recording_mbid FROM tracks WHERE artist_mbid = ?1",
)?;
let rows = stmt.query_map([artist_mbid], |row| Ok((row.get(0)?, row.get(1)?)))?;
rows.collect()
}

View File

@@ -12,6 +12,13 @@ pub struct SimilarArtist {
pub match_score: f64,
}
pub struct TopTrack {
pub name: String,
pub mbid: Option<String>,
pub playcount: u64,
pub listeners: u64,
}
// Last.fm returns {"error": N, "message": "..."} on failure
#[derive(Deserialize)]
struct ApiError {
@@ -40,6 +47,24 @@ struct ArtistEntry {
match_score: String,
}
#[derive(Deserialize)]
struct TopTracksResponse {
toptracks: TopTracksWrapper,
}
#[derive(Deserialize)]
struct TopTracksWrapper {
track: Vec<TrackEntry>,
}
#[derive(Deserialize)]
struct TrackEntry {
name: String,
mbid: Option<String>,
playcount: String,
listeners: String,
}
impl LastfmClient {
pub fn new(api_key: String) -> Self {
Self { api_key }
@@ -75,4 +100,33 @@ impl LastfmClient {
})
.collect())
}
pub fn get_top_tracks(
&self,
artist_mbid: &str,
) -> Result<Vec<TopTrack>, Box<dyn std::error::Error>> {
let url = format!(
"{}?method=artist.getTopTracks&mbid={}&api_key={}&limit=1000&format=json",
BASE_URL, artist_mbid, self.api_key
);
let body: String = ureq::get(&url).call()?.body_mut().read_to_string()?;
if let Ok(err) = serde_json::from_str::<ApiError>(&body) {
eprintln!(" Last.fm: {}", err.message);
return Ok(Vec::new());
}
let resp: TopTracksResponse = serde_json::from_str(&body)?;
Ok(resp
.toptracks
.track
.into_iter()
.map(|t| TopTrack {
name: t.name,
mbid: t.mbid.filter(|s| !s.is_empty()),
playcount: t.playcount.parse().unwrap_or(0),
listeners: t.listeners.parse().unwrap_or(0),
})
.collect())
}
}

View File

@@ -9,7 +9,7 @@ use std::path::Path;
fn usage(program: &str) -> ! {
eprintln!("Usage:");
eprintln!(" {program} index [-v] <directory>");
eprintln!(" {program} build <file>");
eprintln!(" {program} build [-v] <file>");
std::process::exit(1);
}
@@ -105,12 +105,15 @@ fn cmd_index(args: &[String]) {
}
fn cmd_build(args: &[String]) {
if args.len() != 3 {
eprintln!("Usage: {} build <file>", args[0]);
let verbose = args.iter().any(|a| a == "-v");
let rest: Vec<&String> = args.iter().skip(2).filter(|a| *a != "-v").collect();
if rest.len() != 1 {
eprintln!("Usage: {} build [-v] <file>", args[0]);
std::process::exit(1);
}
let path = Path::new(&args[2]);
let path = Path::new(rest[0].as_str());
let artist_mbid = match metadata::read_artist_mbid(path) {
Ok(Some(mbid)) => mbid,
Ok(None) => {
@@ -123,17 +126,103 @@ fn cmd_build(args: &[String]) {
}
};
let conn = db::open("playlists.db").expect("failed to open database");
dotenvy::dotenv().ok();
let api_key = env::var("LASTFM_API_KEY").unwrap_or_default();
if api_key.is_empty() {
eprintln!("Error: LASTFM_API_KEY not set");
std::process::exit(1);
}
match db::get_available_similar_artists(&conn, &artist_mbid) {
Ok(artists) => {
for (name, score) in &artists {
println!("{name} ({score:.4})");
}
}
let conn = db::open("playlists.db").expect("failed to open database");
let lastfm = lastfm::LastfmClient::new(api_key);
let seed_name = metadata::read_artist_name(path)
.ok()
.flatten()
.unwrap_or_else(|| artist_mbid.clone());
let similar = match db::get_available_similar_artists(&conn, &artist_mbid) {
Ok(a) => a,
Err(e) => {
eprintln!("DB error: {e}");
std::process::exit(1);
}
};
// Seed artist + similar artists: (mbid, name, match_score)
let mut artists: Vec<(String, String, f64)> = vec![
(artist_mbid.clone(), seed_name, 1.0),
];
artists.extend(similar);
// Collect scored tracks: (total, popularity, match_score, artist_name, path)
let mut playlist: Vec<(f64, f64, f64, String, String)> = Vec::new();
for (mbid, name, match_score) in &artists {
let local_tracks = match db::get_local_tracks_for_artist(&conn, mbid) {
Ok(t) => t,
Err(e) => {
eprintln!("DB error for {name}: {e}");
continue;
}
};
if local_tracks.is_empty() {
continue;
}
// Fetch top tracks from Last.fm for popularity data
let top_tracks = match lastfm.get_top_tracks(mbid) {
Ok(t) => t,
Err(e) => {
eprintln!("Last.fm error for {name}: {e}");
Vec::new()
}
};
// Build a map from recording_mbid -> playcount
let mut playcount_by_mbid: std::collections::HashMap<String, u64> =
std::collections::HashMap::new();
for tt in &top_tracks {
if let Some(ref mbid) = tt.mbid {
playcount_by_mbid.insert(mbid.clone(), tt.playcount);
}
}
// Find max ln(playcount) for this artist to normalize
let max_ln = top_tracks
.iter()
.filter(|t| t.playcount > 0)
.map(|t| (t.playcount as f64).ln())
.fold(f64::NEG_INFINITY, f64::max);
let max_ln = if max_ln > 0.0 { max_ln } else { 1.0 };
for (track_path, recording_mbid) in &local_tracks {
let playcount = recording_mbid
.as_ref()
.and_then(|rec_mbid| playcount_by_mbid.get(rec_mbid).copied());
// Skip tracks not in the artist's top 1000
let Some(playcount) = playcount else { continue };
let popularity = if playcount > 0 {
(playcount as f64).ln() / max_ln
} else {
0.0
};
let total = match_score * (1.0 + popularity);
playlist.push((total, popularity, *match_score, name.clone(), track_path.clone()));
}
}
playlist.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal));
for (total, popularity, similarity, artist, track_path) in &playlist {
if verbose {
println!("{total:.4}\t{similarity:.4}\t{popularity:.4}\t{artist}\t{track_path}");
} else {
println!("{total:.4}\t{artist}\t{track_path}");
}
}
}