Tuned the algorithm
This commit is contained in:
94
Cargo.lock
generated
94
Cargo.lock
generated
@@ -114,6 +114,18 @@ dependencies = [
|
|||||||
"wasi",
|
"wasi",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "getrandom"
|
||||||
|
version = "0.3.4"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "899def5c37c4fd7b2664648c28120ecec138e4d395b459e5ca34f9cce2dd77fd"
|
||||||
|
dependencies = [
|
||||||
|
"cfg-if",
|
||||||
|
"libc",
|
||||||
|
"r-efi",
|
||||||
|
"wasip2",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "hashbrown"
|
name = "hashbrown"
|
||||||
version = "0.15.5"
|
version = "0.15.5"
|
||||||
@@ -258,6 +270,7 @@ version = "0.1.0"
|
|||||||
dependencies = [
|
dependencies = [
|
||||||
"dotenvy",
|
"dotenvy",
|
||||||
"lofty",
|
"lofty",
|
||||||
|
"rand",
|
||||||
"rusqlite",
|
"rusqlite",
|
||||||
"serde",
|
"serde",
|
||||||
"serde_json",
|
"serde_json",
|
||||||
@@ -265,6 +278,15 @@ dependencies = [
|
|||||||
"walkdir",
|
"walkdir",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "ppv-lite86"
|
||||||
|
version = "0.2.21"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "85eae3c4ed2f50dcfe72643da4befc30deadb458a9b590d720cde2f2b1e97da9"
|
||||||
|
dependencies = [
|
||||||
|
"zerocopy",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "proc-macro2"
|
name = "proc-macro2"
|
||||||
version = "1.0.106"
|
version = "1.0.106"
|
||||||
@@ -283,6 +305,41 @@ dependencies = [
|
|||||||
"proc-macro2",
|
"proc-macro2",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "r-efi"
|
||||||
|
version = "5.3.0"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "69cdb34c158ceb288df11e18b4bd39de994f6657d83847bdffdbd7f346754b0f"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "rand"
|
||||||
|
version = "0.9.2"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "6db2770f06117d490610c7488547d543617b21bfa07796d7a12f6f1bd53850d1"
|
||||||
|
dependencies = [
|
||||||
|
"rand_chacha",
|
||||||
|
"rand_core",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "rand_chacha"
|
||||||
|
version = "0.9.0"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "d3022b5f1df60f26e1ffddd6c66e8aa15de382ae63b3a0c1bfc0e4d3e3f325cb"
|
||||||
|
dependencies = [
|
||||||
|
"ppv-lite86",
|
||||||
|
"rand_core",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "rand_core"
|
||||||
|
version = "0.9.5"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "76afc826de14238e6e8c374ddcc1fa19e374fd8dd986b0d2af0d02377261d83c"
|
||||||
|
dependencies = [
|
||||||
|
"getrandom 0.3.4",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "ring"
|
name = "ring"
|
||||||
version = "0.17.14"
|
version = "0.17.14"
|
||||||
@@ -291,7 +348,7 @@ checksum = "a4689e6c2294d81e88dc6261c768b63bc4fcdb852be6d1352498b114f61383b7"
|
|||||||
dependencies = [
|
dependencies = [
|
||||||
"cc",
|
"cc",
|
||||||
"cfg-if",
|
"cfg-if",
|
||||||
"getrandom",
|
"getrandom 0.2.17",
|
||||||
"libc",
|
"libc",
|
||||||
"untrusted",
|
"untrusted",
|
||||||
"windows-sys 0.52.0",
|
"windows-sys 0.52.0",
|
||||||
@@ -502,6 +559,15 @@ version = "0.11.1+wasi-snapshot-preview1"
|
|||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "ccf3ec651a847eb01de73ccad15eb7d99f80485de043efb2f370cd654f4ea44b"
|
checksum = "ccf3ec651a847eb01de73ccad15eb7d99f80485de043efb2f370cd654f4ea44b"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "wasip2"
|
||||||
|
version = "1.0.2+wasi-0.2.9"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "9517f9239f02c069db75e65f174b3da828fe5f5b945c4dd26bd25d89c03ebcf5"
|
||||||
|
dependencies = [
|
||||||
|
"wit-bindgen",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "webpki-roots"
|
name = "webpki-roots"
|
||||||
version = "1.0.6"
|
version = "1.0.6"
|
||||||
@@ -608,6 +674,32 @@ version = "0.52.6"
|
|||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "589f6da84c646204747d1270a2a5661ea66ed1cced2631d546fdfb155959f9ec"
|
checksum = "589f6da84c646204747d1270a2a5661ea66ed1cced2631d546fdfb155959f9ec"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "wit-bindgen"
|
||||||
|
version = "0.51.0"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "d7249219f66ced02969388cf2bb044a09756a083d0fab1e566056b04d9fbcaa5"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "zerocopy"
|
||||||
|
version = "0.8.40"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "a789c6e490b576db9f7e6b6d661bcc9799f7c0ac8352f56ea20193b2681532e5"
|
||||||
|
dependencies = [
|
||||||
|
"zerocopy-derive",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "zerocopy-derive"
|
||||||
|
version = "0.8.40"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "f65c489a7071a749c849713807783f70672b28094011623e200cb86dcb835953"
|
||||||
|
dependencies = [
|
||||||
|
"proc-macro2",
|
||||||
|
"quote",
|
||||||
|
"syn",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "zeroize"
|
name = "zeroize"
|
||||||
version = "1.8.2"
|
version = "1.8.2"
|
||||||
|
|||||||
@@ -10,4 +10,5 @@ serde = { version = "1", features = ["derive"] }
|
|||||||
serde_json = "1"
|
serde_json = "1"
|
||||||
rusqlite = { version = "0.34", features = ["bundled"] }
|
rusqlite = { version = "0.34", features = ["bundled"] }
|
||||||
ureq = "3"
|
ureq = "3"
|
||||||
|
rand = "0.9"
|
||||||
walkdir = "2.5"
|
walkdir = "2.5"
|
||||||
|
|||||||
140
src/main.rs
140
src/main.rs
@@ -3,13 +3,17 @@ mod filesystem;
|
|||||||
mod lastfm;
|
mod lastfm;
|
||||||
mod metadata;
|
mod metadata;
|
||||||
|
|
||||||
|
use std::collections::HashMap;
|
||||||
use std::env;
|
use std::env;
|
||||||
use std::path::Path;
|
use std::path::Path;
|
||||||
|
|
||||||
|
use rand::distr::weighted::WeightedIndex;
|
||||||
|
use rand::prelude::*;
|
||||||
|
|
||||||
fn usage(program: &str) -> ! {
|
fn usage(program: &str) -> ! {
|
||||||
eprintln!("Usage:");
|
eprintln!("Usage:");
|
||||||
eprintln!(" {program} index [-v] <directory>");
|
eprintln!(" {program} index [-v] <directory>");
|
||||||
eprintln!(" {program} build [-v] <file>");
|
eprintln!(" {program} build [-v] [-n COUNT] <file>");
|
||||||
std::process::exit(1);
|
std::process::exit(1);
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -106,10 +110,35 @@ fn cmd_index(args: &[String]) {
|
|||||||
|
|
||||||
fn cmd_build(args: &[String]) {
|
fn cmd_build(args: &[String]) {
|
||||||
let verbose = args.iter().any(|a| a == "-v");
|
let verbose = args.iter().any(|a| a == "-v");
|
||||||
let rest: Vec<&String> = args.iter().skip(2).filter(|a| *a != "-v").collect();
|
|
||||||
|
// Parse -n COUNT
|
||||||
|
let mut count: usize = 20;
|
||||||
|
let mut rest: Vec<&String> = Vec::new();
|
||||||
|
let mut iter = args.iter().skip(2);
|
||||||
|
while let Some(arg) = iter.next() {
|
||||||
|
if arg == "-v" {
|
||||||
|
continue;
|
||||||
|
} else if arg == "-n" {
|
||||||
|
match iter.next() {
|
||||||
|
Some(val) => match val.parse::<usize>() {
|
||||||
|
Ok(n) if n > 0 => count = n,
|
||||||
|
_ => {
|
||||||
|
eprintln!("Error: -n requires a positive integer");
|
||||||
|
std::process::exit(1);
|
||||||
|
}
|
||||||
|
},
|
||||||
|
None => {
|
||||||
|
eprintln!("Error: -n requires a value");
|
||||||
|
std::process::exit(1);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
rest.push(arg);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if rest.len() != 1 {
|
if rest.len() != 1 {
|
||||||
eprintln!("Usage: {} build [-v] <file>", args[0]);
|
eprintln!("Usage: {} build [-v] [-n COUNT] <file>", args[0]);
|
||||||
std::process::exit(1);
|
std::process::exit(1);
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -189,13 +218,13 @@ fn cmd_build(args: &[String]) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Find max ln(playcount) for this artist to normalize
|
// Find max playcount for this artist to normalize
|
||||||
let max_ln = top_tracks
|
let max_playcount = top_tracks
|
||||||
.iter()
|
.iter()
|
||||||
.filter(|t| t.playcount > 0)
|
.map(|t| t.playcount)
|
||||||
.map(|t| (t.playcount as f64).ln())
|
.max()
|
||||||
.fold(f64::NEG_INFINITY, f64::max);
|
.unwrap_or(1)
|
||||||
let max_ln = if max_ln > 0.0 { max_ln } else { 1.0 };
|
.max(1);
|
||||||
|
|
||||||
for (track_path, recording_mbid) in &local_tracks {
|
for (track_path, recording_mbid) in &local_tracks {
|
||||||
let playcount = recording_mbid
|
let playcount = recording_mbid
|
||||||
@@ -206,23 +235,98 @@ fn cmd_build(args: &[String]) {
|
|||||||
let Some(playcount) = playcount else { continue };
|
let Some(playcount) = playcount else { continue };
|
||||||
|
|
||||||
let popularity = if playcount > 0 {
|
let popularity = if playcount > 0 {
|
||||||
(playcount as f64).ln() / max_ln
|
(playcount as f64 / max_playcount as f64).powf(0.15)
|
||||||
} else {
|
} else {
|
||||||
0.0
|
0.0
|
||||||
};
|
};
|
||||||
|
|
||||||
let total = match_score * (1.0 + popularity);
|
let similarity = (match_score.exp()) / std::f64::consts::E;
|
||||||
playlist.push((total, popularity, *match_score, name.clone(), track_path.clone()));
|
let total = similarity * (1.0 + popularity);
|
||||||
|
playlist.push((total, popularity, similarity, name.clone(), track_path.clone()));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
playlist.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal));
|
|
||||||
|
|
||||||
for (total, popularity, similarity, artist, track_path) in &playlist {
|
|
||||||
if verbose {
|
if verbose {
|
||||||
println!("{total:.4}\t{similarity:.4}\t{popularity:.4}\t{artist}\t{track_path}");
|
let mut sorted = playlist.clone();
|
||||||
} else {
|
sorted.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal));
|
||||||
println!("{total:.4}\t{artist}\t{track_path}");
|
for (total, popularity, similarity, artist, track_path) in &sorted {
|
||||||
|
eprintln!("{total:.4}\t{similarity:.4}\t{popularity:.4}\t{artist}\t{track_path}");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Convert to (score, artist, path) for playlist generation
|
||||||
|
let candidates: Vec<(f64, String, String)> = playlist
|
||||||
|
.into_iter()
|
||||||
|
.map(|(total, _, _, artist, path)| (total, artist, path))
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
let selected = generate_playlist(&candidates, count);
|
||||||
|
|
||||||
|
for (_, _, track_path) in &selected {
|
||||||
|
println!("{track_path}");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn generate_playlist(
|
||||||
|
candidates: &[(f64, String, String)],
|
||||||
|
n: usize,
|
||||||
|
) -> Vec<(f64, String, String)> {
|
||||||
|
if candidates.is_empty() {
|
||||||
|
return Vec::new();
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut rng = rand::rng();
|
||||||
|
let mut pool: Vec<(f64, String, String)> = candidates.to_vec();
|
||||||
|
let mut result: Vec<(f64, String, String)> = Vec::new();
|
||||||
|
let mut artist_counts: HashMap<String, usize> = HashMap::new();
|
||||||
|
|
||||||
|
let distinct_artists: usize = {
|
||||||
|
let mut seen = std::collections::HashSet::new();
|
||||||
|
for (_, artist, _) in &pool {
|
||||||
|
seen.insert(artist.clone());
|
||||||
|
}
|
||||||
|
seen.len()
|
||||||
|
};
|
||||||
|
|
||||||
|
let divisor = match distinct_artists {
|
||||||
|
1 => 1,
|
||||||
|
2 => 2,
|
||||||
|
3 => 3,
|
||||||
|
4 => 3,
|
||||||
|
5 => 4,
|
||||||
|
_ => 5,
|
||||||
|
};
|
||||||
|
let artist_cap = ((n + divisor - 1) / divisor).max(1);
|
||||||
|
|
||||||
|
while result.len() < n && !pool.is_empty() {
|
||||||
|
// Find eligible tracks (artist hasn't hit cap)
|
||||||
|
let eligible: Vec<usize> = pool
|
||||||
|
.iter()
|
||||||
|
.enumerate()
|
||||||
|
.filter(|(_, (_, artist, _))| {
|
||||||
|
*artist_counts.get(artist).unwrap_or(&0) < artist_cap
|
||||||
|
})
|
||||||
|
.map(|(i, _)| i)
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
// If no eligible tracks, relax and use all remaining
|
||||||
|
let indices: &[usize] = if eligible.is_empty() {
|
||||||
|
&(0..pool.len()).collect::<Vec<_>>()
|
||||||
|
} else {
|
||||||
|
&eligible
|
||||||
|
};
|
||||||
|
|
||||||
|
let weights: Vec<f64> = indices.iter().map(|&i| pool[i].0.max(0.001)).collect();
|
||||||
|
let dist = match WeightedIndex::new(&weights) {
|
||||||
|
Ok(d) => d,
|
||||||
|
Err(_) => break,
|
||||||
|
};
|
||||||
|
|
||||||
|
let picked = indices[dist.sample(&mut rng)];
|
||||||
|
let track = pool.remove(picked);
|
||||||
|
*artist_counts.entry(track.1.clone()).or_insert(0) += 1;
|
||||||
|
result.push(track);
|
||||||
|
}
|
||||||
|
|
||||||
|
result
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user