From 97e5a8f5dfe3d90d014f28f763b1157603a7219b Mon Sep 17 00:00:00 2001 From: Connor Johnstone Date: Mon, 2 Mar 2026 23:01:48 -0500 Subject: [PATCH] Tuned the algorithm --- Cargo.lock | 94 +++++++++++++++++++++++++++++++++- Cargo.toml | 1 + src/main.rs | 142 +++++++++++++++++++++++++++++++++++++++++++++------- 3 files changed, 217 insertions(+), 20 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index fa2f740..1f12eb5 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -114,6 +114,18 @@ dependencies = [ "wasi", ] +[[package]] +name = "getrandom" +version = "0.3.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "899def5c37c4fd7b2664648c28120ecec138e4d395b459e5ca34f9cce2dd77fd" +dependencies = [ + "cfg-if", + "libc", + "r-efi", + "wasip2", +] + [[package]] name = "hashbrown" version = "0.15.5" @@ -258,6 +270,7 @@ version = "0.1.0" dependencies = [ "dotenvy", "lofty", + "rand", "rusqlite", "serde", "serde_json", @@ -265,6 +278,15 @@ dependencies = [ "walkdir", ] +[[package]] +name = "ppv-lite86" +version = "0.2.21" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "85eae3c4ed2f50dcfe72643da4befc30deadb458a9b590d720cde2f2b1e97da9" +dependencies = [ + "zerocopy", +] + [[package]] name = "proc-macro2" version = "1.0.106" @@ -283,6 +305,41 @@ dependencies = [ "proc-macro2", ] +[[package]] +name = "r-efi" +version = "5.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "69cdb34c158ceb288df11e18b4bd39de994f6657d83847bdffdbd7f346754b0f" + +[[package]] +name = "rand" +version = "0.9.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6db2770f06117d490610c7488547d543617b21bfa07796d7a12f6f1bd53850d1" +dependencies = [ + "rand_chacha", + "rand_core", +] + +[[package]] +name = "rand_chacha" +version = "0.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d3022b5f1df60f26e1ffddd6c66e8aa15de382ae63b3a0c1bfc0e4d3e3f325cb" +dependencies = [ + "ppv-lite86", + "rand_core", +] + +[[package]] +name = "rand_core" +version = "0.9.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "76afc826de14238e6e8c374ddcc1fa19e374fd8dd986b0d2af0d02377261d83c" +dependencies = [ + "getrandom 0.3.4", +] + [[package]] name = "ring" version = "0.17.14" @@ -291,7 +348,7 @@ checksum = "a4689e6c2294d81e88dc6261c768b63bc4fcdb852be6d1352498b114f61383b7" dependencies = [ "cc", "cfg-if", - "getrandom", + "getrandom 0.2.17", "libc", "untrusted", "windows-sys 0.52.0", @@ -502,6 +559,15 @@ version = "0.11.1+wasi-snapshot-preview1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ccf3ec651a847eb01de73ccad15eb7d99f80485de043efb2f370cd654f4ea44b" +[[package]] +name = "wasip2" +version = "1.0.2+wasi-0.2.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9517f9239f02c069db75e65f174b3da828fe5f5b945c4dd26bd25d89c03ebcf5" +dependencies = [ + "wit-bindgen", +] + [[package]] name = "webpki-roots" version = "1.0.6" @@ -608,6 +674,32 @@ version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "589f6da84c646204747d1270a2a5661ea66ed1cced2631d546fdfb155959f9ec" +[[package]] +name = "wit-bindgen" +version = "0.51.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d7249219f66ced02969388cf2bb044a09756a083d0fab1e566056b04d9fbcaa5" + +[[package]] +name = "zerocopy" +version = "0.8.40" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a789c6e490b576db9f7e6b6d661bcc9799f7c0ac8352f56ea20193b2681532e5" +dependencies = [ + "zerocopy-derive", +] + +[[package]] +name = "zerocopy-derive" +version = "0.8.40" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f65c489a7071a749c849713807783f70672b28094011623e200cb86dcb835953" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "zeroize" version = "1.8.2" diff --git a/Cargo.toml b/Cargo.toml index f018344..0ecef9f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -10,4 +10,5 @@ serde = { version = "1", features = ["derive"] } serde_json = "1" rusqlite = { version = "0.34", features = ["bundled"] } ureq = "3" +rand = "0.9" walkdir = "2.5" diff --git a/src/main.rs b/src/main.rs index cae7efe..090bdd8 100644 --- a/src/main.rs +++ b/src/main.rs @@ -3,13 +3,17 @@ mod filesystem; mod lastfm; mod metadata; +use std::collections::HashMap; use std::env; use std::path::Path; +use rand::distr::weighted::WeightedIndex; +use rand::prelude::*; + fn usage(program: &str) -> ! { eprintln!("Usage:"); eprintln!(" {program} index [-v] "); - eprintln!(" {program} build [-v] "); + eprintln!(" {program} build [-v] [-n COUNT] "); std::process::exit(1); } @@ -106,10 +110,35 @@ fn cmd_index(args: &[String]) { fn cmd_build(args: &[String]) { let verbose = args.iter().any(|a| a == "-v"); - let rest: Vec<&String> = args.iter().skip(2).filter(|a| *a != "-v").collect(); + + // Parse -n COUNT + let mut count: usize = 20; + let mut rest: Vec<&String> = Vec::new(); + let mut iter = args.iter().skip(2); + while let Some(arg) = iter.next() { + if arg == "-v" { + continue; + } else if arg == "-n" { + match iter.next() { + Some(val) => match val.parse::() { + Ok(n) if n > 0 => count = n, + _ => { + eprintln!("Error: -n requires a positive integer"); + std::process::exit(1); + } + }, + None => { + eprintln!("Error: -n requires a value"); + std::process::exit(1); + } + } + } else { + rest.push(arg); + } + } if rest.len() != 1 { - eprintln!("Usage: {} build [-v] ", args[0]); + eprintln!("Usage: {} build [-v] [-n COUNT] ", args[0]); std::process::exit(1); } @@ -189,13 +218,13 @@ fn cmd_build(args: &[String]) { } } - // Find max ln(playcount) for this artist to normalize - let max_ln = top_tracks + // Find max playcount for this artist to normalize + let max_playcount = top_tracks .iter() - .filter(|t| t.playcount > 0) - .map(|t| (t.playcount as f64).ln()) - .fold(f64::NEG_INFINITY, f64::max); - let max_ln = if max_ln > 0.0 { max_ln } else { 1.0 }; + .map(|t| t.playcount) + .max() + .unwrap_or(1) + .max(1); for (track_path, recording_mbid) in &local_tracks { let playcount = recording_mbid @@ -206,23 +235,98 @@ fn cmd_build(args: &[String]) { let Some(playcount) = playcount else { continue }; let popularity = if playcount > 0 { - (playcount as f64).ln() / max_ln + (playcount as f64 / max_playcount as f64).powf(0.15) } else { 0.0 }; - let total = match_score * (1.0 + popularity); - playlist.push((total, popularity, *match_score, name.clone(), track_path.clone())); + let similarity = (match_score.exp()) / std::f64::consts::E; + let total = similarity * (1.0 + popularity); + playlist.push((total, popularity, similarity, name.clone(), track_path.clone())); } } - playlist.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal)); - - for (total, popularity, similarity, artist, track_path) in &playlist { - if verbose { - println!("{total:.4}\t{similarity:.4}\t{popularity:.4}\t{artist}\t{track_path}"); - } else { - println!("{total:.4}\t{artist}\t{track_path}"); + if verbose { + let mut sorted = playlist.clone(); + sorted.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal)); + for (total, popularity, similarity, artist, track_path) in &sorted { + eprintln!("{total:.4}\t{similarity:.4}\t{popularity:.4}\t{artist}\t{track_path}"); } } + + // Convert to (score, artist, path) for playlist generation + let candidates: Vec<(f64, String, String)> = playlist + .into_iter() + .map(|(total, _, _, artist, path)| (total, artist, path)) + .collect(); + + let selected = generate_playlist(&candidates, count); + + for (_, _, track_path) in &selected { + println!("{track_path}"); + } +} + +fn generate_playlist( + candidates: &[(f64, String, String)], + n: usize, +) -> Vec<(f64, String, String)> { + if candidates.is_empty() { + return Vec::new(); + } + + let mut rng = rand::rng(); + let mut pool: Vec<(f64, String, String)> = candidates.to_vec(); + let mut result: Vec<(f64, String, String)> = Vec::new(); + let mut artist_counts: HashMap = HashMap::new(); + + let distinct_artists: usize = { + let mut seen = std::collections::HashSet::new(); + for (_, artist, _) in &pool { + seen.insert(artist.clone()); + } + seen.len() + }; + + let divisor = match distinct_artists { + 1 => 1, + 2 => 2, + 3 => 3, + 4 => 3, + 5 => 4, + _ => 5, + }; + let artist_cap = ((n + divisor - 1) / divisor).max(1); + + while result.len() < n && !pool.is_empty() { + // Find eligible tracks (artist hasn't hit cap) + let eligible: Vec = pool + .iter() + .enumerate() + .filter(|(_, (_, artist, _))| { + *artist_counts.get(artist).unwrap_or(&0) < artist_cap + }) + .map(|(i, _)| i) + .collect(); + + // If no eligible tracks, relax and use all remaining + let indices: &[usize] = if eligible.is_empty() { + &(0..pool.len()).collect::>() + } else { + &eligible + }; + + let weights: Vec = indices.iter().map(|&i| pool[i].0.max(0.001)).collect(); + let dist = match WeightedIndex::new(&weights) { + Ok(d) => d, + Err(_) => break, + }; + + let picked = indices[dist.sample(&mut rng)]; + let track = pool.remove(picked); + *artist_counts.entry(track.1.clone()).or_insert(0) += 1; + result.push(track); + } + + result }