Files
Main/shanty-playlist/src/selection.rs
T
Connor Johnstone 6f73bb87ce
CI / check (push) Successful in 1m12s
CI / docker (push) Successful in 2m1s
Added the playlist generator
2026-03-20 18:09:47 -04:00

96 lines
2.8 KiB
Rust

use std::collections::{HashMap, HashSet};
use rand::distr::weighted::WeightedIndex;
use rand::prelude::*;
use crate::types::Candidate;
/// Builds a playlist of up to `n` tracks by weighted random sampling without
/// replacement, with a per-artist cap and a minimum quota of seed-artist tracks.
///
/// Each draw picks one track from the remaining pool with probability
/// proportional to its `score` (floored at `0.001` so zero, negative or NaN
/// scores stay selectable and the weights stay valid).
///
/// Selection rules, applied per draw in order of precedence:
/// 1. **Seed enforcement** – at least `max(n / 10, 1)` tracks should come from
///    artists listed in `seed_names`. Once the number of remaining slots is no
///    larger than the outstanding seed deficit, only seed-artist tracks are
///    eligible (the artist cap is not applied to these forced picks).
/// 2. **Artist cap** – otherwise only tracks whose artist has fewer than
///    `ceil(n / divisor)` picks so far are eligible, where `divisor` depends on
///    how many distinct artists are in `candidates`.
/// 3. **Fallback** – if no track satisfies the rules above, every remaining
///    track is eligible so the playlist can still be filled.
///
/// Ported from drift's `generate_playlist()`, with one deliberate deviation:
/// when seed enforcement is active but no seed-artist track is left in the
/// pool, the draw falls back to the artist-capped set (rule 2) before falling
/// back to the whole pool, instead of skipping the cap entirely.
///
/// # Arguments
/// * `candidates` – scored tracks to sample from; not modified.
/// * `n` – desired playlist length.
/// * `seed_names` – artist names (compared exactly against `Candidate::artist`)
///   that must be represented in the result.
///
/// # Returns
/// A vector of at most `min(n, candidates.len())` tracks in pick order. It is
/// empty when `candidates` is empty or `n` is zero, and may be shorter than
/// expected if the weighted distribution cannot be built for a draw.
pub fn generate_playlist(
    candidates: &[Candidate],
    n: usize,
    seed_names: &HashSet<String>,
) -> Vec<Candidate> {
    if candidates.is_empty() || n == 0 {
        return Vec::new();
    }

    let mut rng = rand::rng();
    let mut pool: Vec<&Candidate> = candidates.iter().collect();
    let mut result: Vec<Candidate> = Vec::with_capacity(n.min(pool.len()));
    // Keyed by borrowed artist name so no String is cloned per pick.
    let mut artist_counts: HashMap<&str, usize> = HashMap::new();
    // Running number of picked tracks whose artist is in `seed_names`.
    let mut seed_count: usize = 0;

    let seed_min = (n / 10).max(1);

    let distinct_artists = pool
        .iter()
        .map(|c| c.artist.as_str())
        .collect::<HashSet<_>>()
        .len();
    // Fewer distinct artists => each one is allowed a larger share of the list.
    let divisor = match distinct_artists {
        1 => 1,
        2 => 2,
        3 => 3,
        4 => 3,
        5 => 4,
        _ => 5,
    };
    let artist_cap = n.div_ceil(divisor).max(1);

    while result.len() < n && !pool.is_empty() {
        let remaining = n - result.len();
        let seed_deficit = seed_min.saturating_sub(seed_count);
        // Only start forcing seeds when every remaining slot is needed for them.
        let force_seed = seed_deficit > 0 && remaining <= seed_deficit;

        let mut eligible: Vec<usize> = if force_seed {
            pool.iter()
                .enumerate()
                .filter(|(_, c)| seed_names.contains(&c.artist))
                .map(|(i, _)| i)
                .collect()
        } else {
            Vec::new()
        };

        // Not forcing seeds, or no seed track is left: respect the artist cap.
        if eligible.is_empty() {
            eligible = pool
                .iter()
                .enumerate()
                .filter(|(_, c)| {
                    artist_counts.get(c.artist.as_str()).copied().unwrap_or(0) < artist_cap
                })
                .map(|(i, _)| i)
                .collect();
        }

        // Every remaining artist is capped out: relax the cap rather than
        // returning a short playlist.
        if eligible.is_empty() {
            eligible = (0..pool.len()).collect();
        }

        let weights: Vec<f64> = eligible
            .iter()
            .map(|&i| pool[i].score.max(0.001))
            .collect();
        // Stop with what we have if the distribution cannot be constructed.
        let Ok(dist) = WeightedIndex::new(&weights) else {
            break;
        };

        let picked = eligible[dist.sample(&mut rng)];
        // Pool order is irrelevant to the sampling distribution, so the O(1)
        // swap_remove is safe here.
        let track = pool.swap_remove(picked);

        *artist_counts.entry(track.artist.as_str()).or_insert(0) += 1;
        if seed_names.contains(&track.artist) {
            seed_count += 1;
        }

        result.push(Candidate {
            score: track.score,
            artist: track.artist.clone(),
            artist_mbid: track.artist_mbid.clone(),
            track_id: track.track_id,
            file_path: track.file_path.clone(),
            title: track.title.clone(),
            album: track.album.clone(),
            duration: track.duration,
        });
    }

    result
}