Tuned the algorithm
This commit is contained in:
142
src/main.rs
142
src/main.rs
@@ -3,13 +3,17 @@ mod filesystem;
|
||||
mod lastfm;
|
||||
mod metadata;
|
||||
|
||||
use std::collections::HashMap;
|
||||
use std::env;
|
||||
use std::path::Path;
|
||||
|
||||
use rand::distr::weighted::WeightedIndex;
|
||||
use rand::prelude::*;
|
||||
|
||||
/// Prints the command-line usage summary to stderr and exits with status 1.
///
/// `program` is the executable name interpolated into each usage line.
/// This function never returns.
fn usage(program: &str) -> ! {
    eprintln!("Usage:");
    eprintln!(" {program} index [-v] <directory>");
    // `-n COUNT` is optional, so this single line also covers the plain
    // `build [-v] <file>` invocation.
    eprintln!(" {program} build [-v] [-n COUNT] <file>");
    std::process::exit(1);
}
|
||||
|
||||
@@ -106,10 +110,35 @@ fn cmd_index(args: &[String]) {
|
||||
|
||||
fn cmd_build(args: &[String]) {
|
||||
let verbose = args.iter().any(|a| a == "-v");
|
||||
let rest: Vec<&String> = args.iter().skip(2).filter(|a| *a != "-v").collect();
|
||||
|
||||
// Parse -n COUNT
|
||||
let mut count: usize = 20;
|
||||
let mut rest: Vec<&String> = Vec::new();
|
||||
let mut iter = args.iter().skip(2);
|
||||
while let Some(arg) = iter.next() {
|
||||
if arg == "-v" {
|
||||
continue;
|
||||
} else if arg == "-n" {
|
||||
match iter.next() {
|
||||
Some(val) => match val.parse::<usize>() {
|
||||
Ok(n) if n > 0 => count = n,
|
||||
_ => {
|
||||
eprintln!("Error: -n requires a positive integer");
|
||||
std::process::exit(1);
|
||||
}
|
||||
},
|
||||
None => {
|
||||
eprintln!("Error: -n requires a value");
|
||||
std::process::exit(1);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
rest.push(arg);
|
||||
}
|
||||
}
|
||||
|
||||
if rest.len() != 1 {
|
||||
eprintln!("Usage: {} build [-v] <file>", args[0]);
|
||||
eprintln!("Usage: {} build [-v] [-n COUNT] <file>", args[0]);
|
||||
std::process::exit(1);
|
||||
}
|
||||
|
||||
@@ -189,13 +218,13 @@ fn cmd_build(args: &[String]) {
|
||||
}
|
||||
}
|
||||
|
||||
// Find max ln(playcount) for this artist to normalize
|
||||
let max_ln = top_tracks
|
||||
// Find max playcount for this artist to normalize
|
||||
let max_playcount = top_tracks
|
||||
.iter()
|
||||
.filter(|t| t.playcount > 0)
|
||||
.map(|t| (t.playcount as f64).ln())
|
||||
.fold(f64::NEG_INFINITY, f64::max);
|
||||
let max_ln = if max_ln > 0.0 { max_ln } else { 1.0 };
|
||||
.map(|t| t.playcount)
|
||||
.max()
|
||||
.unwrap_or(1)
|
||||
.max(1);
|
||||
|
||||
for (track_path, recording_mbid) in &local_tracks {
|
||||
let playcount = recording_mbid
|
||||
@@ -206,23 +235,98 @@ fn cmd_build(args: &[String]) {
|
||||
let Some(playcount) = playcount else { continue };
|
||||
|
||||
let popularity = if playcount > 0 {
|
||||
(playcount as f64).ln() / max_ln
|
||||
(playcount as f64 / max_playcount as f64).powf(0.15)
|
||||
} else {
|
||||
0.0
|
||||
};
|
||||
|
||||
let total = match_score * (1.0 + popularity);
|
||||
playlist.push((total, popularity, *match_score, name.clone(), track_path.clone()));
|
||||
let similarity = (match_score.exp()) / std::f64::consts::E;
|
||||
let total = similarity * (1.0 + popularity);
|
||||
playlist.push((total, popularity, similarity, name.clone(), track_path.clone()));
|
||||
}
|
||||
}
|
||||
|
||||
playlist.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal));
|
||||
|
||||
for (total, popularity, similarity, artist, track_path) in &playlist {
|
||||
if verbose {
|
||||
println!("{total:.4}\t{similarity:.4}\t{popularity:.4}\t{artist}\t{track_path}");
|
||||
} else {
|
||||
println!("{total:.4}\t{artist}\t{track_path}");
|
||||
if verbose {
|
||||
let mut sorted = playlist.clone();
|
||||
sorted.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal));
|
||||
for (total, popularity, similarity, artist, track_path) in &sorted {
|
||||
eprintln!("{total:.4}\t{similarity:.4}\t{popularity:.4}\t{artist}\t{track_path}");
|
||||
}
|
||||
}
|
||||
|
||||
// Convert to (score, artist, path) for playlist generation
|
||||
let candidates: Vec<(f64, String, String)> = playlist
|
||||
.into_iter()
|
||||
.map(|(total, _, _, artist, path)| (total, artist, path))
|
||||
.collect();
|
||||
|
||||
let selected = generate_playlist(&candidates, count);
|
||||
|
||||
for (_, _, track_path) in &selected {
|
||||
println!("{track_path}");
|
||||
}
|
||||
}
|
||||
|
||||
fn generate_playlist(
|
||||
candidates: &[(f64, String, String)],
|
||||
n: usize,
|
||||
) -> Vec<(f64, String, String)> {
|
||||
if candidates.is_empty() {
|
||||
return Vec::new();
|
||||
}
|
||||
|
||||
let mut rng = rand::rng();
|
||||
let mut pool: Vec<(f64, String, String)> = candidates.to_vec();
|
||||
let mut result: Vec<(f64, String, String)> = Vec::new();
|
||||
let mut artist_counts: HashMap<String, usize> = HashMap::new();
|
||||
|
||||
let distinct_artists: usize = {
|
||||
let mut seen = std::collections::HashSet::new();
|
||||
for (_, artist, _) in &pool {
|
||||
seen.insert(artist.clone());
|
||||
}
|
||||
seen.len()
|
||||
};
|
||||
|
||||
let divisor = match distinct_artists {
|
||||
1 => 1,
|
||||
2 => 2,
|
||||
3 => 3,
|
||||
4 => 3,
|
||||
5 => 4,
|
||||
_ => 5,
|
||||
};
|
||||
let artist_cap = ((n + divisor - 1) / divisor).max(1);
|
||||
|
||||
while result.len() < n && !pool.is_empty() {
|
||||
// Find eligible tracks (artist hasn't hit cap)
|
||||
let eligible: Vec<usize> = pool
|
||||
.iter()
|
||||
.enumerate()
|
||||
.filter(|(_, (_, artist, _))| {
|
||||
*artist_counts.get(artist).unwrap_or(&0) < artist_cap
|
||||
})
|
||||
.map(|(i, _)| i)
|
||||
.collect();
|
||||
|
||||
// If no eligible tracks, relax and use all remaining
|
||||
let indices: &[usize] = if eligible.is_empty() {
|
||||
&(0..pool.len()).collect::<Vec<_>>()
|
||||
} else {
|
||||
&eligible
|
||||
};
|
||||
|
||||
let weights: Vec<f64> = indices.iter().map(|&i| pool[i].0.max(0.001)).collect();
|
||||
let dist = match WeightedIndex::new(&weights) {
|
||||
Ok(d) => d,
|
||||
Err(_) => break,
|
||||
};
|
||||
|
||||
let picked = indices[dist.sample(&mut rng)];
|
||||
let track = pool.remove(picked);
|
||||
*artist_counts.entry(track.1.clone()).or_insert(0) += 1;
|
||||
result.push(track);
|
||||
}
|
||||
|
||||
result
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user