use std::collections::{HashMap, HashSet}; use rand::distr::weighted::WeightedIndex; use rand::prelude::*; use crate::types::Candidate; /// Weighted random sampling with per-artist caps and seed enforcement. /// Ported faithfully from drift's generate_playlist(). pub fn generate_playlist( candidates: &[Candidate], n: usize, seed_names: &HashSet, max_artists: Option, skip_seed_enforcement: bool, ) -> Vec { if candidates.is_empty() { return Vec::new(); } let mut rng = rand::rng(); let mut pool: Vec<&Candidate> = candidates.iter().collect(); let mut result: Vec = Vec::new(); let mut artist_counts: HashMap = HashMap::new(); let mut distinct_artists_set: HashSet = HashSet::new(); let max_distinct = max_artists.map(|m| (m as usize).max(1)); let seed_min = if skip_seed_enforcement { 0 } else { (n / 10).max(1) }; let distinct_artists: usize = { let mut seen = HashSet::new(); for c in &pool { seen.insert(&c.artist); } seen.len() }; let divisor = match distinct_artists { 1 => 1, 2 => 2, 3 => 3, 4 => 3, 5 => 4, _ => 5, }; let artist_cap = n.div_ceil(divisor).max(1); while result.len() < n && !pool.is_empty() { let seed_count: usize = seed_names .iter() .map(|name| *artist_counts.get(name).unwrap_or(&0)) .sum(); let remaining = n - result.len(); let seed_deficit = seed_min.saturating_sub(seed_count); let force_seed = seed_deficit > 0 && remaining <= seed_deficit; let eligible: Vec = pool .iter() .enumerate() .filter(|(_, c)| { // Max distinct artists: reject new artists once we hit the cap if let Some(max) = max_distinct && distinct_artists_set.len() >= max && !distinct_artists_set.contains(&c.artist) { return false; } if force_seed { seed_names.contains(&c.artist) } else { *artist_counts.get(&c.artist).unwrap_or(&0) < artist_cap } }) .map(|(i, _)| i) .collect(); let fallback_indices: Vec = (0..pool.len()).collect(); let indices: &[usize] = if eligible.is_empty() { &fallback_indices } else { &eligible }; let weights: Vec = indices.iter().map(|&i| pool[i].score.max(0.001)).collect(); let dist = match WeightedIndex::new(&weights) { Ok(d) => d, Err(_) => break, }; let picked = indices[dist.sample(&mut rng)]; let track = pool.remove(picked); *artist_counts.entry(track.artist.clone()).or_insert(0) += 1; distinct_artists_set.insert(track.artist.clone()); result.push(Candidate { score: track.score, artist: track.artist.clone(), artist_mbid: track.artist_mbid.clone(), track_id: track.track_id, file_path: track.file_path.clone(), title: track.title.clone(), album: track.album.clone(), duration: track.duration, }); } result }