Added the playlist generator
This commit is contained in:
@@ -0,0 +1,95 @@
|
||||
use std::collections::{HashMap, HashSet};
|
||||
|
||||
use rand::distr::weighted::WeightedIndex;
|
||||
use rand::prelude::*;
|
||||
|
||||
use crate::types::Candidate;
|
||||
|
||||
/// Weighted random sampling with per-artist caps and seed enforcement.
|
||||
/// Ported faithfully from drift's generate_playlist().
|
||||
pub fn generate_playlist(
|
||||
candidates: &[Candidate],
|
||||
n: usize,
|
||||
seed_names: &HashSet<String>,
|
||||
) -> Vec<Candidate> {
|
||||
if candidates.is_empty() {
|
||||
return Vec::new();
|
||||
}
|
||||
|
||||
let mut rng = rand::rng();
|
||||
let mut pool: Vec<&Candidate> = candidates.iter().collect();
|
||||
let mut result: Vec<Candidate> = Vec::new();
|
||||
let mut artist_counts: HashMap<String, usize> = HashMap::new();
|
||||
|
||||
let seed_min = (n / 10).max(1);
|
||||
|
||||
let distinct_artists: usize = {
|
||||
let mut seen = HashSet::new();
|
||||
for c in &pool {
|
||||
seen.insert(&c.artist);
|
||||
}
|
||||
seen.len()
|
||||
};
|
||||
|
||||
let divisor = match distinct_artists {
|
||||
1 => 1,
|
||||
2 => 2,
|
||||
3 => 3,
|
||||
4 => 3,
|
||||
5 => 4,
|
||||
_ => 5,
|
||||
};
|
||||
let artist_cap = n.div_ceil(divisor).max(1);
|
||||
|
||||
while result.len() < n && !pool.is_empty() {
|
||||
let seed_count: usize = seed_names
|
||||
.iter()
|
||||
.map(|name| *artist_counts.get(name).unwrap_or(&0))
|
||||
.sum();
|
||||
let remaining = n - result.len();
|
||||
let seed_deficit = seed_min.saturating_sub(seed_count);
|
||||
let force_seed = seed_deficit > 0 && remaining <= seed_deficit;
|
||||
|
||||
let eligible: Vec<usize> = pool
|
||||
.iter()
|
||||
.enumerate()
|
||||
.filter(|(_, c)| {
|
||||
if force_seed {
|
||||
seed_names.contains(&c.artist)
|
||||
} else {
|
||||
*artist_counts.get(&c.artist).unwrap_or(&0) < artist_cap
|
||||
}
|
||||
})
|
||||
.map(|(i, _)| i)
|
||||
.collect();
|
||||
|
||||
let fallback_indices: Vec<usize> = (0..pool.len()).collect();
|
||||
let indices: &[usize] = if eligible.is_empty() {
|
||||
&fallback_indices
|
||||
} else {
|
||||
&eligible
|
||||
};
|
||||
|
||||
let weights: Vec<f64> = indices.iter().map(|&i| pool[i].score.max(0.001)).collect();
|
||||
let dist = match WeightedIndex::new(&weights) {
|
||||
Ok(d) => d,
|
||||
Err(_) => break,
|
||||
};
|
||||
|
||||
let picked = indices[dist.sample(&mut rng)];
|
||||
let track = pool.remove(picked);
|
||||
*artist_counts.entry(track.artist.clone()).or_insert(0) += 1;
|
||||
result.push(Candidate {
|
||||
score: track.score,
|
||||
artist: track.artist.clone(),
|
||||
artist_mbid: track.artist_mbid.clone(),
|
||||
track_id: track.track_id,
|
||||
file_path: track.file_path.clone(),
|
||||
title: track.title.clone(),
|
||||
album: track.album.clone(),
|
||||
duration: track.duration,
|
||||
});
|
||||
}
|
||||
|
||||
result
|
||||
}
|
||||
Reference in New Issue
Block a user