112 lines
3.4 KiB
Rust
112 lines
3.4 KiB
Rust
use std::collections::{HashMap, HashSet};
|
|
|
|
use rand::distr::weighted::WeightedIndex;
|
|
use rand::prelude::*;
|
|
|
|
use crate::types::Candidate;
|
|
|
|
/// Weighted random sampling with per-artist caps and seed enforcement.
|
|
/// Ported faithfully from drift's generate_playlist().
|
|
pub fn generate_playlist(
|
|
candidates: &[Candidate],
|
|
n: usize,
|
|
seed_names: &HashSet<String>,
|
|
max_artists: Option<u8>,
|
|
skip_seed_enforcement: bool,
|
|
) -> Vec<Candidate> {
|
|
if candidates.is_empty() {
|
|
return Vec::new();
|
|
}
|
|
|
|
let mut rng = rand::rng();
|
|
let mut pool: Vec<&Candidate> = candidates.iter().collect();
|
|
let mut result: Vec<Candidate> = Vec::new();
|
|
let mut artist_counts: HashMap<String, usize> = HashMap::new();
|
|
let mut distinct_artists_set: HashSet<String> = HashSet::new();
|
|
let max_distinct = max_artists.map(|m| (m as usize).max(1));
|
|
|
|
let seed_min = if skip_seed_enforcement {
|
|
0
|
|
} else {
|
|
(n / 10).max(1)
|
|
};
|
|
|
|
let distinct_artists: usize = {
|
|
let mut seen = HashSet::new();
|
|
for c in &pool {
|
|
seen.insert(&c.artist);
|
|
}
|
|
seen.len()
|
|
};
|
|
|
|
let divisor = match distinct_artists {
|
|
1 => 1,
|
|
2 => 2,
|
|
3 => 3,
|
|
4 => 3,
|
|
5 => 4,
|
|
_ => 5,
|
|
};
|
|
let artist_cap = n.div_ceil(divisor).max(1);
|
|
|
|
while result.len() < n && !pool.is_empty() {
|
|
let seed_count: usize = seed_names
|
|
.iter()
|
|
.map(|name| *artist_counts.get(name).unwrap_or(&0))
|
|
.sum();
|
|
let remaining = n - result.len();
|
|
let seed_deficit = seed_min.saturating_sub(seed_count);
|
|
let force_seed = seed_deficit > 0 && remaining <= seed_deficit;
|
|
|
|
let eligible: Vec<usize> = pool
|
|
.iter()
|
|
.enumerate()
|
|
.filter(|(_, c)| {
|
|
// Max distinct artists: reject new artists once we hit the cap
|
|
if let Some(max) = max_distinct
|
|
&& distinct_artists_set.len() >= max
|
|
&& !distinct_artists_set.contains(&c.artist)
|
|
{
|
|
return false;
|
|
}
|
|
if force_seed {
|
|
seed_names.contains(&c.artist)
|
|
} else {
|
|
*artist_counts.get(&c.artist).unwrap_or(&0) < artist_cap
|
|
}
|
|
})
|
|
.map(|(i, _)| i)
|
|
.collect();
|
|
|
|
let fallback_indices: Vec<usize> = (0..pool.len()).collect();
|
|
let indices: &[usize] = if eligible.is_empty() {
|
|
&fallback_indices
|
|
} else {
|
|
&eligible
|
|
};
|
|
|
|
let weights: Vec<f64> = indices.iter().map(|&i| pool[i].score.max(0.001)).collect();
|
|
let dist = match WeightedIndex::new(&weights) {
|
|
Ok(d) => d,
|
|
Err(_) => break,
|
|
};
|
|
|
|
let picked = indices[dist.sample(&mut rng)];
|
|
let track = pool.remove(picked);
|
|
*artist_counts.entry(track.artist.clone()).or_insert(0) += 1;
|
|
distinct_artists_set.insert(track.artist.clone());
|
|
result.push(Candidate {
|
|
score: track.score,
|
|
artist: track.artist.clone(),
|
|
artist_mbid: track.artist_mbid.clone(),
|
|
track_id: track.track_id,
|
|
file_path: track.file_path.clone(),
|
|
title: track.title.clone(),
|
|
album: track.album.clone(),
|
|
duration: track.duration,
|
|
});
|
|
}
|
|
|
|
result
|
|
}
|